PySpark: how to avoid explode for groups in the top and nested structure (code optimisation)


Problem

I'd like to compute some stats from request data, grouped by values in the top layer and by values in the nested layer. The main problem is that the explode-join approach with three groupBy steps is too slow on big data (100 GB).

Sample data:

import pyspark.sql.types as T

rows = [
    {"id": 1, "typeId": 1, "items":[
        {"itemType": 1,"flag": False,"event": None},
        {"itemType": 3,"flag": True,"event":[{"info1": ""},{"info1": ""}]},
        {"itemType": 3,"flag": True,"event":[{"info1": ""},{"info1": ""}]},
    ]},
    {"id": 2, "typeId": 2, "items":None},
    {"id": 3, "typeId": 1, "items":[
        {"itemType": 1,"flag": False,"event": None},
        {"itemType": 6,"flag": False,"event":[{"info1": ""}]},
        {"itemType": 6,"flag": False,"event":None},
    ]},
    {"id": 4, "typeId": 2, "items":[
        {"itemType": 1,"flag": True,"event":[{"info1": ""}]},
    ]},
    {"id": 5, "typeId": 3, "items":None},
]

schema = T.StructType([
   T.StructField("id", T.IntegerType(), False),
   T.StructField("typeId", T.IntegerType()),
   T.StructField("items", T.ArrayType(T.StructType([
           T.StructField("itemType", T.IntegerType()),
           T.StructField("flag", T.BooleanType()),
           T.StructField("event", T.ArrayType(T.StructType([
                   T.StructField("info1", T.StringType()),
           ]))),
       ])), True),
])

df = spark.createDataFrame(rows, schema)
df.printSchema()

The resulting schema:

root
 |-- id: integer (nullable = false)
 |-- typeId: integer (nullable = true)
 |-- items: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- itemType: integer (nullable = true)
 |    |    |-- flag: boolean (nullable = true)
 |    |    |-- event: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- info1: string (nullable = true)

I'd like to perform these calculations for each typeId and by items.itemType:

  • total of rows (requests)
  • total of rows (requests) that contain some item
  • total of rows (requests) that contain some item with items.flag == True
  • total of items
  • total of flagged items (items.flag == True)
  • total of events on items (sum(size("items.event")))

Code

Getting the total of requests for every typeId is simple; in the real analysis, layer1_groups contains more category columns:

import pyspark.sql.functions as F

layer1_groups = ["typeId"]

# get count for groups in top layer
totaldf = df.groupby(layer1_groups).agg(F.count(F.lit(1)).alias("requests"))

For later computations (e.g. a ratio against a number computed on the nested group), join these counts back to the original dataframe:

df = df.join(totaldf, layer1_groups)

Explode items to allow grouping by the nested items.itemType:

exploded_df = df.withColumn("I", F.explode_outer("items")).select("*","I.*").drop("items","I")
# add the number of events per item; size() returns -1/NULL for NULL arrays, so clamp it to 0
exploded_df = exploded_df.withColumn("eSize", F.greatest(F.size("event"), F.lit(0)))

First, group the stats per request (groupby "id"), because in later computations I want to count a request only if it has flagged items, etc.:

layer2_groups = ["itemType"]

each_requests = exploded_df.groupby(["id", *layer1_groups, *layer2_groups]).agg(
    F.first("requests").alias("requests"),
    F.count(F.lit(1)).alias("ItemCount"),
    F.sum(F.col("flag").cast(T.ByteType())).alias("fItemCount"),
    F.sum("eSize").alias("eCount"),
)

The final grouping is the same but without the "id" column:

# aggregate again without "id" to obtain the final results
requests_results = each_requests.groupby([*layer1_groups, *layer2_groups]).agg(
    F.first("requests").alias("requests"),
    F.count_if(F.col("ItemCount")>0).alias("requestsWithItems"),
    F.count_if(F.col("fItemCount")>0).alias("requestsWith_fItems"),
    F.sum("ItemCount").alias("ItemCount"),
    F.sum("fItemCount").alias("fItemCount"),
    F.sum("eCount").alias("eCount"),
)
requests_results.show()

The result is:

+------+--------+--------+-----------------+-------------------+---------+----------+------+
|typeId|itemType|requests|requestsWithItems|requestsWith_fItems|ItemCount|fItemCount|eCount|
+------+--------+--------+-----------------+-------------------+---------+----------+------+
|     1|       1|       2|                2|                  0|        2|         0|     0|
|     1|       3|       2|                1|                  1|        2|         2|     4|
|     1|       6|       2|                1|                  0|        2|         0|     1|
|     2|       1|       2|                1|                  1|        1|         1|     1|
|     2|    NULL|       2|                1|                  0|        1|      NULL|     0|
|     3|    NULL|       1|                1|                  0|        1|      NULL|     0|
+------+--------+--------+-----------------+-------------------+---------+----------+------+

Whole code

Gist: https://gist.github.com/vanheck/bfcadf7396d765ddd2fff5f544fd7cf2

Question

Is there some way to make this code faster? Or can I avoid the explode function altogether to obtain these stats?


1 Answer

Answer by Ananth Tirumanur:

You do not need to explode to get the statistics that you need. I tried the code below on my local machine and it worked. I have pasted my result; please adjust as per your requirements.

from pyspark.sql import functions as F

# Function to count items based on a condition
def count_items_condition(col, condition):
    return F.size(F.expr(f"filter({col}, item -> {condition})"))

# Function to sum the size of "event" arrays in "items"
def sum_events_size(col):
    return F.expr(f"aggregate({col}, 0, (acc, item) -> acc + size(item.event))")

# Add calculations as new columns
df = df.withColumn("total_requests", F.lit(1)) \
       .withColumn("total_requests_with_item", F.when(F.size("items") > 0, 1).otherwise(0)) \
       .withColumn("total_requests_with_item_flag_true", count_items_condition("items", "item.flag")) \
       .withColumn("total_items", F.size("items")) \
       .withColumn("total_flagged_items", count_items_condition("items", "item.flag")) \
       .withColumn("total_events_on_items", sum_events_size("items"))

# Group by typeId and sum the calculations
result = df.groupBy("typeId") \
           .agg(
               F.sum("total_requests").alias("total_requests"),
               F.sum("total_requests_with_item").alias("total_requests_with_item"),
               F.sum("total_requests_with_item_flag_true").alias("total_requests_with_item_flag_true"),
               F.sum("total_items").alias("total_items"),
               F.sum("total_flagged_items").alias("total_flagged_items"),
               F.sum("total_events_on_items").alias("total_events_on_items")
           )

result.show()

+------+--------------+------------------------+----------------------------------+-----------+-------------------+---------------------+
|typeId|total_requests|total_requests_with_item|total_requests_with_item_flag_true|total_items|total_flagged_items|total_events_on_items|
+------+--------------+------------------------+----------------------------------+-----------+-------------------+---------------------+
|     1|             2|                       2|                                 2|          6|                  2|                    2|
|     2|             2|                       1|                                 0|          0|                  0|                    1|
|     3|             1|                       0|                                -1|         -1|                 -1|                 NULL|
+------+--------------+------------------------+----------------------------------+-----------+-------------------+---------------------+
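Note that the -1 and NULL values in this output come from rows where items is NULL: size(NULL) returns -1 under the default spark.sql.legacy.sizeOfNull setting, and higher-order functions such as aggregate() return NULL for NULL arrays. As a minimal sketch of the same explode-free idea with NULL handling, assuming Spark 3.1+ for the native F.filter / F.aggregate higher-order functions (the safe_size helper and the intermediate column names are illustrative, and this still only aggregates per typeId, not per items.itemType as the question asks):

from pyspark.sql import functions as F

def safe_size(col):
    # size() yields -1/NULL for NULL arrays, so clamp it to 0
    return F.greatest(F.size(col), F.lit(0))

# flagged elements of the nested items array
flagged_items = F.filter("items", lambda x: x["flag"])

# sum of size(event) over all items, clamped to 0 for NULL event arrays
events_per_row = F.aggregate(
    "items", F.lit(0), lambda acc, x: acc + F.greatest(F.size(x["event"]), F.lit(0))
)

per_row = df.select(
    "typeId",
    F.lit(1).alias("requests"),
    (safe_size("items") > 0).cast("int").alias("requests_with_items"),
    (safe_size(flagged_items) > 0).cast("int").alias("requests_with_flagged_items"),
    safe_size("items").alias("item_count"),
    safe_size(flagged_items).alias("flagged_item_count"),
    F.coalesce(events_per_row, F.lit(0)).alias("event_count"),
)

per_row.groupBy("typeId").agg(
    F.sum("requests").alias("total_requests"),
    F.sum("requests_with_items").alias("total_requests_with_item"),
    F.sum("requests_with_flagged_items").alias("total_requests_with_item_flag_true"),
    F.sum("item_count").alias("total_items"),
    F.sum("flagged_item_count").alias("total_flagged_items"),
    F.sum("event_count").alias("total_events_on_items"),
).show()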