I have an LDA model running on a corpus of 12,054 documents with a vocabulary of 9,681 words and 60 clusters. I am trying to get the topic distribution over documents by calling .topicDistributions() or .javaTopicDistributions(). Both methods return an RDD of topic distributions over documents, so as I understand it, the number of rows should be the number of documents and the number of columns should be the number of topics. But when I count the RDD returned by topicDistributions(), I get 11,655, which is fewer than the 12,054 documents passed to the model. Each document does have the correct number of topics (60). Why is this?
Here's the demo I followed: http://spark.apache.org/docs/latest/mllib-clustering.html
and the API documentation: https://spark.apache.org/docs/1.4.0/api/java/org/apache/spark/mllib/clustering/DistributedLDAModel.html
Here's the code:
import scala.Tuple2;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.DistributedLDAModel;
import org.apache.spark.mllib.clustering.LDA;
import org.apache.spark.mllib.clustering.LDAModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;

// Parse tf vectors from the corpus
// (data is a JavaRDD<String>, one bracketed, comma-separated count vector per line)
JavaRDD<Vector> parsedData = data.map(
    new Function<String, Vector>() {
        public Vector call(String s) {
            s = s.substring(1, s.length() - 1); // strip the surrounding brackets
            String[] sarray = s.trim().split(",");
            double[] values = new double[sarray.length];
            for (int i = 0; i < sarray.length; i++) {
                values[i] = Double.parseDouble(sarray[i]);
            }
            return Vectors.dense(values);
        }
    }
);
System.out.println(parsedData.count()); // prints 12,054

// Index documents with unique IDs
JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
    new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
        public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
            return doc_id.swap();
        }
    }
));
System.out.println(corpus.count()); // prints 12,054

LDA lda = new LDA();
LDAModel ldaModel = lda.setK(k.intValue()).run(corpus); // k is 60 here

RDD<Tuple2<Object, Vector>> topic_dist_over_docs =
    ((DistributedLDAModel) ldaModel).topicDistributions();
System.out.println(topic_dist_over_docs.count()); // prints 11,655 ???

JavaPairRDD<Long, Vector> topic_dist_over_docs2 =
    ((DistributedLDAModel) ldaModel).javaTopicDistributions();
System.out.println(topic_dist_over_docs2.count()); // also prints 11,655 ???
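To narrow down which documents are being dropped, the keys of the two pair RDDs can be diffed (a minimal sketch reusing corpus and topic_dist_over_docs2 from above; the all-zero check is only a guess at a possible pattern, not a confirmed cause):

// Documents present in the corpus but absent from the returned distributions
JavaPairRDD<Long, Vector> missing = corpus.subtractByKey(topic_dist_over_docs2);
System.out.println(missing.count()); // expected: 12,054 - 11,655 = 399

// Inspect a few dropped documents, e.g. to see if their term vectors are all zero
for (Tuple2<Long, Vector> doc : missing.take(10)) {
    double sum = 0.0;
    for (double v : doc._2().toArray()) {
        sum += v;
    }
    System.out.println("doc " + doc._1() + " term-count sum: " + sum);
}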
There seems to be a bug in Spark 1.4 with topicDistributions(). After updating to the then-experimental Spark 1.5, I was able to resolve the issue.
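For reference, a quick sanity check after upgrading (a minimal sketch reusing the corpus and ldaModel variables from the question):

long numDocs = corpus.count();
long numDists = ((DistributedLDAModel) ldaModel).javaTopicDistributions().count();
System.out.println(numDocs == numDists); // should print true once every document gets a distribution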