Why does Spark Dataset.map require all parts of the query to be serializable?

I would like to use the Dataset.map function to transform the rows of my dataset. The sample looks like this:

val result = testRepository.readTable(db, tableName)
  .map(testInstance.doSomeOperation)
  .count()

where testInstance is an instance of a class that extends java.io.Serializable, but testRepository does not. The code throws the following error:

Job aborted due to stage failure.
Caused by: NotSerializableException: TestRepository

Question

I understand why testInstance.doSomeOperation needs to be serializable, since it's inside the map and will be distributed to the Spark workers. But why does testRepository need to be serialized? I don't see why that is necessary for the map. Changing the definition to class TestRepository extends java.io.Serializable solves the issue, but that is not desirable in the larger context of the project.

Is there a way to make this work without making TestRepository serializable, or why is it required to be serializable?

Minimal working example

Here's a full example with the code from both classes that reproduces the NotSerializableException:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

case class MyTableSchema(id: String, key: String, value: Double)
val db = "temp_autodelete"
val tableName = "serialization_test"

class TestRepository {
  def readTable(database: String, tableName: String): Dataset[MyTableSchema] = {
    spark.table(f"$database.$tableName")
      .as[MyTableSchema]
  }
}

val testRepository = new TestRepository()

class TestClass() extends java.io.Serializable {
  def doSomeOperation(row: MyTableSchema): MyTableSchema = {
    row
  }
}

val testInstance = new TestClass()

val result = testRepository.readTable(db, tableName)
  .map(testInstance.doSomeOperation)
  .count()
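
The example assumes a notebook or REPL where spark is already available and the table temp_autodelete.serialization_test already exists. For completeness, here is a minimal setup sketch of how such a session and table could be created; the builder options and sample rows are illustrative and not part of the original code:

// Minimal setup sketch (illustrative only): create a local SparkSession and a
// small table matching MyTableSchema so the example above can run end to end.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("serialization-test")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

spark.sql(f"CREATE DATABASE IF NOT EXISTS $db")

Seq(MyTableSchema("a", "k1", 1.0), MyTableSchema("b", "k2", 2.0))
  .toDS()
  .write
  .mode("overwrite")
  .saveAsTable(f"$db.$tableName")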

1 Answer

Answered by Koedlt

The reason is that your map operation reads its input from something that is already computed on the executors.

If you look at your pipeline:

val result = testRepository.readTable(db, tableName)
  .map(testInstance.doSomeOperation)
  .count()

The first thing you do is testRepository.readTable(db, tableName). If we look inside the readTable method, we see that you are doing a spark.table operation there. The API docs give the following signature for this method:

def table(tableName: String): DataFrame

This is not an operation that takes place solely on the driver (imagine reading a file of more than 1 TB only on the driver), and it creates a DataFrame, which is itself a distributed dataset. That means the testRepository.readTable(db, tableName) call needs to be distributed, and so your testRepository object needs to be distributed as well.
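
To see this, you can inspect what readTable returns: the physical plan and the partition count show a distributed table scan rather than driver-local data. A small illustrative check, assuming the session and table from the question exist:

// Illustrative check: readTable returns a lazily evaluated, distributed Dataset.
val ds = testRepository.readTable(db, tableName)

// The physical plan shows a distributed scan of the table, not a driver-side collection.
ds.explain()

// The underlying RDD is split into partitions that are processed on the executors.
println(ds.rdd.getNumPartitions)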

Hope this helps you!