Add a tag to a list in a DataFrame based on a threshold for the values in another list, in Scala Spark


I have a DataFrame with a column "grades" containing a list of Grade objects, each with two fields: name (String) and value (Double). I would like to add the tag PASS to the tags list if the grades list contains a Grade with the name HOME and a value of at least 20.0. Example below:

INPUT:
+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10|  []   | [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+



OUTPUT:

+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10| [PASS]| [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+
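For reference, a minimal sketch that builds this input, assuming a local SparkSession and a hypothetical Record case class for the rows (Grade mirrors the description above):

import org.apache.spark.sql.SparkSession

case class Grade(name: String, value: Double)
case class Record(model: String, cnd: String, age: Int, tags: Seq[String], grades: Seq[Grade])

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// rows matching the INPUT table above
val dataFrame = Seq(
  Record("foo1", "xx", 10, Seq.empty, Seq(Grade("ATW", 10.0), Grade("HOME", 20.0))),
  Record("foo2", "xz", 12, Seq.empty, Seq(Grade("ATW", 70.0))),
  Record("foo3", "xc", 13, Seq.empty, Seq(Grade("ATW", 90.0), Grade("HOME", 10.0)))
).toDF()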

I haven't been able to find a reasonable solution. So far I have got this:

import org.apache.spark.sql.functions._

dataFrame.withColumn("tags",
  when(
    array_contains(col("grades.name"), lit("HOME")) &&
      col("grades.value") >= lit(20.0),
    array_union(col("tags"), lit(Array("PASS")))
  ).otherwise(col("tags"))
)

But for some reason this code throws:

org.apache.spark.sql.AnalysisException: cannot resolve '(`grades`.`value` >= 20.0D)' due to data type mismatch: differing types in '(`grades`.`value` >= 20.0D)' (array<double> and double).;;

The data is read from BigQuery, and there is no way the value field contains an array of doubles.
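That array<double> most likely comes from the path expression itself rather than the data: on an array of structs, col("grades.value") resolves to the array of every element's value field, not a single Double. A quick sketch to confirm:

dataFrame.select(col("grades.value")).printSchema()
// root
//  |-- value: array (nullable = true)
//  |    |-- element: double (containsNull = true)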


1 Answer

Answered by vilalabinot (best answer):

Assume your dataset is called data (simplified below for the sake of brevity):

+----+---------------------------+
|tags|grades                     |
+----+---------------------------+
|[]  |[{ATW, 10.0}, {HOME, 20.0}]|
|[]  |[{ATW, 70.0}]              |
|[]  |[{ATW, 90.0}, {HOME, 10.0}]|
+----+---------------------------+

If your grades column happens to be a string, you will first want to convert the JSON to an array of structs as below (you can skip this step otherwise):

data = data.withColumn("grades",
  expr("from_json(grades, 'array<struct<name:string,value:double>>')")
)
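To double-check the conversion, printing the schema should now show an array of structs (quick sketch):

data.select("grades").printSchema()
// root
//  |-- grades: array (nullable = true)
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- name: string (nullable = true)
//  |    |    |-- value: double (nullable = true)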

Once this is in place, we can apply the following:

data = data.withColumn("tags",
  when(
    // condition: at least one grade has name = HOME and value >= 20
    expr("size(filter(grades, x -> x.name == 'HOME' and x.value >= 20))").geq(1),
    // if so, concatenate whatever is in the tags column with array("PASS")
    array_union(col("tags"), array(lit("PASS")))
    // otherwise, leave the tags column untouched
  ).otherwise(col("tags")))
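Since both filter and exists are Spark 2.4+ higher-order functions, the same condition can also be written slightly more directly with exists; an equivalent sketch:

data = data.withColumn("tags",
  when(
    // true as soon as one grade has name = HOME and value >= 20
    expr("exists(grades, x -> x.name = 'HOME' and x.value >= 20.0)"),
    array_union(col("tags"), array(lit("PASS")))
  ).otherwise(col("tags")))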

Final output looks like:

+------+---------------------------+
|tags  |grades                     |
+------+---------------------------+
|[PASS]|[{ATW, 10.0}, {HOME, 20.0}]|
|[]    |[{ATW, 70.0}]              |
|[]    |[{ATW, 90.0}, {HOME, 10.0}]|
+------+---------------------------+

Good luck!