I have a PySpark DataFrame which looks like this:
df = spark.createDataFrame(
    data=[
        (1, "GERMANY", "20230606", True),
        (2, "GERMANY", "20230620", False),
        (3, "GERMANY", "20230627", True),
        (4, "GERMANY", "20230705", True),
        (5, "GERMANY", "20230714", False),
        (6, "GERMANY", "20230715", True),
    ],
    schema=["ID", "COUNTRY", "DATE", "FLAG"],
)
df.show()
+---+-------+--------+-----+
| ID|COUNTRY| DATE| FLAG|
+---+-------+--------+-----+
| 1|GERMANY|20230606| true|
| 2|GERMANY|20230620|false|
| 3|GERMANY|20230627| true|
| 4|GERMANY|20230705| true|
| 5|GERMANY|20230714|false|
| 6|GERMANY|20230715| true|
+---+-------+--------+-----+
The DataFrame has more countries. I want to create a new column COUNT_WITH_RESET following this logic:
- If FLAG = False, then COUNT_WITH_RESET = 0.
- If FLAG = True, then COUNT_WITH_RESET should count the number of rows, starting from the previous date where FLAG = False, for that specific country.
This should be the output for the example above:
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 1|
| 4|GERMANY|20230705| true| 2|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 1|
+---+-------+--------+-----+----------------+
I have tried row_number() over a window, but I can't manage to reset the count. I have also tried .rowsBetween(Window.unboundedPreceding, Window.currentRow). Here's my approach:
from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_reset = Window.partitionBy("COUNTRY").orderBy("DATE")

df_with_reset = df.withColumn(
    "COUNT_WITH_RESET",
    F.when(~F.col("FLAG"), 0).otherwise(F.row_number().over(window_reset)),
)
df_with_reset.show()
+---+-------+--------+-----+----------------+
| ID|COUNTRY| DATE| FLAG|COUNT_WITH_RESET|
+---+-------+--------+-----+----------------+
| 1|GERMANY|20230606| true| 1|
| 2|GERMANY|20230620|false| 0|
| 3|GERMANY|20230627| true| 3|
| 4|GERMANY|20230705| true| 4|
| 5|GERMANY|20230714|false| 0|
| 6|GERMANY|20230715| true| 6|
+---+-------+--------+-----+----------------+
This is obviously wrong, as my window partitions only by country, but am I on the right track? Is there a specific built-in function in PySpark to achieve this? Do I need a UDF? Any help would be appreciated.
1. Partition the dataframe by COUNTRY, then calculate the cumulative sum over the inverted FLAG column to assign group numbers; each increment marks a new block of rows starting with a false.
2. Partition the dataframe by COUNTRY along with blocks, then calculate the row number over the ordered partition to create the sequential counter, as sketched below.
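Here is a minimal sketch of those two steps (the blocks column and the window/variable names are just for illustration). One edge case the step description glosses over: the rows before a country's first false form block 0, which does not start with a false row, so its row number is already the desired count, while every later block must exclude its leading false row from the count.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

# Step 1: a cumulative sum of the inverted FLAG assigns a block number that
# increases by 1 at every FLAG=False row (block 0 holds any rows before the
# first False for that country).
w_country = Window.partitionBy("COUNTRY").orderBy("DATE")
df_blocks = df.withColumn(
    "blocks", F.sum((~F.col("FLAG")).cast("int")).over(w_country)
)

# Step 2: the row number within each (COUNTRY, blocks) partition is the
# sequential counter; blocks > 0 begin with their False row, so subtract 1.
w_block = Window.partitionBy("COUNTRY", "blocks").orderBy("DATE")
df_result = df_blocks.withColumn(
    "COUNT_WITH_RESET",
    F.when(~F.col("FLAG"), 0)
     .when(F.col("blocks") == 0, F.row_number().over(w_block))
     .otherwise(F.row_number().over(w_block) - 1),
).drop("blocks")

df_result.show()

This reproduces the expected table above. An equivalent shortcut is to replace the whole when/otherwise expression with F.sum(F.col("FLAG").cast("int")).over(w_block): the running sum of true flags within a block is 0 on the False row itself and 1, 2, ... afterwards, so no special cases are needed.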