I have following dataset, where only title will be features and only category_id (already int) will be label. Ignore category_text for now.
category_id,title,category_text
12321332,"drill bit","drilling"
23432212,"class plug","electrical tools"
34567789,"laptop","computers"
I'm able to train it as follows.
But prediction where I can't make it work. I cannot use tainandtest variable as I'm not just testing the model, I wanted to use it in production like scenario where I send user input like drill bit to model and expecting it to return category_id: 12321332
Problem1:
When I try to build my own vector for the input drill bit and try in predict I'm getting Index 9 out of bounds [0, 1)
Problem2: The Predict method returning double, I can't find a good documentation for Java to predict label as string, appreciate any insight on it.
Fully runnable code
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class SearchML {
public static void main(String...s) {
SparkSession spark = SparkSession.builder()
.master("local")
.appName("RandomForestClassifierExample")
.getOrCreate();
StructType schema = DataTypes.createStructType(new StructField[]{
DataTypes.createStructField("category_id", DataTypes.IntegerType, false),
DataTypes.createStructField("title", DataTypes.StringType, false),
DataTypes.createStructField("category_text", DataTypes.StringType, false)
});
var dataFrame = spark.read().format("csv")
.option("header", "true")
.option("delimiter", ",")
.option("mode", "DROPMALFORMED")
.option("quote", "\"")
.schema(schema)
.load("csv_entry_slim.csv")
.cache();
dataFrame = dataFrame.groupBy("title", "category_id").agg(functions.collect_list("title").alias("titleArray"));
CountVectorizer cv = new CountVectorizer()
.setInputCol("titleArray")
.setOutputCol("features");
cv.setMaxDF(1);
cv.setVocabSize(5000);
dataFrame = cv.fit(dataFrame).transform(dataFrame);
var indexer = new StringIndexer()
.setInputCol("category_id")
.setOutputCol("label");
dataFrame = indexer.setHandleInvalid("skip").fit(dataFrame).transform(dataFrame);
var seed = 5043;
var tainAndTest = dataFrame.randomSplit(new double[]{0.99999, 0.00001}, seed);
// train Random Forest model with training data set
var randomForestClassifier = new RandomForestClassifier()
.setImpurity("gini")
.setMaxDepth(3)
.setNumTrees(20)
.setFeatureSubsetStrategy("auto")
.setMaxBins(10)
.setSeed(seed);
var randomForestModel = randomForestClassifier.fit(tainAndTest[0]);
// PREDICTION WHERE I CAN'T MAKE IT WORK. I CANNOT USE tainAndTest AS I'M NOT JUST TESTING THE MODEL, I WANTED TO USE IT IN PRODUCTION LIKE SCENARIO WHERE I SEND USER INPUT LIKE drill bit TO MODEL AND EXPECTING IT TO RETURN CATEGORY_ID: 12321332
List<String> stringAsList = new ArrayList<>();
stringAsList.add("drill bit");
StructType schemaTest = DataTypes.createStructType(new StructField[]{
DataTypes.createStructField("title", DataTypes.StringType, false)
});
JavaSparkContext sparkContext = new JavaSparkContext(spark.sparkContext());
JavaRDD<org.apache.spark.sql.Row> rowRDD = sparkContext.parallelize(stringAsList).map((String row) -> RowFactory.create(row));
Dataset<org.apache.spark.sql.Row> userInputDataFrame = spark.sqlContext().createDataFrame(rowRDD, schemaTest).toDF();
userInputDataFrame = userInputDataFrame.groupBy("title").agg(functions.collect_list("title").alias("titleArray"));
userInputDataFrame = cv.fit(userInputDataFrame).transform(userInputDataFrame);
System.out.println(userInputDataFrame);
Vector v = (Vector) scala.collection.JavaConverters.seqAsJavaList(userInputDataFrame.collectAsList().get(0).toSeq()).get(2);
var testRes = randomForestModel.predict(v);
System.out.println(testRes);
}
}