PySpark Window using dates as boundaries


I have a challenge in PySpark and I haven't been able to find a good solution so far.

For each line, I need the sum of the values from that line's month and the five previous months (a rolling six-month window), per ID.

E.g.

Let's assume I have this input:

ID DATE VALUE
1 2023-04-17 1.0
1 2023-05-17 1.0
1 2023-06-17 1.0
1 2023-07-17 1.0
1 2023-08-17 1.0
1 2023-09-17 0.5
1 2023-10-17 2.0
1 2023-10-20 1.0
2 2023-04-17 1.0
2 2023-05-17 1.0
2 2023-06-17 1.0
2 2023-07-17 1.0
2 2023-08-17 1.0
2 2023-09-17 0.5
2 2023-10-17 2.0
2 2023-10-20 1.0

I need to create a sum column like this:

ID DATE VALUE SUM(VALUE) COMMENT
1 2023-04-17 1.0 1.0 sum value only from month 04 from id 1
1 2023-05-17 1.0 2.0 sum value from months 04, 05 from id 1
1 2023-06-17 1.0 3.0 sum value from months 04, 05, 06 from id 1
1 2023-07-17 1.0 4.0 sum value from months 04, 05, 06, 07 from id 1
1 2023-08-17 1.0 5.0 sum value from months 04, 05, 06, 07, 08 from id 1
1 2023-09-17 0.5 5.5 sum value from months 04, 05, 06, 07, 08, 09 from id 1
1 2023-10-17 2.0 6.5 sum value from months 05, 06, 07, 08, 09, 10 from id 1
1 2023-10-20 1.0 7.5 sum value from months 05, 06, 07, 08, 09, 10 from id 1
2 2023-04-17 1.0 1.0 sum value only from month 04 from id 2
2 2023-05-17 1.0 2.0 sum value from months 04, 05 from id 2
2 2023-06-17 1.0 3.0 sum value from months 04, 05, 06 from id 2
2 2023-07-17 1.0 4.0 sum value from months 04, 05, 06, 07 from id 2
2 2023-08-17 1.0 5.0 sum value from months 04, 05, 06, 07, 08 from id 2
2 2023-09-17 0.5 5.5 sum value from months 04, 05, 06, 07, 08, 09 from id 2
2 2023-10-17 2.0 6.5 sum value from months 05, 06, 07, 08, 09, 10 from id 2
2 2023-10-20 1.0 7.5 sum value from months 05, 06, 07, 08, 09, 10 from id 2

I've tried to use Window with rowsBetween/rangeBetween, but I couldn't get the expected result.

EDIT: I can't show the real code as it contains confidential information from work, but I put together these test cells in Jupyter.


from datetime import datetime

from pyspark.sql import functions as F
from pyspark.sql.window import Window

data = [
    ['1', datetime(2024, 1, 1), 1.0],
    ['1', datetime(2023, 7, 1), 1.0],
    ['1', datetime(2023, 12, 1), 1.0],
    ['1', datetime(2023, 11, 1), 1.0],
    ['1', datetime(2023, 10, 1), 1.0],
    ['1', datetime(2023, 9, 15), 1.0],
    ['1', datetime(2023, 9, 16), 1.5],
    ['1', datetime(2023, 8, 17), 1.0],
    ['2', datetime(2023, 10, 1), 2.0],
    ['2', datetime(2023, 9, 7), 2.0],
    ['2', datetime(2023, 8, 20), 2.0]
]

header = ['id', 'data', 'valor']

# sc and spark come from the Jupyter/Spark session
rdd = sc.parallelize(data)
df = spark.createDataFrame(rdd, header)
# truncate each date to the first day of its month, then cast to timestamp for the window ordering
df = df.withColumn('ano_mes', F.trunc(F.col('data'), 'month'))
df = df.withColumn('ano_mes', F.to_timestamp(F.col('ano_mes')))
df.show(truncate=False)

w = (
    Window.partitionBy('id')
    .orderBy('ano_mes')
    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
)

# this sums every month up to the current row per ID, but in this example I need
# only the current month plus the last 2 months per ID
df.withColumn('sum_valor', F.sum('valor').over(w)).show(truncate=False)
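
One workaround I've been sketching for the "current month plus the last 2" frame (not yet validated on the real data) is to order the window by an integer month index instead of a timestamp, since rangeBetween() in the DataFrame API only accepts numeric offsets. For the original six-month requirement the same window would simply use rangeBetween(-5, 0).

# consecutive months differ by exactly 1 in this index, so the range offsets count months
df_idx = df.withColumn('mes_idx', F.year('ano_mes') * 12 + F.month('ano_mes'))

w_idx = (
    Window.partitionBy('id')
    .orderBy('mes_idx')
    .rangeBetween(-2, 0)  # the current month plus the 2 previous months
)

df_idx.withColumn('sum_valor', F.sum('valor').over(w_idx)).show(truncate=False)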

I managed to get the expected result using SQL, but I'm still curious whether the interval-based frame itself can be expressed with the PySpark DataFrame API.

Here is the SQL code:

# register the DataFrame as a temp view so it can be referenced in the SQL query
df.createOrReplaceTempView('df')

spark.sql("""
    select *,
           sum(valor) over (
               partition by id
               order by ano_mes
               range between interval 2 months preceding and current row
           ) as sum
    from df
""").show(truncate=False)