Spark Dataset.groupBy as input to Spark ML Pipeline.fit


Given: A big Dataset (1 billion+ records) from a DeltaTable on Databricks

I want to split this Dataset into roughly 1000 partitions, based on a few properties of each record. Then, I want to fit a Spark ML Pipeline for each of these partitions.

My first idea was to use Dataset.groupByKey followed by mapGroups. However, mapGroups hands each group to my function as an Iterator, not a Dataset, and Pipeline.fit only accepts a Dataset as input.

dataset
  .groupByKey(item => (item.propertyOne, item.propertyTwo))(Encoders.product)
  .mapGroups { (key, group) =>
    val pipeline: Pipeline = // ...
    pipeline.fit(group) // Doesn't compile: group is an Iterator, not a Dataset
  }(Encoders.kryo[PipelineModel])
  .foreach(_.save(some_path))

The alternative implementation I have now is to iterate over the possible property values and filter the original Dataset for each one, passing the result to Pipeline.fit. This works, but it is painfully slow. I also tried running the fits from multiple threads (see the sketch after the loop below), but to no avail.

// ~1000 possible values
allPossiblePropertyValues.foreach { v =>
  pipeline.fit(dataset.filter(_.property == v)).save(some_path)
}
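
For reference, the multithreaded variant I tried looked roughly like the sketch below. Each fit is submitted from its own driver thread via a Future; the pool size of 8 is an arbitrary choice, and pipeline, dataset, and some_path are as above. Spark does accept concurrent job submissions from separate driver threads, but each fit still filters and scans the full source Dataset, so it didn't help much here.

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Bounded pool so we don't submit all ~1000 fits at once.
implicit val ec: ExecutionContext =
  ExecutionContext.fromExecutor(Executors.newFixedThreadPool(8))

val fits = allPossiblePropertyValues.map { v =>
  Future {
    // Each thread submits its own Spark job; the filter still
    // scans the whole Dataset for every property value.
    pipeline.fit(dataset.filter(_.property == v)).save(some_path)
  }
}

// Block until every per-partition model has been trained and saved.
Await.result(Future.sequence(fits), Duration.Inf)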
