一文带你理解并实战Spark隐式狄利克雷分布(LDA)

原理 Spark是一个极为优秀的大数据框架,在大数据批处理上基本无人能敌,流处理上也有一席之地,机器学习则是当前正火热AI人工智能的驱动引擎,在大数据场景下如何发挥AI技术成为优秀的大数据挖掘工程师必备技能。

​ 本文采用的组件版本为:Ubuntu 19.10、Jdk 1.8.0_241、Scala 2.11.12、Hadoop 3.2.1、Spark 2.4.5,老规矩先开启一系列Hadoop、Spark服务与Spark-shell窗口:

1.LDA原理介绍

隐式狄利克雷分布(LDA) 是一种主题模型,可以从文本文档集合中推断出主题。可以将LDA视为一种聚类算法,如下所示:

  • 主题对应于聚类中心,文档对应于数据集中的示例(行)。
  • 主题和文档都存在于特征空间中,其中特征向量是字数(词袋)的向量。
  • LDA不会使用传统的距离来估计聚类,而是使用基于文本文件生成方式的统计模型的功能。

DA涉及到的先验知识有:二项分布、Gamma函数、Beta分布、多项分布、Dirichlet分布、马尔科夫链、MCMC、Gibs Sampling、EM算法等。限于篇幅,本文不设计具体理论推导,具体可参考知乎深度文章:https://zhuanlan.zhihu.com/p/31470216

2.LDA参数

LDA通过setOptimizer功能支持不同的推理算法。EMLDAOptimizer使用似然函数的期望最大化学习聚类并产生综合结果,同时 OnlineLDAOptimizer使用迭代小批量采样进行在线变异推断 ,并且通常对内存友好。

LDA将文档集合作为单词计数和以下参数(使用构建器模式设置)的向量:

  • k:主题数(即群集中心)
  • optimizer:用于学习LDA模型的优化程序, EMLDAOptimizer或者OnlineLDAOptimizer
  • docConcentration:Dirichlet参数,用于优先于主题的文档分布。较大的值鼓励更平滑的推断分布。
  • topicConcentration:Dirichlet参数,用于表示主题(单词)在先主题的分布。较大的值鼓励更平滑的推断分布。
  • maxIterations:限制迭代次数。
  • checkpointInterval:如果使用检查点(在Spark配置中设置),则此参数指定创建检查点的频率。如果maxIterations很大,使用检查点可以帮助减少磁盘上的随机文件大小,并有助于故障恢复。

所有spark.mllib的LDA模型都支持:

  • describeTopics:以最重要的术语和术语权重的数组形式返回主题
  • topicsMatrix:返回一个vocabSize由k矩阵,其中各列是一个主题

期望最大化在EMLDAOptimizer和DistributedLDAModel中实现。对于提供给LDA的参数:

  • docConcentration:仅支持对称先验,因此提供的k维向量中的所有值都必须相同。所有值也必须> 1.0。提供Vector(-1)会导致默认行为(值(50 / k)+1的统一k维向量)
  • topicConcentration:仅支持对称先验。值必须> 1.0。提供-1会导致默认值为0.1 + 1。
  • maxIterations:EM迭代的最大数量。
  • 注意:进行足够的迭代很重要。在早期的迭代中,EM通常没有用的主题,但是经过更多的迭代后,这些主题会大大改善。根据您的数据集,通常至少合理使用20次甚至50-100次迭代。
  • EMLDAOptimizer生成一个DistributedLDAModel,它不仅存储推断出的主题,还存储完整的训练语料库和训练语料库中每个文档的主题分布。DistributedLDAModel支持:
  • topTopicsPerDocument:训练语料库中每个文档的主要主题及其权重
  • topDocumentsPerTopic:每个主题的顶部文档以及该主题在文档中的相应权重。
  • logPrior:给定超参数docConcentration和topicConcentration时,估计主题和文档主题分布的对数概率
  • logLikelihood:给定推断的主题和文档主题分布,训练语料库的对数可能性

3.Spark示例

在以下示例中,我们加载代表文档语料库的单词计数向量。然后,我们使用LDA从文档中推断出三个主题。所需聚类的数量传递给算法。然后,我们输出主题,表示为单词上的概率分布。

import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vectors
// 加载和解析数据
val data = sc.textFile("data/mllib/sample_lda_data.txt")
val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
// 用唯一ID索引文章
val corpus = parsedData.zipWithIndex.map(_.swap).cache()
// 使用LDA将文章聚类为3类主题
val ldaModel = new LDA().setK(3).run(corpus)
//  输出主题。每个主题都是单词分布(匹配单词向量)
println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):")
val topics = ldaModel.topicsMatrix
for (topic <- Range(0, 3)) {
  print(s"Topic $topic :")
  for (word <- Range(0, ldaModel.vocabSize)) {
    print(s"${topics(word, topic)}")
  }
  println()
}
// 保存和加载模型
ldaModel.save(sc, "target/org/apache/spark/LatentDirichletAllocationExample/LDAModel")
val sameModel = DistributedLDAModel.load(sc,
  "target/org/apache/spark/LatentDirichletAllocationExample/LDAModel")

4.源码解析

以上代码主要做了两件事:加载和切分数据、训练模型。在样本数据中,每一行代表一篇文档,经过处理后,corpus的类型为List((id,vector)*),一个(id,vector)代表一篇文档。将处理后的数据传给org.apache.spark.mllib.clustering.LDA类的run方法, 就可以开始训练模型。run方法的代码如下所示:

def run(documents: RDD[(Long, Vector)]): LDAModel = {
    val state = ldaOptimizer.initialize(documents, this)
    var iter = 0
    val iterationTimes = Array.fill[Double](maxIterations)(0)
    while (iter < maxIterations) {
      val start = System.nanoTime()
      state.next()
      val elapsedSeconds = (System.nanoTime() - start) / 1e9
      iterationTimes(iter) = elapsedSeconds
      iter += 1
    }
    state.getLDAModel(iterationTimes)
  }

这段代码首先调用initialize方法初始化状态信息,然后循环迭代调用next方法直到满足最大的迭代次数。在我们没有指定的情况下,迭代次数默认为20。需要注意的是, ldaOptimizer有两个具体的实现类EMLDAOptimizer和OnlineLDAOptimizer,它们分别表示使用EM算法和在线学习算法实现参数估计。在未指定的情况下,默认使用EMLDAOptimizer。

Spark kmeans族的聚类算法的内容至此结束,有关Spark的基础文章可参考前文:

想要入门大数据?这篇文章不得不看!Spark源码分析系列

阿里是怎么做大数据的?淘宝怎么能承载双11?大数据之眸告诉你

Spark分布式机器学习源码分析:如何用分布式集群构建线性模型?

高频面经总结:最全大数据+AI方向面试100题(附答案详解)

Spark分布式机器学习系列:一文带你理解并实战朴素贝叶斯!

Spark分布式机器学习系列:一文带你理解并实战决策树模型!

Spark分布式机器学习系列:一文带你理解并实战集成树模型!

一文带你理解并实战协同过滤!Spark分布式机器学习系列

Spark分布式机器学习源码分析:Kmeans族聚类


参考链接:

http://spark.apache.org/docs/latest/mllib-clustering.html

https://github.com/endymecy/spark-ml-source-analysis

举报
评论 0