I am having trouble writing a custom predict method using MLflow and PySpark (2.4.0). What I have so far is a custom transformer that changes the data into the format I need:
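from pyspark.ml import Transformer
from pyspark.sql.functions import explode, split

class CustomGroupBy(Transformer):
    def __init__(self):
        # Call the parent constructor so the stage gets its uid and param map.
        super(CustomGroupBy, self).__init__()
    def _transform(self, dataset):
        # Split the comma-separated widgetid string and emit one row per widget.
        df = dataset.select("userid", explode(split("widgetid", ",")).alias("widgetid"))
        return df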
Then I built a custom estimator to run one of the PySpark machine learning algorithms:
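from pyspark.ml import Estimator
from pyspark.ml.fpm import FPGrowth
from pyspark.ml.param.shared import HasInputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class PipelineFPGrowth(Estimator, HasInputCol, DefaultParamsReadable, DefaultParamsWritable):
    def __init__(self, inputCol=None, minSupport=0.005, minConfidence=0.01):
        super(PipelineFPGrowth, self).__init__()
        # Note: these are stored as plain attributes, not as Spark Params.
        self.minSupport = minSupport
        self.minConfidence = minConfidence
        # Store inputCol so that getInputCol() works in _fit.
        if inputCol is not None:
            self.setInputCol(inputCol)
    def setInputCol(self, value):
        return self._set(inputCol=value)
    def _fit(self, dataset):
        # Delegate to Spark's FPGrowth and return the fitted FPGrowthModel.
        c = self.getInputCol()
        fpgrowth = FPGrowth(itemsCol=c, minSupport=self.minSupport, minConfidence=self.minConfidence)
        model = fpgrowth.fit(dataset)
        return model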
Both stages run in a Spark ML Pipeline:
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[CustomGroupBy(), PipelineFPGrowth().setInputCol("widgetid")]).fit(df)
This all works. If I create a new pyspark dataframe with new data to predict on, I get predictions.
newDF = spark.createDataFrame([(123456,['123ABC', '789JSF'])], ["userid", "widgetid"])
pipeline.stages[1].transform(newDF).show(3, False)
# How to access frequent itemset.
pipeline.stages[1].freqItemsets.show(3, False)
Where I run into problems is writing a custom predict. I need to append the frequent itemset that FPGrowth generates to the end of the predictions. I have written the logic for that, but I am having a hard time figuring out how to put it into a custom method. I have tried adding it to my custom estimator but this didn't work. Then I wrote a separate class to take in the returned model and give the extended predictions. This was also unsuccessful.
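To make the goal concrete, here is a minimal sketch of the kind of wrapper described above. It is not a working MLflow flavor: the class name `FPGrowthWithItemsets` is hypothetical, and it only assumes the wrapped object exposes `transform(df)` and a `freqItemsets` attribute, as the fitted `FPGrowthModel` used above does. How the itemsets should be joined onto the predictions is left open, so the sketch simply returns both.

```python
class FPGrowthWithItemsets:
    """Hypothetical wrapper around a fitted model (not part of PySpark or MLflow)."""

    def __init__(self, fitted_model):
        # Assumption: fitted_model exposes .transform(df) and .freqItemsets,
        # like the fitted FPGrowthModel stage in the pipeline.
        self.fitted_model = fitted_model

    def predict(self, dataset):
        # Delegate the prediction to the wrapped model...
        predictions = self.fitted_model.transform(dataset)
        # ...and hand back the frequent itemsets alongside it.
        return predictions, self.fitted_model.freqItemsets
```

With the pipeline above, this would be used as `predictions, itemsets = FPGrowthWithItemsets(pipeline.stages[1]).predict(newDF)`. MLflow's `mlflow.pyfunc.PythonModel` base class defines a `predict(self, context, model_input)` method, so the same delegation could presumably live there, but I have not gotten that to work.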
Eventually I need to log and save the model so I can Dockerize it, which I believe means writing a custom flavor with mlflow.pyfunc. Does anyone have a hint on how to extend the predict method and then log and save the model?