{"id":431,"date":"2017-09-02T16:59:30","date_gmt":"2017-09-02T08:59:30","guid":{"rendered":"http:\/\/vinta.ws\/code\/?p=431"},"modified":"2026-03-17T01:19:02","modified_gmt":"2026-03-16T17:19:02","slug":"spark-ml-cookbook-scala","status":"publish","type":"post","link":"https:\/\/vinta.ws\/code\/spark-ml-cookbook-scala.html","title":{"rendered":"Spark ML cookbook (Scala)"},"content":{"rendered":"<p>Scala is the first class citizen language for interacting with Apache Spark, but it's difficult to learn. This article is mostly about Spark ML - the new Spark Machine Learning library which was rewritten in DataFrame-based API.<\/p>\n<h2>Convert a String Categorical Feature into Numeric One<\/h2>\n<p><code>StringIndexer<\/code> converts labels (categorical values) into numbers (0.0, 1.0, 2.0 and so on) which ordered by label frequencies, the most frequent label gets <code>0<\/code>. This method is able to handle unseen labels with optional strategies.<\/p>\n<p><code>StringIndexer<\/code>'s <code>inputCol<\/code> accepts string, numeric and boolean types.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">val df1 = spark.createDataFrame(Seq(\n    (1, \"Python\"),\n    (2, \"C++\"),\n    (3, \"C++\"),\n    (4, \"JavaScript\"),\n    (5, \"Python\"),\n    (6, \"Python\"),\n    (7, \"Go\")\nBinaryClassificationEvaluator\n)).toDF(\"repo_id\", \"repo_language\")\n\nval df2 = spark.createDataFrame(Seq(\n    (1, \"Python\"),\n    (2, \"C++\"),\n    (3, \"C++\"),\n    (4, \"JavaScript\"),\n    (5, \"Python\"),\n    (6, \"Python\"),\n    (7, \"Go\"),\n    (8, \"JavaScript\"),\n    (9, \"Brainfuck\"),\n    (10, \"Brainfuck\"),\n    (11, \"Red\")\n)).toDF(\"repo_id\", \"repo_language\")\n\nimport org.apache.spark.ml.feature.StringIndexer\n\nval stringIndexer = new StringIndexer()\n  .setInputCol(\"repo_language\")\n  .setOutputCol(\"repo_language_index\")\n  .setHandleInvalid(\"keep\")\nval stringIndexerModel = stringIndexer.fit(df1)\n\nstringIndexerModel.labels\n\/\/ 
Array[String] = Array(Python, C++, JavaScript, Go)\n\nval indexedDF = stringIndexerModel.transform(df2)\nindexedDF.show()\n\/\/ +-------+-------------+-------------------+\n\/\/ |repo_id|repo_language|repo_language_index|\n\/\/ +-------+-------------+-------------------+\n\/\/ |      1|       Python|                0.0|\n\/\/ |      2|          C++|                1.0|\n\/\/ |      3|          C++|                1.0|\n\/\/ |      4|   JavaScript|                2.0|\n\/\/ |      5|       Python|                0.0|\n\/\/ |      6|       Python|                0.0|\n\/\/ |      7|           Go|                3.0|\n\/\/ |      8|   JavaScript|                2.0|\n\/\/ |      9|    Brainfuck|                4.0| &lt;- previously unseen\n\/\/ |     10|    Brainfuck|                4.0| &lt;- previously unseen\n\/\/ |     11|          Red|                4.0| &lt;- previously unseen\n\/\/ +-------+-------------+-------------------+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#stringindexer\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#stringindexer<\/a><br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/34681534\/spark-ml-stringindexer-handling-unseen-labels\">https:\/\/stackoverflow.com\/questions\/34681534\/spark-ml-stringindexer-handling-unseen-labels<\/a><br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/32277576\/how-to-handle-categorical-features-with-spark-ml\/32278617\">https:\/\/stackoverflow.com\/questions\/32277576\/how-to-handle-categorical-features-with-spark-ml\/32278617<\/a><\/p>\n<h2>Convert an Indexed Numeric Feature Back to the Original Categorical One<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.IndexToString\n\nval indexToString = new IndexToString()\n  .setInputCol(\"repo_language_index\")\n  .setOutputCol(\"repo_language_ori\")\n\nval oriIndexedDF = 
indexToString.transform(indexedDF)\noriIndexedDF.show()\n\/\/ +-------+-------------+-------------------+----------------------+\n\/\/ |repo_id|repo_language|repo_language_index|     repo_language_ori|\n\/\/ +-------+-------------+-------------------+----------------------+\n\/\/ |      1|       Python|                0.0|                Python|\n\/\/ |      2|          C++|                1.0|                   C++|\n\/\/ |      3|          C++|                1.0|                   C++|\n\/\/ |      4|   JavaScript|                2.0|            JavaScript|\n\/\/ |      5|       Python|                0.0|                Python|\n\/\/ |      6|       Python|                0.0|                Python|\n\/\/ |      7|           Go|                3.0|                    Go|\n\/\/ |      8|   JavaScript|                2.0|            JavaScript|\n\/\/ |      9|    Brainfuck|                4.0|             __unknown| &lt;- previously unseen\n\/\/ |     10|    Brainfuck|                4.0|             __unknown| &lt;- previously unseen\n\/\/ |     11|          Red|                4.0|             __unknown| &lt;- previously unseen\n\/\/ +-------+-------------+-------------------+----------------------+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#indextostring\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#indextostring<\/a><\/p>\n<h2>One-hot Encoding for Categorical Features<\/h2>\n<p><code>OneHotEncoder<\/code>'s input column only accepts numeric types. If you have string columns, you need to use <code>StringIndexer<\/code> to transform them into doubles, besides, <code>StringIndexer<\/code> is able to properly deal with unseen values. 
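<\/p>\n<p>The two steps can be chained in a single <code>Pipeline<\/code>. A minimal sketch, assuming the <code>df1<\/code> and <code>df2<\/code> DataFrames from the <code>StringIndexer<\/code> example above:<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.Pipeline\nimport org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}\n\nval stringIndexer = new StringIndexer()\n  .setInputCol(\"repo_language\")\n  .setOutputCol(\"repo_language_index\")\n  .setHandleInvalid(\"keep\")\nval oneHotEncoder = new OneHotEncoder()\n  .setInputCol(\"repo_language_index\")\n  .setOutputCol(\"repo_language_one_hot\")\n\n\/\/ both stages are PipelineStages, so they can be fitted and applied together\nval pipeline = new Pipeline()\n  .setStages(Array(stringIndexer, oneHotEncoder))\nval pipelineModel = pipeline.fit(df1)\nval encodedDF = pipelineModel.transform(df2)<\/code><\/pre>\n<p>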
In my humble opinion, you should always apply <code>StringIndexer<\/code> before <code>OneHotEncoder<\/code>.<\/p>\n<p>Be careful: <code>OneHotEncoder<\/code>'s vector length is determined by the maximum value in the column, so you must apply <code>OneHotEncoder<\/code> to the union of the training set and the test set. Also, since <code>OneHotEncoder<\/code> does not accept empty strings as values, you need to replace all empty strings with a placeholder, something like <code>__empty<\/code>.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.OneHotEncoder\n\nval knownDF = spark.createDataFrame(Seq(\n  (2, \"b\"),\n  (3, \"c\"),\n  (0, \"x\"),\n  (6, \"c\"),\n  (4, \"a\"),\n  (1, \"a\"),\n  (5, \"a\")\n)).toDF(\"category_1\", \"category_2\")\n\nval unseenDF = spark.createDataFrame(Seq(\n  (123, \"e\"),\n  (6, \"c\"),\n  (2, \"b\"),\n  (456, \"c\"),\n  (1, \"a\")\n)).toDF(\"category_1\", \"category_2\")\n\nval knownOneHotDF = new OneHotEncoder()\n  .setDropLast(true)\n  .setInputCol(\"category_1\")\n  .setOutputCol(\"category_1_one_hot\")\n  .transform(knownDF)\nknownOneHotDF.show()\n\/\/ +----------+----------+------------------+\n\/\/ |category_1|category_2|category_1_one_hot|\n\/\/ +----------+----------+------------------+\n\/\/ |         2|         b|     (6,[2],[1.0])|\n\/\/ |         3|         c|     (6,[3],[1.0])|\n\/\/ |         0|         x|     (6,[0],[1.0])|\n\/\/ |         6|         c|         (6,[],[])|\n\/\/ |         4|         a|     (6,[4],[1.0])|\n\/\/ |         1|         a|     (6,[1],[1.0])|\n\/\/ |         5|         a|     (6,[5],[1.0])|\n\/\/ +----------+----------+------------------+\n\nval unseenOneHotDF = new OneHotEncoder()\n  .setDropLast(true)\n  .setInputCol(\"category_1\")\n  .setOutputCol(\"category_1_one_hot\")\n  .transform(unseenDF)\nunseenOneHotDF.show()\n\/\/ +----------+----------+------------------+\n\/\/ |category_1|category_2|category_1_one_hot|\n\/\/ 
+----------+----------+------------------+\n\/\/ |       123|         e| (456,[123],[1.0])|\n\/\/ |         6|         c|   (456,[6],[1.0])|\n\/\/ |         2|         b|   (456,[2],[1.0])|\n\/\/ |       456|         c|       (456,[],[])|\n\/\/ |         1|         a|   (456,[1],[1.0])|\n\/\/ +----------+----------+------------------+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#onehotencoder\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#onehotencoder<\/a><br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/32277576\/how-to-handle-categorical-features-with-spark-ml\/40615508\">https:\/\/stackoverflow.com\/questions\/32277576\/how-to-handle-categorical-features-with-spark-ml\/40615508<\/a><br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/33089781\/spark-dataframe-handing-empty-string-in-onehotencoder\">https:\/\/stackoverflow.com\/questions\/33089781\/spark-dataframe-handing-empty-string-in-onehotencoder<\/a><\/p>\n<h2>Create a Regular Expression Tokenizer<\/h2>\n<p>With <code>setGaps(true)<\/code>, the pattern matches the gaps (delimiters) between tokens; with <code>setGaps(false)<\/code>, the pattern matches the tokens themselves.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.RegexTokenizer\nimport org.apache.spark.sql.functions._\n\nval sentenceDF = spark.createDataFrame(Seq(\n  (1, \"Hi, I heard about Spark\"),\n  (2, \"I wish Java could use case classes.\"),\n  (3, \"Deep,Learning,models,are,state-of-the-art\"),\n  (4, \"fuck_yeah!!! No.\")\n)).toDF(\"id\", \"sentence\")\n\nval countTokensUDF = udf((words: Seq[String]) =&gt; words.length)\n\nval regexTokenizer = new RegexTokenizer()\n  .setInputCol(\"sentence\")\n  .setOutputCol(\"words\")\n  .setPattern(\"\"\"[\\\w-_]+\"\"\").setGaps(false)\n  \/\/ .setPattern(\"\"\"\\\W\"\"\").setGaps(true)\n  \/\/ .setPattern(\"\"\"[,. ]\"\"\").setGaps(true)\nval tokenizedDF = regexTokenizer.transform(sentenceDF)\n\nval df = tokenizedDF\n  .select(\"sentence\", \"words\")\n  .withColumn(\"count\", countTokensUDF($\"words\"))\n\/\/ +-----------------------------------------+-----------------------------------------------+-----+\n\/\/ |sentence                                 |words                                          |count|\n\/\/ +-----------------------------------------+-----------------------------------------------+-----+\n\/\/ |Hi, I heard about Spark                  |[hi, i, heard, about, spark]                   |5    |\n\/\/ |I wish Java could use case classes.      |[i, wish, java, could, use, case, classes]     |7    |\n\/\/ |Deep,Learning,models,are,state-of-the-art|[deep, learning, models, are, state-of-the-art]|5    |\n\/\/ |fuck_yeah!!! No.                         |[fuck_yeah, no]                                |2    |\n\/\/ +-----------------------------------------+-----------------------------------------------+-----+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#tokenizer\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#tokenizer<\/a><\/p>\n<h2>Handle Comma-separated Categorical Column<\/h2>\n<p>You could use <code>RegexTokenizer<\/code>, <code>CountVectorizer<\/code> or <code>HashingTF<\/code>.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.{CountVectorizer, HashingTF, RegexTokenizer}\n\nval df = spark.createDataFrame(Seq(\n  (1, \"Action,Sci-Fi\"),\n  (2, \"Sci-Fi,Romance,Horror\"),\n  (3, \"War,Horror\")\n)).toDF(\"movie_id\", \"genres\")\n\nval regexTokenizer = new RegexTokenizer()\n  .setInputCol(\"genres\")\n  .setOutputCol(\"genres_words\")\n  .setPattern(\"\"\"[\\\w-_]+\"\"\").setGaps(false)\nval wordsDF = regexTokenizer.transform(df)\n\nval countVectorizerModel = new CountVectorizer()\n  .setInputCol(\"genres_words\")\n  .setOutputCol(\"genres_count_vector\")\n  
.setMinDF(1) \/\/ a term must appear in at least n documents of the corpus to enter the vocabulary\n  .setMinTF(1) \/\/ within each document, ignore any term that appears fewer than n times\n  .fit(wordsDF)\nval countVectorDF = countVectorizerModel.transform(wordsDF)\n\n\/\/ HashingTF might suffer from potential hash collisions\n\/\/ it's good to use a power of two for numFeatures\nval hashingTF = new HashingTF()\n  .setInputCol(\"genres_words\")\n  .setOutputCol(\"genres_htf_vector\")\n  .setNumFeatures(4)\nval htfVectorDF = hashingTF.transform(countVectorDF)\n\nhtfVectorDF.show(false)\n\/\/ +--------+---------------------+-------------------------+-------------------------+-------------------+\n\/\/ |movie_id|genres               |genres_words             |genres_count_vector      |genres_htf_vector  |\n\/\/ +--------+---------------------+-------------------------+-------------------------+-------------------+\n\/\/ |1       |Action,Sci-Fi        |[action, sci-fi]         |(5,[0,3],[1.0,1.0])      |(4,[0],[2.0])      |\n\/\/ |2       |Sci-Fi,Romance,Horror|[sci-fi, romance, horror]|(5,[0,1,4],[1.0,1.0,1.0])|(4,[0,2],[2.0,1.0])|\n\/\/ |3       |War,Horror           |[war, horror]            |(5,[1,2],[1.0,1.0])      |(4,[0,2],[1.0,1.0])|\n\/\/ +--------+---------------------+-------------------------+-------------------------+-------------------+\n\ncountVectorizerModel.vocabulary\n\/\/ Array(sci-fi, horror, action, romance, war)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#countvectorizer\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#countvectorizer<\/a><br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#tf-idf\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#tf-idf<\/a><\/p>\n<h2>Train a Word2Vec Model<\/h2>\n<p>The output vector of any Word2Vec model is dense!<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.Word2Vec\n\nval df = spark.createDataFrame(Seq(\n  
(1, \"Hi I heard about Apache Spark\".toLowerCase().split(\" \")),\n  (2, \"I wish Java could use case classes\".toLowerCase().split(\" \")),\n  (3, \"Logistic regression models are neat\".toLowerCase().split(\" \")),\n  (4, \"Apache Spark with Scala is awesome\".toLowerCase().split(\" \")),\n  (5, Array(\"\u4e2d\u6587\", \"\u561b\u311f\u901a\", \"but\", \"\u5fc5\u9808\", \"\u53e6\u5916\", \"\u5206\u8a5e\"))\n)).toDF(\"id\", \"words\")\n\nval word2Vec = new Word2Vec()\n  .setInputCol(\"words\")\n  .setOutputCol(\"words_w2v\")\n  .setMaxIter(10)\n  .setVectorSize(3)\n  .setWindowSize(5)\n  .setMinCount(1)\nval word2VecModel = word2Vec.fit(df)\n\nword2VecModel.transform(df)\n\/\/ +---+------------------------------------------+----------------------------------------------------------+\n\/\/ |id |words                                     |words_w2v                                                 |\n\/\/ +---+------------------------------------------+----------------------------------------------------------+\n\/\/ |1  |[hi, i, heard, about, apache, spark]      |[-0.02013699459393,-0.02995631482274,0.047685102870066956]|\n\/\/ |2  |[i, wish, java, could, use, case, classes]|[-0.05012317272186,0.01141336891094,-0.03742781743806387] |\n\/\/ |3  |[logistic, regression, models, are, neat] |[-0.04678827972413,0.032994424477,0.0010566591750830413]  |\n\/\/ |4  |[apache, spark, with, scala, is, awesome] |[0.0265524153169,0.02056275321716,0.013326843579610188]   |\n\/\/ |5  |[\u4e2d\u6587, \u561b\u311f\u901a, but, \u5fc5\u9808, \u53e6\u5916, \u5206\u8a5e]         |[0.0571783996973,-0.02301329133545,0.013507421438892681]  |\n\/\/ +---+------------------------------------------+----------------------------------------------------------+\n\nval df2 = spark.createDataFrame(Seq(\n  (6, Array(\"not-in-vocabularies\", \"neither\", \"no\")),\n  (7, Array(\"spark\", \"not-in-vocabularies\")),\n  (8, Array(\"not-in-vocabularies\", \"spark\")),\n  (9, Array(\"no\", 
\"not-in-vocabularies\", \"spark\")),\n  (10, Array(\"\u4e2d\u6587\", \"spark\"))\n)).toDF(\"id\", \"words\")\n\nword2VecModel.transform(df2)\n\/\/ the order of words doesn't matter\n\/\/ +---+-------------------------------------+-----------------------------------------------------------------+\n\/\/ |id |words                                |words_w2v                                                        |\n\/\/ +---+-------------------------------------+-----------------------------------------------------------------+\n\/\/ |6  |[not-in-vocabularies, neither, no]   |[0.0,0.0,0.0]                                                    |\n\/\/ |7  |[spark, hell_no, not-in-vocabularies]|[0.0027440187210838,-0.0529780387878418,0.05730373660723368]     |\n\/\/ |8  |[hell_no, not-in-vocabularies, spark]|[0.0027440187210838,-0.0529780387878418,0.05730373660723368]     |\n\/\/ |9  |[not-in-vocabularies, hell_no, spark]|[0.0027440187210838,-0.0529780387878418,0.05730373660723368]     |\n\/\/ |10 |[no, not-in-vocabularies, spark]     |[0.0027440187210838,-0.0529780387878418,0.05730373660723368]     |\n\/\/ |11 |[\u4e2d\u6587, spark]                         |[-0.009499748703092337,-0.018227852880954742,0.13357853144407272]|\n\/\/ +---+-------------------------------------+-----------------------------------------------------------------+\n\nanotherWord2VecModel.findSynonyms(\"developer\", 5)\n\/\/ +-----------+------------------+\n\/\/ |       word|        similarity|\n\/\/ +-----------+------------------+\n\/\/ |        dev| 0.881394624710083|\n\/\/ |development|0.7730562090873718|\n\/\/ |       oier|0.6866029500961304|\n\/\/ |  develover|0.6720684766769409|\n\/\/ |     webdev|0.6582568883895874|\n\/\/ +-----------+------------------+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#word2vec\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#word2vec<\/a><\/p>\n<h2>Calculate the Pearson Correlation between 
Features<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.VectorAssembler\nimport org.apache.spark.ml.linalg.Matrix\nimport org.apache.spark.ml.stat.Correlation\nimport org.apache.spark.sql.Row\n\nval featureNames = Array(\"stargazers_count\", \"forks_count\", \"subscribers_count\")\nval vectorAssembler = new VectorAssembler()\n  .setInputCols(featureNames)\n  .setOutputCol(\"features\")\n\nval df = vectorAssembler.transform(rawRepoInfoDS)\nval correlationDF = Correlation.corr(df, \"features\")\nval Row(coeff: Matrix) = correlationDF.head\n\nprintln(featureNames.mkString(\", \"))\nprintln(coeff.toString)\n\/\/ stargazers_count, forks_count, subscribers_count\n\/\/ 1.0                 0.5336901230713282  0.7664204175159971  \n\/\/ 0.5336901230713282  1.0                 0.5414244966152617  \n\/\/ 0.7664204175159971  0.5414244966152617  1.0<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-statistics.html\">https:\/\/spark.apache.org\/docs\/latest\/ml-statistics.html<\/a><\/p>\n<h2>DIMSUM<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.linalg.DenseVector\nimport org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}\nimport org.apache.spark.sql.Row\n\nval repoWordRDD = repoVectorDF\n  .select($\"repo_id\", $\"text_w2v\")\n  .rdd\n  .flatMap((row: Row) =&gt; {\n    val repoId = row.getInt(0)\n    val vector = row.getAs[DenseVector](1)\n    vector.toArray.zipWithIndex.map({\n      case (element, index) =&gt; MatrixEntry(repoId, index, element)\n    })\n  })\nval repoWordMatrix = new CoordinateMatrix(repoWordRDD)\nval wordRepoMatrix = repoWordMatrix.transpose\n\nval repoSimilarityRDD = wordRepoMatrix\n  .toRowMatrix\n  .columnSimilarities(0.1) \/\/ the DIMSUM sampling threshold\n  .entries\n  .flatMap({\n    case MatrixEntry(row: Long, col: Long, sim: Double) =&gt; {\n      if (sim &gt;= 0.5) {\n        Some((row, col, sim))\n      }\n      else {\n        None\n      }\n    }\n  
})\nval repoSimilarityDF = spark.createDataFrame(repoSimilarityRDD).toDF(\"item_1\", \"item_2\", \"similarity\")\nrepoSimilarityDF.show(false)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/42455725\/columnsimilarities-back-to-spark-data-frame\">https:\/\/stackoverflow.com\/questions\/42455725\/columnsimilarities-back-to-spark-data-frame<\/a><br \/>\n<a href=\"https:\/\/forums.databricks.com\/questions\/248\/when-should-i-use-rowmatrixcolumnsimilarities.html\">https:\/\/forums.databricks.com\/questions\/248\/when-should-i-use-rowmatrixcolumnsimilarities.html<\/a><\/p>\n<h2>Train a Locality Sensitive Hashing (LSH) Model: Bucketed Random Projection LSH<\/h2>\n<p>To choose the value of <code>bucketLength<\/code>: if the input vectors are normalized, 1 to 10 times pow(numRecords, -1\/inputDim) would be a reasonable value. For instance, <code>Math.pow(334913.0, -1.0 \/ 200.0) = 0.9383726472256705<\/code>.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.BucketedRandomProjectionLSH\nimport org.apache.spark.ml.linalg.Vectors\n\nval userDF = spark.createDataFrame(Seq(\n  (1, Vectors.sparse(6, Seq((0, -4.0), (1, 1.0), (2, 0.2)))),\n  (2, Vectors.sparse(6, Seq((0, 5.5), (1, -0.6), (2, 9.0)))),\n  (3, Vectors.sparse(6, Seq((1, 1.0), (2, 5.3), (4, 3.0)))),\n  (4, Vectors.sparse(6, Seq((1, 1.0), (2, 1.0), (4, 1.0)))),\n  (5, Vectors.sparse(6, Seq((2, 1.0), (5, -0.2)))),\n  (6, Vectors.sparse(6, Seq((0, 0.7)))),\n  (7, Vectors.sparse(6, Seq((1, 0.3), (2, 1.0))))\n)).toDF(\"user_id\", \"features\")\n\nval repoDF = spark.createDataFrame(Seq(\n  (11, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0), (5, 1.0)))),\n  (12, Vectors.sparse(6, Seq((0, 9.0), (1, -2.0), (2, -21.0), (3, 9.0), (4, 1.0), (5, 9.0)))),\n  (13, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, -3.0), (3, 3.0), (4, 7.0), (5, 9.0)))),\n  (14, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, -3.0)))),\n  (15, Vectors.sparse(6, Seq((1, 1.0), 
(2, 1.0))))\n)).toDF(\"repo_id\", \"features\")\n\nval lsh = new BucketedRandomProjectionLSH()\n  .setBucketLength(0.6812920690579612)\n  .setNumHashTables(4)\n  .setInputCol(\"features\")\n  .setOutputCol(\"hashes\")\nval lshModel = lsh.fit(repoDF)\n\nval hashedUserDF = lshModel.transform(userDF)\nval hashedRepoDF = lshModel.transform(repoDF)\nhashedRepoDF.show(false)\n\/\/ +-------+----------------------------------------------+--------------------------------+\n\/\/ |repo_id|features                                      |hashes                          |\n\/\/ +-------+----------------------------------------------+--------------------------------+\n\/\/ |11     |(6,[0,1,2,3,4,5],[1.0,1.0,1.0,1.0,1.0,1.0])   |[[1.0], [-2.0], [-1.0], [-1.0]] |\n\/\/ |12     |(6,[0,1,2,3,4,5],[9.0,-2.0,-21.0,9.0,1.0,9.0])|[[21.0], [-28.0], [18.0], [0.0]]|\n\/\/ |13     |(6,[0,1,2,3,4,5],[1.0,1.0,-3.0,3.0,7.0,9.0])  |[[4.0], [-10.0], [6.0], [-3.0]] |\n\/\/ |14     |(6,[0,1,2],[1.0,1.0,-3.0])                    |[[2.0], [-3.0], [2.0], [1.0]]   |\n\/\/ |15     |(6,[1,2],[1.0,1.0])                           |[[-1.0], [0.0], [-2.0], [0.0]]  |\n\/\/ +-------+----------------------------------------------+--------------------------------+\n\nval similarDF = lshModel\n  .approxSimilarityJoin(hashedUserDF, hashedRepoDF, 10.0, \"distance\")\n  .select($\"datasetA.user_id\".alias(\"user_id\"), $\"datasetB.repo_id\".alias(\"repo_id\"), $\"distance\")\n  .orderBy($\"user_id\", $\"distance\".asc)\nsimilarDF.show(false)\n\/\/ +-------+-------+------------------+\n\/\/ |user_id|repo_id|distance          |\n\/\/ +-------+-------+------------------+\n\/\/ |1      |15     |4.079215610874228 |\n\/\/ |3      |15     |5.243090691567332 |\n\/\/ |4      |15     |1.0               |\n\/\/ |4      |11     |1.7320508075688772|\n\/\/ |5      |15     |1.019803902718557 |\n\/\/ |5      |11     |2.33238075793812  |\n\/\/ |6      |15     |1.57797338380595  |\n\/\/ |7      |15     |0.7               |\n\/\/ |7    
  |11     |2.118962010041709 |\n\/\/ +-------+-------+------------------+\n\nval userVector = Vectors.sparse(6, Seq((0, 1.5), (1, 0.8), (2, 2.0)))\nval singleSimilarDF = lshModel\n  .approxNearestNeighbors(hashedRepoDF, userVector, 5, \"distance\")\n  .select($\"repo_id\", $\"features\", $\"distance\")\nsingleSimilarDF.show(false)\n\/\/ +-------+----------------------------------------------+------------------+\n\/\/ |repo_id|features                                      |distance          |\n\/\/ +-------+----------------------------------------------+------------------+\n\/\/ |15     |(6,[1,2],[1.0,1.0])                           |1.8138357147217055|\n\/\/ |12     |(6,[0,1,2,3,4,5],[9.0,-2.0,-21.0,9.0,1.0,9.0])|27.49709075520536 |\n\/\/ +-------+----------------------------------------------+------------------+<\/code><\/pre>\n<p>The problem with <code>approxSimilarityJoin()<\/code> is that you can't control the number of generated items; the disadvantage of <code>approxNearestNeighbors()<\/code> is that you have to manually iterate over all users to find their similar items. Moreover, both methods can easily suffer from the infamous <code>java.lang.OutOfMemoryError<\/code>.<\/p>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#locality-sensitive-hashing\">https:\/\/spark.apache.org\/docs\/latest\/ml-features.html#locality-sensitive-hashing<\/a><\/p>\n<h2>Train a Locality Sensitive Hashing (LSH) Model: MinHash LSH<\/h2>\n<p>MinHash LSH treats the input as a binary vector, that is, all non-zero values (including negative values) are treated as <code>1<\/code>. 
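<\/p>\n<p>A minimal sketch of that binarization, using a hypothetical <code>binarizeUDF<\/code> (not part of Spark) to show what MinHash LSH effectively sees, assuming the <code>userDF<\/code> from the example above:<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.linalg.{Vector, Vectors}\nimport org.apache.spark.sql.functions.udf\n\n\/\/ map every non-zero entry (even negative ones) to 1.0\nval binarizeUDF = udf((v: Vector) =&gt; {\n  val sv = v.toSparse\n  Vectors.sparse(sv.size, sv.indices, Array.fill(sv.indices.length)(1.0))\n})\n\nval binarizedUserDF = userDF.withColumn(\"binary_features\", binarizeUDF($\"features\"))<\/code><\/pre>\n<p>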
Basically, the Word2Vec vector won't be an appropriate input to MinHash LSH.<\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature.MinHashLSH\nimport org.apache.spark.ml.linalg.Vectors\n\nval userDF = spark.createDataFrame(Seq(\n  (1, Vectors.sparse(6, Seq((0, -4.0), (1, 1.0), (2, 0.2)))),\n  (2, Vectors.sparse(6, Seq((0, 5.5), (1, -0.6), (2, 9.0)))),\n  (3, Vectors.sparse(6, Seq((1, 1.0), (2, 5.3), (4, 3.0)))),\n  (4, Vectors.sparse(6, Seq((1, 1.0), (2, 1.0), (4, 1.0)))),\n  (5, Vectors.sparse(6, Seq((2, 1.0), (5, -0.2)))),\n  (6, Vectors.sparse(6, Seq((2, 0.7)))),\n  (7, Vectors.sparse(6, Seq((3, 0.3), (5, 1.0))))\n)).toDF(\"user_id\", \"features\")\n\nval repoDF = spark.createDataFrame(Seq(\n  (11, Vectors.sparse(6, Seq((1, 1.0), (3, 1.0), (5, 1.0)))),\n  (12, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (5, 1.0)))),\n  (13, Vectors.sparse(6, Seq((1, 1.0), (2, 1.0), (4, 1.0))))\n)).toDF(\"repo_id\", \"features\")\n\nval lsh = new MinHashLSH()\n  .setNumHashTables(4)\n  .setInputCol(\"features\")\n  .setOutputCol(\"hashes\")\n\nval lshModel = lsh.fit(userDF)\nval hashedUserDF = lshModel.transform(userDF)\nval hashedRepoDF = lshModel.transform(repoDF)\n\nhashedUserDF.show(false)\n\/\/ user 1 and 2 have the same hashed vector\n\/\/ user 3 and 4 have the same hashed vector\n\/\/ +-------+--------------------------+-----------------------------------------------------------------------+\n\/\/ |user_id|features                  |hashes                                                                 |\n\/\/ +-------+--------------------------+-----------------------------------------------------------------------+\n\/\/ |1      |(6,[0,1,2],[-4.0,1.0,0.2])|[[-2.031299587E9], [-1.974869772E9], [-1.974047307E9], [4.95314097E8]] |\n\/\/ |2      |(6,[0,1,2],[5.5,-0.6,9.0])|[[-2.031299587E9], [-1.974869772E9], [-1.974047307E9], [4.95314097E8]] |\n\/\/ |3      |(6,[1,2,4],[1.0,5.3,3.0]) |[[-2.031299587E9], [-1.974869772E9], 
[-1.230128022E9], [8.7126731E8]]  |\n\/\/ |4      |(6,[1,2,4],[1.0,1.0,1.0]) |[[-2.031299587E9], [-1.974869772E9], [-1.230128022E9], [8.7126731E8]]  |\n\/\/ |5      |(6,[2,5],[1.0,-0.2])      |[[-2.031299587E9], [-1.758749518E9], [-4.86208737E8], [-1.919887134E9]]|\n\/\/ |6      |(6,[2],[0.7])             |[[-2.031299587E9], [-1.758749518E9], [-4.86208737E8], [1.247220523E9]] |\n\/\/ |7      |(6,[3,5],[0.3,1.0])       |[[-1.278435698E9], [-1.542629264E9], [2.57710548E8], [-1.919887134E9]] |\n\/\/ +-------+--------------------------+-----------------------------------------------------------------------+\n\nval userSimilarRepoDF = lshModel\n  .approxSimilarityJoin(hashedUserDF, hashedRepoDF, 0.6, \"distance\")\n  .select($\"datasetA.user_id\".alias(\"user_id\"), $\"datasetB.repo_id\".alias(\"repo_id\"), $\"distance\")\n  .orderBy($\"user_id\", $\"distance\".asc)\n\nuserSimilarRepoDF.show(false)\n\/\/ +-------+-------+-------------------+\n\/\/ |user_id|repo_id|distance           |\n\/\/ +-------+-------+-------------------+\n\/\/ |1      |13     |0.5                |\n\/\/ |2      |13     |0.5                |\n\/\/ |3      |13     |0.0                |\n\/\/ |4      |13     |0.0                |\n\/\/ |5      |12     |0.33333333333333337|\n\/\/ |7      |12     |0.33333333333333337|\n\/\/ |7      |11     |0.33333333333333337|\n\/\/ +-------+-------+-------------------+<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/databricks.com\/blog\/2017\/05\/09\/detecting-abuse-scale-locality-sensitive-hashing-uber-engineering.html\">https:\/\/databricks.com\/blog\/2017\/05\/09\/detecting-abuse-scale-locality-sensitive-hashing-uber-engineering.html<\/a><\/p>\n<h2>Train a Logistic Regression Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.classification.LogisticRegression\nimport org.apache.spark.ml.linalg.Vectors\n\nval training = spark.createDataFrame(Seq(\n  (1.0, Vectors.dense(1.0, 2.5, 0.0, 0.0)),\n  (1.0, 
Vectors.dense(0.1, 9.0, 0.0, 0.0)),\n  (1.0, Vectors.dense(0.0, 0.0, 1.0, 0.0)),\n  (0.0, Vectors.dense(0.0, 0.0, 2.0, 9.0)),\n  (0.0, Vectors.dense(1.0, 0.0, 0.0, 5.0))\n)).toDF(\"label\", \"features\")\n\nval lr = new LogisticRegression()\n  .setMaxIter(100)\n  .setRegParam(0.0)\n  .setElasticNetParam(0.0)\n  .setFamily(\"binomial\")\n  .setFeaturesCol(\"features\")\n  .setLabelCol(\"label\")\n\nlr.explainParams()\n\nval lrModel = lr.fit(training)\n\nprintln(s\"Coefficients: ${lrModel.coefficients}\")\n\/\/ [2.0149015925419,2.694173163503675,9.547978766053463,-5.592221425156231]\n\nprintln(s\"Intercept: ${lrModel.intercept}\")\n\/\/ 8.552229795281482\n\nval result = lrModel.transform(test)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-classification-regression.html#logistic-regression\">https:\/\/spark.apache.org\/docs\/latest\/ml-classification-regression.html#logistic-regression<\/a><br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/mllib-linear-methods.html#logistic-regression\">https:\/\/spark.apache.org\/docs\/latest\/mllib-linear-methods.html#logistic-regression<\/a><\/p>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary\n\nval binarySummary = lrModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]\nprintln(s\"Area Under ROC: ${binarySummary.areaUnderROC}\")<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/api\/scala\/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary\">https:\/\/spark.apache.org\/docs\/latest\/api\/scala\/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary<\/a><\/p>\n<h2>Evaluate a Binary Classification Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator\nimport org.apache.spark.ml.linalg.Vectors\n\nval df = 
spark.createDataFrame(Seq(\n  (Vectors.dense(0.0, 2.5), 1.0), \/\/ correct\n  (Vectors.dense(1.0, 4.1), 1.0), \/\/ correct\n  (Vectors.dense(9.2, 1.1), 0.0), \/\/ correct\n  (Vectors.dense(1.0, 0.1), 0.0), \/\/ correct\n  (Vectors.dense(5.0, 0.5), 1.0)  \/\/ incorrect\n)).toDF(\"rawPrediction\", \"starring\")\n\nval evaluator = new BinaryClassificationEvaluator()\n  .setMetricName(\"areaUnderROC\")\n  .setRawPredictionCol(\"rawPrediction\")\n  .setLabelCol(\"starring\")\nval metric = evaluator.evaluate(df)\n\/\/ 0.8333333333333333<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/api\/scala\/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator\">https:\/\/spark.apache.org\/docs\/latest\/api\/scala\/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator<\/a><\/p>\n<h2>Train an ALS Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.recommendation.ALS\n\nval df = spark.createDataFrame(Seq(\n  (1, 1, 12),\n  (1, 2, 90),\n  (1, 4, 4),\n  (2, 4, 1),\n  (3, 5, 8)\n)).toDF(\"user\", \"item\", \"rating\")\n\nval als = new ALS()\n  .setImplicitPrefs(true)\n  .setRank(5)\n  .setRegParam(0.5)\n  .setAlpha(40)\n  .setMaxIter(10)\n  .setSeed(42)\n  .setColdStartStrategy(\"drop\")\nval alsModel = als.fit(df)\n\nval predictionDF = alsModel.transform(df)\n\/\/ +----+----+------+----------+\n\/\/ |user|item|rating|prediction|\n\/\/ +----+----+------+----------+\n\/\/ |   1|   1|    12| 0.9988487|\n\/\/ |   3|   5|     8| 0.9984464|\n\/\/ |   1|   4|     4|0.99887615|\n\/\/ |   2|   4|     1| 0.9921428|\n\/\/ |   1|   2|    90| 0.9997897|\n\/\/ +----+----+------+----------+\n\npredictionDF.printSchema()\n\/\/ root\n \/\/ |-- user: integer (nullable = false)\n \/\/ |-- item: integer (nullable = false)\n \/\/ |-- rating: integer (nullable = false)\n\/\/ |-- prediction: float (nullable = false)\n\nval userRecommendationsDF = alsModel.recommendForAllUsers(15)\n\/\/ 
+----+-----------------------------------------------------------------+\n\/\/ |user|recommendations                                                  |\n\/\/ +----+-----------------------------------------------------------------+\n\/\/ |1   |[[2,0.9997897], [4,0.9988761], [1,0.9988487], [5,0.0]]           |\n\/\/ |3   |[[5,0.9984464], [1,2.9802322E-8], [2,0.0], [4,0.0]]              |\n\/\/ |2   |[[4,0.9921428], [2,0.10759391], [1,0.10749264], [5,1.4901161E-8]]|\n\/\/ +----+-----------------------------------------------------------------+\n\nuserRecommendationsDF.printSchema()\n\/\/ root\n\/\/ |-- user: integer (nullable = false)\n\/\/ |-- recommendations: array (nullable = true)\n\/\/ |    |-- element: struct (containsNull = true)\n\/\/ |    |    |-- item: integer (nullable = true)\n\/\/ |    |    |-- rating: float (nullable = true)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-collaborative-filtering.html\">https:\/\/spark.apache.org\/docs\/latest\/ml-collaborative-filtering.html<\/a><\/p>\n<h2>Save and Load an ALS Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.hadoop.mapred.InvalidInputException\nimport org.apache.spark.ml.recommendation.{ALS, ALSModel}\n\nval alsModelSavePath = \".\/spark-data\/20170902\/alsModel.parquet\"\nval alsModel: ALSModel = try {\n  ALSModel.load(alsModelSavePath)\n} catch {\n  case e: InvalidInputException =&gt; {\n    if (e.getMessage().contains(\"Input path does not exist\")) {\n      val als = new ALS()\n        .setImplicitPrefs(true)\n        .setRank(100)\n        .setRegParam(0.5)\n        .setAlpha(40)\n        .setMaxIter(22)\n        .setSeed(42)\n        .setColdStartStrategy(\"drop\")\n        .setUserCol(\"user_id\")\n        .setItemCol(\"repo_id\")\n        .setRatingCol(\"starring\")\n      val alsModel = als.fit(rawRepoStarringDS)\n      alsModel.save(alsModelSavePath)\n      alsModel\n    } else {\n      throw e\n    }\n  
}\n}<\/code><\/pre>\n<h2>Create a Custom Transformer<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">package ws.vinta.albedo.transformers\n\nimport org.apache.spark.broadcast.Broadcast\nimport org.apache.spark.ml.Transformer\nimport org.apache.spark.ml.param.{DoubleParam, Param, ParamMap}\nimport org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}\nimport org.apache.spark.sql.types._\nimport org.apache.spark.sql.{DataFrame, Dataset, Row}\n\nimport scala.collection.mutable\n\nclass NegativeBalancer(override val uid: String, val bcPopularItems: Broadcast[mutable.LinkedHashSet[Int]])\n  extends Transformer with DefaultParamsWritable {\n\n  def this(bcPopularItems: Broadcast[mutable.LinkedHashSet[Int]]) = {\n    this(Identifiable.randomUID(\"negativeBalancer\"), bcPopularItems)\n  }\n\n  val userCol = new Param[String](this, \"userCol\", \"the name of the user column\")\n\n  def getUserCol: String = $(userCol)\n\n  def setUserCol(value: String): this.type = set(userCol, value)\n  setDefault(userCol -&gt; \"user\")\n\n  val itemCol = new Param[String](this, \"itemCol\", \"the name of the item column\")\n\n  def getItemCol: String = $(itemCol)\n\n  def setItemCol(value: String): this.type = set(itemCol, value)\n  setDefault(itemCol -&gt; \"item\")\n\n  val labelCol = new Param[String](this, \"labelCol\", \"the name of the label column\")\n\n  def getLabelCol: String = $(labelCol)\n\n  def setLabelCol(value: String): this.type = set(labelCol, value)\n  setDefault(labelCol -&gt; \"label\")\n\n  val negativeValue = new DoubleParam(this, \"negativeValue\", \"the label value assigned to negative samples\")\n\n  def getNegativeValue: Double = $(negativeValue)\n\n  def setNegativeValue(value: Double): this.type = set(negativeValue, value)\n  setDefault(negativeValue -&gt; 0.0)\n\n  val negativePositiveRatio = new DoubleParam(this, \"negativePositiveRatio\", 
\"the ratio of negative samples to positive samples\")\n\n  def getNegativePositiveRatio: Double = $(negativePositiveRatio)\n\n  def setNegativePositiveRatio(value: Double): this.type = set(negativePositiveRatio, value)\n  setDefault(negativePositiveRatio -&gt; 1.0)\n\n  override def transformSchema(schema: StructType): StructType = {\n    Map($(userCol) -&gt; IntegerType, $(itemCol) -&gt; IntegerType, $(labelCol) -&gt; DoubleType)\n      .foreach{\n        case(columnName: String, expectedDataType: DataType) =&gt; {\n          val actualDataType = schema(columnName).dataType\n          require(actualDataType.equals(expectedDataType), s\"Column $columnName must be of type $expectedDataType but was actually $actualDataType.\")\n        }\n      }\n\n    schema\n  }\n\n  override def transform(dataset: Dataset[_]): DataFrame = {\n    transformSchema(dataset.schema)\n\n    val popularItems: mutable.LinkedHashSet[Int] = this.bcPopularItems.value\n\n    val emptyItemSet = new mutable.HashSet[Int]\n    val addToItemSet = (itemSet: mutable.HashSet[Int], item: Int) =&gt; itemSet += item\n    val mergeItemSets = (set1: mutable.HashSet[Int], set2: mutable.HashSet[Int]) =&gt; set1 ++= set2\n\n    val getUserNegativeItems = (userItemsPair: (Int, mutable.HashSet[Int])) =&gt; {\n      val (user, positiveItems) = userItemsPair\n      val negativeItems = popularItems.diff(positiveItems)\n      val requiredNegativeItemsCount = (positiveItems.size * this.getNegativePositiveRatio).toInt\n      (user, negativeItems.slice(0, requiredNegativeItemsCount))\n    }\n    val expandNegativeItems = (userItemsPair: (Int, mutable.LinkedHashSet[Int])) =&gt; {\n      val (user, negativeItems) = userItemsPair\n      negativeItems.map({(user, _, $(negativeValue))})\n    }\n\n    import dataset.sparkSession.implicits._\n\n    \/\/ TODO: this currently assumes the incoming dataset contains only positive samples; datasets that already contain negative samples may need to be handled later\n    val negativeDF = dataset\n      .select($(userCol), $(itemCol))\n      .rdd\n      .map({\n        case Row(user: Int, item: Int) =&gt; (user, item)\n      })\n      .aggregateByKey(emptyItemSet)(addToItemSet, mergeItemSets)\n      .map(getUserNegativeItems)\n      .flatMap(expandNegativeItems)\n      .toDF($(userCol), $(itemCol), $(labelCol))\n\n    dataset.select($(userCol), $(itemCol), $(labelCol)).union(negativeDF)\n  }\n\n  override def copy(extra: ParamMap): this.type = {\n    defaultCopy(extra)\n  }\n}<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/www.safaribooksonline.com\/library\/view\/high-performance-spark\/9781491943199\/ch09.html#extending_spark_ml\">https:\/\/www.safaribooksonline.com\/library\/view\/high-performance-spark\/9781491943199\/ch09.html#extending_spark_ml<\/a><br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/40615713\/how-to-write-a-custom-transformer-in-mllib\">https:\/\/stackoverflow.com\/questions\/40615713\/how-to-write-a-custom-transformer-in-mllib<\/a><br \/>\n<a href=\"https:\/\/issues.apache.org\/jira\/browse\/SPARK-17048\">https:\/\/issues.apache.org\/jira\/browse\/SPARK-17048<\/a><\/p>\n<h2>Create a Custom Evaluator<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">package ws.vinta.albedo.evaluators\n\nimport org.apache.spark.ml.evaluation.Evaluator\nimport org.apache.spark.ml.param.{Param, ParamMap}\nimport org.apache.spark.ml.util.{DefaultParamsWritable, Identifiable}\nimport org.apache.spark.mllib.evaluation.RankingMetrics\nimport org.apache.spark.sql.{DataFrame, Dataset, Row}\n\nclass RankingEvaluator(override val uid: String, val userActualItemsDF: DataFrame)\n  extends Evaluator with DefaultParamsWritable {\n\n  def this(userActualItemsDF: DataFrame) = {\n    this(Identifiable.randomUID(\"rankingEvaluator\"), userActualItemsDF)\n  }\n\n  val metricName = new Param[String](this, \"metricName\", \"the evaluation metric\")\n\n  def getMetricName: 
String = $(metricName)\n\n  def setMetricName(value: String): this.type = set(metricName, value)\n  setDefault(metricName -&gt; \"ndcg@k\")\n\n  val k = new Param[Int](this, \"k\", \"only evaluate the ranking result of the top k items\")\n\n  def getK: Int = $(k)\n\n  def setK(value: Int): this.type = set(k, value)\n  setDefault(k -&gt; 15)\n\n  override def isLargerBetter: Boolean = $(metricName) match {\n    case \"map\" =&gt; true\n    case \"ndcg@k\" =&gt; true\n    case \"precision@k\" =&gt; true\n  }\n\n  override def evaluate(dataset: Dataset[_]): Double = {\n    import dataset.sparkSession.implicits._\n\n    val userPredictedItemsDF = dataset.select($\"user_id\", $\"recommendations.repo_id\".alias(\"items\"))\n\n    val bothItemsRDD = userPredictedItemsDF.join(userActualItemsDF, Seq(\"user_id\"))\n      .select(userPredictedItemsDF.col(\"items\"), userActualItemsDF.col(\"items\"))\n      .rdd\n      .map((row: Row) =&gt; {\n        \/\/ Row(userPredictedItems, userActualItems)\n        (row(0).asInstanceOf[Seq[Int]].toArray, row(1).asInstanceOf[Seq[Int]].toArray)\n      })\n\n    val rankingMetrics = new RankingMetrics(bothItemsRDD)\n    val metric = $(metricName) match {\n      case \"map\" =&gt; rankingMetrics.meanAveragePrecision\n      case \"ndcg@k\" =&gt; rankingMetrics.ndcgAt($(k))\n      case \"precision@k\" =&gt; rankingMetrics.precisionAt($(k))\n    }\n    metric\n  }\n\n  override def copy(extra: ParamMap): RankingEvaluator = {\n    defaultCopy(extra)\n  }\n}<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html#ranking-systems\">https:\/\/spark.apache.org\/docs\/latest\/mllib-evaluation-metrics.html#ranking-systems<\/a><br \/>\n<a 
href=\"https:\/\/www.safaribooksonline.com\/library\/view\/spark-the-definitive\/9781491912201\/ch19.html#s6c5---recommendation\">https:\/\/www.safaribooksonline.com\/library\/view\/spark-the-definitive\/9781491912201\/ch19.html#s6c5---recommendation<\/a><\/p>\n<h2>Apply Transformer on Multiple Columns<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.feature._\n\nval userCategoricalColumnNames = Array(\"account_type\", \"clean_company\", \"clean_email\", \"clean_location\")\nval userCategoricalTransformers = userCategoricalColumnNames.flatMap((columnName: String) =&gt; {\n  val stringIndexer = new StringIndexer()\n    .setInputCol(columnName)\n    .setOutputCol(s\"${columnName}_index\")\n    .setHandleInvalid(\"keep\")\n  val oneHotEncoder = new OneHotEncoder()\n    .setInputCol(s\"${columnName}_index\")\n    .setOutputCol(s\"${columnName}_ohe\")\n    .setDropLast(true)\n  Array(stringIndexer, oneHotEncoder)\n})\nuserCategoricalTransformers.foreach(println)\n\/\/ strIdx_4029f57e379a\n\/\/ oneHot_f0decb92a05c\n\/\/ strIdx_fb855ad6caaa\n\/\/ oneHot_f1be19344002\n\/\/ strIdx_7fa62a683293\n\/\/ oneHot_097ae442d8fc\n\/\/ strIdx_0ff7ffa022a1\n\/\/ oneHot_4a9f72a7f5d8<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/34167105\/using-spark-mls-onehotencoder-on-multiple-columns\">https:\/\/stackoverflow.com\/questions\/34167105\/using-spark-mls-onehotencoder-on-multiple-columns<\/a><\/p>\n<h2>Cross-validate a Pipeline Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.classification.LogisticRegression\nimport org.apache.spark.ml.evaluation.BinaryClassificationEvaluator\nimport org.apache.spark.ml.feature.VectorAssembler\nimport org.apache.spark.ml.Pipeline\nimport org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}\n\nval vectorAssembler = new VectorAssembler()\n  .setInputCols(Array(\"feature1\", \"feature2\", \"feature3\"))\n  
.setOutputCol(\"features\")\n\nval lr = new LogisticRegression()\n  .setFeaturesCol(\"features\")\n  .setLabelCol(\"starring\")\n\nval pipeline = new Pipeline()\n  .setStages(Array(vectorAssembler, lr))\n\nval paramGrid = new ParamGridBuilder()\n  .addGrid(lr.maxIter, Array(20, 100))\n  .addGrid(lr.regParam, Array(0.0, 0.5, 1.0, 2.0))\n  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))\n  .build()\n\nval evaluator = new BinaryClassificationEvaluator()\n  .setMetricName(\"areaUnderROC\")\n  .setRawPredictionCol(\"rawPrediction\")\n  .setLabelCol(\"starring\")\n\nval cv = new CrossValidator()\n  .setEstimator(pipeline)\n  .setEstimatorParamMaps(paramGrid)\n  .setEvaluator(evaluator)\n  .setNumFolds(3)\n\nval cvModel = cv.fit(trainingDF)<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/spark.apache.org\/docs\/latest\/ml-tuning.html#cross-validation\">https:\/\/spark.apache.org\/docs\/latest\/ml-tuning.html#cross-validation<\/a><\/p>\n<h2>Extract Best Parameters from a Cross-validation Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.PipelineModel\nimport org.apache.spark.ml.classification.LogisticRegressionModel\n\nval bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]\n\/\/ the LogisticRegressionModel is the second stage of the pipeline above\nval lrModel = bestPipelineModel.stages(1).asInstanceOf[LogisticRegressionModel]\nlrModel.extractParamMap()\n\/\/ or\nlrModel.explainParams()<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/31749593\/how-to-extract-best-parameters-from-a-crossvalidatormodel\">https:\/\/stackoverflow.com\/questions\/31749593\/how-to-extract-best-parameters-from-a-crossvalidatormodel<\/a><\/p>\n<h2>Show All Parameters of a Cross-validation Model<\/h2>\n<pre class=\"line-numbers\"><code class=\"language-scala\">import org.apache.spark.ml.param.ParamMap\n\ncvModel.getEstimatorParamMaps\n  .zip(cvModel.avgMetrics)\n  .sortWith(_._2 &gt; _._2)\n  .foreach((pair: (ParamMap, Double)) =&gt; {\n    println(s\"${pair._2}: 
${pair._1}\")\n  })\n\/\/ 0.8999999999999999: {\n\/\/     hashingTF_ac8be8d5806b-numFeatures: 1000,\n\/\/     logreg_9f79de6e51ec-regParam: 0.1\n\/\/ }\n\/\/ 0.8875: {\n\/\/     hashingTF_ac8be8d5806b-numFeatures: 100,\n\/\/     logreg_9f79de6e51ec-regParam: 0.1\n\/\/ }\n\/\/ 0.875: {\n\/\/     hashingTF_ac8be8d5806b-numFeatures: 100,\n\/\/     logreg_9f79de6e51ec-regParam: 0.01\n\/\/ }<\/code><\/pre>\n<p>ref:<br \/>\n<a href=\"https:\/\/stackoverflow.com\/questions\/31749593\/how-to-extract-best-parameters-from-a-crossvalidatormodel\">https:\/\/stackoverflow.com\/questions\/31749593\/how-to-extract-best-parameters-from-a-crossvalidatormodel<\/a><br \/>\n<a href=\"https:\/\/alvinalexander.com\/scala\/how-sort-scala-sequences-seq-list-array-buffer-vector-ordering-ordered\">https:\/\/alvinalexander.com\/scala\/how-sort-scala-sequences-seq-list-array-buffer-vector-ordering-ordered<\/a><\/p>\n","protected":false},"excerpt":{"rendered":"<p>Scala is the first class citizen language for interacting with Apache Spark, but it's difficult to learn. 
This article is mostly about Spark ML - the new Spark Machine Learning library which was rewritten in DataFrame-based API.<\/p>\n","protected":false},"author":1,"featured_media":432,"comment_status":"closed","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[97,112],"tags":[108,111,98,104,109],"class_list":["post-431","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-about-ai","category-about-big-data","tag-apache-spark","tag-feature-engineering","tag-machine-learning","tag-recommender-system","tag-scala"],"_links":{"self":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts\/431","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/comments?post=431"}],"version-history":[{"count":0,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/posts\/431\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/media\/432"}],"wp:attachment":[{"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/media?parent=431"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/categories?post=431"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/vinta.ws\/code\/wp-json\/wp\/v2\/tags?post=431"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}