MongoDB cookbook: Indexes

Indexes are crucial for the efficient execution of queries and aggregations in MongoDB. Without indexes, MongoDB must perform a collection scan, i.e., scan every document in a collection.

If a write operation modifies an indexed field, MongoDB updates all indexes that have the modified field as a key, so every additional index adds write overhead. Choose indexes carefully.

Types Of Indexes

ref:
https://docs.mongodb.com/manual/indexes/
https://docs.mongodb.com/manual/applications/indexes/

Single Field Index

For a single field index, the sort order (ascending or descending) of the index key doesn't matter for sort operations, because MongoDB can traverse the index in either direction. With index intersection, single field indexes can be powerful.

ref:
https://docs.mongodb.com/manual/core/index-single/

Compound Index

The order of the fields listed in a compound index is very important.

ref:
https://docs.mongodb.com/manual/core/index-compound/
https://docs.mongodb.com/manual/tutorial/create-indexes-to-support-queries/

TTL Index

When the TTL thread is active, a background thread in mongod reads the values in the index and removes expired documents from the collection. This task runs every 60 seconds, so a document may remain for a short while after it has expired. You will see the delete operations in the output of db.currentOp().

TTL indexes are single-field indexes. Compound indexes do not support TTL and ignore the expireAfterSeconds option.

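import datetime

from flask_mongoengine import MongoEngine

# Assumes Flask-MongoEngine; in an application `db` is usually created once and bound to the Flask app.
db = MongoEngine()

class JournalEntry(db.Document):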
    users = db.ListField(db.ReferenceField('User'))
    event = db.StringField()
    context = db.DynamicField()
    timestamp = db.DateTimeField(default=datetime.datetime.utcnow)

    meta = {
        'index_background': True,
        'indexes': [
            {
                'fields': ['timestamp'],
                'cls': False,
                'expireAfterSeconds': int(datetime.timedelta(days=90).total_seconds()),
            },
        ],
    }
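
To make the expiry arithmetic concrete, here is a small pure-Python sketch (no database needed). It assumes the 90-day expireAfterSeconds value used above and treats a document as expired once its timestamp plus expireAfterSeconds lies in the past; the actual deletion happens on the next TTL monitor pass, not at that exact instant. The helper names are hypothetical.

```python
import datetime

# Same value as 'expireAfterSeconds' in the JournalEntry index above: 90 days.
EXPIRE_AFTER_SECONDS = int(datetime.timedelta(days=90).total_seconds())


def expires_at(timestamp: datetime.datetime) -> datetime.datetime:
    """Moment after which the TTL monitor may delete a document with this timestamp."""
    return timestamp + datetime.timedelta(seconds=EXPIRE_AFTER_SECONDS)


def is_expired(timestamp: datetime.datetime, now: datetime.datetime) -> bool:
    """True once the expiration threshold (timestamp + expireAfterSeconds) has passed."""
    return now > expires_at(timestamp)
```

For example, an entry stamped 2018-01-01 becomes eligible for deletion after 2018-04-01.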

ref:
https://docs.mongodb.com/manual/core/index-ttl/

Index Intersection

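MongoDB can use the intersection of multiple indexes to fulfill a single query. For example, with separate single field indexes on item and qty, the query planner may intersect both indexes to answer the query below.

db.orders.createIndex({item: 1});
db.orders.createIndex({qty: 1}, {background: true});

db.orders.find({item: 'abc123', qty: {$gt: 15}});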

ref:
https://docs.mongodb.com/manual/core/index-intersection/

Covered Queries

ref:
https://docs.mongodb.com/manual/core/query-optimization/#read-operations-covered-query

Index Limits

The size of an index entry for an indexed field must be less than 1024 bytes. For instance, an arbitrary URL field can easily exceed 1024 bytes. (This limit applies to MongoDB versions before 4.2; starting in 4.2 with featureCompatibilityVersion 4.2 or later, the index key limit is removed.)

MongoDB will not insert a document into an indexed collection if the index entry for one of its indexed fields would exceed the index key limit; it returns an error instead. Likewise, an update to an indexed field fails if the updated value causes the index entry to exceed the limit.

ref:
https://docs.mongodb.com/manual/reference/limits/#indexes
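
If you need to guard against this before writing, a cheap pre-check is to measure the value's UTF-8 size. The sketch below is only a rough heuristic and assumes a string key: the real index entry also carries BSON type and length overhead (and the other fields of a compound key), so a value that passes this check can still be rejected, while a value that fails it is certainly too large. The function name is hypothetical.

```python
INDEX_KEY_LIMIT_BYTES = 1024  # index key limit of MongoDB versions before 4.2


def definitely_exceeds_index_key_limit(value: str, limit: int = INDEX_KEY_LIMIT_BYTES) -> bool:
    """True if the raw UTF-8 bytes alone already reach the limit.

    The actual index entry is larger than the raw string (BSON type and length
    bytes, other compound-key fields), so False does not guarantee the write
    will succeed.
    """
    return len(value.encode("utf-8")) >= limit
```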

List Indexes

db.message.getIndexes()

// show collection statistics
db.message.stats()
db.message.stats().indexSizes

// scale defaults to 1 to return size data in bytes
// 1024 * 1024 means MB
db.getCollection('message').stats({'scale': 1024 * 1024}).indexSizes

ref:
https://docs.mongodb.com/manual/tutorial/manage-indexes/

Add Indexes

TODO:
It seems that creating indexes on an empty collection, even in the background, can cause DB latency.

A multikey index (an index on an array field) stores one key per array element, so it might consume a lot of disk space.

db.message.createIndex({
    '_cls': 1,
    'sender': 1,
    'posted_at': 1
}, {'background': true, 'sparse': true})

db.message.createIndex({
    '_cls': 1,
    'includes': 1,
    'posted_at': 1
}, {'background': true, 'sparse': true})

db.getCollection('message').find({
    '$or': [
        // sent by cp
        {
            '_cls': 'Message.ChatMessage',
            'sender': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        },
        // sent by payer
        {
            '_cls': 'Message.GiftMessage',
            'includes': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        }
    ]
})

In Python with PyMongo:

import pymongo
from your_app.models import YourModel

YourModel._get_collection().create_index(
    [
        ('users', pymongo.ASCENDING),
        ('timestamp', pymongo.DESCENDING),
    ], 
    background=True,
    partialFilterExpression={'timestamp': {'$exists': True}},
)

ref:
http://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.create_index

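A compound index can include at most one field whose value is an array in any given document; MongoDB cannot index parallel arrays. In this example, includes and unlocks are both arrays.

// doesn't work: MongoDB cannot index parallel arrays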
db.message.createIndex({
    '_cls': 1,
    'sender': 1,
    'includes': 1,
    'unlocks': 1
}, {'background': true, 'sparse': true})
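
The restriction comes from how multikey indexes store keys: for each document, MongoDB emits one index key per array element, which is also why array indexes can take a lot of disk space. With two array fields it would need the Cartesian product of both arrays, so it refuses. The following pure-Python model is a simplified sketch, not MongoDB's implementation: it only handles top-level fields (no dotted paths), and it uses None as a stand-in for the special key MongoDB stores for an empty array.

```python
def multikey_index_keys(doc: dict, fields: list) -> list:
    """Return the index keys a compound index on `fields` would hold for `doc`."""
    values = [doc.get(field) for field in fields]  # a missing field is indexed as null
    array_positions = [i for i, value in enumerate(values) if isinstance(value, list)]

    if len(array_positions) > 1:
        names = [fields[i] for i in array_positions]
        raise ValueError(f"cannot index parallel arrays {names}")
    if not array_positions:
        return [tuple(values)]  # one key per document

    pos = array_positions[0]
    elements = []
    for element in values[pos]:  # one key per distinct array element
        if element not in elements:
            elements.append(element)
    if not elements:
        elements = [None]  # stand-in for the key MongoDB stores for an empty array
    return [tuple(values[:pos]) + (element,) + tuple(values[pos + 1:]) for element in elements]
```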

The Order Of Fields of Compound Indexes

The order of fields in an index matters, and you must consider index cardinality and selectivity. By contrast, the order of fields in a find() query or in a $match stage of an aggregation doesn't affect whether an index can be used.

The order of fields in a compound index should be:

  • First, fields on which you will query for exact values.
  • Second, fields on which you will sort.
  • Finally, fields on which you will query for a range of values.
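
The equality, sort, range rule can be written down as a tiny helper that builds the key list in that order. esr_index_key is a hypothetical helper for illustration, not a MongoDB or PyMongo API; its output is a list of (field, direction) pairs in the shape PyMongo's create_index() accepts, with 1 and -1 as directions. It does not weigh cardinality or selectivity, which you still have to judge yourself.

```python
def esr_index_key(equality, sort, range_fields):
    """Order index fields as: equality matches, then sort keys, then range filters.

    equality: field names queried with exact values
    sort: (field, direction) pairs from the query's sort, direction 1 or -1
    range_fields: field names queried with $gt/$gte/$lt/$lte/$ne and similar
    A field already placed earlier is not repeated.
    """
    candidates = [(f, 1) for f in equality] + list(sort) + [(f, 1) for f in range_fields]
    key, seen = [], set()
    for field, direction in candidates:
        if field not in seen:
            key.append((field, direction))
            seen.add(field)
    return key
```

For a query on _cls and sender that filters a posted_at range and sorts by posted_at, this yields [('_cls', 1), ('sender', 1), ('posted_at', -1)], the same field order as the message indexes in this note.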

ref:
https://docs.mongodb.com/manual/core/index-compound/#prefixes
https://emptysqua.re/blog/optimizing-mongodb-compound-indexes/
https://blog.mlab.com/2012/06/cardinal-ins/
https://stackoverflow.com/questions/33545339/how-does-the-order-of-compound-indexes-matter-in-mongodb-performance-wise
https://stackoverflow.com/questions/5245737/mongodb-indexes-order-and-query-order-must-match

Partial Indexes v.s. Sparse Indexes

Partial indexes offer a superset of the functionality of sparse indexes and should be preferred. However, partialFilterExpression only supports a small set of filter operators:

  • $exists: true
  • $eq or field: value
  • $gt, $gte, $lt, $lte
  • $type
  • $and (at the top level only)

If you use 'partialFilterExpression': {'includes': {'$exists': true}}, MongoDB also indexes documents whose includes field is null. Note that the query planner only uses a partial index when the query's filter implies the partialFilterExpression.

db.getCollection('message').createIndex(
    {'_cls': 1, 'includes': 1, 'posted_at': 1},
    {'background': true, 'partialFilterExpression': {'includes': {'$exists': true}}}
)

db.getCollection('message').createIndex(
  {'created_at': -1},
  {'background': true, 'partialFilterExpression': {'created_at': {'$gte': new Date("2018-01-01T16:00:00Z")}}}
)
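
One concrete reason to prefer the partial index: a sparse compound index with only ascending/descending keys includes a document as long as it contains at least one of the indexed fields, so a sparse {_cls, includes, posted_at} index effectively covers every message. A partial index with {'includes': {'$exists': true}} only covers documents where includes is present (null included). A minimal pure-Python model of the two inclusion rules, not MongoDB code:

```python
def in_sparse_compound_index(doc: dict, fields: list) -> bool:
    # Sparse compound index (ascending/descending keys only):
    # the document is indexed if it has at least one of the fields.
    return any(field in doc for field in fields)


def in_partial_exists_index(doc: dict, field: str) -> bool:
    # partialFilterExpression {field: {'$exists': True}}:
    # indexed if the field is present, even when its value is null (None).
    return field in doc


messages = [
    {'_cls': 'Message', 'posted_at': 1},
    {'_cls': 'Message.GiftMessage', 'includes': None, 'posted_at': 2},
    {'_cls': 'Message.GiftMessage', 'includes': ['u1'], 'posted_at': 3},
]
fields = ['_cls', 'includes', 'posted_at']
sparse = [in_sparse_compound_index(m, fields) for m in messages]
partial = [in_partial_exists_index(m, 'includes') for m in messages]
```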

ref:
https://docs.mongodb.com/manual/core/index-partial/
https://docs.mongodb.com/manual/core/index-sparse/

Create An Index On An Array Field

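An array of embedded documents (e.g. [{key: ..., value: ...}]) is usually much easier to query and index than an object with arbitrary keys: a single multikey index on the array covers every element, whereas each object key would need its own index.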

ref:
https://stackoverflow.com/questions/9589856/mongo-indexing-on-object-arrays-vs-objects

Create An Unique Index On An Array Field

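Create a unique index on a field of embedded documents in an array.

The unique constraint applies across separate documents in the collection. That is, a unique index prevents separate documents from having the same value for the indexed key, but it does not prevent a single document from containing the same value more than once. Here, it prevents different documents from having the same transaction ID but allows one document to hold multiple identical transaction IDs.

db.getCollection('test1').createIndex({'purchases.transaction_id': 1}, {unique: true})

// succeeds
db.getCollection('test1').insert({ _id: 1, purchases: [
    {transaction_id: 'A'}
]})

// fails with a duplicate key error: another document already has transaction_id 'A'
db.getCollection('test1').insert({ _id: 5, purchases: [
    {transaction_id: 'A'}
]})

// succeeds: duplicates within the same document are allowed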
db.getCollection('test1').update({ _id: 1}, {$push: {purchases: {transaction_id: 'A'}}})
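
To see why the second insert fails but the $push succeeds, here is a toy model of the constraint: each distinct transaction_id may belong to at most one document, while a single document may repeat it. UniqueArrayIndex and DuplicateKeyError are hypothetical names for this sketch, and it does not model removing keys when values are deleted.

```python
class DuplicateKeyError(Exception):
    pass


class UniqueArrayIndex:
    """Toy model of a unique index on 'purchases.transaction_id'."""

    def __init__(self):
        self.owner = {}  # transaction_id -> _id of the document that holds it

    @staticmethod
    def keys(doc):
        # A multikey index stores each distinct element value once per document.
        return {p.get('transaction_id') for p in doc.get('purchases', [])}

    def write(self, doc):
        """Insert or update `doc`; reject keys already owned by another document."""
        keys = self.keys(doc)
        for key in keys:
            if key in self.owner and self.owner[key] != doc['_id']:
                raise DuplicateKeyError(f"duplicate key: {key!r}")
        for key in keys:
            self.owner[key] = doc['_id']
```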

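To prevent a single document from holding duplicate entries, put the check into the update filter: MongoDB evaluates the filter and applies the $push atomically on a single document. The example below (MongoEngine; User, DirectPurchase and MessagePackProduct are application models) pushes a purchase only if no existing purchase has the same _cls and user.

import bson

user = User(id=bson.ObjectId(user_id))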
purchase = DirectPurchase(
    user=user,
    timestamp=timestamp,
    transaction_id=transaction_id,
)
MessagePackProduct.objects \
    .filter(id=message_pack_id, __raw__={
        'purchases': {'$not': {'$elemMatch': {
            '_cls': purchase._cls,
            'user': purchase.user.id,
        }}},
    }) \
    .update_one(push__purchases=purchase)
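
The same guard-then-push logic in plain Python, purely to show what the filter checks. In MongoDB the filter evaluation and the $push happen atomically on one document, which this single-threaded sketch does not model. push_unless_elem_match is a hypothetical helper, and array elements are assumed to be dicts.

```python
def push_unless_elem_match(doc, array_field, new_element, match_fields):
    """Append new_element unless an existing element equals it on all match_fields.

    Mirrors {array_field: {'$not': {'$elemMatch': criteria}}} + $push.
    Returns the number of documents updated (0 or 1), like update_one().
    """
    criteria = {field: new_element.get(field) for field in match_fields}
    for element in doc.get(array_field, []):
        if all(element.get(field) == value for field, value in criteria.items()):
            return 0  # the filter would not match this document
    doc.setdefault(array_field, []).append(new_element)
    return 1
```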

ref:
https://docs.mongodb.com/manual/core/index-unique/#unique-constraint-across-separate-documents

Sort With Indexes

ref:
https://docs.mongodb.com/manual/tutorial/sort-results-with-indexes/

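Drop Indexes

The key specification passed to dropIndex() must match an existing index exactly, including the field order. You can also drop an index by name, e.g. db.message.dropIndex('_cls_1_includes_1_posted_at_1').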

db.message.dropIndex({
    'includes': 1
})

db.message.dropIndex({
    '_cls': 1,
    'posted_at': 1,
    'includes': 1
})

Remove Unused Indexes

You can use db.getCollection('COLLECTION_NAME').aggregate([{$indexStats: {}}]) to find unused indexes: the accesses.ops field indicates the number of operations that have used the index since the time in accesses.since. You might also want to remove indexes that are a prefix of another index. The pipeline below lists the indexes that have been used at least once; change the $match to {'accesses.ops': 0} to list the unused ones.

db.getCollection('message').aggregate([
    {
        '$indexStats': {}
    },
    {
        '$match': {
            'accesses.ops': {'$gt': 0}
        }
    }
]);

Result:

{
    "name" : "_cls_1_sender_1_posted_at_1",
    "key" : {
        "_cls" : 1,
        "sender" : 1,
        "posted_at" : 1
    },
    "host" : "a6ea11893605:27017",
    "accesses" : {
        "ops" : 3,
        "since" : "2018-01-26T07:04:51.137Z"
    }
}
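
Once you have the $indexStats output as Python dicts, the filtering can also be done client-side. The sketch below works on dicts shaped like the result above; it lists indexes with zero recorded accesses and pairs of indexes where one key is a leading prefix of another. Treat both as candidates only: the counters are reset when mongod restarts and only reflect the host that reported them, and a prefix index may still be needed for its own options (unique, partial, sparse, TTL). The function names are hypothetical.

```python
def unused_indexes(stats):
    """Names of indexes with no recorded accesses (the mandatory _id_ index is skipped)."""
    return [s['name'] for s in stats if s['name'] != '_id_' and s['accesses']['ops'] == 0]


def prefix_redundant_indexes(stats):
    """(shorter, longer) name pairs where shorter's key is a leading prefix of longer's key."""
    keys = {s['name']: list(s['key'].items()) for s in stats}
    pairs = []
    for name, key in keys.items():
        for other, other_key in keys.items():
            if other != name and len(key) < len(other_key) and other_key[:len(key)] == key:
                pairs.append((name, other))
    return pairs


stats = [
    {'name': '_id_', 'key': {'_id': 1}, 'accesses': {'ops': 10}},
    {'name': '_cls_1_sender_1', 'key': {'_cls': 1, 'sender': 1}, 'accesses': {'ops': 0}},
    {'name': '_cls_1_sender_1_posted_at_1',
     'key': {'_cls': 1, 'sender': 1, 'posted_at': 1}, 'accesses': {'ops': 3}},
]
```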

ref:
https://blog.mlab.com/2017/01/using-mongodb-indexstats-to-identify-and-remove-unused-indexes/
https://scalegrid.io/blog/how-to-find-unused-indexes-in-mongodb/

Profiling

// enable
db.setProfilingLevel(2)

// disable
db.setProfilingLevel(0)

// see profiling data after you issue some queries
db.system.profile.find().limit(10).sort( { ts : -1 } ).pretty()

// delete profiling data
db.system.profile.drop()

Query Explain

There are both collection.find().explain() and collection.explain().find(). It's recommended to use collection.find().explain('executionStats') for getting more information, like total documents examined.

db.getCollection('message').find({
    '$or': [
        // sent by cp
        {
            '_cls': 'Message.ChatMessage',
            'sender': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        },
        {
            '_cls': 'Message',
            'sender': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        },
        // sent by payer
        {
            '_cls': 'Message.ChatMessage',
            'includes': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        },
        {
            '_cls': 'Message.ReplyMessage',
            'includes': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        },
        {
            '_cls': 'Message.GiftMessage',
            'includes': ObjectId('582ee32a5b9c861c87dc297e'),
            'posted_at': {
                '$gte': ISODate('2018-01-08T00:00:00.000Z'),
                '$lt': ISODate('2018-01-14T00:00:00.000Z')
            }
        }
    ]
})
// .explain()
// .explain('allPlansExecution')
.explain('executionStats')

ref:
https://docs.mongodb.com/manual/reference/method/cursor.explain/
https://docs.mongodb.com/manual/reference/method/db.collection.explain/#db.collection.explain

You could also explain a .update() query. However, .updateMany() and .updateOne() don't support .explain().

db.getCollection('user').explain().update(
    {'follows.user': ObjectId("57985b784af4124063f090d3")},
    {'$set': {'created_at': ISODate('2018-01-01 00:00:00.000Z')}},
    {'multi': true}
)

Some important fields to look at in the result of explain():

  • executionStats.totalKeysExamined
  • executionStats.totalDocsExamined
  • queryPlanner.winningPlan.stage
  • queryPlanner.winningPlan.inputStage.stage
  • queryPlanner.winningPlan.inputStage.indexName
  • queryPlanner.winningPlan.inputStage.direction

Possible values of stage:

  • COLLSCAN: scanning the entire collection
  • IXSCAN: scanning index keys
  • FETCH: retrieving documents
  • SHARD_MERGE: merging results from shards
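
Reading these fields by hand gets tedious for $or queries, whose winning plan nests several input stages. The helper below is a sketch that walks the winningPlan tree of an explain('executionStats') result, given as a Python dict, and collects stage and index names. It assumes the classic explain layout with inputStage/inputStages; newer server versions may nest the plan differently. The function names are hypothetical.

```python
def iter_stages(plan):
    """Yield every node of an explain plan tree, parent before children."""
    yield plan
    if 'inputStage' in plan:
        yield from iter_stages(plan['inputStage'])
    for child in plan.get('inputStages', []):
        yield from iter_stages(child)


def summarize_explain(explain):
    nodes = list(iter_stages(explain['queryPlanner']['winningPlan']))
    stats = explain.get('executionStats', {})
    return {
        'stages': [node['stage'] for node in nodes],
        'indexes': [node['indexName'] for node in nodes if 'indexName' in node],
        'collscan': any(node['stage'] == 'COLLSCAN' for node in nodes),
        'keys_examined': stats.get('totalKeysExamined'),
        'docs_examined': stats.get('totalDocsExamined'),
        'returned': stats.get('nReturned'),
    }
```

A docs_examined value far larger than returned is the usual sign of a missing or poorly chosen index.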

ref:
https://docs.mongodb.com/manual/reference/explain-results/

Aggregation Explain

db.getCollection('message').explain().aggregate([ /* pipeline stages */ ])

ref:
https://stackoverflow.com/questions/12702080/mongodb-explain-for-aggregation-framework
https://docs.mongodb.com/manual/reference/method/db.collection.explain/

If $project, $unwind, or $group occur before the $sort stage, $sort cannot use any indexes. Additionally, $sort can only use fields defined in the preceding $project stage.

In practice, when creating indexes for an aggregation, you mostly need to consider the leading $match stage (and a $sort that immediately follows it), since those are the stages that can use indexes.

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/sort/#sort-operator-and-performance

MongoEngine

MongoEngine automatically adds _cls to indexes when allow_inheritance is on. To disable this for an index, set 'cls': False in its index specification.

ref:
http://docs.mongoengine.org/guide/defining-documents.html#indexes

MongoDB cookbook: Queries and Aggregations

Frequently accessed items are cached in memory, so that MongoDB can provide optimal response time.

MongoDB Shell in JavaScript

Administration

db.currentOp();

// slow queries
db.currentOp({
    "active": true,
    "secs_running": {"$gt" : 3},
    "ns": /^test\./
});

// queries not using any index
db.adminCommand({ 
    "currentOp": true,
    "op": "query", 
    "planSummary": "COLLSCAN"
});

// operations with high numYields
db.adminCommand({ 
    "currentOp": true, 
    "ns": /^test\./, 
    "numYields": {"$gte": 100} 
}) 

db.serverStatus().connections
// example output:
{
    "current" : 269,
    "available" : 838591,
    "totalCreated" : 417342
}

ref:
https://docs.mongodb.com/manual/reference/method/db.currentOp/
https://hackernoon.com/mongodb-currentop-18fe2f9dbd68
http://www.mongoing.com/archives/6246

BSON Types

ref:
https://docs.mongodb.com/manual/reference/bson-types/

Check If A Document Exists

Using find() with limit(1) can be significantly faster than findOne(), because findOne() always reads and returns the document if it exists, whereas find() just returns a cursor and only reads the data when you iterate over it.

db.getCollection('message').find({_id: ObjectId("585836504b287b5022a3ae26"), delivered: false}, {_id: 1}).limit(1)

ref:
https://stackoverflow.com/questions/8389811/how-to-query-mongodb-to-test-if-an-item-exists
https://blog.serverdensity.com/checking-if-a-document-exists-mongodb-slow-findone-vs-find/

Find Documents

db.getCollection('user').find({username: 'nanababy520'})

db.getCollection('message').find({_id: ObjectId("5a6383b8d93d7a3fadf75af3")})

db.getCollection('message').find({_cls: 'Message'}).sort({posted_at: -1})

db.getCollection('message').find({sender: ObjectId("57aace67ac08e72acc3b265f"), pricing: {$ne: 0}})

db.getCollection('message').find({
    sender: ObjectId("5ac0f56038cfff013a123d85"),
    created_at: {
        $gte: ISODate('2018-04-21 12:00:00Z'),
        $lte: ISODate('2018-04-21 13:00:00Z')
    }
})
.sort({created_at: -1})

Find Documents With Regular Expression

db.getCollection('user').find({'username': /vicky/})

ref:
https://docs.mongodb.com/manual/reference/operator/query/regex/

Find Documents With An Array Field

  • $in: [...] means "intersection" or "any element in"
  • $all: [...] means "subset" or "contain"
  • $elemMatch: {...} means "any element match"
  • $not: {$elemMatch: {$nin: [...]}} means "subset" or "in"

The last one roughly means not any(item not in [...] for item in array), i.e. every element of the array is in [...]. It also matches documents whose array is empty or whose field is missing.

ref:
https://stackoverflow.com/questions/12223465/mongodb-query-subset-of-an-array

db.getCollection('message').find({includes: ObjectId("5a4bb448af9c462c610d0cc7")})

db.getCollection('user').find({gender: 'F', tags: 'promoted'})
db.getCollection('user').find({gender: 'F', 'tags.1': {$exists: true}})
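
The four array operators listed above can be modeled with Python's any()/all() over a list. This is a simplified sketch that assumes the field holds a flat array of scalars (pass [] for a missing field); MongoDB additionally handles scalar fields and nested arrays, which it ignores.

```python
def match_in(array, values):
    """{field: {'$in': values}}: at least one element is in values."""
    return any(element in values for element in array)


def match_all(array, values):
    """{field: {'$all': values}}: every value appears in the array."""
    return all(value in array for value in values)


def match_elem(array, predicate):
    """{field: {'$elemMatch': {...}}}: at least one element satisfies all conditions."""
    return any(predicate(element) for element in array)


def match_subset(array, values):
    """{field: {'$not': {'$elemMatch': {'$nin': values}}}}: no element lies outside values."""
    return not any(element not in values for element in array)
```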

ref:
https://docs.mongodb.com/manual/reference/operator/query/exists/#exists-true

Find Documents With An Array Field Of Embedded Documents

Usually, you could use $elemMatch. The general shape is:

{'the_array_field': {'$elemMatch': {
    'a_field_of_each_element': {'$lte': now},
    'another_field_of_each_element': 123
}}}

db.getCollection('message').find({
    unlocks: {
        $elemMatch: {
            _cls: 'PointsUnlock',
            user: ObjectId("57f662e727a79d07993faec5")
        }
    }
})

db.getCollection('feature.shop.product').find({
    purchases: {
        $elemMatch: {
            _cls: 'Purchase'
        }
    }
})

db.getCollection('feature.shop.product').find({
    '_id': 'prod_CWlSTXBEU4mhEu',
    'purchases': {'$not': {'$elemMatch': {
        '_cls': 'DirectPurchase',
        'user': ObjectId("58b61d9094ab56f912ba10a5")
    }}},
})
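
The reason to reach for $elemMatch is that, without it, each condition on an array of embedded documents may be satisfied by a different element. The pure-Python sketch below contrasts the two rules for equality conditions; it is an illustration with hypothetical helper names, not a reimplementation of the query engine.

```python
def match_dot_notation(array, conditions):
    """{'arr.a': x, 'arr.b': y}: each condition may be met by a different element."""
    return all(
        any(element.get(field) == value for element in array)
        for field, value in conditions.items()
    )


def match_elem_match(array, conditions):
    """{'arr': {'$elemMatch': {'a': x, 'b': y}}}: one element must meet all conditions."""
    return any(
        all(element.get(field) == value for field, value in conditions.items())
        for element in array
    )


unlocks = [
    {'_cls': 'PointsUnlock', 'user': 'user_2'},
    {'_cls': 'GiftUnlock', 'user': 'user_1'},
]
wanted = {'_cls': 'PointsUnlock', 'user': 'user_1'}
```

Here the dot-notation form matches (one element has the _cls, another has the user), while $elemMatch correctly does not.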

ref:
https://docs.mongodb.com/manual/reference/operator/query/elemMatch/

Find Documents With Existence Of Fields Or Values

  • .find({'field': {'$exists': true}}): the field exists
  • .find({'field': {'$exists': false}}): the field does not exist
  • .find({'field': {'$type': 10}}): the field exists with a null value
  • .find({'field': null}): the field exists with a null value or the field does not exist
  • .find({'field': {'$ne': null}}): the field exists and the value is not null
  • .find({'array_field': {'$in': [null, []]}}): the array field does not exist, is null, or is an empty array

db.test.insert({'num': 1, 'check': 'value'})
db.test.insert({'num': 2, 'check': null})
db.test.insert({'num': 3})

db.test.find({});

db.test.find({'check': {'$exists': true}})
// return 1 and 2

db.test.find({'check': {'$exists': false}})
// return 3

db.test.find({'check': {'$type': 10}});
// return 2

db.test.find({'check': null})
// return 2 and 3

db.test.find({'check': {'$ne': null}});
// return 1
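
Python makes the distinction easy to see if you model a missing field with a sentinel object and null with None. The sketch reproduces the results listed above for the three test documents; it is a model of the documented semantics, not a call into MongoDB.

```python
MISSING = object()  # sentinel: the field is absent from the document

docs = [
    {'num': 1, 'check': 'value'},
    {'num': 2, 'check': None},
    {'num': 3},
]


def nums_where(predicate):
    """Return 'num' of every document whose 'check' value satisfies predicate."""
    return [doc['num'] for doc in docs if predicate(doc.get('check', MISSING))]


exists_true = nums_where(lambda v: v is not MISSING)                        # {'$exists': true}
exists_false = nums_where(lambda v: v is MISSING)                           # {'$exists': false}
type_null = nums_where(lambda v: v is None)                                 # {'$type': 10}
equals_null = nums_where(lambda v: v is None or v is MISSING)               # {'check': null}
not_equals_null = nums_where(lambda v: v is not None and v is not MISSING)  # {'$ne': null}
```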

ref:
https://stackoverflow.com/questions/4057196/how-do-you-query-this-in-mongo-is-not-null
https://docs.mongodb.com/manual/tutorial/query-for-null-fields/

Find Documents Where An Array Field Does Not Contain A Certain Value

db.getCollection('user').update({_id: ObjectId("579994ac61ff217f96a585d9"), tags: {$ne: 'tag_to_add'}}, {$push: {tags: 'tag_to_add'}})

db.getCollection('user').update({_id: ObjectId("579994ac61ff217f96a585d9"), tags: {$nin: ['tag_to_add']}}, {$push: {tags: 'tag_to_add'}})

ref:
https://stackoverflow.com/questions/16221599/find-documents-with-arrays-not-containing-a-document-with-a-particular-field-val

Find Documents Where An Array Field Is Not Empty

db.getCollection('message').find({unlocks: {$exists: true, $ne: []}})

ref:
https://stackoverflow.com/questions/14789684/find-mongodb-records-where-array-field-is-not-empty

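Find Documents Where An Array Field's Size Is Greater Than N

A query on 'array_field.N': {'$exists': true} matches documents whose array has more than N elements, so 'messages.0' means the array is not empty and 'unlocks.10' means it has more than 10 elements.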

db.getCollection('user.inbox').find({
    'messages.0': {'$exists': true}
})

db.getCollection('message').find({
    '_cls': 'Message',
    'unlocks.10': {'$exists': true}
}).sort({'posted_at': -1})

db.getCollection('message').find({
    '_cls': 'Message.ChatMessage',
    'sender': ObjectId("582ee32a5b9c861c87dc297e"),
    'unlocks': {'$exists': true, '$not': {'$size': 0}}
})
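
The 'field.N' trick works because a numeric path component addresses the array element at index N, so 'unlocks.10' exists only when the array has at least 11 elements. A simplified sketch of that lookup, assuming plain nested dicts and lists; it does not model MongoDB's implicit traversal of arrays of embedded documents by field name, and path_exists is a hypothetical name.

```python
def path_exists(doc, path):
    """Simplified model of {path: {'$exists': True}} for paths like 'unlocks.10'."""
    current = doc
    for part in path.split('.'):
        if isinstance(current, list) and part.isdigit():
            index = int(part)
            if index >= len(current):
                return False  # the array is too short
            current = current[index]
        elif isinstance(current, dict) and part in current:
            current = current[part]
        else:
            return False
    return True
```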

ref:
https://stackoverflow.com/questions/7811163/query-for-documents-where-array-size-is-greater-than-1/15224544

Find Documents With Computed Values Using $expr

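For instance, compare 2 fields from a single document in a find() query. $size raises an error if its argument is missing or not an array, so wrap each field with $ifNull.

db.getCollection('user').find({
    $expr: {
        $eq: [{$size: {$ifNull: ['$follows', []]}}, {$size: {$ifNull: ['$blocks', []]}}]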
    }
})

ref:
https://thecodebarbarian.com/a-nodejs-perspective-on-mongodb-36-lookup-expr
https://dzone.com/articles/expressive-query-language-in-mongodb-36-2

Project A Subset Of An Array Field With $filter

A sample document:

{
    "_id" : "message_unlock_pricing",
    "seed" : 42,
    "distributions" : {
        "a" : 0.5,
        "b" : 0.5
    },
    "whitelist" : [ 
        {
            "_id" : ObjectId("57dd071dd20fc40c0cbed6b7"),
            "variation" : "a"
        }, 
        {
            "_id" : ObjectId("5b1173a1487fbe2b2e9bba04"),
            "variation" : "b"
        }, 
        {
            "_id" : ObjectId("5a66d5c2af9c462c617ce552"),
            "variation" : "b"
        }
    ]
}
var now = new Date();

db.getCollection('feature.ab.experiment').aggregate([
    {'$project': {
        '_id': 1,
        'seed': 1,
        'distributions': 1,
        'whitelist': {
            '$filter': {
               'input': {'$ifNull': ["$whitelist", []]},
               'as': "user",
               'cond': {'$eq': ['$$user._id', ObjectId("5a66d5c2af9c462c617ce552")]}
            }
         }
    }},
    {'$unwind': {
        'path': '$whitelist',
        'preserveNullAndEmptyArrays': true
    }}
])
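
What the two stages do to the sample document can be sketched in Python: $filter keeps only the whitelist entries for the given user (treating a missing or null whitelist as an empty list, as $ifNull does), and $unwind with preserveNullAndEmptyArrays emits one document per remaining entry, or the document without the whitelist field when nothing matched. The function name and the plain-string ids are assumptions for the sketch.

```python
def filter_and_unwind(doc, user_id):
    # $project: keep these fields as-is
    base = {key: doc[key] for key in ('_id', 'seed', 'distributions') if key in doc}
    whitelist = doc.get('whitelist')
    if whitelist is None:
        whitelist = []  # {'$ifNull': ['$whitelist', []]}
    matched = [entry for entry in whitelist if entry['_id'] == user_id]  # $filter
    if not matched:
        # preserveNullAndEmptyArrays keeps the document; the empty array field is dropped
        return [base]
    return [dict(base, whitelist=entry) for entry in matched]  # one output per element
```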

ref:
https://stackoverflow.com/questions/42607221/mongodb-aggregation-project-check-if-array-contains

Insert Documents

db.getCollection('feature.launch').insert({
    'url': '//example.com/launchs/5a06b88aaf9c462c6146ce12.jpg',
    'user': {
        'id': ObjectId("5a06b88aaf9c462c6146ce12"),
        'username': 'luke0804',
        'tags': ["gender:male"]
    }
})

db.getCollection('feature.launch').insert({
    'url': '//example.com/launchs/57c16f5bb811055b66d8ef46.jpg',
    'user': {
        'id': ObjectId("57c16f5bb811055b66d8ef46"),
        'username': 'riva',
        'tags': ["gender:female"]
    }
})

Update Within A For Loop

var oldTags = ['famous', 'newstar', 'featured', 'western', 'recommended', 'popular'];
oldTags.forEach(function(tag) {
    db.getCollection('user').updateMany({tags: tag}, {$addToSet: {tags: 'badge:' + tag}});
});

Update With Conditions Of Field Values

You could update the value of the field to a specified value if the specified value is less than or greater than the current value of the field. The $min and $max operators can compare values of different types.

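Only set posted_at to the current timestamp if the field is absent (or holds a later date). Note that in BSON comparison order null sorts before any date, so $min leaves an existing null value unchanged. The example uses PyMongo's update_one() through MongoEngine's _get_collection(), because it takes raw filter and update documents; utils.utcnow() is an application helper.

import bson

Post._get_collection().update_one(
    {
        '_id': bson.ObjectId(post_id),
        'media.0': {'$exists': True},
        'title': {'$ne': None},
        'location': {'$ne': None},
        'gender': {'$ne': None},
        'pricing': {'$ne': None},
    },
    {
        '$min': {'posted_at': utils.utcnow()},
    },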
)
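
The behavior of $min on this field can be sketched in Python for the two kinds of value involved, dates and null. This is a simplified model that assumes only datetime and None values; it encodes the BSON comparison order in which null sorts before every date, which is why a stored null is left untouched. apply_min is a hypothetical name.

```python
import datetime


def apply_min(doc, field, value):
    """Simplified {'$min': {field: value}} for datetime or None values.

    Returns True if the document was modified.
    """
    if field not in doc:
        doc[field] = value  # $min sets a missing field
        return True
    current = doc[field]
    if current is None:
        return False  # null sorts before any date, so the date is never "less"
    if value < current:
        doc[field] = value
        return True
    return False
```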

ref:
https://docs.mongodb.com/manual/reference/operator/update/min/
https://docs.mongodb.com/manual/reference/operator/update/max/

Update An Array Field

Array update operators:

  • $: Acts as a placeholder to update the first element in an array for documents that match the query condition.
  • $[]: Acts as a placeholder to update all elements in an array for documents that match the query condition.
  • $[<identifier>]: Acts as a placeholder to update elements in an array that match the arrayFilters condition.
  • $addToSet: Adds elements to an array only if they do not already exist in the set.
  • $push: Adds an item to an array.
  • $pop: Removes the first or last item of an array.
  • $pull: Removes all array elements that match a specified query.
  • $pullAll: Removes all matching values from an array.

ref:
https://docs.mongodb.com/manual/reference/operator/update-array/
http://docs.mongoengine.org/guide/querying.html#atomic-updates
http://thecodebarbarian.com/a-nodejs-perspective-on-mongodb-36-array-filters.html

Add an element to an array field.

import datetime

user_id = '582ee32a5b9c861c87dc297e'
tag = 'my_tag'

# push only if the tag is not already in the array
updated = User.objects \
    .filter(id=user_id, tags__ne=tag) \
    .update_one(push__tags=tag)

# add_to_set skips the push if an identical element already exists
updated = User.objects \
    .filter(id=user_id) \
    .update_one(add_to_set__schedules={
        'tag': tag,
        'nbf': datetime.datetime(2018, 6, 4, 0, 0),
        'exp': datetime.datetime(2019, 5, 1, 0, 0),
    })

Insert an element into an array at a certain position.

import bson

slot = 2
Post.objects.filter(id=post_id, media__id__ne=media_id).update_one(__raw__={
    '$push': {
        'media': {
            '$each': [{'id': bson.ObjectId(media_id)}],
            '$position': slot,
        }
    }
})
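
$push with $each and $position behaves much like Python list slicing, including the edge cases: a position past the end appends, and a negative position (supported since MongoDB 3.6) counts from the end of the array. A sketch for illustration, not the server implementation; push_each_at is a hypothetical name.

```python
def push_each_at(array, items, position=None):
    """Simplified {'$push': {field: {'$each': items, '$position': position}}}.

    Returns a new list; without position the items are appended.
    """
    if position is None:
        return array + items
    # Slicing already matches $position's edge cases: an index past the end
    # appends, and a negative index counts from the end of the array.
    return array[:position] + items + array[position:]
```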

ref:
https://docs.mongodb.com/manual/reference/operator/update/position/
http://docs.mongoengine.org/guide/querying.html#querying-lists

Remove elements from an array field. Note that update(pull__abc=xyz) returns 1 whenever the filter matches a document, even if nothing was removed, because it reports the number of matched documents.

user_id = '582ee32a5b9c861c87dc297e'
tag = 'my_tag'

updated = User.objects \
    .filter(id=user_id) \
    .update_one(pull__tags=tag)

updated = User.objects \
    .filter(id=user_id) \
    .update_one(pull__schedules={'tag': tag})

Remove multiple embedded documents in an array field.

import bson

user_id = '5a66d5c2af9c462c617ce552'
tags = ['valid_tag_1', 'future_tag']

updated_result = User._get_collection().update_one(
    {'_id': bson.ObjectId(user_id)},
    {'$pull': {'schedules': {'tag': {'$in': tags}}}},
)
print(updated_result.raw_result)
# {'n': 1, 'nModified': 1, 'ok': 1.0, 'updatedExisting': True}

ref:
https://stackoverflow.com/questions/28102691/pullall-while-removing-embedded-objects

db.getCollection('feature.feeds').updateMany(
    {
        'aliases': {'$exists': true},
        'exp': {'$gte': ISODate('2019-03-21T00:00:00.000+08:00')},
        'items': {'$elemMatch': {'username': 'engine'}},
    },
    {
        '$pull': {
            'items': {'username': 'engine'},
        }
    }
);

ref:
https://docs.mongodb.com/manual/reference/operator/update/pull/

You could also use add_to_set to add an item to an array only if it is not already present. It also returns 1 whenever filter() matches a document, even if nothing was added. Set full_result=True to get the detailed update result (a pymongo UpdateResult in recent MongoEngine versions).

update_result = User.objects.filter(id=user_id).update_one(
    add_to_set__tags=tag,
    full_result=True,
)
print(update_result.raw_result)
# {'n': 1, 'nModified': 1, 'ok': 1.0, 'updatedExisting': True}

ref:
http://docs.mongoengine.org/guide/querying.html#atomic-updates

Update a multi-level nested array field: arrayFilters supports this too, with one identifier per nesting level, e.g. 'a.$[x].b.$[y].c'.

ref:
https://docs.mongodb.com/manual/reference/operator/update/positional-filtered/
https://stackoverflow.com/questions/23577123/updating-a-nested-array-with-mongodb

Update an embedded document in an array field.

MessagePackProduct.objects \
    .filter(id='prod_CR1u34BIpDbHeo', skus__id='sku_CR23rZOTLhYprP') \
    .update(__raw__={
        '$set': {'skus.$': {'id': 'sku_CR23rZOTLhYprP', 'test': 'test'}}
    })

ref:
https://stackoverflow.com/questions/9200399/replacing-embedded-document-in-array-in-mongodb
https://docs.mongodb.com/manual/reference/method/db.collection.update/#db.collection.update

Update specific embedded documents with arrayFilters in an array field.

User data:

{
    "_id" : ObjectId("5a66d5c2af9c462c617ce552"),
    "username" : "gibuloto",
    "tags" : [
        "beta",
        "future_tag",
        "expired_tag"
    ],
    "schedules" : [
        {
            "tag" : "valid_tag",
            "nbf" : ISODate("2018-05-01T16:00:00.000Z"),
            "exp" : ISODate("2020-06-04T16:00:00.000Z")
        },
        {
            "tag" : "future_tag",
            "nbf" : ISODate("2020-01-28T16:00:00.000Z"),
            "exp" : ISODate("2020-12-14T16:00:00.000Z")
        },
        {
            "tag" : "expired_tag",
            "nbf" : ISODate("2016-02-12T16:00:00.000Z"),
            "exp" : ISODate("2016-04-21T16:00:00.000Z")
        }
    ]
}

It is worth noting that the <identifier> in arrayFilters must begin with a lowercase letter and contain only alphanumeric characters.

import bson

user_id = '5a66d5c2af9c462c617ce552'
tags = ['from_past_to_future']

updated_result = User._get_collection().update_one(
    {'_id': bson.ObjectId(user_id)},
    {
        '$addToSet': {'tags': {'$each': tags}},
        '$unset': {'schedules.$[schedule].nbf': True},
    },
    array_filters=[{'schedule.tag': {'$in': tags}}],
)
print(updated_result.raw_result)
# {'n': 1, 'nModified': 1, 'ok': 1.0, 'updatedExisting': True}

ref:
https://docs.mongodb.com/master/reference/operator/update/positional-filtered/

Update An Array Field With arrayFilters

You should use arrayFilters as much as possible.

The syntax of arrayFilters would be:

db.collection.update(
   {<query selector>},
   {<update operator>: {'array.$[<identifier>].field': value}},
   {arrayFilters: [{<identifier>: <condition>}]}
)

In Python with PyMongo:

Inbox._get_collection().update_many(
    {'messages.id': message_id},
    {'$set': {'messages.$[message].tags': tags}},
    array_filters=[
        {'message.id': message_id},
    ],
)
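
The difference between the positional $ operator and a filtered $[<identifier>] is easiest to see side by side: $ updates only the first element that matched the query, while $[message] updates every element that satisfies the array filter. A pure-Python sketch of both, with hypothetical helper names:

```python
def set_first_match(array, predicate, field, value):
    """Like {'$set': {'arr.$.field': value}}: update only the first matching element."""
    for element in array:
        if predicate(element):
            element[field] = value
            return 1
    return 0


def set_filtered(array, predicate, field, value):
    """Like {'$set': {'arr.$[x].field': value}} with arrayFilters: update every match."""
    count = 0
    for element in array:
        if predicate(element):
            element[field] = value
            count += 1
    return count
```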

ref:
https://docs.mongodb.com/manual/reference/operator/update/positional-filtered/

Insert an element into an array field at a certain position.

db.getCollection('feature.forums.post').update(
   { _id: ObjectId("5b3c6a9c8433b15569cae54e") },
   {
     $push: {
        media: {
           $each: [{
                "mimetype" : "image/jpeg",
                "url" : "https://example.com/posts/5adb795b47d057338abe8910.jpg",
                "presets" : {}
            }],
           $position: 1
        }
     }
   }
)

Or, to replace the element at a known index, $set it explicitly (this overwrites that element rather than inserting a new one).

media_id = 'xxx'
media_slot = 0

Post.objects \
    .filter(id=post_id, **{f'media__{media_slot}__id__ne': media_id}) \
    .update_one(__raw__={'$set': {f'media.{media_slot}': {'id': media_id}}})

ref:
https://docs.mongodb.com/manual/reference/operator/update/position/

Set an array field to empty.

db.getCollection('message').update(
    {'tags': 'pack:joycelai-1'},
    {'$set': {'unlocks': []}},
    {'multi': true}
)

db.getCollection('feature.shop.product').update(
    {},
    {'$set': {'purchases': []}},
    {'multi': true}
)

ref:
https://docs.mongodb.com/manual/reference/method/db.collection.update/
https://docs.mongodb.com/manual/reference/operator/update/set/

Remove elements from an array field.

var userId = ObjectId("57985b784af4124063f090d3");

db.getCollection('user').update(
    {'follows.user': userId},
    {'$pull': {'follows': {'user': userId}}},
    {
        'multi': true,
    }
);

db.getCollection('message').update(
    {'_id': {'$in': [
        ObjectId('5aca1ffc4271ab1624787ec4'),
        ObjectId('5aca31ab93ef2936291c3dd4'),
        ObjectId('5aca33d9b5eaef04943c0d0b'),
        ObjectId('5aca34e7a48c543b07fb0a0f'),
        ObjectId('5aca272d93ef296edc1c3dee'),
        ObjectId('5aca342aa48c54306dfb0a21'),
        ObjectId('5aca20756bd01023a8cb02e9')
    ]}},
    {'$pull': {'tags': 'pack:prod_D75YlDMzcCiAw3'}},
    {'multi': true}
);

ref:
https://docs.mongodb.com/manual/reference/operator/update/pull/

Update A Dictionary Field

Set a key/value in a dictionary field.

tutorial.data = {
    "price_per_message": 1200,
    "inbox": []
}

new_inbox = [
    {
        "id": "5af118c598eacb528e8fb8f9",
        "sender": "5a13239eaf9c462c611510fc"
    },
    {
        "id": "5af1117298eacb212a8fb8e9",
        "sender": "5a99554be9a21d5ff38b8ca5"
    }
]
tutorial.update(set__data__inbox=new_inbox)

ref:
https://stackoverflow.com/questions/21158028/updating-a-dictfield-in-mongoengine

Upsert: Update Or Create

When using upsert=True, make sure the query fields are covered by a unique index; otherwise concurrent upserts may insert duplicate documents. If you don't need the modified document back, just use update_one(field1=123, field2=456, upsert=True).

Additionally, remember that modify() always reloads the whole object, even if the original query only loaded specific fields with only(). Avoid calling document.DB_QUERY_METHOD() on a loaded document; prefer User.objects.filter().only().modify() or User.objects.filter().update() when possible.

tag_schedule = TagSchedule.objects \
    .filter(user=user_id, tag='vip') \
    .modify(
        started_at=started_at,
        ended_at=ended_at,
        upsert=True
    )

user = User.objects \
    .filter(id=user.id, tutorials__buy_diamonds__version=None) \
    .modify(set__tutorials__buy_diamonds__version='v1')

updated = User.objects \
    .filter(user=user_id, tag=tag) \
    .update_one(
        push__followers=new_follower,
    )

ref:
https://docs.mongodb.com/manual/reference/method/db.collection.update/#update-with-unique-indexes
http://docs.mongoengine.org/apireference.html#mongoengine.queryset.QuerySet.modify
http://docs.mongoengine.org/apireference.html#mongoengine.queryset.QuerySet.update_one
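To see what an upsert does, here is a toy, in-memory model of updateOne(query, {'$set': ...}, upsert=True): when nothing matches, the new document is seeded from the filter's equality fields and the $set fields are applied on top. The race the unique index protects against sits between the "no match" check and the insert: two clients can both miss and both insert. The upsert function and the list-backed "collection" are illustrative assumptions, not a driver API, and only plain equality filters are supported.

```python
from typing import Any, Dict, List

Document = Dict[str, Any]


def upsert(collection: List[Document], query: Document, set_fields: Document) -> Document:
    """Toy model of updateOne(query, {'$set': set_fields}, upsert=True)."""
    for doc in collection:
        if all(doc.get(key) == value for key, value in query.items()):
            doc.update(set_fields)  # a match exists: behave like a normal update
            return doc
    # No match: seed the new document from the filter's equality fields,
    # then apply the $set fields on top
    new_doc = {**query, **set_fields}
    collection.append(new_doc)
    return new_doc


schedules: List[Document] = []
upsert(schedules, {'user': 1, 'tag': 'vip'}, {'ended_at': '2018-05-20'})
upsert(schedules, {'user': 1, 'tag': 'vip'}, {'ended_at': '2018-06-20'})
print(schedules)  # [{'user': 1, 'tag': 'vip', 'ended_at': '2018-06-20'}]
```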

Rename A Field

Simply rename a field with $rename.

db.getCollection('user').updateMany(
    {
        'phone_no': {'$exists': true},
        'social.phone-number.uid': {'$exists': false},
    },
    {'$rename': {
        'phone_no': 'social.phone-number.uid',
    }}
);

ref:
https://docs.mongodb.com/manual/reference/operator/update/rename/
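$rename moves the value to the new path and creates any missing embedded documents on the way (here social and social.phone-number), while an absent source field makes it a no-op. The function below is a hypothetical pure-Python emulation of that behavior for dotted paths, handy for reasoning about what the migration does to a single document.

```python
from typing import Any, Dict


def rename_field(doc: Dict[str, Any], old: str, new: str) -> None:
    """Emulate {'$rename': {old: new}} on a plain dict (dotted paths allowed)."""
    *old_parents, old_leaf = old.split('.')
    source: Any = doc
    for part in old_parents:
        source = source.get(part)
        if not isinstance(source, dict):
            return  # the old path does not exist: nothing to rename
    if old_leaf not in source:
        return
    value = source.pop(old_leaf)

    *new_parents, new_leaf = new.split('.')
    target: Any = doc
    for part in new_parents:
        # Create missing embedded documents, as $set would
        target = target.setdefault(part, {})
    target[new_leaf] = value


user = {'_id': 1, 'phone_no': '+886912345678'}
rename_field(user, 'phone_no', 'social.phone-number.uid')
print(user)  # {'_id': 1, 'social': {'phone-number': {'uid': '+886912345678'}}}
```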

If the value also needs converting, do the conversion in an aggregation and rename the field manually. ($toString requires MongoDB 4.0 or later.)

db.getCollection('user').aggregate([
    {'$match': {
        'twitter.id': {'$exists': true},
        'social.twitter.uid': {'$exists': false},
    }},
    {'$project': {
        'twitter_id': '$twitter.id',
        'twitter_id_str': {'$toString': '$twitter.id'},
    }},
]).forEach(function (document) {
    printjson({
        'id': document._id,
    });

    // Match on _id as well, so only this document is migrated
    db.getCollection('user').updateOne(
        {
            '_id': document._id,
            'twitter.id': document.twitter_id,
            'social.twitter.uid': {'$exists': false},
        },
        {
            '$unset': {'twitter.id': true},
            '$set': {'social.twitter.uid': document.twitter_id_str},
        }
    );
});

Insert/Replace Large Numbers Of Documents

const operations = contracts.map((contract) => {
    // TODO: should create a new contract if there is any change of the contract?
    // use MongoDB transaction to change the new one and old one
    return {
        'replaceOne': {
            'filter': {'settlement_datetime': currentSettlementMonth.toDate(), 'user': contract.user},
            'replacement': contract,
            'upsert': true,
        },
    };
});

db.collection('user.contract').bulkWrite(
    operations,
    {ordered: true},
    (bulkError, result) => {
        if (bulkError) {
            return next(bulkError, null);
        }

        logger.info('Finished importing all contracts');
        return next(null, result);
    },
);

Update Large Numbers Of Documents

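Batch the updates with the Bulk API or bulkWrite() instead of sending them one by one. When only specific elements of an array should change, combine Bulk.find.arrayFilters() with Bulk.find.update().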

In Python:

import datetime

import pymongo

expiration_time = datetime.datetime.utcnow() - datetime.timedelta(hours=48)

# initialize_unordered_bulk_op() was removed in PyMongo 4, so build the
# operations for Collection.bulk_write() instead
operations = [
    pymongo.UpdateOne(
        {'_id': outbox.id},
        {'$pull': {'messages': {'posted_at': {'$lt': expiration_time}}}},
    )
    for outbox in Outbox.objects.only('id').filter(messages__posted_at__lt=expiration_time)
]

# bulk_write() raises InvalidOperation on an empty list, so skip it when there is nothing to do
if operations:
    results = Outbox._get_collection().bulk_write(operations, ordered=False)

In JavaScript:

const operations = docs.map((doc) => {
    logger.debug(doc, 'Revenue');

    const operation = {
        'updateOne': {
            'filter': {
                '_id': doc._id,
            },
            'update': {
                '$set': {
                    'tags': doc.contract.tags,
                },
            },
        },
    };
    return operation;
});

db.collection('user.revenue').bulkWrite(
    operations,
    {ordered: false},
    (bulkError, bulkResult) => {
        if (bulkError) {
            return next(bulkError, null);
        }

        logger.info(bulkResult, 'Saved tags');
        return next(null, true);
    },
);

ref:
https://docs.mongodb.com/manual/reference/method/Bulk/
https://docs.mongodb.com/manual/reference/method/Bulk.find.arrayFilters/
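Per document, the $pull in the Python example removes every element of messages whose posted_at is older than the cutoff and keeps the rest, including elements without posted_at, since a missing field does not satisfy $lt. A plain-Python restatement with made-up messages, meant for reasoning about the result rather than for production use:

```python
import datetime
from typing import Any, Dict, List

Message = Dict[str, Any]


def is_expired(message: Message, cutoff: datetime.datetime) -> bool:
    posted_at = message.get('posted_at')
    # A missing posted_at does not satisfy {'$lt': cutoff}, so the element is kept
    return posted_at is not None and posted_at < cutoff


def pull_expired(messages: List[Message], cutoff: datetime.datetime) -> List[Message]:
    """Mimic {'$pull': {'messages': {'posted_at': {'$lt': cutoff}}}} on one array."""
    return [message for message in messages if not is_expired(message, cutoff)]


now = datetime.datetime(2018, 5, 10, 12, 0)
cutoff = now - datetime.timedelta(hours=48)
messages = [
    {'id': 'a', 'posted_at': now - datetime.timedelta(hours=72)},
    {'id': 'b', 'posted_at': now - datetime.timedelta(hours=1)},
    {'id': 'c'},
]
print([message['id'] for message in pull_expired(messages, cutoff)])  # ['b', 'c']
```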

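You could also send several operations for the same document in one bulk write, although a single update with a combined $set usually makes more sense. The example below is a case where the sequence matters: with ordered=True the operations run in order, so the last one only sets posted_at once the earlier $set operations have filled in the required fields.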

from pymongo import UpdateOne
import bson

def _operations():
    # Set each optional field present in the payload
    # (the := walrus operator requires Python 3.8+)
    if title := payload.get('title'):
        yield UpdateOne({'_id': bson.ObjectId(post_id)}, {'$set': {'title': title}})

    if location := payload.get('location'):
        yield UpdateOne({'_id': bson.ObjectId(post_id)}, {'$set': {'location': location}})

    if pricing := payload.get('pricing'):
        yield UpdateOne({'_id': bson.ObjectId(post_id)}, {'$set': {'pricing': pricing}})

    if description := payload.get('description'):
        yield UpdateOne({'_id': bson.ObjectId(post_id)}, {'$set': {'description': description}})

    # Publish the post once the required fields are filled in; with ordered=True
    # this runs after the $set operations above
    yield UpdateOne(
        {
            '_id': bson.ObjectId(post_id),
            'media.0': {'$exists': True},
            'title': {'$ne': None},
            'location': {'$ne': None},
            'pricing': {'$ne': None},
            'posted_at': {'$eq': None},
        },
        {'$set': {'posted_at': utils.utcnow()}},
    )

operations = list(_operations())
result = Post._get_collection().bulk_write(operations, ordered=True)
print(result.bulk_api_result)

ref:
https://api.mongodb.com/python/current/examples/bulk.html
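For comparison, the optional fields can also be collected into a single $set document and sent as one UpdateOne (or update_one()) call. The helper name and field list are hypothetical, taken from the example above; like the walrus checks, it skips falsy values such as 0 or ''.

```python
from typing import Any, Dict, Iterable

# Field names taken from the bulk example above
OPTIONAL_FIELDS = ('title', 'location', 'pricing', 'description')


def build_set_update(payload: Dict[str, Any], fields: Iterable[str] = OPTIONAL_FIELDS) -> Dict[str, Any]:
    """Collect the truthy payload values into a single $set update document."""
    changes = {field: payload[field] for field in fields if payload.get(field)}
    # MongoDB versions before 5.0 reject an empty $set, so return {} and let the
    # caller skip the write when there is nothing to change
    return {'$set': changes} if changes else {}


print(build_set_update({'title': 'Hi', 'pricing': 0, 'location': 'Taipei'}))
# {'$set': {'title': 'Hi', 'location': 'Taipei'}}
```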

Remove matching items from an array field across many documents.

var userId = ObjectId("57a42a779f22bb6bcc434520");

db.getCollection('user').updateMany(
    {'follows.user': userId},
    {'$pull': {'follows': {'user': userId}}}
);

ref:
https://stackoverflow.com/questions/33594397/how-to-update-a-large-number-of-documents-in-mongodb-most-effeciently

Remove Large Numbers Of Documents

In the mongo shell:

var bulk = db.getCollection('feature.journal.v2').initializeUnorderedBulkOp()
bulk.find({}).remove()
bulk.execute()

// or

var bulk = db.getCollection('feature.journal.v2').initializeUnorderedBulkOp()
bulk.find({event: 'quest.rewarded'}).remove()
bulk.find({event: 'message.sent'}).remove()
bulk.execute()

ref:
https://docs.mongodb.com/manual/reference/method/Bulk.find.remove/#bulk-find-remove

MongoEngine In Python

ref:
http://docs.mongoengine.org/guide/index.html
http://docs.mongoengine.org/apireference.html

Define Collections

Every MongoEngine Document has an id field: if you don't declare a primary key yourself, MongoEngine adds one automatically (an ObjectIdField stored as _id).

ref:
http://docs.mongoengine.org/guide/defining-documents.html

Define A Field With Default EmbeddedDocument

A field whose default is an EmbeddedDocument class behaves differently depending on whether the query uses only().

class User(ab.models.ABTestingMixin, db.Document):
    class UserSettings(db.EmbeddedDocument):
        reply_price = db.IntField(min_value=0, default=500, required=True)
        preferences = db.ListField(db.StringField())

    email = db.EmailField(max_length=255)
    created_at = db.DateTimeField(default=utils.now)
    last_active = db.DateTimeField(default=utils.now)
    settings = db.EmbeddedDocumentField(UserSettings, default=UserSettings)

If the user document has no settings field in the database, the results differ:

user = User.objects.get(username='gibuloto')
assert isinstance(user.settings, User.UserSettings)  # default applied

user = User.objects.only('settings').get(username='gibuloto')
assert user.settings is None  # default NOT applied

user = User.objects.exclude('settings').get(username='gibuloto')
assert isinstance(user.settings, User.UserSettings)  # default applied

Filter With Raw Queries

post = Post.objects \
    .no_dereference().only('posted_at') \
    .filter(__raw__={
        '_id': bson.ObjectId(post_id),
        'media.0': {'$exists': True},
        'title': {'$ne': None},
        'location': {'$ne': None},
        'gender': {'$ne': None},
        'pricing': {'$ne': None},
    }) \
    .modify(__raw__={'$min': {'posted_at': utils.utcnow()}}, new=True)

print(post.posted_at)

ref:
http://docs.mongoengine.org/guide/querying.html#raw-queries

Check If A Document Exists

Use .exists().

import datetime

now = datetime.datetime.now(datetime.timezone.utc)
if TagSchedule.objects.filter(user=user_id, tag=tag, started_at__gt=now).exists():
    return 'exists'

You have to use __raw__ if the field you want to query is a db.ListField(db.GenericEmbeddedDocumentField(XXX)) field.

if MessagePackProduct.objects.filter(id=message_pack_id, __raw__={'purchases.user': g.user.id}).exists():
    return 'exists'

Upsert: Get Or Create

buy_diamonds = BuyDiamonds.objects.filter(user_id=user.id).upsert_one()

ref:
http://docs.mongoengine.org/apireference.html#mongoengine.queryset.QuerySet.upsert_one

Store Files On GridFS

# models.py
class User(db.Document):
    username = db.StringField()
    image = db.ImageField(collection_name='user.images')

# tasks.py (User is imported from models.py)
import io

import bson
import celery
import gridfs
import mongoengine
from PIL import Image

@celery.shared_task(bind=True, ignore_result=True)
def gridfs_save(task, user_id, format='JPEG', raw_data: bytes=None, **kwargs):
    image_id = None

    if raw_data is None:
        user = User.objects.only('image').get(id=user_id)
        if user.image.grid_id:
            image_id, raw_data = user.image.grid_id, user.image.read()

    if not raw_data:
        return

    gf = gridfs.GridFS(mongoengine.connection.get_db(), User.image.collection_name)

    with io.BytesIO(raw_data) as raw_image:
        with Image.open(raw_image) as image:
            image = image.convert('RGB')
            with io.BytesIO() as buffer:
                image.save(buffer, format=format, quality=80, **kwargs)
                buffer.seek(0)
                grid_id = gf.put(buffer, format=format, width=image.width, height=image.height, thumbnail_id=None)

    # NOTE: If function was passed with raw_data, only override if ID is the same as the read
    query = mongoengine.Q(id=user_id)
    if image_id:
        query = query & mongoengine.Q(image=image_id)

    user = User.objects.only('image').filter(query).modify(
        __raw__={'$set': {'image': grid_id}},
        new=False,
    )

    def cleanup():
        # Delete the old image
        if user and user.image:
            yield user.image.grid_id

        # The user image was already changed before the scheduled optimization took place,
        # so the optimized image is no longer referenced: drop it
        if user is None and image_id:
            yield grid_id

    gridfs_delete.apply_async(kwargs=dict(
        collection=User.image.collection_name,
        grid_ids=list(cleanup()),
    ))

@celery.shared_task(bind=True, ignore_result=True)
def gridfs_delete(task, collection, grid_ids):
    gf = gridfs.GridFS(mongoengine.connection.get_db(), collection)
    for grid_id in grid_ids:
        gf.delete(bson.ObjectId(grid_id))

ref:
http://docs.mongoengine.org/guide/gridfs.html

Store Datetime

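MongoDB stores datetimes in UTC, as a signed 64-bit integer counting milliseconds since the Unix epoch. Convert to local time in your application.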

ref:
https://docs.mongodb.com/manual/reference/method/Date/
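Because BSON dates are millisecond counts, microseconds are lost on the round trip, and PyMongo hands back naive UTC datetimes unless MongoClient is created with tz_aware=True. The sketch below reproduces the conversion in plain Python; to_bson_millis and from_bson_millis are hypothetical names, not PyMongo functions, and the sample date is made up.

```python
import datetime

EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)


def to_bson_millis(value: datetime.datetime) -> int:
    """Milliseconds since the epoch; naive datetimes are treated as UTC, as PyMongo does."""
    if value.tzinfo is None:
        value = value.replace(tzinfo=datetime.timezone.utc)
    # Floor division drops the sub-millisecond part
    return (value - EPOCH) // datetime.timedelta(milliseconds=1)


def from_bson_millis(millis: int) -> datetime.datetime:
    """Back to an aware UTC datetime (what PyMongo returns with tz_aware=True)."""
    return EPOCH + datetime.timedelta(milliseconds=millis)


taipei = datetime.timezone(datetime.timedelta(hours=8))
local = datetime.datetime(2018, 9, 22, 0, 0, 0, 123456, tzinfo=taipei)
millis = to_bson_millis(local)
print(millis)                    # 1537545600123
print(from_bson_millis(millis))  # 2018-09-21 16:00:00.123000+00:00
```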

2-phase Commit

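The easiest way to think about 2-phase commit is idempotency: every step is a conditional update that moves the transaction through initial -> pending -> applied -> done, so running the same step many times gives the same result as running it once. That is what makes it safe to resume after a failure.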

ref:
https://docs.mongodb.com/manual/tutorial/perform-two-phase-commits/
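A compressed in-memory sketch of those states, assuming a transfer between two accounts (the scenario used in the linked tutorial). Each step is guarded by the current state, and applying to an account is guarded by whether it already lists the transaction in pending_transactions, so re-running any step after a crash changes nothing. Everything here (Account, Transaction, the function names) is hypothetical and only models the logic; in MongoDB each step would be one conditional update.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Account:
    balance: int
    pending_transactions: List[str] = field(default_factory=list)


@dataclass
class Transaction:
    id: str
    source: str
    destination: str
    value: int
    state: str = 'initial'


def advance(txn: Transaction, from_state: str, to_state: str) -> None:
    # Like updateOne({'_id': txn.id, 'state': from_state}, {'$set': {'state': to_state}})
    if txn.state == from_state:
        txn.state = to_state


def apply_to_accounts(txn: Transaction, accounts: Dict[str, Account]) -> None:
    # The "not already pending" guard makes this step idempotent
    for name, delta in ((txn.source, -txn.value), (txn.destination, txn.value)):
        account = accounts[name]
        if txn.id not in account.pending_transactions:
            account.balance += delta
            account.pending_transactions.append(txn.id)


def clear_pending(txn: Transaction, accounts: Dict[str, Account]) -> None:
    for name in (txn.source, txn.destination):
        pending = accounts[name].pending_transactions
        if txn.id in pending:
            pending.remove(txn.id)


def run(txn: Transaction, accounts: Dict[str, Account]) -> None:
    advance(txn, 'initial', 'pending')
    if txn.state == 'pending':
        apply_to_accounts(txn, accounts)
        advance(txn, 'pending', 'applied')
    if txn.state == 'applied':
        clear_pending(txn, accounts)
        advance(txn, 'applied', 'done')


accounts = {'A': Account(1000), 'B': Account(1000)}
txn = Transaction('t1', 'A', 'B', 100)
run(txn, accounts)
run(txn, accounts)  # running it again is a no-op
print(txn.state, accounts['A'].balance, accounts['B'].balance)  # done 900 1100
```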

Aggregation Pipeline

  • $match: Filters documents.
  • $project: Reshapes documents by including, excluding, or computing fields.
  • $addFields: Adds new fields or overwrites existing ones.
  • $group: Groups documents by an _id expression and computes accumulated values per group.
  • $lookup: Joins documents from another collection (a left outer join).
  • $replaceRoot: Promotes an embedded document to the top level, replacing all other fields.
  • $unwind: Deconstructs an array field, outputting one document per array element.
  • $facet: Runs multiple sub-pipelines on the same input within one stage and writes each result to its own field.

There are special system variables, for instance $$ROOT (the current top-level document), $$REMOVE (removes a field), and $$PRUNE (used with $redact), which you can use in some stages of the aggregation pipeline.

ref:
https://docs.mongodb.com/manual/reference/aggregation-variables/#system-variables

Return Date As Unix Timestamp

import datetime

def stages():
    yield {'$project': {
        # Use the created_at field path; subtracting two dates yields milliseconds
        'createdAt': {'$floor': {'$divide': [{'$subtract': ['$created_at', datetime.datetime.utcfromtimestamp(0)]}, 1000]}},
    }}

try:
    docs = MessagePackProduct.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://stackoverflow.com/questions/39274311/convert-iso-date-to-timestamp-in-mongo-query
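Subtracting two dates in an aggregation expression yields the difference in milliseconds, so dividing by 1000 and flooring gives whole seconds. The same arithmetic in plain Python, useful for checking a pipeline's output; it assumes the epoch is a naive UTC datetime like datetime.datetime.utcfromtimestamp(0), and the sample date is made up.

```python
import datetime

# Assumption: the epoch constant is a naive UTC datetime like this one
UNIX_EPOCH = datetime.datetime(1970, 1, 1)


def unix_seconds(created: datetime.datetime) -> int:
    """Mirror {'$floor': {'$divide': [{'$subtract': [created, epoch]}, 1000]}}."""
    # Subtracting two dates gives milliseconds in the aggregation pipeline
    millis = (created - UNIX_EPOCH) // datetime.timedelta(milliseconds=1)
    # Integer floor division matches $floor applied to the $divide result
    return millis // 1000


print(unix_seconds(datetime.datetime(2018, 4, 8, 15, 30, 0, 999000)))  # 1523201400
```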

Match Multiple Conditions Stored In An Array Field

db.getCollection('feature.promotions').insert({
    "name": "TEST",
    "nbf": ISODate("2018-05-31 16:00:00.000Z"),
    "exp": ISODate("2018-06-30 15:59:00.001Z"),
    "positions": {
        "discover": {
            "urls": [
                "https://example.com/events/2018/jun/event1/banner.html"
            ]
        }
    },
    "requirements" : [
        {
            // users who like women and whose app version is at least v2.21
            "preferences" : [
                "gender:female"
            ],
            "version_major_min": 2.0,
            "version_minor_min": 21.0
        },
        {
            // female CPs
            "tags" : [
                "stats",
                "gender:female"
            ]
        }
    ]
});

import werkzeug

user_agent = werkzeug.UserAgent('hello-world/2.25.1 (iPhone; iOS 11.4.1; Scale/2.00; com.example.app; zh-tw)')
user_preferences = ['gender:female', 'gender:male']
user_tags = ['beta', 'vip']
user_platforms = [user_agent.platform]

def stages():
    now = utils.utcnow()

    yield {'$match': {
        '$and': [
            {'nbf': {'$lte': now}},
            {'exp': {'$gt': now}},
            {'requirements': {'$elemMatch': {
                'preferences': {'$not': {'$elemMatch': {'$nin': user_preferences}}},
                'tags': {'$not': {'$elemMatch': {'$nin': user_tags}}},
                'platforms': {'$not': {'$elemMatch': {'$nin': user_platforms}}},
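                '$or': [
                    # A newer major version than the minimum
                    {'version_major_min': {'$lt': user_agent.version.major}},
                    # The same major version and a new enough minor version
                    {
                        'version_major_min': user_agent.version.major,
                        'version_minor_min': {'$lte': user_agent.version.minor},
                    },
                    # No version requirement at all
                    {
                        'version_major_min': {'$exists': False},
                        'version_minor_min': {'$exists': False},
                    },
                ],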
            }}},
        ],
    }}
    yield {'$project': {
        'name': True,
        'nbf': True,
        'exp': True,
        'positions': {'$objectToArray': '$positions'},
    }}
    yield {'$unwind': '$positions'}
    yield {'$sort': {
        'exp': 1,
    }}
    yield {'$project': {
        '_id': False,
        'name': True,
        'position': '$positions.k',
        'url': {'$arrayElemAt': ['$positions.v.urls', 0]},
        'startedAt': {'$floor': {'$divide': [{'$subtract': ['$nbf', constants.UNIX_EPOCH]}, 1000]}},
        'endedAt': {'$floor': {'$divide': [{'$subtract': ['$exp', constants.UNIX_EPOCH]}, 1000]}},
    }}
    yield {'$group': {
        '_id': '$position',
        'items': {'$push': '$$ROOT'},
    }}

try:
    docs = Promotion.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    docs = list(docs)

ref:
https://docs.mongodb.com/manual/reference/operator/query/in/
https://docs.mongodb.com/manual/reference/operator/query/nin/
https://docs.mongodb.com/manual/reference/operator/aggregation/setIsSubset/
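The {'$not': {'$elemMatch': {'$nin': values}}} pattern used for preferences, tags, and platforms reads as "no element of the requirement lies outside what the user has", i.e. the required array is a subset of the user's values. A requirement that omits the field also passes, because $elemMatch cannot match a missing field and $not then succeeds. A pure-Python restatement of that predicate (a hypothetical helper, for reasoning only):

```python
from typing import Iterable, Optional


def requirement_met(required: Optional[Iterable[str]], user_values: Iterable[str]) -> bool:
    """Mirror {'field': {'$not': {'$elemMatch': {'$nin': user_values}}}}."""
    if required is None:
        return True  # field missing: $elemMatch cannot match, so $not passes
    allowed = set(user_values)
    # $elemMatch + $nin looks for one element outside the allowed values
    return not any(value not in allowed for value in required)


user_tags = ['beta', 'vip']
print(requirement_met(['vip'], user_tags))           # True
print(requirement_met(['stats', 'vip'], user_tags))  # False: 'stats' is missing
print(requirement_met(None, user_tags))              # True
```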

Do Distinct With $group

def stages():
    yield {'$match': {
        'tags': 'some_tag',
    }}
    yield {'$unwind': '$unlocks'}
    yield {'$replaceRoot': {'newRoot': '$unlocks'}}
    yield {'$match': {
        '_cls': 'MessagePackUnlock',
    }}
    yield {'$group': {
        '_id': '$user',
        'timestamp': {'$first': '$timestamp'},
    }}

for unlock in MessagePackMessage.objects.aggregate(*stages()):
    tasks.offline_purchase_pack.apply(kwargs=dict(
        user_id=unlock['_id'],
        message_pack_id=message_pack.id,
        timestamp=unlock['timestamp'],
    ))

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/group/
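$group keyed on $user keeps one output document per user, and $first takes timestamp from whichever document of that group arrives first. Without a $sort before $group that order is not defined, so add one (for example on timestamp) if you need the earliest unlock. A plain-Python model of the stage with made-up sample data; group_first is a hypothetical helper, and unlike the server it happens to return groups in first-seen order.

```python
from typing import Any, Dict, Iterable, List


def group_first(docs: Iterable[Dict[str, Any]], key: str, field: str) -> List[Dict[str, Any]]:
    """Mirror {'$group': {'_id': '$<key>', field: {'$first': '$<field>'}}}."""
    groups: Dict[Any, Dict[str, Any]] = {}
    for doc in docs:
        # setdefault keeps the first document seen for each key
        groups.setdefault(doc.get(key), {'_id': doc.get(key), field: doc.get(field)})
    return list(groups.values())


unlocks = [
    {'user': 'u1', 'timestamp': 3},
    {'user': 'u2', 'timestamp': 5},
    {'user': 'u1', 'timestamp': 1},
]
# Sorting first makes "$first" mean "earliest"
print(group_first(sorted(unlocks, key=lambda d: d['timestamp']), 'user', 'timestamp'))
# [{'_id': 'u1', 'timestamp': 1}, {'_id': 'u2', 'timestamp': 5}]
```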

Slice Items In Each $group

import random

def stages():
    yield {'$match': {'tags': {'$regex': '^badge:'}}}
    yield {'$unwind': {'path': '$tags', 'includeArrayIndex': 'index'}}
    yield {'$match': {'tags': {'$regex': '^badge:'}}}
    yield {'$project': {'_id': True, 'tag': '$tags', 'index': {'$mod': ['$index', random.random()]}}}
    yield {'$sort': {'index': 1}}
    yield {'$group': {'_id': '$tag', 'users': {'$addToSet': '$_id'}}}
    yield {'$project': {'_id': True, 'users': {'$slice': ['$users', 1000]}}}

docs = User.objects.aggregate(*stages())
for doc in docs:
    badge, user_ids = doc['_id'], doc['users']

Collect Items With $group And $addToSet

User data:

{
    "_id" : ObjectId("5a66d5c2af9c462c617ce552"),
    "username" : "gibuloto",
    "tags" : [ 
        "beta"
    ],
    "schedules" : [ 
        {
            "tag" : "stats",
            "nbf" : ISODate("2018-02-01T16:00:00.000Z"),
            "exp" : ISODate("2018-08-12T16:00:00.000Z")
        }, 
        {
            "tag" : "vip",
            "nbf" : ISODate("2018-05-13T16:00:00.000Z"),
            "exp" : ISODate("2018-05-20T16:00:00.000Z")
        }
    ]
}
def stages():
    now = utils.utcnow()

    yield {'$match': {
        'schedules': {'$elemMatch': {
            'nbf': {'$lte': now},
            'exp': {'$gte': now}
        }}
    }}
    yield {'$unwind': '$schedules'}
    yield {'$match': {
        'schedules.nbf': {'$lte': now},
        'schedules.exp': {'$gte': now}
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'username': True,
        'tag': '$schedules.tag',
        'nbf': '$schedules.nbf',
        'exp': '$schedules.exp'
    }}
    yield {'$group': {
        '_id': '$id',
        'tags': {'$addToSet': '$tag'},
    }}

for user_tag_schedule in User.objects.aggregate(*stages()):
    print(user_tag_schedule)

# output:
# {'_id': ObjectId('579b9387b7af8e1fd1635da9'), 'tags': ['stats']}
# {'_id': ObjectId('5a66d5c2af9c462c617ce552'), 'tags': ['chat', 'vip']}

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/group/

Project A New Field Based On Whether Elements Exist In Another Array Field

Use $addFields with $cond.

def stages():
    user_preferences = g.user.settings.preferences or ['gender:female']
    yield {'$match': {
        'gender': {'$in': [prefix_gender.replace('gender:', '') for prefix_gender in user_preferences]}
    }}

    yield {'$addFields': {
        'isPinned': {'$cond': {
            # $in requires an array, so fall back to [] when badges is missing
            'if': {'$in': [constants.tags.HIDDEN, {'$ifNull': ['$badges', []]}]},
            'then': True,
            'else': False,
        }},
    }}
    yield {'$sort': {
        'isPinned': -1,
        'posted_at': -1,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'author': '$author',
        'title': '$title',
        'location': '$location',
        'postedAt': {'$floor': {'$divide': [{'$subtract': ['$posted_at', constants.UNIX_EPOCH]}, 1000]}},
        'viewCount': '$view_count',
        'commentCount': {'$size': {'$ifNull': ['$comments', []]}},
        'badges': '$badges',
        'isPinned': '$isPinned',
    }}

try:
    results = Post.objects.aggregate(*stages()).next()
except StopIteration:
    return Response(status=http.HTTPStatus.NOT_FOUND)

ref:
https://stackoverflow.com/questions/16512329/project-new-boolean-field-based-on-element-exists-in-an-array-of-a-subdocument
https://docs.mongodb.com/manual/reference/operator/aggregation/project/
https://docs.mongodb.com/manual/reference/operator/aggregation/addFields/
https://docs.mongodb.com/manual/reference/operator/aggregation/cond/

Project And Filter Out Elements Of An Array With $filter

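Some entries in details might have no value. Wrapping optional fields in $ifNull turns a missing field into an explicit null, and $filter then drops every entry whose value is null.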

def stages():
    yield {'$match': {
        '_id': bson.ObjectId(post_id),
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'author': '$author',
        'title': '$title',
        'location': '$location',
        'postedAt': {'$floor': {'$divide': [{'$subtract': ['$posted_at', constants.UNIX_EPOCH]}, 1000]}},
        'viewCount': '$view_count',
        'commentCount': {'$size': {'$ifNull': ['$comments', []]}},
        'details': [
            {'key': 'gender', 'value': '$gender'},
            {'key': 'pricing', 'value': '$pricing'},
            {'key': 'lineId', 'value': {'$ifNull': ['$lineId', None]}},
            {'key': 'description', 'value': {'$ifNull': ['$description', None]}},
        ],
    }}
    yield {'$addFields': {
        'details': {
            '$filter': {
                'input': '$details',
                'as': 'detail',
                'cond': {'$ne': ['$$detail.value', None]},
            }
        }
    }}

try:
    post = next(Post.objects.aggregate(*stages()))
except StopIteration:
    return Response(status=http.HTTPStatus.NOT_FOUND)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/filter/#exp._S_filter
https://docs.mongodb.com/manual/reference/operator/aggregation/addFields/
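The details trick in plain Python: build key/value pairs, use None for absent values (what $ifNull does for lineId and description), then keep only the pairs whose value is not None (the $filter step). This model treats every key as if it were wrapped in $ifNull; build_details and the sample post are hypothetical.

```python
from typing import Any, Dict, List

DETAIL_KEYS = ('gender', 'pricing', 'lineId', 'description')


def build_details(post: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Mirror the $project + $filter pair: list key/value pairs, drop null values."""
    details = [{'key': key, 'value': post.get(key)} for key in DETAIL_KEYS]  # like $ifNull
    return [detail for detail in details if detail['value'] is not None]  # like $filter


post = {'gender': 'female', 'pricing': 1200, 'description': None}
print(build_details(post))
# [{'key': 'gender', 'value': 'female'}, {'key': 'pricing', 'value': 1200}]
```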

Project Specific Fields Of Elements Of An Array With $map

def stages():
    yield {'$match': {
        '_id': bson.ObjectId(post_id),
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'author': '$author',
        'title': '$title',
        'location': '$location',
        'postedAt': {'$floor': {'$divide': [{'$subtract': ['$posted_at', constants.UNIX_EPOCH]}, 1000]}},
        'viewCount': '$view_count',
        'commentCount': {'$size': {'$ifNull': ['$comments', []]}},
        'details': [
            {'key': 'gender', 'value': '$gender'},
            {'key': 'pricing', 'value': '$pricing'},
            {'key': 'lineId', 'value': {'$ifNull': ['$lineId', None]}},
            {'key': 'description', 'value': {'$ifNull': ['$description', None]}},
        ],
        'media': {
            '$map': {
                'input': '$media',
                'as': 'transcoded_media',
                'in': {
                    'mimetype': '$$transcoded_media.mimetype',
                    'dash': '$$transcoded_media.presets.dash',
                    'hls': '$$transcoded_media.presets.hls',
                    'thumbnail': '$$transcoded_media.thumbnail',
                }
            }
        },
    }}
    yield {'$addFields': {
        'details': {
            '$filter': {
                'input': '$details',
                'as': 'detail',
                'cond': {'$ne': ['$$detail.value', None]},
            }
        }
    }}

try:
    post = next(Post.objects.aggregate(*stages()))
except StopIteration:
    return Response(status=http.HTTPStatus.NOT_FOUND)

ref:
https://stackoverflow.com/questions/33831665/how-to-project-specific-fields-from-a-document-inside-an-array
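$map evaluates the in expression once per array element, with $$transcoded_media bound to the current element; a field whose path does not exist is simply left out of the resulting object. A Python model of the same reshaping, using the field names from the example above with made-up sample data (get_path and project_media are hypothetical helpers):

```python
from typing import Any, Dict, List

_MISSING = object()

# Output field -> path inside each media element, as in the $map "in" expression
MEDIA_FIELDS = {
    'mimetype': 'mimetype',
    'dash': 'presets.dash',
    'hls': 'presets.hls',
    'thumbnail': 'thumbnail',
}


def get_path(doc: Any, path: str) -> Any:
    """Resolve a dotted path, returning _MISSING if any part is absent."""
    for part in path.split('.'):
        if not isinstance(doc, dict) or part not in doc:
            return _MISSING
        doc = doc[part]
    return doc


def project_media(media: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Mirror the $map stage: reshape every element of the media array."""
    projected = []
    for item in media:
        values = {name: get_path(item, path) for name, path in MEDIA_FIELDS.items()}
        # A field whose path is missing is omitted from the output document
        projected.append({name: value for name, value in values.items() if value is not _MISSING})
    return projected


media = [
    {'mimetype': 'video/mp4', 'presets': {'dash': 'a.mpd', 'hls': 'a.m3u8', 'mp4': 'a.mp4'}, 'thumbnail': 't.jpg'},
    {'mimetype': 'image/jpeg'},
]
print(project_media(media))
# [{'mimetype': 'video/mp4', 'dash': 'a.mpd', 'hls': 'a.m3u8', 'thumbnail': 't.jpg'}, {'mimetype': 'image/jpeg'}]
```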

Do Advanced $project With $let

If you find yourself needing two $project stages to compute some fields, use $let instead: it binds intermediate values to variables within a single expression.

def stages():
    yield {'$match': {
        'purchases.user': g.user.id,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'image': {
            '$ifNull': [{'$arrayElemAt': ['$images', 0]}, None],
        },
        'purchasedAt': {
            '$let': {
                'vars': {
                    'purchase': {
                        '$arrayElemAt': [
                            {
                                '$filter': {
                                    'input': '$purchases',
                                    'as': 'purchase',
                                    'cond': {
                                        '$and': [
                                            {'$eq': ['$$purchase.user', g.user.id]},
                                        ],
                                    },
                                },
                            },
                            0,
                        ],
                    },
                },
                'in': '$$purchase.timestamp',
            },
        },
    }}

try:
    docs = MessagePackProduct.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/let/

Deconstruct An Array Field With $unwind And Query Them With $match

def stages():
    category_tag = 'category:user'
    currency = 'usd'
    platform = 'ios'

    yield {'$match': {
        'active': True,
        'tags': category_tag,
        'total': {'$gt': 0},
        'preview_message': {'$exists': True},
    }}
    yield {'$unwind': '$skus'}
    yield {'$match': {
        'skus.attributes.platform': platform,
        'skus.attributes.currency': currency,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'caption': True,
        'description': True,
        'image': {
            '$ifNull': [{'$arrayElemAt': ['$images', 0]}, None],
        },
        'sku': '$skus',
        'created_at': True,
        'is_purchased': {'$in': [g.user.id, {'$ifNull': ['$purchases.user', []]}]},
    }}
    yield {'$sort': {'is_purchased': 1, 'created_at': -1}}

try:
    docs = MessagePackProduct.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/match/
https://docs.mongodb.com/manual/reference/operator/aggregation/unwind/
https://docs.mongodb.com/manual/reference/operator/aggregation/project/
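$unwind emits one copy of the product per SKU, so the second $match tests each SKU on its own, and a product with two matching SKUs would come out twice. A minimal Python model of the two stages with a made-up product list; by default $unwind drops documents whose array is missing, null, or empty, and treats a non-array value as a one-element array. unwind and sku_matches are hypothetical helpers.

```python
from typing import Any, Dict, Iterable, Iterator

Document = Dict[str, Any]


def unwind(docs: Iterable[Document], field: str) -> Iterator[Document]:
    """Mirror {'$unwind': '$<field>'} with preserveNullAndEmptyArrays left off."""
    for doc in docs:
        values = doc.get(field)
        if values is None:
            continue  # missing or null field: the document is dropped
        if not isinstance(values, list):
            values = [values]  # a non-array value is treated as a one-element array
        for value in values:  # an empty array produces no output
            yield {**doc, field: value}


def sku_matches(doc: Document, platform: str, currency: str) -> bool:
    """Mirror the second $match, which now sees one SKU per document."""
    attributes = doc['skus'].get('attributes', {})
    return attributes.get('platform') == platform and attributes.get('currency') == currency


products = [{
    '_id': 'prod_1',
    'skus': [
        {'id': 'sku_a', 'attributes': {'platform': 'ios', 'currency': 'usd'}},
        {'id': 'sku_b', 'attributes': {'platform': 'android', 'currency': 'usd'}},
    ],
}, {'_id': 'prod_2', 'skus': []}]

matched = [doc for doc in unwind(products, 'skus') if sku_matches(doc, 'ios', 'usd')]
print([(doc['_id'], doc['skus']['id']) for doc in matched])  # [('prod_1', 'sku_a')]
```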

Query The First Element In An Array Field With $arrayElemAt And $filter

def stages():
    category_tag = 'category:user'
    currency = 'usd'
    platform = 'ios'

    yield {'$match': {
        'active': True,
        'tags': category_tag,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'caption': True,
        'description': True,
        'image': {
            '$ifNull': [{'$arrayElemAt': ['$images', 0]}, None],
        },
        'preview_message': True,
        'metadata': True,
        'created_at': True,
        'updated_at': True,
        'active': True,
        'sku': {
            '$ifNull': [
                {
                    '$arrayElemAt': [
                        {
                            '$filter': {
                                'input': '$skus',
                                'as': 'sku',
                                'cond': {
                                    '$and': [
                                        {'$eq': ['$$sku.currency', currency]},
                                        {'$eq': ['$$sku.attributes.platform', platform]},
                                    ]
                                }
                            },
                        },
                        0
                    ]
                },
                None
            ],
        },
        'tags': True,
        'total': True,
        'is_bought': {'$in': [g.user.id, {'$ifNull': ['$purchases.user', []]}]},
    }}
    yield {'$sort': {'is_bought': 1, 'created_at': -1}}

try:
    docs = MessagePackProduct.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://docs.mongodb.com/master/reference/operator/aggregation/filter/
https://stackoverflow.com/questions/3985214/retrieve-only-the-queried-element-in-an-object-array-in-mongodb-collection
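$arrayElemAt at index 0 of a $filter result returns the first matching element and produces no value when nothing matches, which $ifNull turns into None. In Python this is next() over a generator with a default. first_sku and the SKU data are hypothetical, following the currency/attributes.platform layout used above.

```python
from typing import Any, Dict, List, Optional

Sku = Dict[str, Any]


def first_sku(skus: Optional[List[Sku]], currency: str, platform: str) -> Optional[Sku]:
    """Mirror {'$ifNull': [{'$arrayElemAt': [{'$filter': ...}, 0]}, None]}."""
    matches = (
        sku for sku in skus or []
        if sku.get('currency') == currency
        and sku.get('attributes', {}).get('platform') == platform
    )
    # next() with a default plays the role of $ifNull when nothing matches
    return next(matches, None)


skus = [
    {'id': 'sku_android', 'currency': 'usd', 'attributes': {'platform': 'android'}},
    {'id': 'sku_ios', 'currency': 'usd', 'attributes': {'platform': 'ios'}},
]
print(first_sku(skus, 'usd', 'ios')['id'])  # sku_ios
print(first_sku(skus, 'twd', 'ios'))        # None
```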

Join Another Collection Using $lookup

def stages():
    yield {'$match': {
        'tags': 'pack:prod_CR1u34BIpDbHeo',
    }}
    yield {'$lookup': {
        'from': 'user',
        'localField': 'sender',
        'foreignField': '_id',
        'as': 'sender_data',
    }}
    yield {'$unwind': '$sender_data'}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'sender': {
            'id': '$sender_data._id',
            'username': '$sender_data.username',
        },
        'caption': True,
        'posted_at': True,
    }}
    yield {'$sort': {'posted_at': -1}}

try:
    docs = Message.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/lookup/
https://thecodebarbarian.com/a-nodejs-perspective-on-mongodb-36-lookup-expr
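$lookup with localField/foreignField is an equality left outer join: every input document is kept and the matches are always stored as an array, empty when there is none. The $unwind that follows then drops messages whose sender was not found. A dictionary-based model of that join with hypothetical sample data; an array-valued localField, which $lookup matches element by element, is not handled here.

```python
from typing import Any, Dict, Iterable, List

Document = Dict[str, Any]


def lookup(docs: Iterable[Document], foreign: Iterable[Document],
           local_field: str, foreign_field: str, as_field: str) -> List[Document]:
    """Mirror $lookup with localField/foreignField: an equality left outer join."""
    index: Dict[Any, List[Document]] = {}
    for other in foreign:
        # A missing foreignField is indexed under None, like a null value
        index.setdefault(other.get(foreign_field), []).append(other)
    # Every input document is kept; "as" is an empty list when nothing matches
    return [{**doc, as_field: index.get(doc.get(local_field), [])} for doc in docs]


users = [{'_id': 'u1', 'username': 'gibuloto'}]
messages = [{'_id': 'm1', 'sender': 'u1'}, {'_id': 'm2', 'sender': 'u9'}]

joined = lookup(messages, users, 'sender', '_id', 'sender_data')
print([(doc['_id'], len(doc['sender_data'])) for doc in joined])  # [('m1', 1), ('m2', 0)]

# The following $unwind drops m2, whose sender_data array is empty
print([doc['_id'] for doc in joined if doc['sender_data']])  # ['m1']
```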

Join Another Collection With Multiple Conditions Using pipeline in $lookup

In a $match stage inside the $lookup pipeline, the let variables can only be referenced through the $expr operator; other stages, such as $project, can use $$variable directly.

var start = ISODate('2018-09-22T00:00:00.000+08:00');

db.getCollection('feature.shop.order').aggregate([
    {'$match': {
        'payment.timestamp': {'$gte': start},
        'status': {'$in': ['paid']},
    }},
    {'$lookup': {
        'from': 'user',
        'localField': 'customer',
        'foreignField': '_id',
        'as': 'customer_data',
    }},
    {'$unwind': '$customer_data'},
    {'$project': {
        'variation': '$customer_data.experiments.message_unlock_price.variation',
        'amount_normalized': {'$divide': ['$amount', 100.0]},
    }},
    {'$addFields': {
        'amount_usd': {'$multiply': ['$amount_normalized', 0.033]},
    }},
    {'$group': {
        '_id': '$variation',
        'purchase_amount': {'$sum': '$amount_usd'},
        'paid_user_count': {'$sum': 1},
    }},
    {'$lookup': {
        'from': 'user',
        'let': {
            'variation': '$_id',
        },
        'pipeline': [
            {'$match': {
                'last_active': {'$gte': start},
                'experiments': {'$exists': true},
            }},
            {'$match': {
                '$expr': {
                    '$and': [
                        {'$eq': ['$experiments.message_unlock_price.variation', '$$variation']},
                    ],
                },
            }},
            {'$group': {
                '_id': '$experiments.message_unlock_price.variation',
                'count': {'$sum': 1},
            }},
        ],
        'as': 'variation_data',
    }},
    {'$unwind': '$variation_data'},
    {'$project': {
        '_id': 1,
        'purchase_amount': 1,
        'paid_user_count': 1,
        'total_user_count': '$variation_data.count',
    }},
    {'$addFields': {
        'since': start,
        'arpu': {'$divide': ['$purchase_amount', '$total_user_count']},
        'arppu': {'$divide': ['$purchase_amount', '$paid_user_count']},
    }},
    {'$sort': {'_id': 1}},
]);

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/lookup/#join-conditions-and-uncorrelated-sub-queries

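Another example, joining both the sender and the unlocker with an $or condition inside the $lookup pipeline: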

def stages():
    yield {'$match': {'_id': bson.ObjectId(message_id)}}
    yield {'$limit': 1}
    yield {'$project': {
        '_cls': 1,
        'sender': 1,
        'unlocks': 1,
    }}
    yield {'$unwind': '$unlocks'}
    yield {'$match': {
        'unlocks.user': bson.ObjectId(user_id),
        'unlocks.amount': {'$gt': 0},
    }}
    yield {'$lookup': {
        'from': 'user',
        'let': {
            'sender': '$sender',
            'unlocker': '$unlocks.user',
        },
        'pipeline': [
            {'$match': {
                '$expr': {
                    '$or': [
                        {'$eq': ['$_id', '$$sender']},
                        {'$eq': ['$_id', '$$unlocker']}
                    ]
                }
            }}
        ],
        'as': 'users',
    }}
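    yield {'$addFields': {
        # The order of the joined users is not guaranteed, so pick each one by _id
        'sender': {'$arrayElemAt': [
            {'$filter': {'input': '$users', 'as': 'user', 'cond': {'$eq': ['$$user._id', '$sender']}}},
            0,
        ]},
        'unlocker': {'$arrayElemAt': [
            {'$filter': {'input': '$users', 'as': 'user', 'cond': {'$eq': ['$$user._id', '$unlocks.user']}}},
            0,
        ]},
    }}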
    yield {'$project': {
        '_id': 0,
        '_cls': 1,
        'id': '$_id',
        'sender': {
            'id': '$sender._id',
            'username': '$sender.username',
        },
        'unlocker': {
            'id': '$unlocker._id',
            'username': '$unlocker.username',
        },
        'amount': '$unlocks.amount',
    }}

try:
    context = Message.objects.aggregate(*stages()).next()
except StopIteration:
    pass

ref:
https://stackoverflow.com/questions/37086387/multiple-join-conditions-using-the-lookup-operator
https://docs.mongodb.com/manual/reference/operator/aggregation/lookup/#specify-multiple-join-conditions-with-lookup

Count Documents In Another Collection With $lookup (JOIN)

def stages():
    category_tag = f'category:{category}'
    yield {'$match': {
        'active': True,
        'tags': category_tag,
    }}
    yield {'$addFields': {
        'message_pack_id_tag': {'$concat': ['pack:', '$_id']},
    }}
    yield {'$lookup': {
        'from': 'message',
        'localField': 'message_pack_id_tag',
        'foreignField': 'tags',
        'as': 'total',
    }}
    yield {'$addFields': {
        'total': {'$size': '$total'}
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'total': True,
    }}

try:
    docs = MessagePackProduct.objects.aggregate(*stages())
except StopIteration:
    docs = []
else:
    for doc in docs:
        print(doc)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/lookup/#equality-match

Use $lookup as findOne() Which Returns An Object

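Follow $lookup with $unwind so the joined field becomes a single embedded document instead of an array, much like findOne(). Note that $unwind drops the input document entirely when the lookup finds no match, unless you set preserveNullAndEmptyArrays.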

import bson

def stages():
    yield {'$match': {'_id': bson.ObjectId(gift_id)}}
    yield {'$limit': 1}
    yield {'$lookup': {
        'from': 'user',
        'localField': 'sender',
        'foreignField': '_id',
        'as': 'sender',
    }}
    yield {'$unwind': '$sender'}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'sender': {
            'id': '$sender._id',
            'username': '$sender.username',
        },
        'product_id': '$product._id',
        'sent_at': '$sent_at',
        'amount': '$cost.amount',
    }}

try:
    _context = Gift.objects.aggregate(*stages()).next()
except StopIteration:
    _context = None  # no matching gift

ref:
https://stackoverflow.com/questions/37691727/how-to-use-mongodbs-aggregate-lookup-as-findone
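If several pipelines repeat this $lookup + $unwind pair, a small generator can build it. This is a minimal sketch: lookup_one() and its keep_unmatched parameter are hypothetical names, and keep_unmatched simply maps to the preserveNullAndEmptyArrays option of $unwind.

```python
def lookup_one(from_collection, local_field, foreign_field='_id', as_field=None, keep_unmatched=False):
    """Yield a $lookup stage plus an $unwind stage that turns the joined array into one object.

    Hypothetical helper: with keep_unmatched=True, documents without a match are kept
    instead of being dropped by $unwind.
    """
    as_field = as_field or local_field
    yield {'$lookup': {
        'from': from_collection,
        'localField': local_field,
        'foreignField': foreign_field,
        'as': as_field,
    }}
    yield {'$unwind': {
        'path': f'${as_field}',
        'preserveNullAndEmptyArrays': keep_unmatched,
    }}


# the two stages used in the findOne() example above
for stage in lookup_one('user', 'sender'):
    print(stage)
```

Inside a stages() generator, `yield from lookup_one('user', 'sender')` would stand in for the hand-written $lookup and $unwind stages.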

Collapse Documents In An Array

def stages():
    yield {'$match': {
        'tags': 'tutorial:buy-diamonds:v1',
    }}
    # sort before $project, because the projection below drops created_at
    yield {'$sort': {'created_at': -1}}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'caption.text': True,
        'sender': True,
        'media.type': '$media.mimetype',
    }}
    # collapse the (at most 10) documents into the inbox array of a single document
    yield {'$facet': {
        'inbox': [
            {'$limit': 10},
        ],
    }}
    yield {'$project': {
        'inbox': True,
        'required_unlock_count': {'$literal': 5},
        'price_per_message': {'$literal': 1200},
    }}

try:
    result = Message.objects.aggregate(*stages()).next()
except StopIteration:
    result = {}

JSON output:

{
    "inbox": [
        {
            "caption": {
                "text": "fuck yeah"
            },
            "id": "5aaba1e9593950337a90dcb3",
            "media": {
                "type": "video/mp4"
            },
            "sender": "5a66d5c2af9c462c617ce552"
        },
        {
            "caption": {
                "text": "test"
            },
            "id": "5ad549276b2c362a4efe5e21",
            "media": {
                "type": "image/jpeg"
            },
            "sender": "5a66d5c2af9c462c617ce552"
        }
    ],
    "price_per_message": 1200,
    "required_unlock_count": 5
}

Do Pagination With $facet And $project

def stages():
    # normal query
    yield {'$match': {
        'purchases.user': g.user.id,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'created_at': True,
        'meta': {
            'revision': '$revision',
            'tags': '$tags',
        },
    }}
    yield {'$sort': {'created_at': -1}}

    # pagination
    page = 0
    limit = 10
    yield {'$facet': {
        'meta': [
            {'$count': 'total'},
        ],
        'objects': [
            {'$skip': page * limit},
            {'$limit': limit},
        ]
    }}
    # JSON output:
    # {
    #    "meta": [
    #       {"total": 2}
    #    ],
    #    "objects": [
    #       {
    #          "id": "prod_CR1u34BIpDbHeo",
    #          "name": "Product Name 2"
    #       },
    #       {
    #          "id": "prod_Fkhf9JFK3Rdgk9",
    #          "name": "Product Name 1"
    #       }
    #    ]
    # }
    yield {'$project': {
        'total': {'$let': {
            'vars': {
                'meta': {'$arrayElemAt': ['$meta', 0]},
            },
            'in': '$$meta.total',
        }},
        'objects': True,
    }}
    # JSON output:
    # {
    #    "total": 2,
    #    "objects": [
    #       {
    #          "id": "prod_CR1u34BIpDbHeo",
    #          "name": "Product Name 2"
    #       },
    #       {
    #          "id": "prod_Fkhf9JFK3Rdgk9",
    #          "name": "Product Name 1"
    #       }
    #    ]
    # }

try:
    output = MessagePackProduct.objects.aggregate(*stages()).next()
except StopIteration:
    output = {}
else:
    print(output)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/facet/
https://docs.mongodb.com/manual/reference/operator/aggregation/project/
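The hard-coded page and limit can be lifted into a reusable generator. A minimal sketch: paginate() is a hypothetical helper, and the $ifNull fallback to 0 is an addition not in the original. It is there because $count outputs no document when nothing matches, which leaves meta empty and total missing.

```python
def paginate(page=0, limit=10):
    """Yield $facet + $project stages that wrap results as {'total': ..., 'objects': [...]}.

    Hypothetical helper; page is zero-based.
    """
    if page < 0 or limit < 1:
        raise ValueError('page must be >= 0 and limit must be >= 1')
    yield {'$facet': {
        'meta': [
            {'$count': 'total'},
        ],
        'objects': [
            {'$skip': page * limit},
            {'$limit': limit},
        ],
    }}
    yield {'$project': {
        'total': {'$let': {
            'vars': {
                'meta': {'$arrayElemAt': ['$meta', 0]},
            },
            # $count emits nothing for an empty input, so fall back to 0
            'in': {'$ifNull': ['$$meta.total', 0]},
        }},
        'objects': True,
    }}


for stage in paginate(page=2, limit=10):
    print(stage)
```

In the stages() generators above, `yield from paginate(page, limit)` would replace the $facet and $project stages written out by hand.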

Perform $facet + $project => Unwrap with $unwind => Do $facet + $project Again

def stages():
    yield {'$match': {
        'purchases.user': g.user.id,
    }}
    yield {'$project': {
        '_id': False,
        'id': '$_id',
        'name': True,
        'image': {
            '$ifNull': [{'$arrayElemAt': ['$images', 0]}, None],
        },
        'created_at': True,
    }}
    yield {'$sort': {'created_at': -1}}

    # pagination
    page = 0
    limit = 10
    yield {'$facet': {
        'meta': [
            {'$count': 'total'},
        ],
        'objects': [
            {'$skip': page * limit},
            {'$limit': limit},
        ]
    }}
    yield {'$project': {
        'total': {'$let': {
            'vars': {
                'meta': {'$arrayElemAt': ['$meta', 0]},
            },
            'in': '$$meta.total',
        }},
        'objects': True,
    }}

    # do $lookup after the pagination
    yield {'$unwind': '$objects'}
    yield {'$addFields': {
        'objects.message_pack_id_tag': {'$concat': ['pack:', '$objects.id']},
    }}
    yield {'$lookup': {
        'from': 'message',
        'localField': 'objects.message_pack_id_tag',
        'foreignField': 'tags',
        'as': 'objects.total',
    }}
    yield {'$addFields': {
        'objects.total': {'$size': '$objects.total'}
    }}

    # re-wrap into the pagination structure
    yield {'$facet': {
        'total_list': [
            {'$project': {
                'total': True,
            }},
        ],
        'objects': [
            {'$replaceRoot': {'newRoot': '$objects'}},
        ]
    }}
    yield {'$project': {
        'total': {'$let': {
            'vars': {
                'meta': {'$arrayElemAt': ['$total_list', 0]},
            },
            'in': '$$meta.total',
        }},
        'objects': True,
    }}

try:
    output = MessagePackProduct.objects.aggregate(*stages()).next()
except StopIteration:
    output = {}
else:
    print(output)

Do $group First To Reduce The Number Of $lookup Calls

def stages():
    yield {'$match': {
        'tags': f'pack:{message_pack_id}',
    }}
    yield {'$group': {
        '_id': '$sender',
        'messages': {'$push': '$$ROOT'},
    }}
    yield {'$lookup': {
        'from': 'user',
        'localField': '_id',
        'foreignField': '_id',
        'as': 'sender_data',
    }}
    yield {'$unwind': '$messages'}
    yield {'$project': {
        '_id': False,
        'id': '$messages._id',
        'caption': {
            'text': '$messages.caption.text',
            'y': '$messages.caption.y',
        },
        'sender': {
            'id': {'$arrayElemAt': ['$sender_data._id', 0]},
            'username': {'$arrayElemAt': ['$sender_data.username', 0]},
        },
    }}

# aggregate() returns a cursor and does not raise StopIteration, so no try/except is needed here
docs = Message.objects.aggregate(*stages())
for doc in docs:
    print(doc)

ref:
https://docs.mongodb.com/manual/reference/operator/aggregation/group/

Copy Collections To Another Database

var bulk = db.getSiblingDB('target_db')['target_collection'].initializeOrderedBulkOp();
db.getCollection('source_collection').find().forEach(function(d) {
    bulk.insert(d);
});
bulk.execute();

// the same pattern applied to several collections, copied into the "test" database
['company.revenue', 'user.contract', 'user.revenue'].forEach(function(name) {
    var bulk = db.getSiblingDB('test').getCollection(name).initializeOrderedBulkOp();
    db.getCollection(name).find().forEach(function(d) {
        bulk.insert(d);
    });
    bulk.execute();
});

ref:
https://stackoverflow.com/questions/11554762/how-to-copy-a-collection-from-one-database-to-another-in-mongodb

Sadly, cloneCollection() cannot clone collections from one local database to another local database.

ref:
https://docs.mongodb.com/manual/reference/command/cloneCollection/

Useful Tools

Backup

$ mongodump -h 127.0.0.1:27017 --oplog -j=8 --gzip --archive=/data/mongodump.tar.gz

ref:
https://docs.mongodb.com/manual/reference/program/mongodump/

Restore

$ mongorestore --drop --gzip --archive=2018-08-12T03.tar.gz

If the restore fails with an error like the following, it typically indicates data corruption, which is often caused by problems with the underlying storage device, file system, or network connection:

restoring indexes for collection test.message from metadata
Failed: test.message: error creating indexes for test.message: createIndex error: BSONElement: bad type -47

ref:
https://docs.mongodb.com/manual/reference/program/mongorestore/

Profiling

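Set the profiling level to 2 to record every operation in the system.profile collection (level 1 records only operations slower than the slowms threshold, and level 0 turns the profiler off).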

db.setProfilingLevel(2);

db.getCollection('system.profile').find({
    'ns': {
        '$nin': ['test.system.profile', 'test.system.indexes', 'test.system.js', 'test.system.users']
    }
}).sort({'ts': -1}).limit(5).pretty();

ref:
https://docs.mongodb.com/manual/tutorial/manage-the-database-profiler/
https://stackoverflow.com/questions/15204341/mongodb-logging-all-queries

$ pip install mongotail

# set the profiling level
$ mongotail 127.0.0.1:27017/test -l 2

# tail logs
$ mongotail 127.0.0.1:27017/test -f -m -f

ref:
https://github.com/mrsarm/mongotail

Monitoring

$ mongotop
$ mongostat

ref:
https://docs.mongodb.com/manual/reference/program/mongotop/
https://docs.mongodb.com/manual/reference/program/mongostat/

$ pip install mtools

$ mloginfo mongod.log

ref:
https://github.com/rueckstiess/mtools

Run a Celery task at a specific time

Schedule Tasks

You can run any Celery task at a specific time by passing the eta ("Estimated Time of Arrival") argument to apply_async().

import datetime

import celery

@celery.shared_task(bind=True)
def add_tag(task, user_id, tag):
    User.objects.filter(id=user_id, tags__ne=tag).update(push__tags=tag)
    return True

user_id = '582ee32a5b9c861c87dc297e'
tag = 'new_tag'
started_at = datetime.datetime(2018, 3, 12, tzinfo=datetime.timezone.utc)
add_tag.apply_async((user_id, tag), eta=started_at)

ref:
http://docs.celeryproject.org/en/master/userguide/calling.html#eta-and-countdown

Revoke Tasks

Revoking a task that is still waiting for its eta prevents it from running: the worker discards it instead of executing it.

from celery.result import AsyncResult

# task_id is the id of a previously scheduled task, e.g. add_tag.apply_async(...).id
AsyncResult(task_id).revoke()

ref:
http://docs.celeryproject.org/en/latest/reference/celery.result.html#celery.result.AsyncResult.revoke

Revoking tasks works by sending a broadcast message to all the workers, the workers then keep a list of revoked tasks in memory. When a worker starts up it will synchronize revoked tasks with other workers in the cluster.

The list of revoked tasks is in-memory, so if all workers restart, the list of revoked ids will also vanish. If you want to preserve this list between restarts, you need to specify a file for these to be stored in by using the --statedb argument to celery worker.

ref:
http://docs.celeryproject.org/en/latest/userguide/workers.html#worker-persistent-revokes

Timezone in Python: Offset-naive and Offset-aware datetimes

TL;DR: You should always store datetimes in UTC and convert them to the user's timezone only for display.

A timezone offset is how far a timezone's local time is ahead of or behind Coordinated Universal Time (UTC), usually expressed in hours (some zones also use minutes). The offset of UTC is +00:00, and the offset of the Asia/Taipei timezone is UTC+08:00 (you could also write it as GMT+08:00). For practical purposes, there is no perceptible difference between Greenwich Mean Time (GMT) and UTC.

Subtracting the timezone offset from a local time gives UTC time. For instance, 18:00+08:00 in Asia/Taipei minus the +08:00 offset is 10:00+00:00, i.e. 10 o'clock UTC. Conversely, adding the local timezone offset to UTC time gives local time.
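The arithmetic above, sketched with the standard library. The fixed-offset timezone stands in for Asia/Taipei, which works only because Taiwan currently does not observe daylight saving time; for zones with DST, use a real timezone database such as pytz or zoneinfo.

```python
import datetime

# fixed +08:00 offset standing in for Asia/Taipei (no DST), for illustration only
taipei = datetime.timezone(datetime.timedelta(hours=8))
local = datetime.datetime(2018, 2, 4, 18, 0, tzinfo=taipei)

# local time minus its offset gives UTC (plain wall-clock arithmetic on the naive part)
utc_wall_clock = local.replace(tzinfo=None) - local.utcoffset()
print(utc_wall_clock)  # 2018-02-04 10:00:00

# astimezone() performs the same conversion and keeps the result offset-aware
utc = local.astimezone(datetime.timezone.utc)
print(utc.isoformat())  # 2018-02-04T10:00:00+00:00

# UTC plus the local offset gives local time again
print(utc.astimezone(taipei).isoformat())  # 2018-02-04T18:00:00+08:00
```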

ref:
https://opensource.com/article/17/5/understanding-datetime-python-primer
https://julien.danjou.info/blog/2015/python-and-timezones

Is It GMT+8 or UTC+8? (in Chinese)
http://pansci.asia/archives/84978

Installation

$ pip install -U python-dateutil pytz tzlocal

Show System Timezone

import tzlocal

tzlocal.get_localzone()
# <DstTzInfo 'Asia/Taipei' LMT+8:06:00 STD>

tzlocal.get_localzone().zone
# 'Asia/Taipei'

from time import localtime, strftime
print(strftime("%z", localtime()))
# +0800

ref:
https://github.com/regebro/tzlocal
https://stackoverflow.com/questions/13218506/how-to-get-system-timezone-setting-and-pass-it-to-pytz-timezone/

Find Timezones Of A Certain Country

import pytz

pytz.country_timezones('tw')
# ['Asia/Taipei']

pytz.country_timezones('cn')
# ['Asia/Shanghai', 'Asia/Urumqi']

ref:
https://pythonhosted.org/pytz/#country-information

Offset-naive Datetime

A naive datetime is usually interpreted as local time, but it carries no tzinfo to say so, which makes it error-prone.

A naive datetime object contains no timezone information: its tzinfo attribute is None. Datetime objects without a timezone should be considered a "bug" in your application, because it is up to the programmer to keep track of which timezone users are working in.

import datetime

import dateutil.parser

datetime.datetime.now()
# returns the current date and time in local timezone, in this example: Asia/Taipei (UTC+08:00)
# datetime.datetime(2018, 2, 2, 9, 15, 6, 211358), naive

datetime.datetime.utcnow()
# returns the current date and time in UTC
# datetime.datetime(2018, 2, 2, 1, 15, 6, 211358), naive

dateutil.parser.parse('2018-02-04T16:30:00')
# datetime.datetime(2018, 2, 4, 16, 30), naive

ref:
https://docs.python.org/3/library/datetime.html
https://dateutil.readthedocs.io/en/stable/
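Python's datetime documentation defines an object as aware only when tzinfo is set and tzinfo.utcoffset() does not return None. A small check based on that rule (is_aware() is a hypothetical helper name):

```python
import datetime


def is_aware(dt):
    """Return True if `dt` is offset-aware, following the rule in Python's datetime docs."""
    return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None


naive = datetime.datetime(2018, 2, 4, 16, 30)
aware = naive.replace(tzinfo=datetime.timezone.utc)

print(is_aware(naive))  # False
print(is_aware(aware))  # True
```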

Offset-aware Datetime

An aware datetime object embeds timezone information. Rules of thumb for timezones in Python:

  • Always work with "offset-aware" datetime objects.
  • Always store datetime in UTC and do timezone conversion only when interacting with users.
  • Always use ISO 8601 as input and output string format.

There are two useful methods: pytz.utc.localize(naive_dt) for turning a naive datetime into an offset-aware one, and aware_dt.astimezone(pytz.timezone('Asia/Taipei')) for converting offset-aware objects to another timezone.

Avoid naive_dt.astimezone(some_tzinfo): since Python 3.6, it treats the naive datetime as system local time and then converts it to some_tzinfo, which is rarely what you want.

import datetime

import pytz

now_utc = pytz.utc.localize(datetime.datetime.utcnow())
# same instant as datetime.datetime.now(pytz.utc)
# same instant as datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
# datetime.datetime(2018, 2, 4, 10, 17, 40, 679562, tzinfo=<UTC>), aware

now_taipei = now_utc.astimezone(pytz.timezone('Asia/Taipei'))
# convert to another timezone
# datetime.datetime(2018, 2, 4, 18, 17, 40, 679562, tzinfo=<DstTzInfo 'Asia/Taipei' CST+8:00:00 STD>), aware

now_utc.isoformat()
# '2018-02-04T10:17:40.679562+00:00'

now_taipei.isoformat()
# '2018-02-04T18:17:40.679562+08:00'

now_utc == now_taipei
# True

When working with pytz, call tz.localize(naive_dt) instead of naive_dt.replace(tzinfo=tz). A pytz timezone attached through tzinfo= or replace() falls back to the zone's earliest offset (LMT+8:06 for Asia/Taipei, as shown below) and does not handle daylight saving time correctly.

dt1 = datetime.datetime.now(pytz.timezone('Asia/Taipei'))
# datetime.datetime(2018, 2, 4, 18, 22, 28, 409332, tzinfo=<DstTzInfo 'Asia/Taipei' CST+8:00:00 STD>), aware

dt2 = datetime.datetime(2018, 2, 4, 18, 22, 28, 409332, tzinfo=pytz.timezone('Asia/Taipei'))
# datetime.datetime(2018, 2, 4, 18, 22, 28, 409332, tzinfo=<DstTzInfo 'Asia/Taipei' LMT+8:06:00 STD>), aware

dt1 == dt2
# False

ref:
https://pythonhosted.org/pytz/

Naive and aware datetime objects cannot be meaningfully compared: == always returns False, and ordering comparisons raise TypeError.

naive = datetime.datetime.utcnow()
aware = pytz.utc.localize(naive)

naive == aware
# False

naive >= aware
# TypeError: can't compare offset-naive and offset-aware datetimes

Parse String to Datetime

python-dateutil usually comes in handy.

import dateutil.parser
import dateutil.tz

dt1 = dateutil.parser.parse('2018-02-04T19:30:00+08:00')
# datetime.datetime(2018, 2, 4, 19, 30, tzinfo=tzoffset(None, 28800)), aware

dt2 = dateutil.parser.parse('2018-02-04T11:30:00+00:00')
# datetime.datetime(2018, 2, 4, 11, 30, tzinfo=tzutc()), aware

dt3 = dateutil.parser.parse('2018-02-04T11:30:00Z')
# datetime.datetime(2018, 2, 4, 11, 30, tzinfo=tzutc()), aware

dt1 == dt2 == dt3
# True

ref:
https://dateutil.readthedocs.io/en/stable/
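If you only need ISO 8601 strings with numeric offsets, the standard library's datetime.fromisoformat() (Python 3.7+) can do the same without an extra dependency. A short sketch; note that it accepts a trailing "Z" only from Python 3.11 on, so that case still needs dateutil on older versions.

```python
import datetime

# fromisoformat() exists since Python 3.7; a trailing "Z" is only accepted from Python 3.11
dt1 = datetime.datetime.fromisoformat('2018-02-04T19:30:00+08:00')
dt2 = datetime.datetime.fromisoformat('2018-02-04T11:30:00+00:00')

print(dt1.utcoffset())  # 8:00:00
print(dt1 == dt2)  # True
```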

Convert Datetime To Unix Timestamp

import datetime

naive_dt = datetime.datetime(2018, 9, 10, 0, 0, 0)
naive_timestamp = naive_dt.timestamp()
# naive_dt is interpreted in the system's local timezone, in this example: Asia/Taipei (UTC+08:00)

aware_dt = datetime.datetime(2018, 9, 10, 0, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=8)))
aware_timestamp = aware_dt.timestamp()

naive_timestamp == aware_timestamp
# True, but only because the system timezone is UTC+08:00 in this example

# PyMongo returns naive datetimes that represent UTC (unless the client is created with tz_aware=True),
# so attach UTC before converting
dt_fetched_from_mongodb.replace(tzinfo=datetime.timezone.utc).timestamp()
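A quick standard-library check of the same idea that does not depend on the system timezone: a naive datetime known to be UTC (such as one read through PyMongo) gets UTC attached explicitly, and it yields the same timestamp as the equivalent +08:00 time. The sample values are chosen for illustration.

```python
import datetime

taipei = datetime.timezone(datetime.timedelta(hours=8))
aware_dt = datetime.datetime(2018, 9, 10, 0, 0, 0, tzinfo=taipei)

# the same instant as a naive UTC value, as PyMongo returns it by default
naive_utc = datetime.datetime(2018, 9, 9, 16, 0, 0)
utc_timestamp = naive_utc.replace(tzinfo=datetime.timezone.utc).timestamp()

print(utc_timestamp)  # 1536508800.0
print(utc_timestamp == aware_dt.timestamp())  # True
```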

Parse Unix Timestamp To Datetime

import datetime
import time

import pytz

ts = time.time()
# seconds since the Epoch (1970-01-01T00:00:00 in UTC)
# 1517748706.063205

dt1 = datetime.datetime.fromtimestamp(ts)
# return the date and time of the timestamp in local timezone, in this example: Asia/Taipei (UTC+08:00)
# datetime.datetime(2018, 2, 4, 20, 51, 46, 63205), naive

dt2 = datetime.datetime.utcfromtimestamp(ts)
# return the date and time of the timestamp in UTC timezone
# datetime.datetime(2018, 2, 4, 12, 51, 46, 63205), naive

pytz.timezone('Asia/Taipei').localize(dt1) == pytz.utc.localize(dt2)
# True

ref:
https://stackoverflow.com/questions/13890935/does-pythons-time-time-return-the-local-or-utc-timestamp
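To skip the naive step entirely, datetime.fromtimestamp() also accepts a tz argument and returns an aware datetime, so the result no longer depends on the system timezone (utcfromtimestamp() is also deprecated since Python 3.12). The timestamp below is the whole-second part of the one above.

```python
import datetime

ts = 1517748706  # whole seconds from the example above

dt_utc = datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
print(dt_utc.isoformat())  # 2018-02-04T12:51:46+00:00

taipei = datetime.timezone(datetime.timedelta(hours=8))
dt_taipei = datetime.datetime.fromtimestamp(ts, tz=taipei)
print(dt_taipei.isoformat())  # 2018-02-04T20:51:46+08:00

print(dt_utc == dt_taipei)  # True
```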

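We might receive a Unix timestamp from a JavaScript client. Note that moment() parses a date-only string in the client's local timezone (Asia/Taipei in this example), so 1517500800 corresponds to 2018-02-02T00:00:00+08:00.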

var moment = require('moment')
var ts = moment('2018-02-02').unix()
// 1517500800

ref:
https://momentjs.com/docs/#/parsing/unix-timestamp/

Store Datetime In Databases

  • MySQL lets developers decide which timezone to use, so you should convert datetimes to UTC before saving them into the database.
  • MongoDB stores dates in UTC, and drivers such as PyMongo treat naive datetimes as UTC, so you have to normalize datetimes to UTC before saving them.

ref:
https://tommikaikkonen.github.io/timezones/
https://blog.elsdoerfer.name/2008/03/03/fun-with-timezones-in-django-mysql/
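A minimal sketch of normalizing to UTC before saving; to_utc() and its assume_tz parameter are hypothetical. It refuses to guess the timezone of naive input and raises instead, which keeps ambiguous values out of the database.

```python
import datetime


def to_utc(dt, assume_tz=None):
    """Return `dt` as an offset-aware UTC datetime, ready to be saved.

    Naive datetimes are ambiguous, so the caller has to say which timezone they
    are in via `assume_tz`; otherwise a ValueError is raised instead of guessing.
    """
    if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
        if assume_tz is None:
            raise ValueError('naive datetime: pass assume_tz explicitly')
        # fine for datetime.timezone / zoneinfo; with pytz, use assume_tz.localize(dt) instead
        dt = dt.replace(tzinfo=assume_tz)
    return dt.astimezone(datetime.timezone.utc)


taipei = datetime.timezone(datetime.timedelta(hours=8))
print(to_utc(datetime.datetime(2018, 2, 4, 18, 0), assume_tz=taipei))
# 2018-02-04 10:00:00+00:00
```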

Tools

ref:
https://www.epochconverter.com/
https://www.timeanddate.com/worldclock/converter.html

Build a recommender system with Spark: Logistic Regression

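In this series of articles, we will use tools such as Apache Spark, XGBoost, Elasticsearch, and MySQL to build a Machine Learning Pipeline for a recommender system. A recommender system can be roughly divided into two parts: Candidate Generation and Ranking. The former generates a set of candidate items for each user, commonly with methods such as Collaborative Filtering, Content-based filtering, tag matching, popularity rankings, or manual curation; the latter ranks these candidates and presents the final recommendations as a Top N list, commonly with methods such as Logistic Regression.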

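In this article, we take Logistic Regression, one of the methods commonly used in the Ranking stage, as an example and use Apache Spark's Logistic Regression model to build a recommender system for GitHub repositories. Using users' starring records and various attributes of users and repos as features, we predict whether a user will star a given repo (a classification problem). The trained model can then serve as the Ranking module of our recommender system. However, because LR is a linear model, it usually requires a lot of Feature Engineering to learn non-linear relationships. So this article focuses on the Pipeline mechanism of Spark ML and on feature engineering, and does not go deep into the algorithm itself.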

The complete code can be found at https://github.com/vinta/albedo.

Articles in this series:

Submit the Application

$ spark-submit \
--master spark://192.168.10.100:7077 \
--packages "com.github.fommil.netlib:all:1.1.2,com.hankcs:hanlp:portable-1.3.4,mysql:mysql-connector-java:5.1.41" \
--class ws.vinta.albedo.UserProfileBuilder \
target/albedo-1.0.0-SNAPSHOT.jar

$ spark-submit \
--master spark://192.168.10.100:7077 \
--packages "com.github.fommil.netlib:all:1.1.2,com.hankcs:hanlp:portable-1.3.4,mysql:mysql-connector-java:5.1.41" \
--class ws.vinta.albedo.RepoProfileBuilder \
target/albedo-1.0.0-SNAPSHOT.jar

$ spark-submit \
--master spark://192.168.10.100:7077 \
--packages "com.github.fommil.netlib:all:1.1.2,com.hankcs:hanlp:portable-1.3.4,mysql:mysql-connector-java:5.1.41" \
--class ws.vinta.albedo.LogisticRegressionRanker \
target/albedo-1.0.0-SNAPSHOT.jar

ref:
https://vinta.ws/code/setup-spark-scala-and-maven-with-intellij-idea.html
https://spark.apache.org/docs/latest/submitting-applications.html
https://spoddutur.github.io/spark-notes/distribution_of_executors_cores_and_memory_for_spark_application

Load Data

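We previously used the GitHub API and the GitHub Archive on BigQuery to collect 1.5 million starring records along with the associated user and repo data. We currently have the following datasets, roughly modeled after the GitHub API, with these columns: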

rawUserInfoDS.printSchema()
// root
// |-- user_id: integer (nullable = true)
// |-- user_login: string (nullable = true)
// |-- user_account_type: string (nullable = true)
// |-- user_name: string (nullable = true)
// |-- user_company: string (nullable = true)
// |-- user_blog: string (nullable = true)
// |-- user_location: string (nullable = true)
// |-- user_email: string (nullable = true)
// |-- user_bio: string (nullable = true)
// |-- user_public_repos_count: integer (nullable = true)
// |-- user_public_gists_count: integer (nullable = true)
// |-- user_followers_count: integer (nullable = true)
// |-- user_following_count: integer (nullable = true)
// |-- user_created_at: timestamp (nullable = true)
// |-- user_updated_at: timestamp (nullable = true)

rawRepoInfoDS.printSchema()
// root
// |-- repo_id: integer (nullable = true)
// |-- repo_owner_id: integer (nullable = true)
// |-- repo_owner_username: string (nullable = true)
// |-- repo_owner_type: string (nullable = true)
// |-- repo_name: string (nullable = true)
// |-- repo_full_name: string (nullable = true)
// |-- repo_description: string (nullable = true)
// |-- repo_language: string (nullable = true)
// |-- repo_created_at: timestamp (nullable = true)
// |-- repo_updated_at: timestamp (nullable = true)
// |-- repo_pushed_at: timestamp (nullable = true)
// |-- repo_homepage: string (nullable = true)
// |-- repo_size: integer (nullable = true)
// |-- repo_stargazers_count: integer (nullable = true)
// |-- repo_forks_count: integer (nullable = true)
// |-- repo_subscribers_count: integer (nullable = true)
// |-- repo_is_fork: boolean (nullable = true)
// |-- repo_has_issues: boolean (nullable = true)
// |-- repo_has_projects: boolean (nullable = true)
// |-- repo_has_downloads: boolean (nullable = true)
// |-- repo_has_wiki: boolean (nullable = true)
// |-- repo_has_pages: boolean (nullable = true)
// |-- repo_open_issues_count: integer (nullable = true)
// |-- repo_topics: string (nullable = true)

rawStarringDS.printSchema()
// root
// |-- user_id: integer (nullable = true)
// |-- repo_id: integer (nullable = true)
// |-- starred_at: timestamp (nullable = true)
// |-- starring: double (nullable = true)

ref:
https://www.githubarchive.org/
http://ghtorrent.org/

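After loading the data, the first thing to do is Exploratory Data Analysis (EDA): play around with the data at hand. I recommend trying Apache Zeppelin or Databricks notebooks. Besides built-in support for all the languages Spark supports, they also integrate with NoSQL and JDBC databases and make charting easy; they are arguably even more convenient than Jupyter Notebook.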

ref:
https://zeppelin.apache.org/
https://databricks.com/

Build User Profile / Item Profile

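In this project, the main data entities are users and repos, so we build a User Profile and an Item Profile for each, to be used as features in the model training stage. We keep this step separate from the model training pipeline, which makes the overall architecture more flexible. In practice, we can use the user id or item id as a key and store the prepared features directly in Redis or another schemaless NoSQL database, so that multiple models can use them later; for real-time recommendations, the features can also be fetched quickly, and only some fields need to be recomputed.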

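Since we mainly use data from the GitHub API here, GitHub has to some extent already done a lot of the data cleaning and normalization for us. In reality, the data your system handles is usually not this clean: it may come from various data sources, in various formats, and change over time, and it often takes considerable effort to Extract, Transform, Load (ETL) it, so it is best to agree on the format up front when writing logs (event tracking). Moreover, data keeps changing in a production environment, and monitoring is a very important part of ensuring the freshness and fault tolerance of the data.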

ref:
http://www.algorithmdog.com/ad-rec-deploy
https://tech.meituan.com/online-feature-system02.html

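Due to limited space, the rest of this article only explains a few important parts. In short, at the end of this step we get two DataFrames, userProfileDF and repoProfileDF, which hold the prepared features. The detailed code is here: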

ref:
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/UserProfileBuilder.scala
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/RepoProfileBuilder.scala

Feature Engineering

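Taking recommender systems as an example, features can be divided into the following four types:

  • User features: attributes of the user, e.g. user id, gender, occupation, or city
  • Item features: attributes of the item, e.g. item id, author, title, category, rating, or tags
  • Interaction features: an action the user takes on an item, aggregations of that action, or cross features, e.g. whether the user has watched movies of the same genre, the genre distribution of recently played songs, or how many high-priced products the user bought last week
  • Context features: metadata about the action the user takes on an item, e.g. when it happened, the device used, or the current GPS location

Some features are available at data-collection time, some require extra steps to obtain (e.g. an external API or another model), and some must be updated in real time. By the way, since we want to predict "whether a user will star a given repo", the user in the features below can be either a repo stargazer or a repo owner.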

Raw features:

  • User features
    • user_id
    • user_login
    • user_name
    • user_email
    • user_blog
    • user_bio
    • user_company
    • user_location
    • user_followers_count
    • user_following_count
    • user_public_repos_count
    • user_public_gists_count
    • user_created_at
    • user_updated_at
  • Item features
    • repo_id
    • repo_name
    • repo_owner
    • repo_owner_type
    • repo_language
    • repo_description
    • repo_homepage
    • repo_subscribers_count
    • repo_stargazers_count
    • repo_forks_count
    • repo_size
    • repo_created_at
    • repo_updated_at
    • repo_pushed_at
    • repo_has_issues
    • repo_has_projects
    • repo_has_downloads
    • repo_has_wiki
    • repo_has_pages
    • repo_open_issues_count
    • repo_topics
  • Interaction features
    • user_stars_repo
    • user_follows_user
  • Context features
    • user_repo_starred_at

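Brainstormed features:

  • User features
    • user_days_between_created_at_today: how many days have passed since the user signed up
    • user_days_between_updated_at_today: how many days have passed since the user's profile was last updated
    • user_repos_avg_stargazers_count: the average number of stars across all of the user's repos (excluding forks)
    • user_organizations: the organizations the user belongs to
    • user_has_null: whether at least one of the user's fields is null
    • user_has_blog: whether the user has a website
    • user_is_freelancer: whether the user's bio contains words like Freelancer
    • user_is_junior: whether the user's bio contains words like Beginner or Junior
    • user_is_lead: whether the user's bio contains words like Team Lead, Architect, Creator, CTO, or VP of Engineering
    • user_is_scholar: whether the user's bio contains words like Researcher, Scientist, PhD, or Professor
    • user_is_pm: whether the user's bio contains words like Product Manager
    • user_knows_backend: whether the user's bio contains words like Backend or Back end
    • user_knows_data: whether the user's bio contains words like Machine Learning, Deep Learning, or Data Science
    • user_knows_devops: whether the user's bio contains words like DevOps, SRE, SysAdmin, or Infrastructure
    • user_knows_frontend: whether the user's bio contains words like Frontend or Front end
    • user_knows_mobile: whether the user's bio contains words like Mobile, iOS, or Android
    • user_knows_recsys: whether the user's bio contains words like Recommender System, Data Mining, or Information Retrieval
    • user_knows_web: whether the user's bio contains words like Web Development or Fullstack
  • Item features
    • repo_created_at_days_since_today: how many days have passed since the repo was created
    • repo_updated_at_days_since_today: how many days have passed since the repo was last updated
    • repo_pushed_at_days_since_today: how many days have passed since the last push to the repo
    • repo_stargazers_count_in_30days: the number of stars the repo received in the last 30 days
    • repo_subscribers_stargazers_ratio: the ratio of the repo's watchers to stars
    • repo_forks_stargazers_ratio: the ratio of the repo's forks to stars
    • repo_open_issues_stargazers_ratio: the ratio of the repo's open issues to stars
    • repo_releases_count: the number of releases or tags of the repo
    • repo_license: the repo's license
    • repo_readme: the content of the repo's README
    • repo_has_null: whether at least one of the repo's fields is null
    • repo_has_readme: whether the repo has a README file
    • repo_has_changelog: whether the repo has a CHANGELOG file
    • repo_has_contributing: whether the repo has a CONTRIBUTING file
    • repo_has_tests: whether the repo has tests
    • repo_has_ci: whether the repo has CI
    • repo_has_dockerfile: whether the repo has a Dockerfile
    • repo_is_unmaintained: whether the repo is no longer maintained
    • repo_is_awesome: whether the repo is included in any awesome-xxx list
    • repo_is_vinta_starred: whether the repo has been starred by @vinta, aka the author of this article
  • Interaction features
    • user_starred_repos_count: the total number of repos the user has starred
    • user_avg_daily_starred_repos_count: the average number of repos the user stars per day
    • user_forked_repos_count: the total number of repos the user has forked
    • user_follower_following_count_ratio: the ratio of the user's followers to followings
    • user_recent_searched_keywords: the 50 keywords the user searched most recently
    • user_recent_commented_repos: the 50 repos the user commented on most recently
    • user_recent_watched_repos: the 50 repos the user watched most recently
    • user_recent_starred_repos_descriptions: the descriptions of the 50 repos the user starred most recently
    • user_recent_starred_repos_languages: the languages of the 50 repos the user starred most recently
    • user_recent_starred_repos_topics: the topics of the 50 repos the user starred most recently
    • user_follows_repo_owner: whether the user follows the repo's owner
    • repo_language_index_in_user_recent_repo_languages: the position of the repo's language in the list of languages the user recently starred
    • repo_language_count_in_user_recent_repo_languages: how many times the repo's language appears in the list of languages the user recently starred
    • repo_topics_user_recent_topics_similarity: the similarity between the repo's topics and the topics the user recently starred
  • Context features
    • als_model_prediction: the prediction from an ALS model, i.e. how much the user prefers the repo
    • gbdt_model_index: the tree index from a GBDT model, an automatically generated feature for the observation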
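Several of the brainstormed item features are simple arithmetic on the raw columns. A minimal sketch in plain Python, assuming each repo is a dict keyed by the column names of rawRepoInfoDS; repo_derived_features() is a hypothetical helper, and returning 0.0 for repos with no stars is an assumption of this sketch, not something the original specifies.

```python
import datetime


def repo_derived_features(repo, today):
    """Compute a few brainstormed repo features from one raw repo row (hypothetical helper)."""
    stars = repo["repo_stargazers_count"]

    def per_star(count):
        # assumption: repos without stars get 0.0 instead of a division-by-zero error
        return count / stars if stars else 0.0

    return {
        "repo_created_at_days_since_today": (today - repo["repo_created_at"].date()).days,
        "repo_pushed_at_days_since_today": (today - repo["repo_pushed_at"].date()).days,
        "repo_subscribers_stargazers_ratio": per_star(repo["repo_subscribers_count"]),
        "repo_forks_stargazers_ratio": per_star(repo["repo_forks_count"]),
        "repo_open_issues_stargazers_ratio": per_star(repo["repo_open_issues_count"]),
    }


# made-up sample row
repo = {
    "repo_stargazers_count": 200,
    "repo_forks_count": 50,
    "repo_subscribers_count": 20,
    "repo_open_issues_count": 10,
    "repo_created_at": datetime.datetime(2018, 1, 1, 12, 0),
    "repo_pushed_at": datetime.datetime(2018, 1, 10, 8, 0),
}
print(repo_derived_features(repo, today=datetime.date(2018, 1, 11)))
```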

Common Methods in Feature Engineering
https://vinta.ws/code/feature-engineering.html

Detect Outliers

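Besides missing values, outliers (anomalies) also need attention. For continuous features, a box plot quickly reveals outliers; for categorical features, run a SELECT COUNT(*) ... GROUP BY and check the count of each category with a bar chart. Depending on the problem you are solving, outliers may simply be ignored or may need special treatment, e.g. figuring out why they occur: an error during data collection, or some hidden underlying factor.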

ref:
https://www.analyticsvidhya.com/blog/2016/01/guide-data-exploration/
https://www.slideshare.net/tw_dsconf/123-70852901
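The box plot rule can be checked without plotting: the whiskers usually sit 1.5 × IQR beyond the quartiles, and points outside them are drawn as outliers. A minimal sketch in plain Python with made-up numbers; iqr_outliers() is a hypothetical helper, and statistics.quantiles() (Python 3.8+) uses its default "exclusive" method, so quartiles may differ slightly from other tools.

```python
import statistics
from collections import Counter


def iqr_outliers(values, k=1.5):
    """Return values outside the box plot whiskers [Q1 - k*IQR, Q3 + k*IQR] (Tukey's rule)."""
    q1, _, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    return [v for v in values if v < low or v > high]


# made-up star counts for a continuous feature
stars = [30, 45, 52, 60, 61, 75, 80, 95, 120, 98000]
print(iqr_outliers(stars))  # [98000]

# for a categorical feature, counting per category plays the role of GROUP BY + bar chart
languages = ["python", "go", "python", "javascript", "python", "cobol"]
print(Counter(languages).most_common())
```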

Impute Missing Values

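You can use df.describe().show() to see statistics for each column: count, mean, stddev, min, and max. Besides df.where("some_column IS NULL"), comparing the count of different columns also quickly reveals which columns have missing values. While you are at it, check whether the columns with missing values are related to the target variable.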

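Here we fill in the null and NaN values directly. Since the following columns are all strings, we simply replace missing values with empty strings to simplify later processing, and create a has_null feature along the way.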

For users:

import org.apache.spark.sql.functions._

val nullableColumnNames = Array("user_name", "user_company", "user_blog", "user_location", "user_bio")

val imputedUserInfoDF = rawUserInfoDS
  .withColumn("user_has_null", when(nullableColumnNames.map(rawUserInfoDS(_).isNull).reduce(_ || _), true).otherwise(false))
  .na.fill("", nullableColumnNames)

For repos:

import org.apache.spark.sql.functions._

val nullableColumnNames = Array("repo_description", "repo_homepage")

val imputedRepoInfoDF = rawRepoInfoDS
  .withColumn("repo_has_null", when(nullableColumnNames.map(rawRepoInfoDS(_).isNull).reduce(_ || _), true).otherwise(false))
  .na.fill("", nullableColumnNames)

ref:
https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.sql.DataFrameNaFunctions
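The same two steps in plain Python can help check the logic on a handful of rows before running it on Spark. A minimal sketch that treats each row as a dict; impute_row() and null_counts() are hypothetical helpers, not Spark APIs. As in the Scala code, the has_null flag is computed before the fill; otherwise every row would report no nulls.

```python
NULLABLE_USER_COLUMNS = ("user_name", "user_company", "user_blog", "user_location", "user_bio")


def null_counts(rows, columns):
    """Count None values per column, the gap that per-column counts in df.describe() reveal."""
    return {c: sum(1 for r in rows if r.get(c) is None) for c in columns}


def impute_row(row, columns=NULLABLE_USER_COLUMNS, flag="user_has_null"):
    """Flag rows with any null in `columns`, then replace those nulls with empty strings."""
    out = dict(row)
    out[flag] = any(row.get(c) is None for c in columns)  # must run before the fill below
    for c in columns:
        if out.get(c) is None:
            out[c] = ""
    return out


# made-up sample rows
rows = [
    {"user_id": 1, "user_name": "Vinta", "user_company": None, "user_blog": "",
     "user_location": "Taipei", "user_bio": None},
    {"user_id": 2, "user_name": "a", "user_company": "b", "user_blog": "c",
     "user_location": "d", "user_bio": "e"},
]
print(null_counts(rows, NULLABLE_USER_COLUMNS))
print([impute_row(r) for r in rows])
```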

For numeric columns, consider using Imputer.

ref:
https://spark.apache.org/docs/latest/ml-features.html#imputer

Clean Data

For users, normalize a few text columns with User-defined Functions:

import org.apache.spark.sql.functions._

val cleanUserInfoDF = imputedUserInfoDF
  .withColumn("user_clean_company", cleanCompanyUDF($"user_company"))
  .withColumn("user_clean_location", cleanLocationUDF($"user_location"))
  .withColumn("user_clean_bio", lower($"user_bio"))

For repos, filter out forks, repos whose repo_stargazers_count is too high or too low, and repos whose description contains words such as "unmaintained" or "assignment":

val reducedRepoInfo = imputedRepoInfoDF
  .where($"repo_is_fork" === false)
  .where($"repo_forks_count" <= 90000)
  .where($"repo_stargazers_count".between(30, 100000))

val unmaintainedWords = Array("%unmaintained%", "%no longer maintained%", "%deprecated%", "%moved to%")
val assignmentWords = Array("%assignment%", "%作業%", "%作业%")
val demoWords = Array("test", "%demo project%")
val blogWords = Array("my blog")

val cleanRepoInfoDF = reducedRepoInfo
  .withColumn("repo_clean_description", lower($"repo_description"))
  .withColumn("repo_is_unmaintained", when(unmaintainedWords.map($"repo_clean_description".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("repo_is_assignment", when(assignmentWords.map($"repo_clean_description".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("repo_is_demo", when(demoWords.map($"repo_clean_description".like(_)).reduce(_ or _) and $"repo_stargazers_count" <= 40, true).otherwise(false))
  .withColumn("repo_is_blog", when(blogWords.map($"repo_clean_description".like(_)).reduce(_ or _) and $"repo_stargazers_count" <= 40, true).otherwise(false))
  .where($"repo_is_unmaintained" === false)
  .where($"repo_is_assignment" === false)
  .where($"repo_is_demo" === false)
  .where($"repo_is_blog" === false)
  .withColumn("repo_clean_language", lower($"repo_language"))
  .withColumn("repo_clean_topics", lower($"repo_topics"))
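The word lists rely on SQL LIKE semantics, where % matches any run of characters, so a pattern without % (such as "test" or "my blog") only matches a description that is exactly that string. A plain-Python sketch of that matching, useful for sanity-checking the word lists; sql_like() and repo_flags() are hypothetical helpers, and sql_like() ignores LIKE's escape character.

```python
import re


def sql_like(text, pattern):
    """Case-sensitive SQL LIKE: % matches any run of characters, _ matches exactly one."""
    regex = "".join(
        ".*" if ch == "%" else "." if ch == "_" else re.escape(ch)
        for ch in pattern
    )
    return re.fullmatch(regex, text, flags=re.DOTALL) is not None


UNMAINTAINED_WORDS = ["%unmaintained%", "%no longer maintained%", "%deprecated%", "%moved to%"]
DEMO_WORDS = ["test", "%demo project%"]


def repo_flags(description, stargazers_count):
    """Mirror two of the Spark flags: lowercase first, then match the LIKE patterns."""
    clean = description.lower()
    return {
        "repo_is_unmaintained": any(sql_like(clean, p) for p in UNMAINTAINED_WORDS),
        "repo_is_demo": any(sql_like(clean, p) for p in DEMO_WORDS) and stargazers_count <= 40,
    }


print(repo_flags("This project is NO LONGER MAINTAINED", 500))
print(repo_flags("a test runner", 35))  # "test" has no %, so this is not a demo
```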

Construct Features

For users, build new features based on the brainstormed features above:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// like() only matches the whole string unless the pattern contains %, so each keyword is wrapped in % to match it anywhere in the bio
val webThings = Array("%web%", "%fullstack%", "%full stack%")
val backendThings = Array("%backend%", "%back end%", "%back-end%")
val frontendThings = Array("%frontend%", "%front end%", "%front-end%")
val mobileThings = Array("%mobile%", "%ios%", "%android%")
val devopsThings = Array("%devops%", "%sre%", "%admin%", "%infrastructure%")
val dataThings = Array("%machine learning%", "%deep learning%", "%data scien%", "%data analy%")
val recsysThings = Array("%data mining%", "%recommend%", "%information retrieval%")

val leadTitles = Array("%team lead%", "%architect%", "%creator%", "%director%", "%cto%", "%vp of engineering%")
val scholarTitles = Array("%researcher%", "%scientist%", "%phd%", "%professor%")
val freelancerTitles = Array("%freelance%")
val juniorTitles = Array("%junior%", "%beginner%", "%newbie%")
val pmTitles = Array("%product manager%")

val userStarredReposCountDF = rawStarringDS
  .groupBy($"user_id")
  .agg(count("*").alias("user_starred_repos_count"))

val starringRepoInfoDF = rawStarringDS
  .select($"user_id", $"repo_id", $"starred_at")
  .join(rawRepoInfoDS, Seq("repo_id"))

val userTopLanguagesDF = starringRepoInfoDF
  .withColumn("rank", rank().over(Window.partitionBy($"user_id").orderBy($"starred_at".desc)))
  .where($"rank" <= 50)
  .groupBy($"user_id")
  .agg(collect_list(lower($"repo_language")).alias("user_recent_repo_languages"))
  .select($"user_id", $"user_recent_repo_languages")

val userTopTopicsDF = starringRepoInfoDF
  .where($"repo_topics" =!= "")
  .withColumn("rank", rank().over(Window.partitionBy($"user_id").orderBy($"starred_at".desc)))
  .where($"rank" <= 50)
  .groupBy($"user_id")
  .agg(concat_ws(",", collect_list(lower($"repo_topics"))).alias("temp_user_recent_repo_topics"))
  .select($"user_id", split($"temp_user_recent_repo_topics", ",").alias("user_recent_repo_topics"))

val userTopDescriptionDF = starringRepoInfoDF
  .where($"repo_description" =!= "")
  .withColumn("rank", rank().over(Window.partitionBy($"user_id").orderBy($"starred_at".desc)))
  .where($"rank" <= 50)
  .groupBy($"user_id")
  .agg(concat_ws(" ", collect_list(lower($"repo_description"))).alias("user_recent_repo_descriptions"))
  .select($"user_id", $"user_recent_repo_descriptions")

val constructedUserInfoDF = cleanUserInfoDF
  .withColumn("user_knows_web", when(webThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_backend", when(backendThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_frontend", when(frontendThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_mobile", when(mobileThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_devops", when(devopsThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_data", when(dataThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_knows_recsys", when(recsysThings.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_is_lead", when(leadTitles.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_is_scholar", when(scholarTitles.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_is_freelancer", when(freelancerTitles.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_is_junior", when(juniorTitles.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_is_pm", when(pmTitles.map($"user_clean_bio".like(_)).reduce(_ or _), true).otherwise(false))
  .withColumn("user_followers_following_ratio", round($"user_followers_count" / ($"user_following_count" + lit(1.0)), 3))
  .withColumn("user_days_between_created_at_today", datediff(current_date(), $"user_created_at"))
  .withColumn("user_days_between_updated_at_today", datediff(current_date(), $"user_updated_at"))
  .join(userStarredReposCountDF, Seq("user_id"))
  .withColumn("user_avg_daily_starred_repos_count", round($"user_starred_repos_count" / ($"user_days_between_created_at_today" + lit(1.0)), 3))
  .join(userTopDescriptionDF, Seq("user_id"))
  .join(userTopTopicsDF, Seq("user_id"))
  .join(userTopLanguagesDF, Seq("user_id"))
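The recent-languages, recent-topics and recent-descriptions features all follow one pattern: rank a user's stars by starred_at, keep the top 50, and collect the result. A small pure-Scala sketch of that logic on an in-memory list (Star and recentLanguages are hypothetical names). It also shows a side effect of rank(): rows tied on starred_at share a rank, so a user can keep more than N rows, whereas row_number() would cap the count exactly.

```scala
// Hypothetical sketch of the rank() <= n window filter; not the project's code.
object RecentItems {
  final case class Star(userId: Int, language: String, starredAt: Long)

  // Keep each user's most recently starred rows under rank() semantics:
  // rows tied on starredAt share a rank and the next rank skips ahead,
  // so a user can end up with more than n rows when there are ties.
  def recentLanguages(stars: Seq[Star], n: Int): Map[Int, Seq[String]] =
    stars.groupBy(_.userId).map { case (user, rows) =>
      val sorted = rows.sortBy(row => -row.starredAt)
      val kept = sorted.filter { row =>
        // rank = 1 + number of rows strictly newer than this one
        1 + sorted.count(_.starredAt > row.starredAt) <= n
      }
      user -> kept.map(_.language.toLowerCase)
    }
}
```

With n = 2 and two repos tied for second place, both tied repos are kept, so that user gets three languages.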

For repos, we likewise build new features from the brainstormed ideas; these are rough, just enough to get the idea across:

import org.apache.spark.sql.functions._

val vintaStarredRepos = rawStarringDS
  .where($"user_id" === 652070)
  .select($"repo_id".as[Int])
  .collect()
  .to[List]

val constructedRepoInfoDF = cleanRepoInfoDF
  .withColumn("repo_has_activities_in_60days", datediff(current_date(), $"repo_pushed_at") <= 60)
  .withColumn("repo_has_homepage", when($"repo_homepage" === "", false).otherwise(true))
  .withColumn("repo_is_vinta_starred", when($"repo_id".isin(vintaStarredRepos: _*), true).otherwise(false))
  .withColumn("repo_days_between_created_at_today", datediff(current_date(), $"repo_created_at"))
  .withColumn("repo_days_between_updated_at_today", datediff(current_date(), $"repo_updated_at"))
  .withColumn("repo_days_between_pushed_at_today", datediff(current_date(), $"repo_pushed_at"))
  .withColumn("repo_subscribers_stargazers_ratio", round($"repo_subscribers_count" / ($"repo_stargazers_count" + lit(1.0)), 3))
  .withColumn("repo_forks_stargazers_ratio", round($"repo_forks_count" / ($"repo_stargazers_count" + lit(1.0)), 3))
  .withColumn("repo_open_issues_stargazers_ratio", round($"repo_open_issues_count" / ($"repo_stargazers_count" + lit(1.0)), 3))
  .withColumn("repo_text", lower(concat_ws(" ", $"repo_owner_username", $"repo_name", $"repo_language", $"repo_description")))

ref:
https://databricks.com/blog/2015/09/16/apache-spark-1-5-dataframe-api-highlights.html
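The ratio and age features rest on two small formulas: a ratio with +1 in the denominator, rounded to three decimals (Spark's round() uses HALF_UP), and a whole-day count like datediff(current_date(), date). A pure-Scala sketch under those assumptions; RepoRatios is a hypothetical helper:

```scala
import java.time.LocalDate
import java.time.temporal.ChronoUnit
import scala.math.BigDecimal.RoundingMode

// Hypothetical helper mirroring the Spark expressions above; not the project's code.
object RepoRatios {
  // Mirrors round(numerator / (denominator + 1.0), 3): the +1 keeps the ratio
  // defined when the denominator is 0, e.g. a repo with no stargazers.
  def smoothedRatio(numerator: Long, denominator: Long): Double =
    BigDecimal(numerator / (denominator + 1.0))
      .setScale(3, RoundingMode.HALF_UP)
      .toDouble

  // Mirrors datediff(current_date(), from): whole days from `from` to `today`.
  def daysBetween(from: LocalDate, today: LocalDate): Long =
    ChronoUnit.DAYS.between(from, today)
}
```

For example, a repo last pushed on 2017-01-01 and evaluated on 2017-03-02 is 60 days old, right at the edge of repo_has_activities_in_60days.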

Convert Features

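For users, this step mainly bins a few categorical features: companies and locations shared by too few users are lumped into a single "__other" category: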

import org.apache.spark.sql.functions._

val companyCountDF = cleanUserInfoDF
  .groupBy($"user_clean_company")
  .agg(count("*").alias("count_per_user_company"))

val locationCountDF = cleanUserInfoDF
  .groupBy($"user_clean_location")
  .agg(count("*").alias("count_per_user_location"))

val transformedUserInfoDF = constructedUserInfoDF
  .join(companyCountDF, Seq("user_clean_company"))
  .join(locationCountDF, Seq("user_clean_location"))
  .withColumn("user_has_blog", when($"user_blog" === "", 0.0).otherwise(1.0))
  .withColumn("user_binned_company", when($"count_per_user_company" <= 5, "__other").otherwise($"user_clean_company"))
  .withColumn("user_binned_location", when($"count_per_user_location" <= 50, "__other").otherwise($"user_clean_location"))

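For repos, languages used by 30 repos or fewer are binned the same way, and the comma-separated topics are split into an array: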

import org.apache.spark.sql.functions._

val languagesDF = cleanRepoInfoDF
  .groupBy($"repo_clean_language")
  .agg(count("*").alias("count_per_repo_language"))
  .select($"repo_clean_language", $"count_per_repo_language")
  .cache()

val transformedRepoInfoDF = constructedRepoInfoDF
  .join(languagesDF, Seq("repo_clean_language"))
  .withColumn("repo_binned_language", when($"count_per_repo_language" <= 30, "__other").otherwise($"repo_clean_language"))
  .withColumn("repo_clean_topics", split($"repo_clean_topics", ","))

ref:
https://docs.databricks.com/spark/latest/mllib/binary-classification-mllib-pipelines.html
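Binning works the same way for company, location and language: count how many rows share each value and collapse the rare ones into a single "__other" bucket, which keeps the later one-hot encoding from creating a column for every rare value. A pure-Scala sketch of the idea on an in-memory list (Binning is a hypothetical helper; the thresholds above are 5 for companies, 50 for locations and 30 for languages):

```scala
// Hypothetical helper mirroring the groupBy/count + join + when(...) steps above.
object Binning {
  // Replace any category seen at most `threshold` times with "__other".
  def binRareCategories(values: Seq[String], threshold: Int): Seq[String] = {
    val counts: Map[String, Int] =
      values.groupBy(identity).map { case (v, vs) => v -> vs.size }
    values.map(v => if (counts(v) <= threshold) "__other" else v)
  }
}
```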

Prepare the Feature Pipeline

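We filter out users who have starred a huge number of repos. The collected data shows some users starring ten or twenty thousand repos; these are probably crawler accounts or people who star whatever they come across. A recommender is of little use to such users, so we simply drop them from the dataset, here anyone with more than 2,000 starred repos.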

import org.apache.spark.sql.functions._

val maxStarredReposCount = 2000

val userStarredReposCountDF = rawStarringDS
  .groupBy($"user_id")
  .agg(count("*").alias("user_starred_repos_count"))

val reducedStarringDF = rawStarringDS
  .join(userStarredReposCountDF, Seq("user_id"))
  .where($"user_starred_repos_count" <= maxStarredReposCount)
  .select($"user_id", $"repo_id", $"starred_at", $"starring")

val profileStarringDF = reducedStarringDF
  .join(userProfileDF, Seq("user_id"))
  .join(repoProfileDF, Seq("repo_id"))
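The filter amounts to: count each user's stars, then drop every row that belongs to a user above the cap (2,000 here). A pure-Scala sketch of that logic (HeavyUserFilter is a hypothetical name; the real job does it with a groupBy and a join as shown above):

```scala
// Hypothetical sketch of the heavy-user filter; not the project's code.
object HeavyUserFilter {
  final case class Starring(userId: Int, repoId: Int)

  // Count stars per user, then keep only rows from users at or below the cap,
  // mirroring the groupBy/count + join + where(<= maxStarredReposCount) steps.
  def dropHeavyUsers(rows: Seq[Starring], maxStarred: Int): Seq[Starring] = {
    val perUser: Map[Int, Int] =
      rows.groupBy(_.userId).map { case (u, rs) => u -> rs.size }
    rows.filter(r => perUser(r.userId) <= maxStarred)
  }
}
```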

Build the Feature Pipeline

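We write the chain of feature-processing steps as a Spark ML Pipeline, which makes it easy to swap in or add Transformers such as standardization, one-hot encoding and Word2Vec. The ALS model's prediction is also used as one of the features.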

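import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.recommendation.ALSModel
import ws.vinta.albedo.transformers.UserRepoTransformer

import scala.collection.mutable

// HanLPTokenizer and SimpleVectorAssembler are custom transformers from the albedo project (not part of Spark).
// categoricalColumnNames, continuousColumnNames, booleanColumnNames, listColumnNames and textColumnNames
// are mutable buffers of column names defined outside this snippet.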

val profileStarringDF = reducedStarringDF
  .join(userProfileDF, Seq("user_id"))
  .join(repoProfileDF, Seq("repo_id"))
  .cache()

categoricalColumnNames += "user_id"
categoricalColumnNames += "repo_id"

val userRepoTransformer = new UserRepoTransformer()
  .setInputCols(Array("repo_language", "user_recent_repo_languages"))

continuousColumnNames += "repo_language_index_in_user_recent_repo_languages"
continuousColumnNames += "repo_language_count_in_user_recent_repo_languages"

val alsModelPath = s"${settings.dataDir}/${settings.today}/alsModel.parquet"
val alsModel = ALSModel.load(alsModelPath)
  .setUserCol("user_id")
  .setItemCol("repo_id")
  .setPredictionCol("als_score")
  .setColdStartStrategy("drop")

continuousColumnNames += "als_score"

val categoricalTransformers = categoricalColumnNames.flatMap((columnName: String) => {
  val stringIndexer = new StringIndexer()
    .setInputCol(columnName)
    .setOutputCol(s"${columnName}__idx")
    .setHandleInvalid("keep")

  val oneHotEncoder = new OneHotEncoder()
    .setInputCol(s"${columnName}__idx")
    .setOutputCol(s"${columnName}__ohe")
    .setDropLast(false)

  Array(stringIndexer, oneHotEncoder)
})

val listTransformers = listColumnNames.flatMap((columnName: String) => {
  val countVectorizerModel = new CountVectorizer()
    .setInputCol(columnName)
    .setOutputCol(s"${columnName}__cv")
    .setMinDF(10)
    .setMinTF(1)

  Array(countVectorizerModel)
})

val textTransformers = textColumnNames.flatMap((columnName: String) => {
  val hanLPTokenizer = new HanLPTokenizer()
    .setInputCol(columnName)
    .setOutputCol(s"${columnName}__words")
    .setShouldRemoveStopWords(true)

  val stopWordsRemover = new StopWordsRemover()
    .setInputCol(s"${columnName}__words")
    .setOutputCol(s"${columnName}__filtered_words")
    .setStopWords(StopWordsRemover.loadDefaultStopWords("english"))

  val word2VecModelPath = s"${settings.dataDir}/${settings.today}/word2VecModel.parquet"
  val word2VecModel = Word2VecModel.load(word2VecModelPath)
    .setInputCol(s"${columnName}__filtered_words")
    .setOutputCol(s"${columnName}__w2v")

  Array(hanLPTokenizer, stopWordsRemover, word2VecModel)
})

val finalBooleanColumnNames = booleanColumnNames.toArray
val finalContinuousColumnNames = continuousColumnNames.toArray
val finalCategoricalColumnNames = categoricalColumnNames.map(columnName => s"${columnName}__ohe").toArray
val finalListColumnNames = listColumnNames.map(columnName => s"${columnName}__cv").toArray
val finalTextColumnNames = textColumnNames.map(columnName => s"${columnName}__w2v").toArray
val vectorAssembler = new SimpleVectorAssembler()
  .setInputCols(finalBooleanColumnNames ++ finalContinuousColumnNames ++ finalCategoricalColumnNames ++ finalListColumnNames ++ finalTextColumnNames)
  .setOutputCol("features")

val featureStages = mutable.ArrayBuffer.empty[PipelineStage]
featureStages += userRepoTransformer
featureStages += alsModel
featureStages ++= categoricalTransformers
featureStages ++= listTransformers
featureStages ++= textTransformers
featureStages += vectorAssembler

val featurePipeline = new Pipeline().setStages(featureStages.toArray)
val featurePipelineModel = featurePipeline.fit(profileStarringDF)

ref:
https://spark.apache.org/docs/latest/ml-pipeline.html
https://spark.apache.org/docs/latest/ml-features.html
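For each categorical column, the StringIndexer + OneHotEncoder pair above turns a string into a one-hot vector. A dense, pure-Scala sketch of what that pair computes under the settings used here, setHandleInvalid("keep") and setDropLast(false). IndexAndOneHot is a hypothetical helper, and the vector size assumes the "keep" bucket for unseen values becomes one extra category:

```scala
// Hypothetical sketch of StringIndexer + OneHotEncoder behavior; not Spark's code.
object IndexAndOneHot {
  // Labels ordered by descending frequency, StringIndexer's default ordering;
  // this sketch breaks frequency ties alphabetically.
  def fitLabels(values: Seq[String]): Vector[String] =
    values.groupBy(identity).toVector
      .map { case (v, vs) => (v, vs.size) }
      .sortBy { case (v, n) => (-n, v) }
      .map(_._1)

  // handleInvalid = "keep": a value never seen during fit gets the extra index labels.size.
  def index(labels: Vector[String], value: String): Int = {
    val i = labels.indexOf(value)
    if (i >= 0) i else labels.size
  }

  // dropLast = false: one slot per label plus the slot for unseen values.
  // Spark emits a sparse vector; a dense Vector keeps the sketch readable.
  def oneHot(labels: Vector[String], value: String): Vector[Double] = {
    val hot = index(labels, value)
    Vector.tabulate(labels.size + 1)(i => if (i == hot) 1.0 else 0.0)
  }
}
```

Keeping the last slot (dropLast = false) means an unseen user_id or repo_id at prediction time still maps to a valid, distinct position instead of failing the pipeline.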

Handle Imbalanced Data

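Because we are training a binary classification model, we need both positive and negative samples. Our raw data rawStarringDS, however, contains only positives: we know which repos each user has starred (positive samples), but not which repos each user has chosen not to star (negative samples). We could treat every repo a user has not starred as a negative sample, but that would produce an enormous number of negatives, and it is not very reasonable either: a user not starring a repo does not necessarily mean they dislike it; they may simply not know it exists.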

The approach we eventually settled on is to use popular repos that the user has not starred as negative samples, and we wrote a Spark Transformer to do it:

import ws.vinta.albedo.transformers.NegativeBalancer

import scala.collection.mutable

val sc = spark.sparkContext

val popularReposDS = loadPopularRepoDF()
val popularRepos = popularReposDS
  .select($"repo_id".as[Int])
  .collect()
  .to[mutable.LinkedHashSet]
val bcPopularRepos = sc.broadcast(popularRepos)

val negativeBalancer = new NegativeBalancer(bcPopularRepos)
  .setUserCol("user_id")
  .setItemCol("repo_id")
  .setTimeCol("starred_at")
  .setLabelCol("starring")
  .setNegativeValue(0.0)
  .setNegativePositiveRatio(2.0)
val balancedStarringDF = negativeBalancer.transform(reducedStarringDF)

ref:
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/evaluators/RankingEvaluator.scala
http://www.kdnuggets.com/2017/06/7-techniques-handle-imbalanced-data.html
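The core idea behind NegativeBalancer can be sketched in plain Scala. This is a hypothetical reimplementation, not the project's code: it ignores starred_at (which the real transformer receives through setTimeCol) and simply picks, for each user, the most popular repos they have not starred, twice as many as their positives to match setNegativePositiveRatio(2.0):

```scala
// Hypothetical sketch of popularity-based negative sampling; not albedo's NegativeBalancer.
object NegativeSampling {
  final case class Sample(userId: Int, repoId: Int, label: Double)

  // For each user, walk the popular repos in order, skip the ones the user has
  // starred, and keep round(ratio * positives) of them as label-0.0 samples.
  def balance(positives: Seq[(Int, Int)], popularRepos: Seq[Int], ratio: Double): Seq[Sample] =
    positives.groupBy(_._1).toSeq.sortBy(_._1).flatMap { case (user, pairs) =>
      val starred = pairs.map(_._2).toSet
      val wanted = math.round(pairs.size * ratio).toInt
      val negatives = popularRepos.filterNot(starred.contains).take(wanted)
      pairs.map { case (_, repo) => Sample(user, repo, 1.0) } ++
        negatives.map(repo => Sample(user, repo, 0.0))
    }
}
```

Drawing negatives from popular repos means the model learns to separate "popular but not for this user" from "actually starred", rather than from obscure repos the user has probably never seen.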

Split Data

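We simply use a holdout split, randomly assigning rows to the training set and the test set (90% and 10% here). Another option would be to split by time, using older data to predict later behavior.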

val profileBalancedStarringDF = balancedStarringDF
  .join(userProfileDF, Seq("user_id"))
  .join(repoProfileDF, Seq("repo_id"))

val tmpDF = featurePipelineModel.transform(profileBalancedStarringDF)
val keepColumnName = tmpDF.columns.filter((columnName: String) => {
  !columnName.endsWith("__idx") &&
  !columnName.endsWith("__ohe") &&
  !columnName.endsWith("__cv") &&
  !columnName.endsWith("__words") &&
  !columnName.endsWith("__filtered_words") &&
  !columnName.endsWith("__w2v")
})
val featuredBalancedStarringDF = tmpDF.select(keepColumnName.map(col): _*)

val Array(trainingFeaturedDF, testFeaturedDF) = featuredBalancedStarringDF.randomSplit(Array(0.9, 0.1))
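The time-based split mentioned above can be sketched in plain Scala on an in-memory list (TimeSplit, cutoffFor and splitByTime are hypothetical names): choose a cutoff timestamp so that about 90% of the stars fall before it, train on those, and test on the rest.

```scala
// Hypothetical sketch of a time-based train/test split; not the project's code.
object TimeSplit {
  final case class Event(userId: Int, repoId: Int, starredAt: Long)

  // Pick the cutoff so that roughly `trainFraction` of events come before it.
  def cutoffFor(events: Seq[Event], trainFraction: Double): Long = {
    require(events.nonEmpty, "need at least one event")
    val times = events.map(_.starredAt).sorted
    times(math.min(times.size - 1, (times.size * trainFraction).toInt))
  }

  // Everything before the cutoff is training data; the rest is test data,
  // so the model is always evaluated on stars that happened later.
  def splitByTime(events: Seq[Event], cutoff: Long): (Seq[Event], Seq[Event]) =
    events.partition(_.starredAt < cutoff)
}
```

Compared with randomSplit, this avoids training on stars that happened after the ones being tested, at the cost of a test set skewed toward recent activity.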

Build the Model Pipeline

To keep things easy to extend later, we use a Spark ML Pipeline here as well. Spark ML's LogisticRegression also accepts a weightCol that sets the weight of each row.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.ml.{Pipeline, PipelineStage}

import scala.collection.mutable

val weightSQL = """
SELECT *,
       1.0 AS default_weight,
       IF (starring = 1.0, 0.9, 0.1) AS positive_weight,
       IF (starring = 1.0 AND datediff(current_date(), starred_at) <= 365, 0.9, 0.1) AS recent_starred_weight
FROM __THIS__
""".stripMargin
val weightTransformer = new SQLTransformer()
  .setStatement(weightSQL)

val lr = new LogisticRegression()
  .setMaxIter(200)
  .setRegParam(0.7)
  .setElasticNetParam(0.0)
  .setStandardization(true)
  .setLabelCol("starring")
  .setFeaturesCol("features")
  .setWeightCol("recent_starred_weight")

val modelStages = mutable.ArrayBuffer.empty[PipelineStage]
modelStages += weightTransformer
modelStages += lr

val modelPipeline = new Pipeline().setStages(modelStages.toArray)
val modelPipelineModel = modelPipeline.fit(trainingFeaturedDF)

ref:
https://spark.apache.org/docs/latest/ml-classification-regression.html
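To see what weightCol does, here is the row weighting from weightSQL next to a weighted log loss in plain Scala. This is a simplified sketch (WeightedLoss is hypothetical, and it leaves out the regularization term set by setRegParam): each row's loss is multiplied by its weight, so with recent_starred_weight the model leans mostly on stars from the past year, while older stars and all negative samples count for less.

```scala
// Hypothetical sketch of row weighting; regularization is omitted.
object WeightedLoss {
  // Mirrors the recent_starred_weight column built by weightSQL.
  def recentStarredWeight(starring: Double, daysSinceStarred: Int): Double =
    if (starring == 1.0 && daysSinceStarred <= 365) 0.9 else 0.1

  // Weighted average log loss: each row's loss is scaled by its weight, so a
  // row with weight 0.9 pulls the fit nine times harder than one with 0.1.
  def weightedLogLoss(labels: Seq[Double], probs: Seq[Double], weights: Seq[Double]): Double = {
    val losses = labels.indices.map { i =>
      val y = labels(i)
      val p = probs(i)
      -weights(i) * (y * math.log(p) + (1 - y) * math.log(1 - p))
    }
    losses.sum / weights.sum
  }
}
```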

Evaluate the Model: Classification

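Since Logistic Regression is a binary classification model, we can evaluate it with Spark ML's BinaryClassificationEvaluator. What a recommender system really cares about, though, is the ranking of the top N items, so treat the AUC here as a rough reference only.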

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val testRankedDF = modelPipelineModel.transform(testFeaturedDF)

val binaryClassificationEvaluator = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .setRawPredictionCol("rawPrediction")
  .setLabelCol("starring")

val classificationMetric = binaryClassificationEvaluator.evaluate(testRankedDF)
println(s"${binaryClassificationEvaluator.getMetricName} = $classificationMetric")
// areaUnderROC = 0.9450631491281277

ref:
https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
https://docs.databricks.com/spark/latest/mllib/binary-classification-mllib-pipelines.html
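AUC has a handy interpretation: it is the probability that a randomly chosen positive row gets a higher score than a randomly chosen negative one. A brute-force pure-Scala sketch of that pairwise form (Auc is a hypothetical helper). Spark's evaluator computes the area under the ROC curve directly; for an exact curve this equals the pairwise count with ties counted as half.

```scala
// Hypothetical sketch of AUC via pairwise comparisons; O(P * N), toy data only.
object Auc {
  // scored: (score, label) pairs with label 1.0 (positive) or 0.0 (negative).
  def areaUnderROC(scored: Seq[(Double, Double)]): Double = {
    val pos = scored.filter(_._2 == 1.0).map(_._1)
    val neg = scored.filter(_._2 == 0.0).map(_._1)
    require(pos.nonEmpty && neg.nonEmpty, "need both classes")
    // 1 for a correctly ordered pair, 0.5 for a tie, 0 otherwise
    val wins = for (p <- pos; n <- neg) yield
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size * neg.size)
  }
}
```

Note that AUC rewards correct ordering across all users pooled together, which is why it can look high while per-user top-N ranking is still weak.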

Generate Candidates

Another important part of a recommender system is generating the candidate set. Here we use the following approaches:

  • ALS: collaborative filtering
  • Content-based: content-based recommendations
  • Popularity: popularity-based recommendations

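Since this article focuses on the machine learning pipeline for ranking and feature engineering, we won't go into candidate generation here; if you're interested, see the source code linked below or the other articles in this series.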

import ws.vinta.albedo.recommenders.ALSRecommender
import ws.vinta.albedo.recommenders.ContentRecommender
import ws.vinta.albedo.recommenders.PopularityRecommender

val topK = 30

val alsRecommender = new ALSRecommender()
  .setUserCol("user_id")
  .setItemCol("repo_id")
  .setTopK(topK)

val contentRecommender = new ContentRecommender()
  .setUserCol("user_id")
  .setItemCol("repo_id")
  .setTopK(topK)
  .setEnableEvaluationMode(true)

val popularityRecommender = new PopularityRecommender()
  .setUserCol("user_id")
  .setItemCol("repo_id")
  .setTopK(topK)

val recommenders = mutable.ArrayBuffer.empty[Recommender]
recommenders += alsRecommender
recommenders += contentRecommender
recommenders += popularityRecommender

val candidateDF = recommenders
  .map((recommender: Recommender) => recommender.recommendForUsers(testUserDF))
  .reduce(_ union _)
  .select($"user_id", $"repo_id")
  .distinct()

// Each Recommender's output looks roughly like this:
// +-------+-------+----------+------+
// |user_id|repo_id|score     |source|
// +-------+-------+----------+------+
// |652070 |1239728|0.6731846 |als   |
// |652070 |854078 |0.7187486 |als   |
// |652070 |1502338|0.70165294|als   |
// |652070 |1184678|0.7434903 |als   |
// |652070 |547708 |0.7956538 |als   |
// +-------+-------+----------+------+

ref:
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/recommenders/ALSRecommender.scala
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/recommenders/ContentRecommender.scala
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/recommenders/PopularityRecommender.scala
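Merging candidates is a union followed by a distinct on (user_id, repo_id). The per-recommender score and source are dropped, so the Logistic Regression model ranks every candidate from scratch. A pure-Scala sketch (Rec and mergeCandidates are hypothetical names):

```scala
// Hypothetical sketch of the union + distinct candidate merge.
object Candidates {
  final case class Rec(userId: Int, repoId: Int, score: Double, source: String)

  // union + select(user_id, repo_id) + distinct: each (user, repo) pair is kept
  // once, no matter how many recommenders produced it.
  def mergeCandidates(perRecommender: Seq[Seq[Rec]]): Seq[(Int, Int)] =
    perRecommender.flatten.map(r => (r.userId, r.repoId)).distinct
}
```

Unlike this sketch, Spark's distinct() does not preserve row order, which does not matter here because ranking happens in the next step.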

Predict the Ranking

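We hand these candidates to the trained Logistic Regression model for ranking. In the output, element 0 of the probability column is the probability that the label is 0 (negative), and element 1 is the probability that it is 1 (positive).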

val profileCandidateDF = candidateDF
  .join(userProfileDF, Seq("user_id"))
  .join(repoProfileDF, Seq("repo_id"))

val featuredCandidateDF = featurePipelineModel
  .transform(profileCandidateDF)

val rankedCandidateDF = modelPipelineModel
  .transform(featuredCandidateDF)

// rankedCandidateDF looks roughly like this:
// +-------+--------+----------+----------------------------------------+
// |user_id|repo_id |prediction|probability                             |
// +-------+--------+----------+----------------------------------------+
// |652070 |83467664|1.0       |[0.12711894229094317,0.8728810577090568]|
// |652070 |55099616|1.0       |[0.1422859437320775,0.8577140562679224] |
// |652070 |42266235|1.0       |[0.1462014853157966,0.8537985146842034] |
// |652070 |78012800|1.0       |[0.15576081067098502,0.844239189329015] |
// |652070 |5928761 |1.0       |[0.16149848941925066,0.8385015105807493]|
// +-------+--------+----------+----------------------------------------+

ref:
https://stackoverflow.com/questions/37903288/what-do-colum-rawprediction-and-probability-of-dataframe-mean-in-spark-mllib
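Where do the two probability entries come from? For binary logistic regression, Spark computes a margin w · x + b, stores [-margin, margin] in rawPrediction and [1 - p, p] with p = sigmoid(margin) in probability; prediction is 1.0 when p exceeds the default 0.5 threshold. A pure-Scala sketch of that computation (LrScores and its methods are hypothetical, and the weights in the tests are made up):

```scala
// Hypothetical sketch of binary logistic regression scoring; not Spark's code.
object LrScores {
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // margin = w · x + b; probability = (1 - p, p) with p = sigmoid(margin),
  // i.e. element 1 is the chance that the label is 1.0.
  def probability(weights: Array[Double], intercept: Double, x: Array[Double]): (Double, Double) = {
    val margin = weights.indices.map(i => weights(i) * x(i)).sum + intercept
    val p = sigmoid(margin)
    (1.0 - p, p)
  }

  // Rank candidate ids by the positive-class probability, highest first.
  def rankByPositive(candidates: Seq[(Int, (Double, Double))]): Seq[Int] =
    candidates.sortBy { case (_, (_, pos)) => -pos }.map(_._1)
}
```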

Evaluate the Model: Ranking

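Finally, we evaluate the ranking with NDCG (Normalized Discounted Cumulative Gain), a metric from information retrieval for measuring ranking quality. Spark MLlib ships a ready-made RankingMetrics, but it only works with the RDD-based API, so we rewrote it as an Evaluator for the DataFrame-based API: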

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import ws.vinta.albedo.evaluators.RankingEvaluator

val userActualItemsDF = reducedStarringDF
  .withColumn("rank", rank().over(Window.partitionBy($"user_id").orderBy($"starred_at".desc)))
  .where($"rank" <= topK)
  .groupBy($"user_id")
  .agg(collect_list($"repo_id").alias("items"))

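// toArrayUDF is a custom UDF (not shown here) that converts the probability vector to an array,
// so getItem(1) is the positive-class probability
val userPredictedItemsDF = rankedCandidateDF
  .withColumn("rank", rank().over(Window.partitionBy($"user_id").orderBy(toArrayUDF($"probability").getItem(1).desc)))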
  .where($"rank" <= topK)
  .groupBy($"user_id")
  .agg(collect_list($"repo_id").alias("items"))

val rankingEvaluator = new RankingEvaluator(userActualItemsDF)
  .setMetricName("NDCG@k")
  .setK(topK)
  .setUserCol("user_id")
  .setItemsCol("items")
val rankingMetric = rankingEvaluator.evaluate(userPredictedItemsDF)
println(s"${rankingEvaluator.getFormattedMetricName} = $rankingMetric")
// NDCG@30 = 0.021114356461615493

ref:
https://github.com/vinta/albedo/blob/master/src/main/scala/ws/vinta/albedo/evaluators/RankingEvaluator.scala
https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#ranking-systems
https://weekly.codetengu.com/issues/83#kOxuVxW
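For reference, here is binary-relevance NDCG@k written out in pure Scala, modeled on the formula Spark's RankingMetrics.ndcgAt uses: a predicted repo earns a gain of 1 / log(position + 2) if the user actually starred it, and the total is divided by the best achievable total. Ndcg and its methods are hypothetical names; the albedo RankingEvaluator may differ in details such as how users missing from either side are handled.

```scala
// Hypothetical sketch of NDCG@k with binary relevance; not albedo's RankingEvaluator.
object Ndcg {
  // Natural log is fine: the log base cancels out in dcg / idcg.
  def ndcgAt(predicted: Seq[Int], actual: Set[Int], k: Int): Double = {
    require(k > 0, "k must be positive")
    if (actual.isEmpty) 0.0
    else {
      val n = math.min(math.max(predicted.size, actual.size), k)
      val gains = (0 until n).map(i => 1.0 / math.log(i + 2))
      // DCG: gains at the positions where the predicted repo was really starred
      val dcg = (0 until n)
        .filter(i => i < predicted.size && actual.contains(predicted(i)))
        .map(i => gains(i))
        .sum
      // Ideal DCG: the first min(|actual|, n) positions are all hits
      val idcg = (0 until math.min(actual.size, n)).map(i => gains(i)).sum
      dcg / idcg
    }
  }

  // Mean NDCG@k over the users present on both sides.
  def meanNdcgAt(predicted: Map[Int, Seq[Int]], actual: Map[Int, Set[Int]], k: Int): Double = {
    val users = predicted.keySet.intersect(actual.keySet).toSeq
    if (users.isEmpty) 0.0
    else users.map(u => ndcgAt(predicted(u), actual(u), k)).sum / users.size
  }
}
```

Because the actual items are each user's 30 most recent stars, a hit only counts if the model surfaces one of those specific repos, which helps explain why NDCG@30 is low even when AUC is high.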