I have a challenge in PySpark and I haven't been able to find a good solution so far.
I need to sum the values of the last six months for each row.
E.g., let's assume I have this:
| ID | DATE | VALUE |
|---|---|---|
| 1 | 2023-04-17 | 1.0 |
| 1 | 2023-05-17 | 1.0 |
| 1 | 2023-06-17 | 1.0 |
| 1 | 2023-07-17 | 1.0 |
| 1 | 2023-08-17 | 1.0 |
| 1 | 2023-09-17 | 0.5 |
| 1 | 2023-10-17 | 2.0 |
| 1 | 2023-10-20 | 1.0 |
| 2 | 2023-04-17 | 1.0 |
| 2 | 2023-05-17 | 1.0 |
| 2 | 2023-06-17 | 1.0 |
| 2 | 2023-07-17 | 1.0 |
| 2 | 2023-08-17 | 1.0 |
| 2 | 2023-09-17 | 0.5 |
| 2 | 2023-10-17 | 2.0 |
| 2 | 2023-10-20 | 1.0 |
I need to create a sum column like this:
| ID | DATE | VALUE | SUM(VALUE) | COMMENT |
|---|---|---|---|---|
| 1 | 2023-04-17 | 1.0 | 1.0 | sum value only from month 04 from id 1 |
| 1 | 2023-05-17 | 1.0 | 2.0 | sum value from months 04, 05 from id 1 |
| 1 | 2023-06-17 | 1.0 | 3.0 | sum value from months 04, 05, 06 from id 1 |
| 1 | 2023-07-17 | 1.0 | 4.0 | sum value from months 04, 05, 06, 07 from id 1 |
| 1 | 2023-08-17 | 1.0 | 5.0 | sum value from months 04, 05, 06, 07, 08 from id 1 |
| 1 | 2023-09-17 | 0.5 | 5.5 | sum value from months 04, 05, 06, 07, 08, 09 from id 1 |
| 1 | 2023-10-17 | 2.0 | 6.5 | sum value from months 05, 06, 07, 08, 09, 10 from id 1 |
| 1 | 2023-10-20 | 1.0 | 7.5 | sum value from months 05, 06, 07, 08, 09, 10 from id 1 |
| 2 | 2023-04-17 | 1.0 | 1.0 | sum value only from month 04 from id 2 |
| 2 | 2023-05-17 | 1.0 | 2.0 | sum value from months 04, 05 from id 2 |
| 2 | 2023-06-17 | 1.0 | 3.0 | sum value from months 04, 05, 06 from id 2 |
| 2 | 2023-07-17 | 1.0 | 4.0 | sum value from months 04, 05, 06, 07 from id 2 |
| 2 | 2023-08-17 | 1.0 | 5.0 | sum value from months 04, 05, 06, 07, 08 from id 2 |
| 2 | 2023-09-17 | 0.5 | 5.5 | sum value from months 04, 05, 06, 07, 08, 09 from id 2 |
| 2 | 2023-10-17 | 2.0 | 6.5 | sum value from months 05, 06, 07, 08, 09, 10 from id 2 |
| 2 | 2023-10-20 | 1.0 | 7.5 | sum value from months 05, 06, 07, 08, 09, 10 from id 2 |
I've tried using Window with rowsBetween/rangeBetween, but I couldn't get the expected result.
EDIT: I can't show the real code as it contains confidential information about my work, but I made these cells in Jupyter to test:
```python
from datetime import datetime

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

data = [
['1', datetime(2024, 1, 1), 1.0],
['1', datetime(2023, 7, 1), 1.0],
['1', datetime(2023, 12, 1), 1.0],
['1', datetime(2023, 11, 1), 1.0],
['1', datetime(2023, 10, 1), 1.0],
['1', datetime(2023, 9, 15), 1.0],
['1', datetime(2023, 9, 16), 1.5],
['1', datetime(2023, 8, 17), 1.0],
['2', datetime(2023, 10, 1), 2.0],
['2', datetime(2023, 9, 7), 2.0],
['2', datetime(2023, 8, 20), 2.0]
]
header = ['id', 'data', 'valor']
df = spark.createDataFrame(data, header)

# truncate each date to the first day of its month, then cast to timestamp
df = df.withColumn('ano_mes', F.to_timestamp(F.trunc(F.col('data'), 'month')))
df.show(truncate=False)
w = (Window.partitionBy('id')
     .orderBy('ano_mes')
     .rangeBetween(Window.unboundedPreceding, Window.currentRow))

# this sums across all months per ID, but in this example I need only the
# current month and the previous 2, per ID
df.withColumn('sum_valor', F.sum('valor').over(w)).show(truncate=False)
```
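One idea I've been playing with, since `rangeBetween()` only accepts numeric offsets, is to convert each date to a month index and use a numeric range. A sketch of that idea (the `mes_idx` column is just a scratch name I made up, not something from my real code):

```python
# month index: year*12 + month, so consecutive months differ by exactly 1
df = df.withColumn('mes_idx', F.year('data') * 12 + F.month('data'))

# current month plus the 2 previous ones; rangeBetween(-5, 0) would cover 6 months
w3 = (Window.partitionBy('id')
      .orderBy('mes_idx')
      .rangeBetween(-2, 0))

df.withColumn('sum_valor', F.sum('valor').over(w3)).show(truncate=False)
```

I haven't verified this against the full dataset, though.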
I managed to get the expected result using Spark SQL, but I'm still curious whether there's something I can do with the PySpark DataFrame API.
Here is the SQL code:
```python
# register the DataFrame so it can be referenced from SQL
df.createOrReplaceTempView('df')

spark.sql("""
    select *,
           sum(valor) over (
               partition by id
               order by ano_mes
               range between interval 2 months preceding and current row
           ) as sum
    from df
""").show(truncate=False)
```