Apache Flink - Prediction Handling


I am currently working with Apache Flink's SVM class to predict labels for some text data.

The class provides a predict function that takes a DataSet[Vector] as input and returns a DataSet of (vector, prediction) pairs as the result. So far, so good.

My problem is that I lose track of which prediction belongs to which text, and I can't pass the text through the predict() function to get it back afterwards.

Code:

val tweets: DataSet[SparseVector] =
  source.flatMap(new SelectEnglishTweetWithCreatedAtFlatMapper)
    .map(tweet => featureVectorService.transform(tweet._2))

model.predict(tweets).print()


Example result:
(SparseVector((462,8.73165920153676), (10844,8.508515650222549), (15656,2.931052542245018)),-1.0)

Is there a way to keep other data next to the prediction so that everything stays together? Without that context, the prediction doesn't help me.

Alternatively, is there a way to predict a single vector instead of a DataSet, so that I could call the function inside the map function above?

Accepted answer by Till Rohrmann:

The SVM predictor accepts any subtype of Vector as input and returns that same subtype next to each prediction. Hence, there are two options to solve this problem:

  1. Create a subtype of Vector that carries the tweet text as a tag. It is then passed through the predictor unchanged. This approach has the advantage that no additional operation is needed. However, one needs to define new classes and utilities to represent the different vector types with tags:
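import org.apache.flink.api.scala._
import org.apache.flink.ml.classification.SVM
import org.apache.flink.ml.math.DenseVector

val env = ExecutionEnvironment.getExecutionEnvironment

val input = env.fromElements("foobar", "barfo", "test")

// Toy feature: the sum of the character codes of each word (as a Double).
val vectorizedInput = input.map(word => {
  val value = word.chars().sum().toDouble
  new DenseVectorWithTag(Array(value), word)
})

val svm = SVM().setBlocks(env.getParallelism)

val weights = env.fromElements(DenseVector(1.0))

svm.weightsOption = Option(weights) // skipping the training here

// The tagged subtype comes back unchanged next to each prediction.
val predictionResult: DataSet[(DenseVectorWithTag, Double)] = svm.predict(vectorizedInput)

// A DenseVector that additionally carries the original text as a tag.
class DenseVectorWithTag(override val data: Array[Double], val tag: String)
  extends DenseVector(data) {
  override def toString: String = "(" + super.toString + ", " + tag + ")"
}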
  2. Join the prediction DataSet with the input DataSet on the vectorized representation of the tweets. This approach has the advantage that no new classes need to be introduced. The price is an additional join operation, which might be expensive:
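import org.apache.flink.api.scala._
import org.apache.flink.ml.classification.SVM
import org.apache.flink.ml.math.DenseVector

val env = ExecutionEnvironment.getExecutionEnvironment

val input = env.fromElements("foobar", "barfo", "test")

// Keep each vector next to its original text.
val vectorizedInput = input.map(word => {
  val value = word.chars().sum().toDouble
  (DenseVector(value), word)
})

val svm = SVM().setBlocks(env.getParallelism)

val weights = env.fromElements(DenseVector(1.0))

svm.weightsOption = Option(weights) // skipping the training here

// Predict on the bare vectors, then join the text back on the vector (field 0).
val predictionResult = svm.predict(vectorizedInput.map(a => a._1))
val inputWithPrediction: DataSet[(String, Double)] = vectorizedInput
  .join(predictionResult)
  .where(0)
  .equalTo(0)
  .apply((t, p) => (t._2, p._2))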
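Both options can be illustrated without a Flink cluster. The following is a minimal plain-Scala sketch, not FlinkML code: `Vec`, `VecWithTag`, `predict` and `featurize` are hypothetical stand-ins, and `predict` assumes a linear model that outputs 1.0 when the dot product with the weights exceeds 0.0 and -1.0 otherwise. Its generic signature returns the input subtype `T` next to each label, which is what lets the tag in option 1 survive; option 2 re-attaches the text by matching on vector equality.

```scala
// Plain-Scala sketch (no Flink dependency). All names here are hypothetical
// stand-ins used only to illustrate the two options from the answer.
object TaggedPredictionSketch {

  // Hypothetical stand-in for a FlinkML vector.
  class Vec(val data: Array[Double]) {
    // Structural equality so vectors can serve as join keys (option 2).
    override def equals(other: Any): Boolean = other match {
      case v: Vec => java.util.Arrays.equals(data, v.data)
      case _      => false
    }
    override def hashCode: Int = java.util.Arrays.hashCode(data)
  }

  // Option 1: a subtype that carries the original text as a tag.
  class VecWithTag(data: Array[Double], val tag: String) extends Vec(data)

  // Assumed linear predictor: 1.0 if the dot product with the weights
  // exceeds 0.0, otherwise -1.0. The input subtype T is returned unchanged.
  def predict[T <: Vec](xs: Seq[T], weights: Array[Double]): Seq[(T, Double)] =
    xs.map { x =>
      val raw = x.data.zip(weights).map { case (a, b) => a * b }.sum
      (x, if (raw > 0.0) 1.0 else -1.0)
    }

  // Toy feature from the answer: the sum of the character codes.
  def featurize(word: String): Array[Double] = Array(word.map(_.toDouble).sum)

  def main(args: Array[String]): Unit = {
    val words = Seq("foobar", "barfo", "test")
    val weights = Array(1.0)

    // Option 1: the tag travels through predict with its own vector.
    val tagged = predict(words.map(w => new VecWithTag(featurize(w), w)), weights)
    tagged.foreach { case (v, label) => println(s"${v.tag} -> $label") }

    // Option 2: predict on plain vectors, then join the text back on the vector.
    val keyed = words.map(w => (new Vec(featurize(w)), w))
    val predictions = predict(keyed.map(_._1), weights)
    val joined = for {
      (v, text)   <- keyed
      (pv, label) <- predictions if pv == v
    } yield (text, label)
    joined.foreach(println)
  }
}
```

Note that the join in option 2 matches on vector equality: if two different texts map to the same vector (with this toy feature, any two anagrams), each text is paired with every matching prediction, so the output contains duplicate rows. Option 1 avoids this because each text stays attached to its own vector.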