{"id":388,"date":"2017-05-15T00:59:12","date_gmt":"2017-05-14T16:59:12","guid":{"rendered":"http:\/\/vinta.ws\/code\/?p=388"},"modified":"2026-02-18T01:20:35","modified_gmt":"2026-02-17T17:20:35","slug":"spark-ml-cookbook-pyspark","status":"publish","type":"post","link":"https:\/\/vinta.ws\/code\/spark-ml-cookbook-pyspark.html","title":{"rendered":"Spark ML cookbook (Python)"},"content":{"rendered":"<h2>Calculate percentage of sparsity of a user-item rating DataFrame<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">result = ratingDF.agg(F.count('rating'), F.countDistinct('user'), F.countDistinct('item')).collect()[0]\ntotalUserCount = result['count(DISTINCT user)']\ntotalItemCount = result['count(DISTINCT item)']\nzonZeroRatingCount = result['count(rating)']\n\ndensity = (zonZeroRatingCount \/ (totalUserCount * totalItemCount)) * 100\nsparsity = 100 - density<\/code><\/pre>\n<h2>Recommend items for single user<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">topN = 30\nuserID = 123\n\nuserItemsDF = alsModel \n    .itemFactors. \n    selectExpr('{0} AS user'.format(userID), 'id AS item')\nuserPredictedDF = alsModel \n    .transform(userItemsDF) \n    .select('item', 'prediction') \n    .orderBy('prediction', ascending=False) \n    .limit(topN)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/www.safaribooksonline.com\/library\/view\/advanced-analytics-with\/9781491972946\/ch03.html\">https:\/\/www.safaribooksonline.com\/library\/view\/advanced-analytics-with\/9781491972946\/ch03.html<\/a><br \/>\n<a href=\"https:\/\/www.safaribooksonline.com\/library\/view\/building-a-recommendation\/9781785282584\/ch06s04.html\">https:\/\/www.safaribooksonline.com\/library\/view\/building-a-recommendation\/9781785282584\/ch06s04.html<\/a><br \/>\n<a href=\"https:\/\/www.safaribooksonline.com\/library\/view\/building-recommendation-engines\/9781785884856\/ch07s05.html\">https:\/\/www.safaribooksonline.com\/library\/view\/building-recommendation-engines\/9781785884856\/ch07s05.html<\/a><\/p>\n<h2>Recommend items for every user (a slightly fast way)<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">import numpy\nfrom numpy import *\n\nmyModel = MatrixFactorizationModel.load(sc, 'BingBong')\nm1 = myModel.productFeatures()\nm2 = m1.map(lambda (product, feature): feature).collect()\nm3 = matrix(m2).transpose()\npf = sc.broadcast(m3)\nuf = myModel.userFeatures().coalesce(100)\n\n# get predictions on all user\nf1 = uf.map(lambda (userID, features): (userID, squeeze(asarray(matrix(array(features)) * pf.value))))<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/www.slideshare.net\/SparkSummit\/26-trillion-app-recomendations-using-100-lines-of-spark-code-ayman-farahat\">https:\/\/www.slideshare.net\/SparkSummit\/26-trillion-app-recomendations-using-100-lines-of-spark-code-ayman-farahat<\/a><\/p>\n<h2>Evaluate a binary classification<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.ml.evaluation import BinaryClassificationEvaluator\n\nmatrix = [\n    (0.5, 1),\n    (2.0, 1),\n    (0.8, 1),\n    (0.2, 0),\n    (0.1, 0),\n    (0.4, 0),\n]\npredictions = spark.createDataFrame(matrix, ['prediction', 'label'])\npredictions = predictions.withColumn('prediction', predictions['prediction'].cast('double'))\n\nevaluator = BinaryClassificationEvaluator(rawPredictionCol='prediction', labelCol='label', metricName='areaUnderROC')\nevaluator.evaluate(predictions)<\/code><\/pre>\n<h2>Calculate ranking metrics<\/h2>\n<pre class=\"line-numbers\"><code 
class=\"language-py\">from pyspark.mllib.evaluation import RankingMetrics\nfrom pyspark.sql import Window\nfrom pyspark.sql.functions import col, expr\nimport pyspark.sql.functions as F\n\nk = 10\n\nwindowSpec = Window.partitionBy('user').orderBy(col('prediction').desc())\nperUserPredictedItemsDF = outputDF \n    .select('user', 'item', 'prediction', F.rank().over(windowSpec).alias('rank')) \n    .where('rank &lt;= {0}'.format(k)) \n    .groupBy('user') \n    .agg(expr('collect_list(item) as items'))\nperUserPredictedItemsDF.show()\n# +--------+--------------------+\n# |    user|               items|\n# +--------+--------------------+\n# |    2142|[36560369, 197450...|\n# |   47217|[40693501, 643554...|\n# +--------+--------------------+\n\nwindowSpec = Window.partitionBy('from_user_id').orderBy(col('starred_at').desc())\nperUserActualItemsDF = rawDF \n    .select('from_user_id', 'repo_id', 'starred_at', F.rank().over(windowSpec).alias('rank')) \n    .where('rank &lt;= {0}'.format(k)) \n    .groupBy('from_user_id') \n    .agg(expr('collect_list(repo_id) as items')) \n    .withColumnRenamed('from_user_id', 'user')\n# +--------+--------------------+\n# |    user|               items|\n# +--------+--------------------+\n# |    2142|[29122050, 634846...|\n# |   59990|[9820191, 8729416...|\n# +--------+--------------------+\n\nperUserItemsRDD = perUserPredictedItemsDF.join(perUserActualItemsDF, 'user') \n    .rdd \n    .map(lambda row: (row[1], row[2]))\nrankingMetrics = RankingMetrics(perUserItemsRDD)\n\nprint(rankingMetrics.meanAveragePrecision)\nprint(rankingMetrics.precisionAt(k))\nprint(rankingMetrics.ndcgAt(k))<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/www.safaribooksonline.com\/library\/view\/spark-the-definitive\/9781491912201\/ch19.html\">https:\/\/www.safaribooksonline.com\/library\/view\/spark-the-definitive\/9781491912201\/ch19.html<\/a><\/p>\n<h2>Create a custom Transformer<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.ml import Transformer\n\nclass PredictionProcessor(Transformer):\n\n    def _transform(self, predictedDF):\n        nonNullDF = predictedDF.dropna(subset=['prediction', ])\n        predictionDF = nonNullDF.withColumn('prediction', nonNullDF['prediction'].cast('double'))\n        return predictionDF<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"http:\/\/stackoverflow.com\/questions\/32331848\/create-a-custom-transformer-in-pyspark-ml\">http:\/\/stackoverflow.com\/questions\/32331848\/create-a-custom-transformer-in-pyspark-ml<\/a><\/p>\n<h2>RankingMetrics<\/h2>\n<p>\u7d50\u679c\u4ecb\u65bc 0 ~ 1 \u4e4b\u9593\uff0c\u5206\u6578\u8d8a\u5927\u8d8a\u597d\u3002<\/p>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark.mllib.evaluation import RankingMetrics\n\npredictionAndLabels = sc.parallelize([\n    ([1, 2, 3, 4], [1, 2, 3, 4]),\n    ([1, ], [1, 10]),\n    ([6, 4, 2], [6, 2, 100, 8, 2, 55]),\n])\nmetrics = RankingMetrics(predictionAndLabels)\nmetrics.ndcgAt(5)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html#ranking-systems\">https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html#ranking-systems<\/a><\/p>\n<h2>Create a custom Evaluator<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-py\">from pyspark import keyword_only\nfrom pyspark.ml.evaluation import Evaluator\nfrom pyspark.ml.param.shared import Param\nfrom pyspark.mllib.evaluation import RankingMetrics\nfrom pyspark.sql import Window\nfrom pyspark.sql.functions import 
<h2>Create a custom Evaluator</h2>
<pre class="line-numbers"><code class="language-py">from pyspark import keyword_only
from pyspark.ml.evaluation import Evaluator
from pyspark.ml.param.shared import Param
from pyspark.mllib.evaluation import RankingMetrics
from pyspark.sql import Window
from pyspark.sql.functions import col
from pyspark.sql.functions import expr
import pyspark.sql.functions as F

class RankingEvaluator(Evaluator):

    @keyword_only
    def __init__(self, k=None):
        super(RankingEvaluator, self).__init__()
        self.k = Param(self, 'k', 'Top K')
        self._setDefault(k=30)
        # Note: newer PySpark versions (2.1.1+) store these on the
        # instance instead; use self._input_kwargs there.
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, k=None):
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def isLargerBetter(self):
        # NDCG: higher is better.
        return True

    def setK(self, value):
        self._paramMap[self.k] = value
        return self

    def getK(self):
        return self.getOrDefault(self.k)

    def _evaluate(self, outputDF):
        k = self.getK()

        # Top-k predicted items per user, ordered by prediction score.
        windowSpec = Window.partitionBy('user').orderBy(col('prediction').desc())
        perUserPredictedItemsDF = outputDF \
            .select('user', 'item', 'prediction', F.rank().over(windowSpec).alias('rank')) \
            .where('rank &lt;= {0}'.format(k)) \
            .groupBy('user') \
            .agg(expr('collect_list(item) as items'))

        # Top-k actual items per user, ordered by recency.
        windowSpec = Window.partitionBy('user').orderBy(col('starred_at').desc())
        perUserActualItemsDF = outputDF \
            .select('user', 'item', 'starred_at', F.rank().over(windowSpec).alias('rank')) \
            .where('rank &lt;= {0}'.format(k)) \
            .groupBy('user') \
            .agg(expr('collect_list(item) as items'))

        perUserItemsRDD = perUserPredictedItemsDF.join(F.broadcast(perUserActualItemsDF), 'user', 'inner') \
            .rdd \
            .map(lambda row: (row[1], row[2]))
        rankingMetrics = RankingMetrics(perUserItemsRDD)
        metric = rankingMetrics.ndcgAt(k)
        return metric</code></pre>
<h2>Show best parameters from a CrossValidatorModel</h2>
<pre class="line-numbers"><code class="language-py">metric_params_pairs = list(zip(cvModel.avgMetrics, cvModel.getEstimatorParamMaps()))
metric_params_pairs.sort(key=lambda x: x[0], reverse=True)
best_metric_params = metric_params_pairs[0][1]
for pair in metric_params_pairs:
    metric, params = pair
    print('metric', metric)
    for k, v in params.items():
        print(k.name, v)
    print('')
# metric 0.5273636632856705
# rank 50
# regParam 0.01
# maxIter 10
# alpha 1</code></pre>
<p>ref:<br />
<a href="https://stackoverflow.com/questions/31749593/how-to-extract-best-parameters-from-a-crossvalidatormodel">https://stackoverflow.com/questions/31749593/how-to-extract-best-parameters-from-a-crossvalidatormodel</a><br />
<a href="https://stackoverflow.com/questions/39529012/pyspark-get-all-parameters-of-models-created-with-paramgridbuilder">https://stackoverflow.com/questions/39529012/pyspark-get-all-parameters-of-models-created-with-paramgridbuilder</a></p>
<h2>Show best parameters from a TrainValidationSplitModel</h2>
<pre class="line-numbers"><code class="language-py">metric_params_pairs = list(zip(tvModel.validationMetrics, tvModel.getEstimatorParamMaps()))
metric_params_pairs.sort(key=lambda x: x[0], reverse=True)
for pair in metric_params_pairs:
    metric, params = pair
    print('metric', metric)
    for k, v in params.items():
        print(k.name, v)
    print('')
# metric 0.5385481418189484
# rank 50
# regParam 0.1
# maxIter 20
# alpha 1</code></pre>
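<p>For completeness, here is a sketch of where a cvModel like the one above can come from, wiring the custom RankingEvaluator into CrossValidator. The ALS settings and grid values are illustrative, and the hypothetical ratingStarredDF must carry the user, item, rating, and starred_at columns the evaluator expects:</p>
<pre class="line-numbers"><code class="language-py">from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

als = ALS(userCol='user', itemCol='item', ratingCol='rating', implicitPrefs=True)
paramGrid = ParamGridBuilder() \
    .addGrid(als.rank, [10, 50]) \
    .addGrid(als.regParam, [0.01, 0.1]) \
    .build()
cv = CrossValidator(
    estimator=als,
    estimatorParamMaps=paramGrid,
    evaluator=RankingEvaluator(k=30),
    numFolds=3,
)
cvModel = cv.fit(ratingStarredDF)</code></pre>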
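<p>Both model types also keep the winning model itself: CrossValidator refits the best ParamMap on the entire dataset and exposes the result as bestModel (TrainValidationSplitModel behaves the same way), so after inspecting the metrics you can grab it directly:</p>
<pre class="line-numbers"><code class="language-py"># The refitted best model is exposed as bestModel.
bestALSModel = cvModel.bestModel
print(bestALSModel.rank)</code></pre>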