How to group by percentile distributions for every variable in a dataset and output the mean/median in pyspark


I asked a fairly similar yet different question and got a good response here:

Groupby and percentage distributions pyspark equivalent of given pandas code

I am not sure how to tailor that solution to my current need.

What I would like to do is create a separate group for each percentile of each variable, and have the aggregate for each group be the mean & median of a certain fixed variable (it would be the same variable for every percentile). Here is an illustration of what I have in mind. (Just to make sure it's clear: the variables on the left currently exist only at the variable level, like "age" and "income"; I don't have the percentile groupings created in advance, so creating them is part of what I need to do.)

                         mean(credit score)     median(credit score)
age_10th_percentile        700                       550
age_25th_percentile        710                       560
age_50th_percentile        750                       580
income_10th_percentile     710                       590
income_25th_percentile     730                       610
income_50th_percentile     740                       640
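For concreteness, the raw input I'm starting from looks roughly like this (illustrative values and column names only):

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# one row per person, raw variables only -- no percentile groupings exist yet
df = spark.createDataFrame(pd.DataFrame({
    'age':          [23, 35, 47, 52, 61],
    'income':       [40000, 55000, 72000, 90000, 120000],
    'credit_score': [580, 640, 700, 720, 760],
}))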

1 Answer

Answered by Derek O:

The format of the output dataframe(s) won't be the same as what you have written in your question because pyspark doesn't really have the concept of an index. However, you can first calculate the percentiles of each column, use these to bin your data accordingly, and then calculate the mean and median of any other columns within those bins.

We start out with a sample pyspark dataframe that looks like the following:

import numpy as np
import pandas as pd
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

import pyspark.sql.functions as F

spark = SparkSession.builder.appName('percentileAggregations').getOrCreate()

np.random.seed(42)
col1 = 100 + 50*np.random.rand(100)
col2 = 200 + 50*np.random.rand(100)
col3 = 300 + 50*np.random.rand(100)
col4 = 400 + 50*np.random.rand(100)
col5 = 500 + 50*np.random.rand(100)

pandasDF = pd.DataFrame({
    'col1': col1,
    'col2': col2,
    'col3': col3,
    'col4': col4,
    'col5': col5,
})

df = spark.createDataFrame(pandasDF)

+------------------+------------------+------------------+------------------+------------------+
|              col1|              col2|              col3|              col4|              col5|
+------------------+------------------+------------------+------------------+------------------+
|118.72700594236812|201.57145928433673| 332.1015823077144| 402.5840860584304|505.15619344179663|
| 147.5357153204958|  231.820520563189| 304.2069982497524| 426.5677315784074| 545.1276453339783|
|136.59969709057026|215.71779905381632|308.08143570473067| 427.0317560805053| 525.2626186223929|
|129.93292420985182|225.42853455823513|344.92770942635394|431.87149507491034| 541.3228733053871|
|107.80093202212183|245.37832369630465| 330.3214529829795|436.30456668613306| 516.0024800515306|
|107.79972601681013|212.46461145744374|300.45985258083147|448.79260397312675|   544.77616142481|
|102.90418060840997|220.51914615178148| 305.0735771433016|425.81501741505974| 519.4600839367082|
...
+------------------+------------------+------------------+------------------+------------------+

Then we calculate the percentiles of each column, and use these values to assign buckets with a udf:

## calculate percentiles, note this will be one row

all_percentiles = [
    F.expr(f'percentile({col}, array(0, 0.10, 0.25, 0.50))').alias(f'{col}_percentiles')
    for col in df.columns
]

df_percentiles = df.select(
    *all_percentiles
)

percentiles_row = df_percentiles.collect()[0]

percentile_info = {
    f'{col}_percentiles': percentiles_row[f'{col}_percentiles']
     for col in df.columns
}
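At this point percentile_info is a plain Python dict, so it's easy to sanity-check the cut points before bucketing:

for col in df.columns:
    # each entry is a list of four floats: [min, p10, p25, p50] for that column
    print(col, percentile_info[f'{col}_percentiles'])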

def categorizer(bins, value):
    # bins = [min, p10, p25, p50] for one column
    if bins[0] <= value <= bins[1]:
        return "10th"
    elif bins[1] < value <= bins[2]:
        return "25th"
    elif bins[2] < value <= bins[3]:
        return "50th"
    else:
        return "above_50th"

def bucket_udf(bins):
    # wrap categorizer in a udf so it can be applied to a column
    return F.udf(lambda l: categorizer(bins, l))

df_percentile_buckets = df.select('*')

for col in df.columns:
    df_percentile_buckets = df_percentile_buckets.withColumn(
        f"{col}_buckets", bucket_udf(percentile_info[f"{col}_percentiles"])(F.col(col))
    )

This gives us the following:

+------------------+------------------+------------------+------------------+------------------+------------+------------+------------+------------+------------+
|              col1|              col2|              col3|              col4|              col5|col1_buckets|col2_buckets|col3_buckets|col4_buckets|col5_buckets|
+------------------+------------------+------------------+------------------+------------------+------------+------------+------------+------------+------------+
|118.72700594236812|201.57145928433673| 332.1015823077144| 402.5840860584304|505.15619344179663|        50th|        10th|  above_50th|        10th|        25th|
| 147.5357153204958|  231.820520563189| 304.2069982497524| 426.5677315784074| 545.1276453339783|  above_50th|  above_50th|        10th|  above_50th|  above_50th|
|136.59969709057026|215.71779905381632|308.08143570473067| 427.0317560805053| 525.2626186223929|  above_50th|        50th|        25th|  above_50th|        50th|
|129.93292420985182|225.42853455823513|344.92770942635394|431.87149507491034| 541.3228733053871|  above_50th|  above_50th|  above_50th|  above_50th|  above_50th|
|107.80093202212183|245.37832369630465| 330.3214529829795|436.30456668613306| 516.0024800515306|        25th|  above_50th|  above_50th|  above_50th|        50th|
|107.79972601681013|212.46461145744374|300.45985258083147|448.79260397312675|   544.77616142481|        25th|        50th|        10th|  above_50th|  above_50th|
|102.90418060840997|220.51914615178148| 305.0735771433016|425.81501741505974| 519.4600839367082|        10th|        50th|        25th|  above_50th|        50th|
|143.30880728874675|237.77755692715243| 333.1750884554028| 416.1478236470623| 500.5418825740149|  above_50th|  above_50th|  above_50th|        50th|        10th|
|130.05575058716045|211.43990827458111| 300.2530791923109| 439.7593097384352| 545.2690988209632|  above_50th|        25th|        10th|  above_50th|  above_50th|
| 135.4036288898023|203.84899549143964|308.04040257087496|413.54161256310374| 504.5643338393067|  above_50th|        10th|        25th|        50th|        25th|
|101.02922471479012| 214.4875726456884| 327.4366894683293| 421.9485710352818| 515.9656818795207|        10th|        50th|        50th|        50th|        50th|
| 148.4954926080997|208.06106436270022|334.59475988463464| 403.9228190671133| 547.5030983525403|  above_50th|        25th|  above_50th|        25th|  above_50th|
|141.62213204002109|246.48488261712865|   332.59806297513|401.26753717077287| 547.5303573468778|  above_50th|  above_50th|  above_50th|        10th|  above_50th|
|110.61695553391381|240.40601897822086|  311.213465473028|448.13242073389625| 528.6718944061643|        50th|  above_50th|        25th|  above_50th|  above_50th|
|109.09124836035502|231.67018782552117| 335.6089610673768| 441.7990060256103| 531.5918606084899|        25th|  above_50th|  above_50th|  above_50th|  above_50th|
|109.17022549267169| 243.5730295093859|   311.86245437484| 434.7987103046849|  522.422276098916|        25th|  above_50th|        25th|  above_50th|        50th|
|115.21211214797688| 240.1836038449557|316.26998490796336| 420.4476472207135| 514.6605385849032|        50th|  above_50th|        50th|        50th|        50th|
| 126.2378215816119| 209.3285029443018| 337.3245702559012| 408.6647160035423| 516.4332272684958|  above_50th|        25th|  above_50th|        25th|        50th|
| 121.5972509321058| 244.6279499244989|332.48164495236074|407.82185213355433|  533.625922803852|        50th|  above_50th|  above_50th|        25th|  above_50th|
| 114.5614570099021|226.96711209578254| 342.4611705247089|412.51214490822974|  537.618726471884|        50th|  above_50th|  above_50th|        50th|  above_50th|
+------------------+------------------+------------------+------------------+------------------+------------+------------+------------+------------+------------+
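As an aside, if you would rather avoid the Python UDF, roughly the same bucketing can be done with pyspark's built-in Bucketizer; a minimal sketch (it assigns numeric bin indices 0-3 instead of string labels):

from pyspark.ml.feature import Bucketizer

df_bucketized = df.select('*')
for col in df.columns:
    # splits must be increasing; append +inf so values above the 50th percentile get their own bin
    splits = list(percentile_info[f'{col}_percentiles']) + [float('inf')]
    bucketizer = Bucketizer(splits=splits, inputCol=col, outputCol=f'{col}_bucket_idx')
    df_bucketized = bucketizer.transform(df_bucketized)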

Then we can calculate metrics for any column based on the buckets of another column using groupby. For example, below we calculate the mean and median of col2 based on the buckets of col1:

df_col_based_metrics = df_percentile_buckets.groupby(
    'col1_buckets'
).agg(
    F.mean('col2').alias('mean_col2'),
    F.percentile_approx('col2', 0.50).alias('median_col2'),
)

+------------+------------------+------------------+
|col1_buckets|         mean_col2|       median_col2|
+------------+------------------+------------------+
|        10th|223.64443638995675|220.51914615178148|
|        50th|227.48628498873842|230.47821669899486|
|        25th| 225.8940829666104|231.67018782552117|
|  above_50th| 223.5429176531586|220.87055015743894|
+------------+------------------+------------------+
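And if you want every (variable, percentile bucket) pair stacked into a single output like the table in your question, one way is to loop over the bucket columns and union the per-column results; a sketch, with col2 standing in for the fixed "credit score" variable:

from functools import reduce

per_column_metrics = []
for col in df.columns:
    per_column_metrics.append(
        df_percentile_buckets.groupby(f'{col}_buckets').agg(
            F.mean('col2').alias('mean_col2'),
            F.percentile_approx('col2', 0.50).alias('median_col2'),
        ).withColumn(
            # label rows like "col1_10th" so each (variable, bucket) pair becomes one row
            'group', F.concat(F.lit(f'{col}_'), F.col(f'{col}_buckets'))
        ).select('group', 'mean_col2', 'median_col2')
    )

df_all_metrics = reduce(lambda a, b: a.unionByName(b), per_column_metrics)
df_all_metrics.show()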