I want to import a trained PySpark model (or pipeline) into a PySpark script. I trained a decision tree model like so:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
# Create assembler and label indexer for spark.ml format preparation
assembler = VectorAssembler(inputCols=requiredFeatures, outputCol='features')
label_indexer = StringIndexer(inputCol='measurement_status', outputCol='indexed_label')
# Apply transformations
eq_df_labelled = label_indexer.fit(eq_df).transform(eq_df)
eq_df_labelled_featured = assembler.transform(eq_df_labelled)
# Split into training and testing datasets
(training_data, test_data) = eq_df_labelled_featured.randomSplit([0.75, 0.25])
# Create a decision tree algorithm
dtree = DecisionTreeClassifier(
    labelCol='indexed_label',
    featuresCol='features',
    maxDepth=5,
    minInstancesPerNode=1,
    impurity='gini',
    maxBins=32,
    seed=None
)
# Fit classifier object to training data
dtree_model = dtree.fit(training_data)
# Save model to given directory
dtree_model.save("models/dtree")
All of the code above works without any errors. The problem is that when I try to load this model (in the same or in another PySpark application), using:
from pyspark.ml.classification import DecisionTreeClassifier
imported_model = DecisionTreeClassifier()
imported_model.load("models/dtree")
I get the following error:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-4-b283bc2da75f> in <module>
2
3 imported_model = DecisionTreeClassifier()
----> 4 imported_model.load("models/dtree")
5
6 #lodel = DecisionTreeClassifier.load("models/dtree-test/")
~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(cls, path)
328 def load(cls, path):
329 """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
--> 330 return cls.read().load(path)
331
332
~/.local/lib/python3.6/site-packages/pyspark/ml/util.py in load(self, path)
278 if not isinstance(path, basestring):
279 raise TypeError("path should be a basestring, got type %s" % type(path))
--> 280 java_obj = self._jread.load(path)
281 if not hasattr(self._clazz, "_from_java"):
282 raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
~/.local/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
1303 answer = self.gateway_client.send_command(command)
1304 return_value = get_return_value(
-> 1305 answer, self.gateway_client, self.target_id, self.name)
1306
1307 for temp_arg in temp_args:
~/.local/lib/python3.6/site-packages/pyspark/sql/utils.py in deco(*a, **kw)
126 def deco(*a, **kw):
127 try:
--> 128 return f(*a, **kw)
129 except py4j.protocol.Py4JJavaError as e:
130 converted = convert_exception(e.java_exception)
~/.local/lib/python3.6/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling o39.load.
: java.lang.UnsupportedOperationException: empty collection
at org.apache.spark.rdd.RDD.$anonfun$first$1(RDD.scala:1439)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:388)
at org.apache.spark.rdd.RDD.first(RDD.scala:1437)
at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:587)
at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:465)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
I went for this approach because it also didn't work using a Pipeline object. Any ideas about what is happening?
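For reference, the Pipeline variant I tried looked roughly like the sketch below (same stages as in the code above; the save path is just a placeholder, and loading it back failed with the same error on the cluster):
from pyspark.ml import Pipeline, PipelineModel
# Bundle the label indexer, the assembler and the classifier into one pipeline
pipeline = Pipeline(stages=[label_indexer, assembler, dtree])
pipeline_model = pipeline.fit(eq_df)
pipeline_model.save("models/dtree_pipeline")
# Loading it back fails in the same way
reloaded = PipelineModel.load("models/dtree_pipeline")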
UPDATE
I have realised that this error only occurs when I work with my Spark cluster (one master, two workers, using Spark's standalone cluster manager). If I set up the Spark session like so (with the master set to local):
spark = SparkSession\
.builder\
.config(conf=conf)\
.appName("MachineLearningTesting")\
.master("local[*]")\
.getOrCreate()
I do not get the above error.
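For contrast, the session that does trigger the error points the master at the standalone cluster, along these lines (the master URL below is a placeholder for my cluster's address):
spark = SparkSession\
    .builder\
    .config(conf=conf)\
    .appName("MachineLearningTesting")\
    .master("spark://master-host:7077")\
    .getOrCreate()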
Also, I am using Spark 3.0.0; could it be that model importing and exporting still has bugs in Spark 3?
There were two problems:
1. SSH authenticated communication must be enabled between all nodes in the cluster. Even though all the nodes in my Spark cluster are on the same network, only the master had SSH access to the workers, not vice versa.
2. The model must be available to all nodes in the cluster. This may sound obvious, but I thought the model files only needed to be available to the master, which would then distribute them to the worker nodes. In other words, when you load the model like so:
imported_model.load("/absolute_path/models/dtree")
the file
/absolute_path/models/dtree
must exist on every machine in the cluster. This made me realise that in production contexts, models are probably accessed via an external shared file system.
These two steps solved my problem of loading PySpark models into a Spark application running on a cluster.
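For illustration, here is a minimal sketch of what the load can look like once the model directory sits on storage that every node can see (the HDFS URL is a hypothetical placeholder; note that a fitted model is read back with the corresponding Model class rather than with the estimator):
from pyspark.ml.classification import DecisionTreeClassificationModel
# The path must resolve on the driver and on every worker; shared storage
# such as HDFS, S3 or NFS avoids copying the files to each machine by hand.
model = DecisionTreeClassificationModel.load("hdfs://namenode:9000/models/dtree")
# Score a DataFrame that has the same 'features' column as at training time
predictions = model.transform(eq_df_labelled_featured)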