explode a row of spark dataset into several rows with added column using flatmap


I have a DataFrame with the following schema :

root
 |-- journal: string (nullable = true)
 |-- topicDistribution: vector (nullable = true)

The topicDistribution field is a vector of doubles: [0.1, 0.2, 0.15, ...]

What I want is to explode each row into several rows to obtain the following schema:

root
 |-- journal: string
 |-- topic-prob: double // this is the value from the vector
 |-- topic-id : integer // this is the index of the value from the vector

To clarify, I've created a case class:

case class JournalDis(journal: String, topic_id: Integer, prob: Double)

I've managed to achieve this using dataset.explode in a very awkward way:

val df1 = df.explode("topicDistribution", "topic") {
    topics: DenseVector => topics.toArray.zipWithIndex
}.select("journal", "topic")
val df2 = df1
  .withColumn("topic_id", df1("topic").getItem("_2"))
  .withColumn("topic_prob", df1("topic").getItem("_1"))
  .drop(df1("topic"))

But dataset.explode is deprecated. How can I achieve this using the flatMap method?

1 Answer

user7337271 (5 votes):

Not tested but should work:

import spark.implicits._
import org.apache.spark.ml.linalg.Vector

df.as[(String, Vector)].flatMap {
  // emit one JournalDis row per (probability, index) pair in the vector
  case (j, ps) => ps.toArray.zipWithIndex.map {
    case (p, ti) => JournalDis(j, ti, p)
  }
}
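As a side note, on newer Spark versions (3.0+) a similar result can be had without a case class at all, by converting the vector to an array with the built-in `vector_to_array` and then applying `posexplode`, which yields both the element index and its value. A minimal sketch, assuming Spark 3.0+ (the output column renames are illustrative):

```scala
import org.apache.spark.sql.functions.{col, posexplode}
import org.apache.spark.ml.functions.vector_to_array

// posexplode produces two columns, "pos" (index) and "col" (value),
// one output row per element of the array
val exploded = df
  .select(col("journal"), posexplode(vector_to_array(col("topicDistribution"))))
  .withColumnRenamed("pos", "topic_id")
  .withColumnRenamed("col", "topic_prob")
```

This stays entirely in the DataFrame API, so no encoder for the (String, Vector) tuple is needed.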