Modified countByKey in Spark


I have a dataframe as follows:

+------+-------+
| key  | label |
+------+-------+
| key1 | a     |
| key1 | b     |
| key2 | a     |
| key2 | a     |
| key2 | a     |
+------+-------+

I want a modified version of countByKey in Spark that returns the following output:

+------+-------+
| key  | count |
+------+-------+
| key1 |     0 |
| key2 |     3 |
+------+-------+
Explanation: if all labels under a key are the same, return the count of all rows under that key; otherwise, the count for that key is 0.
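To pin down the rule before reaching for Spark, here is a small reference model on plain Scala collections (no Spark involved), which can be handy for checking results on small inputs. The object and method names are hypothetical:

```scala
// Hypothetical reference model of the desired semantics, using plain Scala collections.
object ModifiedCountByKeyModel {
  // For each key: the number of rows if all of its labels are identical, otherwise 0.
  def modifiedCountByKey(rows: Seq[(String, String)]): Map[String, Int] =
    rows.groupBy(_._1).map { case (key, group) =>
      val labels = group.map(_._2)
      key -> (if (labels.distinct.size == 1) labels.size else 0)
    }

  def main(args: Array[String]): Unit = {
    val rows = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a"))
    // prints List((key1,0), (key2,3))
    println(modifiedCountByKey(rows).toList.sortBy(_._1))
  }
}
```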

My approach to solve this problem:

Steps:

  1. reduceByKey(): concatenate all labels (treating labels as strings) to get an RDD of type (key, concat_of_all_labels).
  2. mapValues(): scan each concatenated string character by character to check whether the labels are all the same. If they are, return the number of labels; otherwise, return 0.
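The two steps above map onto the RDD API (reduceByKey and mapValues are RDD pair methods) roughly as follows. This is a minimal sketch assuming every label is exactly one character: with multi-character labels, "aa" and "a" would concatenate to a uniform-looking "aaa" and the string length would no longer equal the number of labels. It also assumes a local SparkSession; the object and method names are hypothetical. Note that the concatenated string grows with the number of rows per key.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical sketch of the question's concatenate-then-check approach.
object ConcatApproach {
  // Step 2: if every character of the concatenated labels is the same, return the
  // number of labels (= string length, assuming single-character labels), else 0.
  def countIfUniform(concat: String): Int =
    if (concat.nonEmpty && concat.forall(_ == concat.head)) concat.length else 0

  def run(sc: SparkContext, rows: Seq[(String, String)]): Array[(String, Int)] =
    sc.parallelize(rows)
      .reduceByKey(_ + _)        // Step 1: concatenate all labels per key (order may vary)
      .mapValues(countIfUniform) // Step 2: the uniformity check does not depend on order
      .collect()
      .sortBy(_._1)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("concat-approach").getOrCreate()
    val rows = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a"))
    run(spark.sparkContext, rows).foreach(println)
    spark.stop()
  }
}
```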

I am new to Spark, and I feel there should be a more efficient way to do this. Is there a better way to get this task done?

3 Answers

Answer by Kombajn zbożowy (score: 5)

It's quite straightforward: get both the count and the distinct count by key; then it's just a simple case when ... then ...

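import org.apache.spark.sql.functions._
import spark.implicits._ // already in scope in spark-shell; needed for toDF and 'col syntax

val df = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a")).toDF("key", "label")
df.groupBy('key)
  .agg(countDistinct('label).as("cntDistinct"), count('label).as("cnt"))
  // keep the row count only for keys whose labels are all the same
  .select('key, when('cntDistinct === 1, 'cnt).otherwise(typedLit(0)).as("count"))
  .show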

+----+-----+
| key|count|
+----+-----+
|key1|    0|
|key2|    3|
+----+-----+
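Since the logic is described as a case when ... then ..., the same query can also be written in Spark SQL. A minimal sketch, assuming a local SparkSession; the object name CaseWhenSql and the temporary view name labels are hypothetical:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical example: the count / distinct-count logic expressed in Spark SQL.
object CaseWhenSql {
  def modifiedCounts(spark: SparkSession): Array[(String, Long)] = {
    import spark.implicits._
    val df = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a")).toDF("key", "label")
    df.createOrReplaceTempView("labels") // hypothetical view name

    spark.sql(
      """SELECT key,
        |       CASE WHEN COUNT(DISTINCT label) = 1 THEN COUNT(label) ELSE 0 END AS `count`
        |FROM labels
        |GROUP BY key
        |ORDER BY key""".stripMargin)
      .collect()
      .map(row => (row.getString(0), row.getLong(1))) // COUNT yields BIGINT, read as Long
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("case-when-sql").getOrCreate()
    modifiedCounts(spark).foreach(println)
    spark.stop()
  }
}
```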
Answer by aelbuni (score: 1)

Adding to the previous solution: using reduceByKey can be more efficient if your data is really big and you care about parallelism.

If your data is big and you want to reduce the shuffling that groupBy can cause, here is another solution using the RDD API and reduceByKey, which combines values within each partition before shuffling:

import org.apache.spark.rdd.PairRDDFunctions

val mockedRdd = sc.parallelize(Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a")))

// Converting to PairRDD (an RDD of tuples also gets these methods through an implicit conversion)
val pairRDD = new PairRDDFunctions[String, String](mockedRdd)

// Map each label to (Set(label), 1), then reduce by merging the sets and adding the counts
val reducedRDD = pairRDD.mapValues(v => (Set(v), 1)).
     reduceByKey((v1, v2) => (v1._1 ++ v2._1, v1._2 + v2._2))

scala> val result = reducedRDD.collect()
result: Array[(String, (scala.collection.immutable.Set[String], Int))] = Array((key1,(Set(a, b),2)), (key2,(Set(a),3)))

Now the final result has the format (key, (Set(labels), count)):

Array((key1,(Set(a, b),2)), (key2,(Set(a),3)))

After you collect the results in your driver, you can simply accept the counts of the Sets that contain only one label:

// Filter out sets with more than one label
scala> result.filter(elm => elm._2._1.size == 1)
res15: Array[(String, (scala.collection.immutable.Set[String], Int))] = Array((key2,(Set(a),3)))
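Filtering on the driver drops key1 entirely, whereas the question asks for key1 with a count of 0. A minimal sketch of doing that final step on the executors instead, so only one (key, count) pair per key is collected; the object and method names are hypothetical and a local SparkSession is assumed:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical helper: turn (Set(labels), count) per key into the (key, count) shape
// from the question, keeping mixed-label keys with a count of 0.
object SetCountToResult {
  def modifiedCounts(sc: SparkContext, rows: Seq[(String, String)]): Array[(String, Int)] =
    sc.parallelize(rows)
      .mapValues(v => (Set(v), 1))
      .reduceByKey((v1, v2) => (v1._1 ++ v2._1, v1._2 + v2._2))
      // runs on the executors: exactly one distinct label keeps the count, otherwise 0
      .mapValues { case (labels, n) => if (labels.size == 1) n else 0 }
      .collect()
      .sortBy(_._1)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("set-count").getOrCreate()
    val rows = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a"))
    modifiedCounts(spark.sparkContext, rows).foreach(println)
    spark.stop()
  }
}
```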

Analysis using Spark 2.3.2

1) Analysing the (DataFrame API) groupBy Solution

I am not really a Spark expert, but I will throw in my 5 cents here :)

Yes, DataFrame and SQL queries go through the Catalyst optimizer, which can optimize a groupBy.

Running explain(true) on the result of the groupBy approach proposed above (DataFrame API) produces the following physical plan; note that the label column is named val in this run:

== Physical Plan ==
*(3) HashAggregate(keys=[key#14], functions=[count(val#15), count(distinct val#15)], output=[key#14, count#94L])
+- Exchange hashpartitioning(key#14, 200)
   +- *(2) HashAggregate(keys=[key#14], functions=[merge_count(val#15), partial_count(distinct val#15)], output=[key#14, count#105L, count#108L])
      +- *(2) HashAggregate(keys=[key#14, val#15], functions=[merge_count(val#15)], output=[key#14, val#15, count#105L])
         +- Exchange hashpartitioning(key#14, val#15, 200)
            +- *(1) HashAggregate(keys=[key#14, val#15], functions=[partial_count(val#15)], output=[key#14, val#15, count#105L])
               +- *(1) Project [_1#11 AS key#14, _2#12 AS val#15]
                  +- *(1) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._1, true, false) AS _1#11, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, scala.Tuple2, true])._2, true, false) AS _2#12]
                     +- Scan ExternalRDDScan[obj#10]

Note that the job has been split into three stages with two exchange phases. It is also worth mentioning that the two hashpartitioning exchanges use different keys: the first one (lower in the plan) hashes on (key, val), and the second on key alone. IMO this causes a second shuffle here, because rows partitioned by the hash of (key, val) are not necessarily co-located with the other rows of the same key.
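One way to avoid the extra (key, val) exchange, offered here as a swapped-in technique rather than either answer's method, is to replace countDistinct with min and max of the label: all non-null labels under a key are equal exactly when their minimum equals their maximum (like countDistinct, min and max ignore nulls). Since min, max and count are ordinary aggregates, I would expect a single exchange on key, but verify this with explain(true) on your Spark version; the aggregate operator chosen for string min/max may differ from the HashAggregate shown above, and whether it is faster depends on your data. The object name is hypothetical:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, count, lit, max, min, when}

// Hypothetical variant: min/max instead of countDistinct.
object MinMaxVariant {
  // Equal min and max means every non-null label under the key is the same.
  def modifiedCounts(df: DataFrame): DataFrame =
    df.groupBy(col("key"))
      .agg(min(col("label")).as("minLabel"), max(col("label")).as("maxLabel"), count(col("label")).as("cnt"))
      .select(col("key"), when(col("minLabel") === col("maxLabel"), col("cnt")).otherwise(lit(0L)).as("count"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("min-max").getOrCreate()
    import spark.implicits._
    val df = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a")).toDF("key", "label")
    val result = modifiedCounts(df)
    result.explain(true) // inspect how many Exchange nodes the plan contains
    result.orderBy("key").show()
    spark.stop()
  }
}
```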

Here is the plan as visualized in the Spark UI:

[Figure: Spark UI DAG visualization of the groupBy solution]

2) Analysing the RDD API Solution

By running reducedRDD.toDebugString, we will get the following plan:

scala> reducedRDD.toDebugString
res81: String =
(8) ShuffledRDD[44] at reduceByKey at <console>:30 []
 +-(8) MapPartitionsRDD[43] at mapValues at <console>:29 []
    |  ParallelCollectionRDD[42] at parallelize at <console>:30 []

Here is the plan as visualized in the Spark UI:

[Figure: Spark UI DAG visualization of the RDD API approach]

You can see that the RDD approach generates fewer stages and tasks, and it shuffles only once: the single ShuffledRDD comes from reduceByKey, which combines values within each partition before the shuffle, while the groupBy plan above has two exchanges. This suggests that the RDD approach may consume fewer resources and less time for this query, though it is worth measuring on your real data.
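A possible refinement of the RDD approach, sketched here as an assumption rather than something the answer benchmarked: the Set of labels can grow large when a key has many distinct labels. Since only "are all labels equal?" matters, each value can instead carry a fixed-size state of (representative label, still-uniform flag, row count). reduceByKey needs an associative and commutative merge; the merge below is associative, and swapping its arguments only changes the representative label of a state that is already non-uniform, so the final counts do not depend on merge order. All names are hypothetical:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical bounded-state alternative to the Set-based reduceByKey.
object UniformState {
  // (representative label, all labels equal so far?, number of rows)
  type State = (String, Boolean, Int)

  // The merged group is uniform only if both halves are uniform with the same label.
  def merge(a: State, b: State): State =
    (a._1, a._2 && b._2 && a._1 == b._1, a._3 + b._3)

  def modifiedCounts(sc: SparkContext, rows: Seq[(String, String)]): Array[(String, Int)] =
    sc.parallelize(rows)
      .mapValues[State](label => (label, true, 1))
      .reduceByKey(merge _)
      .mapValues { case (_, uniform, n) => if (uniform) n else 0 }
      .collect()
      .sortBy(_._1)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("uniform-state").getOrCreate()
    val rows = Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a"))
    modifiedCounts(spark.sparkContext, rows).foreach(println)
    spark.stop()
  }
}
```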

Conclusion: at the end of the day, how much optimization you apply really depends on your business requirements and the size of the data you are dealing with. If your data is not big, the groupBy approach is a straightforward option; otherwise, parallelism, speed, shuffling, and memory management become important, and most of the time you can address them by analyzing query plans and examining your jobs in the Spark UI.

Answer by Community (score: 0)
scala> val rdd = sc.parallelize(Seq(("key1", "a"), ("key1", "b"), ("key2", "a"), ("key2", "a"), ("key2", "a")))

scala> val grpby = rdd.groupByKey()

// per key: (key, number of labels, number of distinct labels), then 0 unless there is exactly one distinct label
scala> val mp = grpby.map(line => (line._1, line._2.toList.length, line._2.toSet.size)).
     |   map { case (a, b, c) => (a, if (c != 1) 0 else b) }

scala> val finres = mp.toDF("key", "count")

scala> finres.show
+----+-----+
| key|count|
+----+-----+
|key1|    0|
|key2|    3|
+----+-----+