{"id":394,"date":"2017-05-20T06:20:33","date_gmt":"2017-05-19T22:20:33","guid":{"rendered":"http:\/\/vinta.ws\/code\/?p=394"},"modified":"2026-02-18T01:20:35","modified_gmt":"2026-02-17T17:20:35","slug":"build-a-recommender-system-with-pyspark-implicit-als","status":"publish","type":"post","link":"https:\/\/vinta.ws\/code\/build-a-recommender-system-with-pyspark-implicit-als.html","title":{"rendered":"Build a recommender system with Spark: Implicit ALS"},"content":{"rendered":"<p>\u5728\u9019\u500b\u7cfb\u5217\u7684\u6587\u7ae0\u88e1\uff0c\u6211\u5011\u5c07\u4f7f\u7528 Apache Spark\u3001XGBoost\u3001Elasticsearch \u548c MySQL \u7b49\u5de5\u5177\u4f86\u642d\u5efa\u4e00\u500b\u63a8\u85a6\u7cfb\u7d71\u7684 Machine Learning Pipeline\u3002\u63a8\u85a6\u7cfb\u7d71\u7684\u7d44\u6210\u53ef\u4ee5\u7c97\u7565\u5730\u5206\u6210 Candidate Generation \u548c Ranking \u5169\u500b\u90e8\u5206\uff0c\u524d\u8005\u662f\u91dd\u5c0d\u7528\u6236\u7522\u751f\u5019\u9078\u7269\u54c1\u96c6\uff0c\u5e38\u7528\u7684\u65b9\u6cd5\u6709 Collaborative Filtering\u3001Content-based\u3001\u6a19\u7c64\u914d\u5c0d\u3001\u71b1\u9580\u6392\u884c\u6216\u4eba\u5de5\u7cbe\u9078\u7b49\uff1b\u5f8c\u8005\u5247\u662f\u5c0d\u9019\u4e9b\u5019\u9078\u7269\u54c1\u6392\u5e8f\uff0c\u4ee5 Top N \u7684\u65b9\u5f0f\u5448\u73fe\u6700\u7d42\u7684\u63a8\u85a6\u7d50\u679c\uff0c\u5e38\u7528\u7684\u65b9\u6cd5\u6709 Logistic Regression\u3002<\/p>\n<p>\u5728\u672c\u7bc7\u6587\u7ae0\u4e2d\uff0c\u6211\u5011\u5c07\u4ee5 Candidate Generation \u968e\u6bb5\u5e38\u7528\u7684\u65b9\u6cd5\u4e4b\u4e00\uff1aCollaborative Filtering \u5354\u540c\u904e\u6ffe\u6f14\u7b97\u6cd5\u70ba\u4f8b\uff0c\u5229\u7528 Apache Spark \u7684 ALS (Alternating Least Squares) \u6a21\u578b\u5efa\u7acb\u4e00\u500b GitHub repositories \u7684\u63a8\u85a6\u7cfb\u7d71\uff0c\u4ee5\u7528\u6236\u5c0d repo \u7684\u6253\u661f\u7d00\u9304\u4f5c\u70ba\u8a13\u7df4\u6578\u64da\uff0c\u63a8\u85a6\u51fa\u7528\u6236\u53ef\u80fd\u6703\u611f\u8208\u8da3\u7684\u5176\u4ed6 repo 
\u4f5c\u70ba\u5019\u9078\u7269\u54c1\u96c6\u3002<\/p>\n<p>\u5b8c\u6574\u7684\u7a0b\u5f0f\u78bc\u53ef\u4ee5\u5728 <a href=\"https:\/\/github.com\/vinta\/albedo\">https:\/\/github.com\/vinta\/albedo<\/a> \u627e\u5230\u3002<\/p>\n<p>\u7cfb\u5217\u6587\u7ae0\uff1a<\/p>\n<ul>\n<li><a href=\"https:\/\/vinta.ws\/code\/build-a-recommender-system-with-pyspark-implicit-als.html\">Build a recommender system with Spark: Implicit ALS<\/a><\/li>\n<li><a href=\"https:\/\/vinta.ws\/code\/build-a-recommender-system-with-spark-and-elasticsearch-content-based.html\">Build a recommender system with Spark: Content-based and Elasticsearch<\/a><\/li>\n<li><a href=\"https:\/\/vinta.ws\/code\/build-a-recommender-system-with-spark-logistic-regression.html\">Build a recommender system with Spark: Logistic Regression<\/a><\/li>\n<li><a href=\"https:\/\/vinta.ws\/code\/feature-engineering.html\">Feature Engineering \u7279\u5fb5\u5de5\u7a0b\u4e2d\u5e38\u898b\u7684\u65b9\u6cd5<\/a><\/li>\n<li><a href=\"https:\/\/vinta.ws\/code\/spark-ml-cookbook-scala.html\">Spark ML cookbook (Scala)<\/a><\/li>\n<li><a href=\"https:\/\/vinta.ws\/code\/spark-sql-cookbook-scala.html\">Spark SQL cookbook (Scala)<\/a><\/li>\n<li>\u4e0d\u5b9a\u671f\u66f4\u65b0\u4e2d<\/li>\n<\/ul>\n<h2>Submit the Application<\/h2>\n<p>\u56e0\u70ba\u9700\u8981\u4f7f\u7528 JDBC \u8b80\u53d6 MySQL \u8cc7\u6599\u5eab\uff0c\u5fc5\u9808\u5b89\u88dd MySQL driver\uff0c\u53ef\u4ee5\u900f\u904e <code>--packages &quot;mysql:mysql-connector-java:5.1.41&quot;<\/code> \u53c3\u6578\u5728 cluster \u7684\u6bcf\u4e00\u53f0\u6a5f\u5668\u4e0a\u5b89\u88dd\u9700\u8981\u7684 Java packages\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-bash\">$ spark-submit \n--packages \"com.github.fommil.netlib:all:1.1.2,mysql:mysql-connector-java:5.1.41\" \n--master spark:\/\/YOUR_SPARK_MASTER:7077 \n--py-files deps.zip \ntrain_als.py -u vinta<\/code><\/pre>\n<p>ref:<br \/>\n<a 
href=\"https:\/\/spark.apache.org\/docs\/latest\/submitting-applications.html\">https:\/\/spark.apache.org\/docs\/latest\/submitting-applications.html<\/a><\/p>\n<h2>Load Data<\/h2>\n<p>\u8b80\u53d6\u4f86\u81ea MySQL \u8cc7\u6599\u5eab\u7684\u6578\u64da\u3002\u4f60\u53ef\u4ee5\u4f7f\u7528 <code>predicates<\/code> \u53c3\u6578\u4f86\u6307\u5b9a <code>WHERE<\/code> \u689d\u4ef6\uff0c\u96d6\u7136\u56b4\u683c\u4f86\u8aaa\u9019\u500b\u53c3\u6578\u662f\u7528\u4f86\u63a7\u5236 partition \u6578\u91cf\u7684\uff0c\u4e00\u500b\u689d\u4ef6\u5c31\u662f\u4e00\u500b partition\u3002<\/p>\n<p>\u5047\u8a2d <code>app_repostarring<\/code> \u7684\u6b04\u4f4d\u5982\u4e0b\uff1a<\/p>\n<pre class=\"line-numbers\"><code class=\"language-sql\">CREATE TABLE <code>app_repostarring<\/code> (\n  <code>id<\/code> int(11) NOT NULL AUTO_INCREMENT,\n  <code>from_user_id<\/code> int(11) NOT NULL,\n  <code>from_username<\/code> varchar(39) NOT NULL,\n  <code>repo_owner_id<\/code> int(11) NOT NULL,\n  <code>repo_owner_username<\/code> varchar(39) NOT NULL,\n  <code>repo_owner_type<\/code> varchar(16) NOT NULL,\n  <code>repo_id<\/code> int(11) NOT NULL,\n  <code>repo_name<\/code> varchar(100) NOT NULL,\n  <code>repo_full_name<\/code> varchar(140) NOT NULL,\n  <code>repo_description<\/code> varchar(191) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci DEFAULT NULL,\n  <code>repo_language<\/code> varchar(32) NOT NULL,\n  <code>repo_created_at<\/code> datetime(6) NOT NULL,\n  <code>repo_updated_at<\/code> datetime(6) NOT NULL,\n  <code>starred_at<\/code> datetime(6) NOT NULL,\n  <code>stargazers_count<\/code> int(11) NOT NULL,\n  <code>forks_count<\/code> int(11) NOT NULL,\n  PRIMARY KEY (<code>id<\/code>),\n  UNIQUE KEY <code>from_user_id_repo_id<\/code> (<code>from_user_id<\/code>, <code>repo_id<\/code>)\n);<\/code><\/pre>\n<pre class=\"line-numbers\"><code class=\"language-py\">def loadRawData():\n    url = 
'jdbc:mysql:\/\/127.0.0.1:3306\/albedo?user=root&amp;password=123&amp;verifyServerCertificate=false&amp;useSSL=false'\n    properties = {'driver': 'com.mysql.jdbc.Driver'}\n    rawDF = spark.read.jdbc(url, table='app_repostarring', properties=properties)\n    return rawDF\n\nrawDF = loadRawData()<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.sql.html?highlight=jdbc#pyspark.sql.DataFrameReader.jdbc\">https:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.sql.html?highlight=jdbc#pyspark.sql.DataFrameReader.jdbc<\/a><br \/>\n<a href=\"http:\/\/www.gatorsmile.io\/numpartitionsinjdbc\/\">http:\/\/www.gatorsmile.io\/numpartitionsinjdbc\/<\/a><\/p>\n<h2>Preprocess Data<\/h2>\n<h3>Format Data<\/h3>\n<p>\u628a raw data \u6574\u7406\u6210 <code>user,item,rating,starred_at<\/code> \u9019\u6a23\u7684\u683c\u5f0f\u3002<code>starred_at<\/code> \u53ea\u6709\u8a55\u50f9 model \u6642\u6709\u7528\u4f86\u6392\u5e8f\uff0c\u8a13\u7df4 model \u6642\u4e26\u6c92\u6709\u7528\u5230\uff0c\u56e0\u70ba Spark \u7684 ALS \u6c92\u8fa6\u6cd5\u8f15\u6613\u5730\u6574\u5408 side information\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.ml import Transformer\n\nclass RatingBuilder(Transformer):\n\n    def _transform(self, rawDF):\n        ratingDF = rawDF \n            .selectExpr('from_user_id AS user', 'repo_id AS item', '1 AS rating', 'starred_at') \n            .orderBy('user', F.col('starred_at').desc())\n        return ratingDF\n\nratingBuilder = RatingBuilder()\nratingDF = ratingBuilder.transform(rawDF)\nratingDF.cache()<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"http:\/\/blog.ethanrosenthal.com\/2016\/11\/07\/implicit-mf-part-2\/\">http:\/\/blog.ethanrosenthal.com\/2016\/11\/07\/implicit-mf-part-2\/<\/a><\/p>\n<h3>Inspect Data<\/h3>\n<pre class=\"line-numbers\"><code class=\"language-py\">import pyspark.sql.functions as F\n\nratingDF.rdd.getNumPartitions()\n# 
200\n\nratingDF.agg(F.count('rating'), F.countDistinct('user'), F.countDistinct('item')).show()\n# +-------------+--------------------+--------------------+\n# |count(rating)|count(DISTINCT user)|count(DISTINCT item)|\n# +-------------+--------------------+--------------------+\n# |      3121629|               10483|              551216|\n# +-------------+--------------------+--------------------+\n\nstargazersCountDF = ratingDF \n    .groupBy('item') \n    .agg(F.count('user').alias('stargazers_count')) \n    .orderBy('stargazers_count', ascending=False)\nstargazersCountDF.show(10)\n# +--------+----------------+\n# |    item|stargazers_count|\n# +--------+----------------+\n# | 2126244|            2211|\n# |10270250|            1683|\n# |  943149|            1605|\n# |  291137|            1567|\n# |13491895|            1526|\n# | 9384267|            1480|\n# | 3544424|            1468|\n# | 7691631|            1441|\n# |29028775|            1427|\n# | 1334369|            1399|\n# +--------+----------------+\n\nstarredCountDF = ratingDF \n    .groupBy('user') \n    .agg(F.count('item').alias('starred_count')) \n    .orderBy('starred_count', ascending=False)\nstarredCountDF.show(10)\n# +-------+-------------+\n# |   user|starred_count|\n# +-------+-------------+\n# |3947125|         8947|\n# |5527642|         7978|\n# | 446613|         7860|\n# | 627410|         7800|\n# |  13998|         6334|\n# |2467194|         6327|\n# |  63402|         6034|\n# |2005841|         6024|\n# |5073946|         5980|\n# |   2296|         5862|\n# +-------+-------------+<\/code><\/pre>\n<h3>Clean Data<\/h3>\n<p>\u4f60\u53ef\u4ee5\u904e\u6ffe\u6389\u90a3\u4e9b\u592a\u5c11 user \u6253\u661f\u7684 item \u548c\u6253\u661f\u4e86\u592a\u5c11 item \u7684 user\uff0c\u63d0\u6607\u77e9\u9663\u7684\u7a20\u5bc6\u5ea6\u3002\u9019\u500b\u73fe\u8c61\u4e5f\u6b63\u597d\u662f Cold Start \u7684\u554f\u984c\uff0c\u4f60\u5c31\u662f\u6c92\u6709\u8db3\u5920\u591a\u7684\u95dc\u65bc\u9019\u4e9b item \u548c 
user \u7684\u6578\u64da\uff08\u53ef\u4ee5\u8003\u616e\u4f7f\u7528 content-based \u7684\u63a8\u85a6\u65b9\u5f0f\uff09\u3002\u9664\u6b64\u4e4b\u5916\uff0c\u5982\u679c\u4f60\u7684\u63a8\u85a6\u7cfb\u7d71\u6240\u63a8\u85a6\u7684 item \u53ea\u6709\u975e\u5e38\u5c11\u4eba\u6253\u661f\uff0c\u5373\u4fbf\u4f60\u5b8c\u7f8e\u5730\u6316\u6398\u4e86\u9577\u5c3e\u6548\u61c9\uff0c\u9019\u6a23\u7684\u63a8\u85a6\u7d50\u679c\u7d66\u7528\u6236\u7684\u300c\u7b2c\u4e00\u5370\u8c61\u300d\u53ef\u80fd\u4e5f\u4e0d\u6703\u592a\u597d\uff08\u9019\u53ef\u80fd\u6c7a\u5b9a\u4e86\u4ed6\u8981\u4e0d\u8981\u7e7c\u7e8c\u4f7f\u7528\u9019\u500b\u7cfb\u7d71\u6216\u662f\u4ed6\u8981\u4e0d\u8981\u771f\u7684\u53bb\u5617\u8a66\u90a3\u500b\u4f60\u63a8\u85a6\u7d66\u4ed6\u7684\u6771\u897f\uff09\u3002<\/p>\n<p>\u4f60\u4e5f\u53ef\u4ee5\u9078\u64c7\u8981\u4e0d\u8981\u904e\u6ffe\u6389\u90a3\u4e9b\u8d85\u591a\u4eba\u6253\u661f\u7684 item \u548c\u6253\u661f\u4e86\u8d85\u591a item \u7684 user\u3002\u5982\u679c\u67d0\u4e9b item \u6709\u8d85\u904e\u516b\u3001\u4e5d\u6210\u7684 user \u90fd\u6253\u661f\u4e86\uff0c\u5c0d\u65bc\u9019\u9ebc\u71b1\u9580\u7684 item\uff0c\u53ef\u80fd\u4e5f\u6c92\u6709\u63a8\u85a6\u7684\u5fc5\u8981\u4e86\uff0c\u56e0\u70ba\u5176\u4ed6 user \u65e9\u665a\u4e5f\u6703\u81ea\u5df1\u767c\u73fe\u7684\uff1b\u5982\u679c\u6709\u5c11\u6578\u7684 user \u5e7e\u4e4e\u6253\u661f\u4e86\u4e00\u534a\u4ee5\u4e0a\u7684 item\uff0c\u9019\u4e9b user \u53ef\u80fd\u662f\u5c6c\u65bc\u67d0\u7a2e web crawler \u7684\u7528\u9014\u6216\u662f\u9019\u4e9b user \u5c31\u662f\u90a3\u7a2e\u770b\u5230\u4ec0\u9ebc\u5c31\u6253\u661f\u4ec0\u9ebc\u7684\u4eba\uff0c\u7121\u8ad6\u662f\u54ea\u4e00\u7a2e\uff0c\u4ed6\u5011\u53ef\u80fd\u90fd\u4e0d\u662f\u4f60\u60f3\u8981 modeling \u7684\u5c0d\u8c61\uff0c\u53ef\u4ee5\u8003\u616e\u5f9e dataset \u4e2d\u62ff\u6389\u3002<\/p>\n<p>\u5be6\u52d9\u4e0a\uff0c\u5982\u679c\u4f60\u6709\u95dc\u65bc user \u6216 item \u7684\u9ed1\u540d\u55ae\uff0c\u4f8b\u5982\u4e00\u4e9b SPAM \u5e33\u865f\u6216 NSFW 
\u7684\u5167\u5bb9\u7b49\uff0c\u4e5f\u53ef\u4ee5\u5728\u9019\u500b\u6b65\u9a5f\u628a\u5b83\u5011\u904e\u6ffe\u6389\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark import keyword_only\nfrom pyspark.ml import Transformer\nfrom pyspark.ml.param.shared import Param\nimport pyspark.sql.functions as F\n\nclass DataCleaner(Transformer):\n\n    @keyword_only\n    def __init__(self, minItemStargazersCount=None, maxItemStargazersCount=None, minUserStarredCount=None, maxUserStarredCount=None):\n        super(DataCleaner, self).__init__()\n        self.minItemStargazersCount = Param(self, 'minItemStargazersCount', '\u79fb\u9664 stargazer \u6578\u4f4e\u65bc\u9019\u500b\u6578\u5b57\u7684 item')\n        self.maxItemStargazersCount = Param(self, 'maxItemStargazersCount', '\u79fb\u9664 stargazer \u6578\u8d85\u904e\u9019\u500b\u6578\u5b57\u7684 item')\n        self.minUserStarredCount = Param(self, 'minUserStarredCount', '\u79fb\u9664 starred repo \u6578\u4f4e\u65bc\u9019\u500b\u6578\u5b57\u7684 user')\n        self.maxUserStarredCount = Param(self, 'maxUserStarredCount', '\u79fb\u9664 starred repo \u6578\u8d85\u904e\u9019\u500b\u6578\u5b57\u7684 user')\n        self._setDefault(minItemStargazersCount=1, maxItemStargazersCount=50000, minUserStarredCount=1, maxUserStarredCount=50000)\n        kwargs = self.__init__._input_kwargs\n        self.setParams(**kwargs)\n\n    @keyword_only\n    def setParams(self, minItemStargazersCount=None, maxItemStargazersCount=None, minUserStarredCount=None, maxUserStarredCount=None):\n        kwargs = self.setParams._input_kwargs\n        return self._set(**kwargs)\n\n    def setMinItemStargazersCount(self, value):\n        self._paramMap[self.minItemStargazersCount] = value\n        return self\n\n    def getMinItemStargazersCount(self):\n        return self.getOrDefault(self.minItemStargazersCount)\n\n    def setMaxItemStargazersCount(self, value):\n        self._paramMap[self.maxItemStargazersCount] = value\n       
 return self\n\n    def getMaxItemStargazersCount(self):\n        return self.getOrDefault(self.maxItemStargazersCount)\n\n    def setMinUserStarredCount(self, value):\n        self._paramMap[self.minUserStarredCount] = value\n        return self\n\n    def getMinUserStarredCount(self):\n        return self.getOrDefault(self.minUserStarredCount)\n\n    def setMaxUserStarredCount(self, value):\n        self._paramMap[self.maxUserStarredCount] = value\n        return self\n\n    def getMaxUserStarredCount(self):\n        return self.getOrDefault(self.maxUserStarredCount)\n\n    def _transform(self, ratingDF):\n        minItemStargazersCount = self.getMinItemStargazersCount()\n        maxItemStargazersCount = self.getMaxItemStargazersCount()\n        minUserStarredCount = self.getMinUserStarredCount()\n        maxUserStarredCount = self.getMaxUserStarredCount()\n\n        toKeepItemsDF = ratingDF \n            .groupBy('item') \n            .agg(F.count('user').alias('stargazers_count')) \n            .where('stargazers_count &gt;= {0} AND stargazers_count &lt;= {1}'.format(minItemStargazersCount, maxItemStargazersCount)) \n            .orderBy('stargazers_count', ascending=False) \n            .select('item', 'stargazers_count')\n        temp1DF = ratingDF.join(toKeepItemsDF, 'item', 'inner')\n\n        toKeepUsersDF = temp1DF \n            .groupBy('user') \n            .agg(F.count('item').alias('starred_count')) \n            .where('starred_count &gt;= {0} AND starred_count &lt;= {1}'.format(minUserStarredCount, maxUserStarredCount)) \n            .orderBy('starred_count', ascending=False) \n            .select('user', 'starred_count')\n        temp2DF = temp1DF.join(toKeepUsersDF, 'user', 'inner')\n\n        cleanDF = temp2DF.select('user', 'item', 'rating', 'starred_at')\n        return cleanDF\n\ndataCleaner = DataCleaner(\n    minItemStargazersCount=2,\n    maxItemStargazersCount=4000,\n    minUserStarredCount=2,\n    maxUserStarredCount=5000\n)\ncleanDF = 
dataCleaner.transform(ratingDF)\n\ncleanDF.agg(F.count('rating'), F.countDistinct('user'), F.countDistinct('item')).show()\n# +-------------+--------------------+--------------------+\n# |count(rating)|count(DISTINCT user)|count(DISTINCT item)|\n# +-------------+--------------------+--------------------+\n# |      2761118|               10472|              245626|\n# +-------------+--------------------+--------------------+<\/code><\/pre>\n<h3>Generate Negative Samples<\/h3>\n<p>\u5c0d implicit feedback \u7684 ALS \u4f86\u8aaa\uff0c\u624b\u52d5\u52a0\u5165\u8ca0\u6a23\u672c\uff08Rui = 0 \u7684\u6a23\u672c\uff09\u662f\u6c92\u6709\u610f\u7fa9\u7684\uff0c\u56e0\u70ba missing value \/ non-observed value \u5c0d\u8a72\u6f14\u7b97\u6cd5\u4f86\u8aaa\u672c\u4f86\u5c31\u662f 0\uff0c\u8868\u793a\u7528\u6236\u78ba\u5be6\u6c92\u6709\u5c0d\u8a72\u7269\u54c1\u505a\u51fa\u884c\u70ba\uff0c\u4e5f\u5c31\u662f Pui = 0 \u6c92\u6709\u504f\u597d\uff0c\u6240\u4ee5 Cui = 1 + alpha x 0 \u7f6e\u4fe1\u5ea6\u4e5f\u6703\u6bd4\u5176\u4ed6\u6b63\u6a23\u672c\u4f4e\u3002\u4e0d\u904e\u56e0\u70ba Spark ML \u7684 ALS \u53ea\u6703\u8a08\u7b97 Rui &gt; 0 \u7684\u9805\u76ee\uff0c\u6240\u4ee5\u5373\u4fbf\u4f60\u624b\u52d5\u52a0\u5165\u4e86 Rui = 0 \u6216 Rui = -1 \u7684\u8ca0\u6a23\u672c\uff0c\u5c0d\u6574\u500b\u6a21\u578b\u5176\u5be6\u6c92\u6709\u5f71\u97ff\u3002<\/p>\n<p>\u96d6\u7136\u6c92\u6709\u8ca0\u6a23\u672c\u4f60\u5c31\u4e0d\u80fd\u7b97 area under ROC curve \u6216\u662f area under Precision-Recall curve \u7b49 binary classifier \u7528\u7684\u6307\u6a19\uff0c\u4e0d\u904e\u4f60\u53ef\u4ee5\u6539\u7528 Learning to rank \u7684\u8a55\u4f30\u65b9\u5f0f\uff0c\u4f8b\u5982 NDCG \u6216 Mean Average Precision \u7b49\u3002\u4f46\u662f ALS \u7684 loss function \u4e5f\u6c92\u8fa6\u6cd5\u76f4\u63a5\u512a\u5316 NDCG \u9019\u6a23\u7684\u6307\u6a19\u5c31\u662f\u4e86\u3002<\/p>\n<p>ref:<br \/>\n<a 
href=\"https:\/\/vinta.ws\/code\/generate-negative-samples-for-recommender-system.html\">https:\/\/vinta.ws\/code\/generate-negative-samples-for-recommender-system.html<\/a><\/p>\n<h2>Split Data<\/h2>\n<p>\u56e0\u70ba Matrix Factorization \u9700\u8981\u8003\u616e\u6bcf\u500b user-item pair\uff0c\u5982\u679c\u4f60\u9935\u7d66 model \u5b83\u6c92\u898b\u904e\u7684\u8cc7\u6599\uff0c\u5b83\u5c31\u6c92\u8fa6\u6cd5\u9032\u884c\u63a8\u85a6\uff08\u51b7\u555f\u52d5\u554f\u984c\uff09\u3002\u53ea\u8981 user \u6216 item \u5176\u4e2d\u4e4b\u4e00\u4e0d\u5b58\u5728\u65bc dataset \u88e1\uff0cALS model \u6240\u8f38\u51fa\u7684 prediction \u503c\u5c31\u6703\u662f <code>NaN<\/code>\u3002\u6240\u4ee5\u61c9\u8a72\u76e1\u91cf\u4fdd\u6301\u6bcf\u500b user \u548c item \u90fd\u51fa\u73fe\u5728 training set \u548c testing set \u88e1\uff0c\u4f8b\u5982\u96a8\u6a5f\u6311\u51fa\u6bcf\u500b user \u7684\u4efb\u610f n \u500b\u6216 n \u6bd4\u4f8b\u7684\u8a55\u5206\u4f5c\u70ba test set\uff0c\u5269\u4e0b\u7684\u8a55\u5206\u7576\u4f5c training set\uff08\u4fd7\u7a31 leave-n-out\uff09\u3002\u5982\u679c\u4f7f\u7528 Machine Learning \u4e2d\u5e38\u898b\u7684 holdout \u65b9\u5f0f\uff0c\u96a8\u6a5f\u5730\u628a\u6240\u6709 data point \u5206\u6563\u5230 training set \u548c test set\uff08\u4f8b\u5982 <code>df.randomSplit([0.7, 0.3])<\/code>\uff09\uff0c\u6703\u6709\u5f88\u9ad8\u7684\u6a5f\u7387\u9020\u6210\u90e8\u5206 user \u6216 item \u53ea\u51fa\u73fe\u5728\u5176\u4e2d\u4e00\u7d44 dataset \u88e1\u3002<\/p>\n<p>ref:<br \/>\n<a href=\"https:\/\/jessesw.com\/Rec-System\/\">https:\/\/jessesw.com\/Rec-System\/<\/a><br \/>\n<a href=\"http:\/\/blog.ethanrosenthal.com\/2016\/10\/19\/implicit-mf-part-1\/\">http:\/\/blog.ethanrosenthal.com\/2016\/10\/19\/implicit-mf-part-1\/<\/a><\/p>\n<p>\u5f9e LibRec \u7684\u6587\u4ef6\u4e0a\u4e5f\u53ef\u4ee5\u767c\u73fe\u9084\u6709\u8a31\u591a\u62c6\u5206\u6578\u64da\u7684\u65b9\u5f0f\uff0c\u4f8b\u5982\uff1a<\/p>\n<ul>\n<li>\u57fa\u4e8e Ratio 
\u7684\u5206\u7c7b\u65b9\u6cd5\u4e3a\u901a\u8fc7\u7ed9\u5b9a\u7684\u6bd4\u4f8b\u6765\u5c06\u6570\u636e\u5206\u4e3a\u4e24\u90e8\u5206\u3002\u8fd9\u4e2a\u5206\u7c7b\u8fc7\u7a0b\u53ef\u4ee5\u5728\u6240\u6709\u6570\u636e\u4e2d\u8fdb\u884c\u968f\u673a\u5206\u7c7b\uff0c\u4e5f\u53ef\u4ee5\u5728\u7528\u6237\u6216\u8005\u7269\u54c1\u7ef4\u5ea6\u4e0a\u8fdb\u884c\u5206\u7c7b\u3002\u5f53\u6709\u65f6\u95f4\u7684\u7279\u5f81\u65f6\uff0c\u53ef\u4ee5\u6839\u636e\u65f6\u95f4\u987a\u5e8f\u7559\u51fa\u6700\u540e\u4e00\u5b9a\u6bd4\u4f8b\u7684\u6570\u636e\u6765\u8fdb\u884c\u6d4b\u8bd5\u3002<\/li>\n<li>LooCV \u7684\u5206\u5272\u65b9\u6cd5\u4e3a leave-one-user\/item\/rating-out\uff0c\u4e5f\u5c31\u662f\u968f\u673a\u9009\u53d6\u6bcf\u4e2a user \u7684\u4efb\u610f\u4e00\u4e2a item \u6216\u8005\u6bcf\u4e2a item \u7684\u4efb\u610f\u4e00\u4e2a user \u4f5c\u4e3a\u6d4b\u8bd5\u6570\u636e\uff0c\u4f59\u4e0b\u7684\u6570\u636e\u6765\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e\u3002\u5728\u5b9e\u73b0\u4e2d\u5b9e\u73b0\u4e86\u57fa\u4e8e User \u548c\u57fa\u4e8e Item \u7684\u591a\u79cd\u5206\u7c7b\u65b9\u5f0f\u3002<\/li>\n<li>GivenN \u5206\u5272\u65b9\u6cd5\u662f\u6307\u4e3a\u6bcf\u4e2a\u7528\u6237\u7559\u51fa\u6307\u5b9a\u6570\u76ee N \u7684\u6570\u636e\u6765\u4f5c\u4e3a\u6d4b\u8bd5\u7528\u4f8b\uff0c\u4f59\u4e0b\u7684\u6837\u672c\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e\u3002<\/li>\n<li>KCV \u5373 K \u6298\u4ea4\u53c9\u9a8c\u8bc1\u3002\u5c06\u6570\u636e\u5206\u5272\u4e3a K \u4efd\uff0c\u5728\u6bcf\u6b21\u6267\u884c\u65f6\u9009\u62e9\u5176\u4e2d\u4e00\u4efd\u4f5c\u4e3a\u6d4b\u8bd5\u6570\u636e\uff0c\u4f59\u4e0b\u7684\u6570\u636e\u4f5c\u4e3a\u8bad\u7ec3\u6570\u636e\uff0c\u5171\u6267\u884c K \u6b21\u3002\u7efc\u5408 K \u6b21\u7684\u8bad\u7ec3\u7ed3\u679c\u6765\u5bf9\u63a8\u8350\u7b97\u6cd5\u7684\u6027\u80fd\u8fdb\u884c\u8bc4\u4f30\u3002<\/li>\n<\/ul>\n<p>ref:<br \/>\n<a 
href=\"https:\/\/www.librec.net\/dokuwiki\/doku.php?id=DataModel_zh#splitter\">https:\/\/www.librec.net\/dokuwiki\/doku.php?id=DataModel_zh#splitter<\/a><\/p>\n<p>\u9019\u88e1\u6211\u5011\u7528 <code>sampleBy()<\/code> \u7c21\u55ae\u5730\u5beb\u4e86\u4e00\u500b\u6839\u64da user \u4f86\u96a8\u6a5f\u5283\u5206 item \u5230 training set \u548c test set \u7684\u65b9\u6cd5\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">def randomSplitByUser(df, weights, seed=None):\n    trainingRatio = weights[0]\n    fractions = {row['user']: trainingRatio for row in df.select('user').distinct().collect()}\n    training = df.sampleBy('user', fractions, seed)\n    testRDD = df.rdd.subtract(training.rdd)\n    test = spark.createDataFrame(testRDD, df.schema)\n    return training, test\n\ntraining, test = randomSplitByUser(ratingDF, weights=[0.7, 0.3])<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"http:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy\">http:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.sql.html#pyspark.sql.DataFrame.sampleBy<\/a><\/p>\n<h2>Train the Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.ml.recommendation import ALS\n\nals = ALS(implicitPrefs=True, seed=42) \n    .setRank(50) \n    .setMaxIter(22) \n    .setRegParam(0.5) \n    .setAlpha(40)\n\nalsModel = als.fit(training)\n\n# \u9019\u4e9b\u5c31\u662f\u8a13\u7df4\u51fa\u4f86\u7684 user \u548c item \u7684 Latent Factors\nalsModel.userFactors.show()\nalsModel.itemFactors.show()<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-collaborative-filtering.html\">https:\/\/spark.apache.org\/docs\/latest\/ml-collaborative-filtering.html<\/a><br \/>\n<a 
href=\"https:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.ml.html#pyspark.ml.recommendation.ALS\">https:\/\/spark.apache.org\/docs\/latest\/api\/python\/pyspark.ml.html#pyspark.ml.recommendation.ALS<\/a><\/p>\n<h2>Predict Preferences<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.ml import Transformer\n\npredictedDF = alsModel.transform(test)\n\nclass PredictionProcessor(Transformer):\n\n    def _transform(self, predictedDF):\n        nonNullDF = predictedDF.dropna(subset=['prediction', ])\n        predictionDF = nonNullDF.withColumn('prediction', nonNullDF['prediction'].cast('double'))\n        return predictionDF\n\n# \u522a\u6389\u90a3\u4e9b NaN \u7684\u6578\u64da\npredictionProcessor = PredictionProcessor()\npredictionDF = predictionProcessor.transform(predictedDF)<\/code><\/pre>\n<h2>Evaluate the Model<\/h2>\n<p>\u56e0\u70ba Spark ML \u6c92\u6709\u63d0\u4f9b\u7d66 DataFrame \u7528\u7684 ranking evaluator\uff0c\u6211\u5011\u53ea\u597d\u81ea\u5df1\u5beb\u4e00\u500b\uff0c\u4f46\u662f\u5167\u90e8\u9084\u662f\u4f7f\u7528 Spark MLlib \u7684 <code>RankingMetrics<\/code>\u3002\u4e0d\u904e\u9019\u500b\u53ea\u662f offline \u7684\u8a55\u4f30\u65b9\u5f0f\u800c\u5df2\uff0c\u7b49\u5230\u8981\u5be6\u969b\u4e0a\u7dda\u7684\u6642\u5019\u53ef\u80fd\u9084\u9700\u8981\u505a A\/B testing\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark import keyword_only\nfrom pyspark.ml.evaluation import Evaluator\nfrom pyspark.ml.param.shared import Param\nfrom pyspark.mllib.evaluation import RankingMetrics\nfrom pyspark.sql import Window\nfrom pyspark.sql.functions import col\nfrom pyspark.sql.functions import expr\nimport pyspark.sql.functions as F\n\nclass RankingEvaluator(Evaluator):\n\n    @keyword_only\n    def __init__(self, k=None):\n        super(RankingEvaluator, self).__init__()\n        self.k = Param(self, 'k', 'Top K')\n        self._setDefault(k=30)\n        kwargs = self.__init__._input_kwargs\n        
self.setParams(**kwargs)\n\n    @keyword_only\n    def setParams(self, k=None):\n        kwargs = self.setParams._input_kwargs\n        return self._set(**kwargs)\n\n    def isLargerBetter(self):\n        return True\n\n    def setK(self, value):\n        self._paramMap[self.k] = value\n        return self\n\n    def getK(self):\n        return self.getOrDefault(self.k)\n\n    def _evaluate(self, predictedDF):\n        k = self.getK()\n\n        predictedDF.show()\n\n        windowSpec = Window.partitionBy('user').orderBy(col('prediction').desc())\n        perUserPredictedItemsDF = predictedDF \n            .select('user', 'item', 'prediction', F.rank().over(windowSpec).alias('rank')) \n            .where('rank &lt;= {0}'.format(k)) \n            .groupBy('user') \n            .agg(expr('collect_list(item) as items'))\n\n        windowSpec = Window.partitionBy('user').orderBy(col('starred_at').desc())\n        perUserActualItemsDF = predictedDF \n            .select('user', 'item', 'starred_at', F.rank().over(windowSpec).alias('rank')) \n            .where('rank &lt;= {0}'.format(k)) \n            .groupBy('user') \n            .agg(expr('collect_list(item) as items'))\n\n        perUserItemsRDD = perUserPredictedItemsDF.join(F.broadcast(perUserActualItemsDF), 'user', 'inner') \n            .rdd \n            .map(lambda row: (row[1], row[2]))\n\n        if perUserItemsRDD.isEmpty():\n            return 0.0\n\n        rankingMetrics = RankingMetrics(perUserItemsRDD)\n        metric = rankingMetrics.ndcgAt(k)\n        return metric\n\nk = 30\nrankingEvaluator = RankingEvaluator(k=k)\nndcg = rankingEvaluator.evaluate(predictionDF)\nprint('NDCG', ndcg)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html\">https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html<\/a><\/p>\n<h2>Recommend 
Items<\/h2>\n<p>\u5be6\u969b\u611f\u53d7\u4e00\u4e0b\u63a8\u85a6\u7cfb\u7d71\u7684\u6548\u679c\u5982\u4f55\u3002\u9019\u88e1\u662f\u76f4\u63a5\u628a\u7d50\u679c print \u51fa\u4f86\uff0c\u800c\u6c92\u6709\u628a\u63a8\u85a6\u7d50\u679c\u5132\u5b58\u5230\u8cc7\u6599\u5eab\u3002\u4e0d\u904e\u901a\u5e38\u4e0d\u6703\u76f4\u63a5\u5c31\u628a\u63a8\u85a6\u7cfb\u7d71\u8f38\u51fa\u7684\u6771\u897f\u5c55\u793a\u7d66\u7528\u6236\uff0c\u6703\u5148\u7d93\u904e\u4e00\u4e9b\u904e\u6ffe\u3001\u6392\u5e8f\u548c\u7522\u751f\u63a8\u85a6\u7406\u7531\u7b49\u7b49\u7684\u6b65\u9a5f\uff0c\u6216\u662f\u52a0\u5165\u4e00\u4e9b\u4eba\u70ba\u7684\u898f\u5247\uff0c\u6bd4\u5982\u8aaa\u5f37\u5236\u63d2\u5165\u5ee3\u544a\u3001\u6700\u8fd1\u4e3b\u6253\u7684\u5546\u54c1\u6216\u662f\u904e\u6ffe\u6389\u90a3\u4e9b\u5f88\u591a\u4eba\u9ede\u64ca\u4f46\u662f\u5176\u5be6\u8cea\u91cf\u4e26\u4e0d\u600e\u9ebc\u6a23\u7684\u6771\u897f\u3002\u7576\u7136\u4e5f\u6709\u53ef\u80fd\u6703\u628a\u9019\u500b\u63a8\u85a6\u7cfb\u7d71\u7684\u8f38\u51fa\u4f5c\u70ba\u5176\u4ed6\u6a5f\u5668\u5b78\u7fd2 model \u7684\u8f38\u5165\u3002<\/p>\n<p>ref:<br \/>\n<a href=\"https:\/\/www.zhihu.com\/question\/28247353\">https:\/\/www.zhihu.com\/question\/28247353<\/a><\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">def recommendItems(rawDF, alsModel, username, topN=30, excludeKnownItems=False):\n    userID = rawDF \n        .where('from_username = \"{0}\"'.format(username)) \n        .select('from_user_id') \n        .take(1)[0]['from_user_id']\n\n    userItemsDF = alsModel \n        .itemFactors. 
\n        selectExpr('{0} AS user'.format(userID), 'id AS item')\n    if excludeKnownItems:\n        userKnownItemsDF = rawDF \n            .where('from_user_id = {0}'.format(userID)) \n            .selectExpr('repo_id AS item')\n        userItemsDF = userItemsDF.join(userKnownItemsDF, 'item', 'left_anti')\n\n    userPredictedDF = alsModel \n        .transform(userItemsDF) \n        .select('item', 'prediction') \n        .orderBy('prediction', ascending=False) \n        .limit(topN)\n\n    repoDF = rawDF \n        .groupBy('repo_id', 'repo_full_name', 'repo_language') \n        .agg(F.max('stargazers_count').alias('stargazers_count'))\n\n    recommendedItemsDF = userPredictedDF \n        .join(repoDF, userPredictedDF['item'] == repoDF['repo_id'], 'inner') \n        .select('prediction', 'repo_full_name', 'repo_language', 'stargazers_count') \n        .orderBy('prediction', ascending=False)\n\n    return recommendedItemsDF\n\nk = 30\nusername = 'vinta'\nrecommendedItemsDF = recommendItems(rawDF, alsModel, username, topN=k, excludeKnownItems=False)\nfor item in recommendedItemsDF.collect():\n    repoName = item['repo_full_name']\n    repoUrl = 'https:\/\/github.com\/{0}'.format(repoName)\n    print(repoUrl, item['prediction'], item['repo_language'], item['stargazers_count'])<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/github.com\/vinta\/albedo\/blob\/master\/src\/main\/python\/train_als.ipynb\">https:\/\/github.com\/vinta\/albedo\/blob\/master\/src\/main\/python\/train_als.ipynb<\/a><\/p>\n<h2>Cross-validate Models<\/h2>\n<p>\u4f7f\u7528 Spark ML \u7684 <code>pipeline<\/code> \u4f86\u505a cross-validation\uff0c\u9078\u51fa\u6700\u9069\u5408\u7684 hyperparameters \u7d44\u5408\u3002<\/p>\n<ul>\n<li><code>rank<\/code>: The number of latent factors in the model, or equivalently, the number of columns k in the user-feature and product-feature matrices.<\/li>\n<li><code>regParam<\/code>: A standard overfitting parameter, also usually called lambda. 
Higher values resist overfitting, but values that are too high hurt the factorization\u2019s accuracy.<\/li>\n<li><code>alpha<\/code>: Controls the relative weight of observed versus unobserved user-product interactions in the factorization.<\/li>\n<li><code>maxIter<\/code>: The number of iterations that the factorization runs. More iterations take more time but may produce a better factorization.<\/li>\n<\/ul>\n<pre class=\"line-numbers\"><code class=\"language-py\">dataCleaner = DataCleaner()\n\nals = ALS(implicitPrefs=True, seed=42)\n\npredictionProcessor = PredictionProcessor()\n\npipeline = Pipeline(stages=[\n    dataCleaner,\n    als,\n    predictionProcessor,\n])\n\nparamGrid = ParamGridBuilder() \n    .addGrid(dataCleaner.minItemStargazersCount, [1, 10, 100]) \n    .addGrid(dataCleaner.maxItemStargazersCount, [4000, ]) \n    .addGrid(dataCleaner.minUserStarredCount, [1, 10, 100]) \n    .addGrid(dataCleaner.maxUserStarredCount, [1000, 4000, ]) \n    .addGrid(als.rank, [50, 100]) \n    .addGrid(als.regParam, [0.01, 0.1, 0.5]) \n    .addGrid(als.alpha, [0.01, 0.89, 1, 40, ]) \n    .addGrid(als.maxIter, [22, ]) \n    .build()\n\nrankingEvaluator = RankingEvaluator(k=30)\n\ncv = CrossValidator(estimator=pipeline,\n                    estimatorParamMaps=paramGrid,\n                    evaluator=rankingEvaluator,\n                    numFolds=2)\n\ncvModel = cv.fit(ratingDF)\n\ndef printCrossValidationParameters(cvModel):\n    metric_params_pairs = list(zip(cvModel.avgMetrics, cvModel.getEstimatorParamMaps()))\n    metric_params_pairs.sort(key=lambda x: x[0], reverse=True)\n    for pair in metric_params_pairs:\n        metric, params = pair\n        print('metric', metric)\n        for k, v in params.items():\n            print(k.name, v)\n        print('')\n\nprintCrossValidationParameters(cvModel)<\/code><\/pre>\n<p>ref:<br \/>\n<a 
href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-pipeline.html\">https:\/\/spark.apache.org\/docs\/latest\/ml-pipeline.html<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>\u5728\u672c\u7bc7\u6587\u7ae0\u4e2d\uff0c\u6211\u5011\u4ee5 Candidate Generation \u968e\u6bb5\u5e38\u7528\u7684\u65b9\u6cd5\u4e4b\u4e00\uff1aCollaborative Filtering \u5354\u540c\u904e\u6ffe\u6f14\u7b97\u6cd5\u70ba\u4f8b\uff0c\u5229\u7528 Apache Spark \u7684 ALS (Alternating Least Squares) \u6a21\u578b\u5efa\u7acb\u4e00\u500b GitHub repositories \u7684\u63a8\u85a6\u7cfb\u7d71\uff0c\u4ee5\u7528\u6236\u5c0d repo \u7684\u6253\u661f\u7d00\u9304\u4f5c\u70ba\u8a13\u7df4\u6578\u64da\uff0c\u63a8\u85a6\u51fa\u7528\u6236\u53ef\u80fd\u6703\u611f\u8208\u8da3\u7684\u5176\u4ed6 repo \u4f5c\u70ba\u5019\u9078\u7269\u54c1\u96c6\u3002<\/p>\n","protected":false},"author":1,"featured_media":395,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[97,112,4],"tags":[108,98,2,104],"class_list":["post-394","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-about-ai","category-about-big-data","category-about-python","tag-apache-spark","tag-machine-learning","tag-python","tag-recommender-system"],"_links":{"self":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts\/394","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/comments?post=394"}],"version-history":[{"count":0,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts\/394\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/media\/395"}],"wp:attachment":[{"href":"https:
\/\/vinta.ws\/code\/wp-json\/wp\/v2\/media?parent=394"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/categories?post=394"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/tags?post=394"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}