Spark SQL generate SCD2 without dropping historic state


Data from a relational database is loaded into Spark - supposedly daily, but in reality not every day. Furthermore, it is a full copy of the DB - no delta loading.

In order to join the dimension tables easily with the main event data I want to:

  • deduplicate it (this also improves the potential for a broadcast join later)
  • have valid_from/valid_to columns so that, even though the data is not available daily (and arrives inconsistently), it can still be used cleanly downstream

I am using Spark 3.0.1 and want to transform the existing data into SCD2 style - without losing history.

spark-shell

import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

case class Foo (key:Int, value:Int, date:String)
val d = Seq(
  Foo(1, 1, "20200101"), Foo(1, 8, "20200102"), Foo(1, 9, "20200120"), Foo(1, 9, "20200121"),
  Foo(1, 9, "20200122"), Foo(1, 1, "20200103"), Foo(2, 5, "20200101"), Foo(1, 10, "20200113")).toDF
d.show

val windowDeduplication =  Window.partitionBy("key", "value").orderBy("key", "date")
val windowPrimaryKey = Window.partitionBy("key").orderBy("key", "date")

val nextThing = lead("date", 1).over(windowPrimaryKey)
d.withColumn("date", to_date(col("date"), "yyyyMMdd")).withColumn("rank", rank().over(windowDeduplication)).filter(col("rank") === 1).drop("rank").withColumn("valid_to", nextThing).withColumn("valid_to", when(nextThing.isNotNull, date_sub(nextThing, 1)).otherwise(current_date)).withColumnRenamed("date", "valid_from").orderBy("key", "valid_from", "valid_to").show

results in:

+---+-----+----------+----------+
|key|value|valid_from|  valid_to|
+---+-----+----------+----------+
|  1|    1|2020-01-01|2020-01-01|
|  1|    8|2020-01-02|2020-01-12|
|  1|   10|2020-01-13|2020-01-19|
|  1|    9|2020-01-20|2020-10-09|
|  2|    5|2020-01-01|2020-10-09|
+---+-----+----------+----------+

which is already pretty good. However:

|  1|    1|2020-01-03|   2|2020-01-12|

is lost. I.e. any value which occurs again later (after an intermediate change) is lost. How can I keep this row without also keeping the larger ranks, as in:

d.withColumn("date", to_date(col("date"), "yyyyMMdd")).withColumn("rank", rank().over(windowDeduplication)).withColumn("valid_to", nextThing).withColumn("valid_to", 

when(nextThing.isNotNull, date_sub(nextThing, 1)).otherwise(current_date)).withColumnRenamed("date", "valid_from").orderBy("key", "valid_from", "valid_to").show

+---+-----+----------+----+----------+
|key|value|valid_from|rank|  valid_to|
+---+-----+----------+----+----------+
|  1|    1|2020-01-01|   1|2020-01-01|
|  1|    8|2020-01-02|   1|2020-01-02|
|  1|    1|2020-01-03|   2|2020-01-12|
|  1|   10|2020-01-13|   1|2020-01-19|
|  1|    9|2020-01-20|   1|2020-01-20|
|  1|    9|2020-01-21|   2|2020-01-21|
|  1|    9|2020-01-22|   3|2020-10-09|
|  2|    5|2020-01-01|   1|2020-10-09|
+---+-----+----------+----+----------+

This is definitely not desired. The goal is:

  • to drop the duplicates
  • but to keep any historic change to the data via valid_from/valid_to columns

How can I properly transform this into an SCD2 representation, i.e. have valid_from/valid_to columns but not drop the intermediate state?

NOTICE: I do not need to update existing data (merge into, JOIN). It is fine to recreate / overwrite it.

I.e. the approach in Implement SCD Type 2 in Spark seems way too complicated. Is there a better way for my case, where that kind of state handling is not required? I.e. the data originates from a daily full copy of a database and only needs to be deduplicated.
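Concretely, recomputing the whole SCD2 table from the full copy and overwriting the previous output on every load would be perfectly acceptable - just a sketch, where scd2 stands for the deduplicated result and the output path is a placeholder:

// full recompute on every load - no MERGE INTO required
// (`scd2` is the deduplicated SCD2 DataFrame, the path is only a placeholder)
scd2.write.mode("overwrite").parquet("/some/output/path/dim_foo")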


1 Answer

Answer by Georg Heiler:

The previous approach only keeps the first (earliest) version of a duplicate. I think the only solution without a join for state handling is a window function where each row is compared against its neighbouring row for the same key - and if nothing in the compared columns changes, the row is discarded.

This is probably less efficient, but more accurate. It also depends on the use case at hand, i.e. how likely it is that a changed value will be seen again.
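The core of the idea as a minimal sketch on a single value column, reusing d and the imports from the question (the generic version below applies the same comparison to every non-key column):

// keep a row only if `value` differs from the adjacent row for the same key,
// or if there is no adjacent row (lead() returns null for the last row per key)
val w = Window.partitionBy("key").orderBy("date")
val changed = col("value") =!= lead("value", 1).over(w)
d.withColumn("changed", changed).
  filter(col("changed").isNull or col("changed")).
  drop("changed").show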

import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

case class Foo (key:Int, value:Int, value2:Int, date:String)
val d = Seq(
  Foo(1, 1, 1, "20200101"), Foo(1, 8, 1, "20200102"), Foo(1, 9, 1, "20200120"), Foo(1, 6, 1, "20200121"),
  Foo(1, 9, 1, "20200122"), Foo(1, 1, 1, "20200103"), Foo(2, 5, 1, "20200101"), Foo(1, 10, 1, "20200113"),
  Foo(1, 9, 1, "20210120"), Foo(1, 9, 1, "20220121"), Foo(1, 9, 3, "20230122")).toDF

def compare2Rows(key: Seq[String], sortChangingIgnored: Seq[String], timeColumn: String)(df: DataFrame): DataFrame = {
  val windowPrimaryKey = Window.partitionBy(key.map(col): _*).orderBy(sortChangingIgnored.map(col): _*)
  // all remaining (payload) columns are compared against the adjacent row of the same key
  val columnsToCompare = df.drop(key ++ sortChangingIgnored: _*).columns

  val nextDataChange = lead(timeColumn, 1).over(windowPrimaryKey)

  // keep a row if at least one payload column changes in the adjacent row,
  // or if there is no adjacent row (lead() is null for the last row per key)
  val anyChange = columnsToCompare.map(e => col(e) =!= lead(col(e), 1).over(windowPrimaryKey)).reduce(_ or _)
  val deduplicated = df.withColumn("data_changes", anyChange).
    filter(col("data_changes").isNull or col("data_changes"))

  deduplicated.
    withColumn("valid_to", when(nextDataChange.isNotNull, date_sub(nextDataChange, 1)).otherwise(current_date)).
    withColumnRenamed(timeColumn, "valid_from").
    drop("data_changes")
}
d.orderBy("key", "date").show
d.withColumn("date", to_date(col("date"), "yyyyMMdd")).transform(compare2Rows(Seq("key"), Seq("date"), "date")).orderBy("key", "valid_from", "valid_to").show

returns:

+---+-----+------+----------+----------+
|key|value|value2|valid_from|  valid_to|
+---+-----+------+----------+----------+
|  1|    1|     1|2020-01-01|2020-01-01|
|  1|    8|     1|2020-01-02|2020-01-02|
|  1|    1|     1|2020-01-03|2020-01-12|
|  1|   10|     1|2020-01-13|2020-01-19|
|  1|    9|     1|2020-01-20|2020-01-20|
|  1|    6|     1|2020-01-21|2022-01-20|
|  1|    9|     1|2022-01-21|2023-01-21|
|  1|    9|     3|2023-01-22|2020-10-09|
|  2|    5|     1|2020-01-01|2020-10-09|
+---+-----+------+----------+----------+

for an input of:

+---+-----+------+--------+
|key|value|value2|    date|
+---+-----+------+--------+
|  1|    1|     1|20200101|
|  1|    8|     1|20200102|
|  1|    1|     1|20200103|
|  1|   10|     1|20200113|
|  1|    9|     1|20200120|
|  1|    6|     1|20200121|
|  1|    9|     1|20200122|
|  1|    9|     1|20210120|
|  1|    9|     1|20220121|
|  1|    9|     3|20230122|
|  2|    5|     1|20200101|
+---+-----+------+--------+

This function has the downside that an unlimited amount of state is built up per key. But as I plan to apply it only to rather small dimension tables, I think it should be fine anyway.
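For completeness, this is roughly how I intend to consume such a small deduplicated dimension downstream - just a sketch reusing the definitions above; the events table and its columns are made up here:

// `dim`: the SCD2 result of the transform above; `events`: a hypothetical event table
val dim = d.withColumn("date", to_date(col("date"), "yyyyMMdd")).
  transform(compare2Rows(Seq("key"), Seq("date"), "date"))
val events = Seq((1, "2020-01-05"), (2, "2020-03-01")).toDF("key", "event_date").
  withColumn("event_date", to_date(col("event_date")))

// range join on the validity interval; the small dimension can be broadcast
val enriched = events.join(
  broadcast(dim),
  events("key") === dim("key") &&
    events("event_date").between(dim("valid_from"), dim("valid_to")),
  "left")
enriched.show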