What is the best way in Spark to get each group as a new DataFrame and pass it on to another function in a loop?


I'm using spark-sql-2.4.1v, and I'm trying to find quantiles, i.e. percentile 0, percentile 25, etc., on each column of my given data.

My data:

+----+---------+-------------+----------+-----------+--------+
|  id|     date|      revenue|con_dist_1| con_dist_2| state  |
+----+---------+-------------+----------+-----------+--------+
|  10|1/15/2018|  0.010680705|         6|0.019875458|   TX   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   AZ   |
|  10|1/15/2018|   0.01378215|         4|0.082049528|   TX   |
|  10|1/15/2018|  0.010680705|         6|0.019875458|   TX   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   AZ   |
|  10|1/15/2018|   0.01378215|         4|0.082049528|   CA   |
|  10|1/15/2018|  0.010680705|         6|0.019875458|   CA   |
|  10|1/15/2018|  0.006628853|         4|0.816039063|   CA   |
+----+---------+-------------+----------+-----------+--------+
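
For reference, this sample can be reproduced roughly as follows in spark-shell; the column types are my assumption, with the date kept as a string:

import spark.implicits._   // spark-shell provides the SparkSession as `spark`

// Rough reconstruction of the sample data shown above; types are assumed.
val df = Seq(
  (10, "1/15/2018", 0.010680705, 6, 0.019875458, "TX"),
  (10, "1/15/2018", 0.006628853, 4, 0.816039063, "AZ"),
  (10, "1/15/2018", 0.01378215,  4, 0.082049528, "TX"),
  (10, "1/15/2018", 0.010680705, 6, 0.019875458, "TX"),
  (10, "1/15/2018", 0.006628853, 4, 0.816039063, "AZ"),
  (10, "1/15/2018", 0.01378215,  4, 0.082049528, "CA"),
  (10, "1/15/2018", 0.010680705, 6, 0.019875458, "CA"),
  (10, "1/15/2018", 0.006628853, 4, 0.816039063, "CA")
).toDF("id", "date", "revenue", "con_dist_1", "con_dist_2", "state")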

I am given the states and columns to calculate on, e.g.

val states = Seq("CA","AZ");
val cols = Seq("con_dist_1" ,"con_dist_2")

For each given state I need to fetch data from the source table and calculate percentiles only for the given columns.

I'm trying something like the following:

for (state <- states) {
  for (col <- cols) {
    // percentile calculation
  }
}
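
The inner "percentile calculation" is presumably something like the sketch below using DataFrameStatFunctions.approxQuantile (the relativeError value of 0.0 is my assumption); note that every (state, column) pair launches its own Spark job:

import org.apache.spark.sql.functions.col

// Sketch: one approxQuantile call per (state, column) pair.
// Each call is a separate Spark job, so this gets slow quickly.
for (state <- states) {
  val stateDf = df.filter(col("state") === state)
  for (c <- cols) {
    // 25th, 50th, 75th percentiles; relativeError = 0.0 asks for exact values
    val quantiles = stateDf.stat.approxQuantile(c, Array(0.25, 0.5, 0.75), 0.0)
    println(s"$state $c -> ${quantiles.mkString(", ")}")
  }
}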

This is too slow. Also, when doing a groupBy on "state" I don't get the other columns like revenue, date and id; how do I get those?

How can I find the quantiles of the columns "con_dist_1" and "con_dist_2" for each state? What is the best way to handle this use case so that it scales well on the cluster?

Expected result:

+-----+---------------+---------------+---------------+---------------+---------------+---------------+
|state|col1_quantile_1|col1_quantile_2|col1_quantile_3|col2_quantile_1|col2_quantile_2|col2_quantile_3|
+-----+---------------+---------------+---------------+---------------+---------------+---------------+
|   AZ|              4|              4|              4|    0.816039063|    0.816039063|    0.816039063|
|   TX|              4|              6|              6|    0.019875458|    0.019875458|    0.082049528|
+-----+---------------+---------------+---------------+---------------+---------------+---------------+

There are 2 answers

Answer by Anup Thomas

You may have to do something similar to the piece of code below:

import org.apache.spark.sql.functions.{col, collect_list}

// collect each column's values per state, then pull individual elements
// out of the resulting arrays
df.groupBy(col("state"))
    .agg(collect_list(col("con_dist_1")).as("col1_quant"), collect_list(col("con_dist_2")).as("col2_quant"))
    .withColumn("col1_quant1", col("col1_quant")(0))
    .withColumn("col1_quant2", col("col1_quant")(1))
    .withColumn("col2_quant1", col("col2_quant")(0))
    .withColumn("col2_quant2", col("col2_quant")(1))
    .show

Output:
+-----+----------+--------------------+-----------+-----------+-----------+-----------+
|state|col1_quant|          col2_quant|col1_quant1|col1_quant2|col2_quant1|col2_quant2|
+-----+----------+--------------------+-----------+-----------+-----------+-----------+
|   AZ|    [4, 4]|[0.816039063, 0.8...|          4|          4|0.816039063|0.816039063|
|   CA|    [4, 6]|[0.082049528, 0.0...|          4|          6|0.082049528|0.019875458|
|   TX| [6, 4, 6]|[0.019875458, 0.0...|          6|          4|0.019875458|0.082049528|
+-----+----------+--------------------+-----------+-----------+-----------+-----------+

Maybe the last set of withColumn calls should be inside a loop based on the number of records for each state, as sketched below.
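
A rough sketch of that loop idea, assuming the groupBy/agg result above is bound to `grouped` and that `n` is a known upper bound on the number of rows per state (both names are mine):

// Hypothetical: generate the withColumn calls in a loop instead of
// spelling them out one by one.
val n = 3
val indexed = Seq("col1_quant", "col2_quant").foldLeft(grouped) { (acc, c) =>
  (0 until n).foldLeft(acc) { (dfAcc, i) =>
    dfAcc.withColumn(s"$c${i + 1}", col(c)(i))
  }
}
indexed.show(false)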

Hope this helps!

Answer by Lamanus

UPDATE

I found the percentile_approx function from the Hive context, so you don't need to use the stat functions.

import org.apache.spark.sql.functions.{col, expr}
import spark.implicits._   // spark is the active SparkSession; needed for $"state"

val states = Seq("CA", "AZ")
val cols = Seq("con_dist_1", "con_dist_2")

// one percentile_approx expression per column (25th, 50th, 75th percentiles)
val l = cols.map(c => expr(s"percentile_approx($c, Array(0.25, 0.5, 0.75)) as ${c}_quantiles"))
// keep only the requested states, then aggregate once per state
val df2 = df.filter($"state".isin(states: _*)).groupBy("state").agg(l.head, l.tail: _*)

// split each quantile array into three named columns
df2.select(col("state") +: cols.flatMap( c => (1 until 4).map( i => col(c + "_quantiles")(i - 1).alias(c + "_quantile_" + i))): _*).show(false)

Here, I tried an automated method for the given states and cols. The result will be:

+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|state|con_dist_1_quantile_1|con_dist_1_quantile_2|con_dist_1_quantile_3|con_dist_2_quantile_1|con_dist_2_quantile_2|con_dist_2_quantile_3|
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|AZ   |4                    |4                    |4                    |0.816039063          |0.816039063          |0.816039063          |
|CA   |4                    |4                    |6                    |0.019875458          |0.082049528          |0.816039063          |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+

Be aware that the result is a bit different from your expected one, because I used states = Seq("CA", "AZ") as given by you.


ORIGINAL

Use a Window partitioned by state and calculate the percent_rank for each column.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.percent_rank

// one window per column, ordered by that column within each state
val w1 = Window.partitionBy("state").orderBy("con_dist_1")
val w2 = Window.partitionBy("state").orderBy("con_dist_2")

df.withColumn("p1", percent_rank.over(w1))
  .withColumn("p2", percent_rank.over(w2))
  .show(false)

You may filter the dataframe first to keep only the specific states. Anyway, the result is:

+---+---------+-----------+----------+-----------+-----+---+---+
|id |date     |revenue    |con_dist_1|con_dist_2 |state|p1 |p2 |
+---+---------+-----------+----------+-----------+-----+---+---+
|10 |1/15/2018|0.006628853|4         |0.816039063|AZ   |0.0|0.0|
|10 |1/15/2018|0.006628853|4         |0.816039063|AZ   |0.0|0.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|CA   |1.0|0.0|
|10 |1/15/2018|0.01378215 |4         |0.082049528|CA   |0.0|0.5|
|10 |1/15/2018|0.006628853|4         |0.816039063|CA   |0.0|1.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|TX   |0.5|0.0|
|10 |1/15/2018|0.010680705|6         |0.019875458|TX   |0.5|0.0|
|10 |1/15/2018|0.01378215 |4         |0.082049528|TX   |0.0|1.0|
+---+---------+-----------+----------+-----------+-----+---+---+
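
If you go this route, you still need a second step to turn the ranks into quantile values. One possible sketch (my own addition, not part of the original answer): per state, take the largest value whose rank does not exceed the target percentile, shown here for con_dist_1 only:

import org.apache.spark.sql.functions.{col, max, when}

// Sketch: for each target percentile, keep the largest con_dist_1 value
// whose percent_rank is at or below that percentile, per state.
val ranked = df.withColumn("p1", percent_rank.over(w1))
val q1 = ranked.groupBy("state").agg(
  max(when(col("p1") <= 0.25, col("con_dist_1"))).as("con_dist_1_quantile_1"),
  max(when(col("p1") <= 0.50, col("con_dist_1"))).as("con_dist_1_quantile_2"),
  max(when(col("p1") <= 0.75, col("con_dist_1"))).as("con_dist_1_quantile_3")
)
q1.show(false)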