How to get an interaction term (multiplication) between two columns using the Interaction transformer in PySpark?

I have a dataframe like this:

+----+-----------+
|flag|probability|
+----+-----------+
| 1.0|  [0.5,0.5]|
| 0.0|  [0.9,0.1]|
| 1.0|  [0.2,0.8]|
+----+-----------+

It's created using:

from pyspark.ml.linalg import Vectors

data = [
    (1.0, Vectors.dense([0.5, 0.5])),
    (0.0, Vectors.dense([0.9, 0.1])),
    (1.0, Vectors.dense([0.2, 0.8])),
]
df = spark.createDataFrame(data, ["flag", "probability"])
df.show()

I want to compute the product of the flag column and the class-1 probability using the Interaction transformer.

For that, I first use a VectorSlicer to extract the class-1 probability and then apply Interaction, like this:

from pyspark.ml.feature import VectorSlicer, VectorAssembler, Interaction

# Extract the class-1 probability (index 1) as a one-element vector
slicer = VectorSlicer(inputCol="probability", outputCol="class_1_probability", indices=[1])
sliced_df = slicer.transform(df)

sliced_df.select("flag", "class_1_probability").show()

# Pack the flag and the sliced probability into a single vector column
assembler = VectorAssembler(inputCols=["flag", "class_1_probability"], outputCol="features_for_interaction")
assembled_df = assembler.transform(sliced_df)

# Feed the assembled vector to Interaction as its only input column
interaction = Interaction(inputCols=["features_for_interaction"], outputCol="interaction_value")
interaction_df = interaction.transform(assembled_df)

interaction_df.select("flag", "class_1_probability", "interaction_value").show()

But instead of a column containing the product, I just get back a vector that echoes the two input values, with no multiplication applied:

+----+-------------------+-----------------+
|flag|class_1_probability|interaction_value|
+----+-------------------+-----------------+
| 1.0|              [0.5]|        [1.0,0.5]|
| 0.0|              [0.1]|        [0.0,0.1]|
| 1.0|              [0.8]|        [1.0,0.8]|
+----+-------------------+-----------------+
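
What I want instead is a single value per row equal to flag times the class-1 probability, i.e. 0.5, 0.0 and 0.8 for the three rows above.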

The documentation clearly states:

This transformer takes in Double and Vector type columns and outputs a flattened vector of their feature interactions.
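
Based on that description, I would have expected to be able to skip the VectorAssembler and hand Interaction the Double column and the sliced vector directly. A minimal sketch of that reading (reusing sliced_df from above):

# My reading of the docs: pass the Double column and the sliced vector
# as two separate inputs, expecting their product back.
interaction = Interaction(
    inputCols=["flag", "class_1_probability"],  # Double column + 1-element vector
    outputCol="interaction_value",
)
interaction.transform(sliced_df).select("flag", "class_1_probability", "interaction_value").show()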

Not sure what I am missing here!
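
For comparison, the number I'm after is easy to get with a plain column expression (this assumes Spark 3.0+, where pyspark.ml.functions.vector_to_array is available), so this is really a question about the intended use of Interaction:

from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

# Baseline without any ML transformer: flag * class-1 probability
df.withColumn("product", F.col("flag") * vector_to_array(F.col("probability"))[1]).show()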
