I want to extract keywords using `pyspark.ml.feature.CountVectorizer`.
My input Spark DataFrame looks like the following:
id | text |
---|---|
1 | sun, mars, solar system, solar system, mars, solar system, venus, solar system, mars |
2 | planet, moon, milky way, milky way, moon, milky way, sun, milky way, mars, star |
I applied the following pipeline:
```python
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.functions import split

# Convert the comma-separated string to an array of words
# (split on ", " so tokens don't keep a leading space)
input_df = input_df.withColumn("text_array", split("text", ", "))

cv_text = CountVectorizer() \
    .setInputCol("text_array") \
    .setOutputCol("cv_text")

cv_model = cv_text.fit(input_df)
cv_result = cv_model.transform(input_df)
cv_result.show()
```
Output:
id | text | text_array | cv_text |
---|---|---|---|
1 | sun, mars, solar system, .. | [sun, mars, solar system, .. | (9,[1,2,4,7],[3.0,4.0,1.0,1.0]) |
2 | planet, moon, milky way, .. | [planet, moon, milky way, .. | (9,[0,1,3,5,6,8],[4.0,1.0,2.0,1.0,1.0,1.0]) |
How can I now get the top n keywords (top 2, for example) for each id, i.e. for each row?
Expected output:
id | text | text_array | cv_text | keywords |
---|---|---|---|---|
1 | sun, mars, solar system, .. | [sun, mars, solar system, .. | (9,[1,2,4,7],[3.0,4.0,1.0,1.0]) | solar system, mars |
2 | planet, moon, milky way, .. | [planet, moon, milky way, .. | (9,[0,1,3,5,6,8],[4.0,1.0,2.0,1.0,1.0,1.0]) | milky way, moon |
I will be very grateful for any advice, docs, or examples!
I haven't found much support for working with sparse vectors beyond the few operations in the `pyspark.ml.feature` module, so for something like taking the top n values I would say a UDF is the way to go. The idea is to use `np.argpartition` to find the positions of the top n entries in the vector's `values` array; those positions can then be looked up in the vector's `indices` array. Note that what you recover this way are vocabulary indices, not the actual words. If the vocabulary is not that big, you can map each index back to its word via `cv_model.vocabulary`.
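A minimal sketch of that idea (the helper name `top_n_from_sparse` and the `n=2` default are mine, not from any library; the core lookup needs no Spark at all):

```python
import numpy as np

def top_n_from_sparse(indices, values, vocabulary, n=2):
    """Return the n words with the highest counts from a CountVectorizer
    sparse vector, given its indices/values arrays and the vocabulary."""
    k = min(n, len(values))
    if k == 0:
        return []
    values = np.asarray(values)
    # argpartition finds the positions of the k largest counts without a full sort
    top = np.argpartition(values, -k)[-k:]
    top = top[np.argsort(values[top])[::-1]]  # order by count, descending
    return [vocabulary[int(indices[i])] for i in top]

# Hypothetical wiring into the question's pipeline (assumes cv_model and
# cv_result from above, plus a running SparkSession):
# from pyspark.sql.functions import udf
# from pyspark.sql.types import ArrayType, StringType
#
# vocab = cv_model.vocabulary
# top_n_udf = udf(lambda v: top_n_from_sparse(v.indices, v.values, vocab, 2),
#                 ArrayType(StringType()))
# cv_result = cv_result.withColumn("keywords", top_n_udf("cv_text"))
```

The vocabulary is captured in the closure, so each executor only ships the word list once per task rather than joining against it.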
I'm not sure I feel that good about the solution above; it's probably not scalable. That said, if you don't actually need the `CountVectorizer`, there is a combination of standard functions you can apply to `input_df` to get the top n words of every sentence directly.