I need a Databricks SQL query to explode an array column and then pivot it into a dynamic number of columns based on the number of values in the array


I have JSON data where location is an array column derived with the values below:

["USA","China","India","UK"]

["Nepal","China","India","UK","Japan"]

I need a simple SQL query to explode the array column and then pivot it into a dynamic number of columns based on the number of values in the array, something like this: pivot(explode(from_json(jsondata:export.location, 'array<string>'))) as loc_

SELECT 
    from_json(jsondata:export.location, 'array<string>') AS `Location`
    pivot(explode(from_json(jsondata:export.location, 'array<string>'))) as loc_,
FROM mytable
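
Note that explode/posexplode is a table-generating function, so it cannot be wrapped inside pivot() as in the attempt above; it has to go in a LATERAL VIEW (or the FROM clause). A minimal sketch of just the explode step, assuming the mytable and jsondata names from the query above, might look like this; making the number of pivoted columns dynamic is the part that needs extra work:

SELECT
    from_json(jsondata:export.location, 'array<string>') AS Location,
    pos,
    loc
FROM mytable
LATERAL VIEW posexplode(from_json(jsondata:export.location, 'array<string>')) t AS pos, loc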

Input

Location
[China, India, UK]
[China, India, UK, Japan]

Output

Location                     loc_1     loc_2     loc_3    loc_4
[China, India, UK]           "China"   "India"   "UK"
[China, India, UK, Japan]    "China"   "India"   "UK"     "Japan"

There are 3 answers

s.polam (BEST ANSWER)

Please find below a simple and slightly different solution.

import org.apache.spark.sql.functions._
import spark.implicits._ // for toDF, $"..." and .as[String] outside a notebook/shell

val df = Seq(
  Seq("China", "India", "UK"), 
  Seq("China", "India", "UK", "Japan")
).toDF("location")

Find the array with the largest number of elements, then convert it into a 'loc_1 STRING, loc_2 STRING, loc_3 STRING, loc_4 STRING' schema string.

val schema = df
.selectExpr("""
  CONCAT_WS(
    ', ', 
    transform(
      sequence(1,size(location)), 
      index -> CONCAT('loc_', index, ' STRING')
    )
  ) headers
""")
.orderBy(size($"location").desc)
.limit(1)
.as[String].collect.mkString("'",", ","'")

Use the from_csv function with this schema to convert the string data into columns.

df
.withColumn("parsed", expr(s"from_csv(concat_ws(', ', location), ${schema})"))
.select($"location", $"parsed.*")
.show(false)
+-------------------------+-----+------+-----+------+
|location                 |loc_1|loc_2 |loc_3|loc_4 |
+-------------------------+-----+------+-----+------+
|[China, India, UK]       |China| India| UK  |NULL  |
|[China, India, UK, Japan]|China| India| UK  | Japan|
+-------------------------+-----+------+-----+------+
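
Since the question asks for SQL, the parsing step of this trick can also be written directly in Databricks SQL once the schema string is known. This is only a sketch, assuming an array<string> column named location in mytable, with the four-column schema hard-coded; generating it dynamically still means building the statement first, as the Scala above does:

SELECT location, parsed.*
FROM (
  SELECT
    location,
    -- join the array with commas, then parse it back into named columns
    from_csv(
      concat_ws(',', location),
      'loc_1 STRING, loc_2 STRING, loc_3 STRING, loc_4 STRING'
    ) AS parsed
  FROM mytable
)
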
Lingesh.K

You can also try applying posexplode followed by a window function to create a pseudo-unique identifier that demarcates each original array row.

import sys

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, StringType, StructField, StructType
from pyspark.sql.window import Window

# Assumes an active SparkSession named `spark` (e.g. in a Databricks notebook)

# Define the data
data = [
    (["China", "India", "UK"],),
    (["China", "India", "UK", "Japan"],)
]

# Define the schema
schema = StructType([
    StructField("Location", ArrayType(StringType()))
])

# Create a DataFrame
df = spark.createDataFrame(data, schema=schema)
df.show(df.count(), False)

# Explode the 'Location' column
print("Explode the 'Location' column")
df = df.select(F.posexplode(F.col('Location')).alias('loc_', 'Location'))
df.show(df.count(), False)

# Create a column that identifies each list of locations using a pseudo-unique identifier
df = df.withColumn('id', F.when(F.col('loc_') == 0, F.monotonically_increasing_id()).otherwise(None))

# Forward fill the null values in the 'id' column using the -sys.maxsize window
window = Window.rowsBetween(-sys.maxsize, 0)
df = df.withColumn('id', F.last('id', ignorenulls=True).over(window))

# Recreate the 'Location' column
df = df.groupBy('id').agg(F.collect_list('Location').alias('Location'), F.max('loc_').alias('max_loc_'))

# Explode the 'Location' column and filter the DataFrame to give the values that have the location index 
# less than or equal to the maximum location index
df = df.select('id', 'Location', F.posexplode('Location').alias('loc_', 'exploded_Location'))
df = df.filter(F.col('loc_') <= F.col('max_loc_'))

# Pivot the DataFrame
df_pivot = df.groupBy('id', 'Location').pivot('loc_').agg(F.first('exploded_Location'))

# Show the pivoted DataFrame
df_pivot.show(df_pivot.count(), False)

This gives the following result:

+---+-------------------------+-----+-----+---+-----+
|id |Location                 |0    |1    |2  |3    |
+---+-------------------------+-----+-----+---+-----+
|0  |[China, India, UK]       |China|India|UK |null |
|1  |[China, India, UK, Japan]|China|India|UK |Japan|
+---+-------------------------+-----+-----+---+-----+
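
The posexplode-plus-pivot idea also maps onto the SQL PIVOT clause, with the caveat that the IN list of positions has to be spelled out (or generated into the statement) up front. A sketch assuming mytable with an array column named location; note that rows with identical location arrays would collapse into one group here:

SELECT *
FROM (
  SELECT location, pos, loc
  FROM mytable
  LATERAL VIEW posexplode(location) t AS pos, loc
)
PIVOT (
  first(loc) FOR pos IN (0 AS loc_1, 1 AS loc_2, 2 AS loc_3, 3 AS loc_4)
)
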
DumbCoder

You can do something like this where you split the array column into individual columns:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StructField, StringType, ArrayType, StructType

# Create a SparkSession
spark = SparkSession.builder \
    .appName("Split Array Column Example") \
    .master('local') \
    .getOrCreate()

# Define the data
data = [
    (["China", "India", "UK"],),
    (["China", "India", "UK", "Japan"],)
]

# Define the schema
schema = StructType([
    StructField("Location", ArrayType(StringType()))
])

# Create a DataFrame
df = spark.createDataFrame(data, schema=schema)

# Determine the maximum array length
max_length = df.selectExpr("max(size(Location))").collect()[0][0]

# Split the array column into individual columns
df_split = df.select('Location', *[
    col("Location")[i].alias(f"Location_{i+1}") 
    for i in range(max_length)
])

# Show the DataFrame with split columns
df_split.show()
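
The same index-based approach works in plain Databricks SQL as well (try_element_at, available in Spark 3.3+ / recent Databricks runtimes, returns NULL once the index runs past the end of the array). As above, the number of columns comes from the maximum array size, so a fully dynamic column list still means generating the statement; a sketch assuming mytable with an array column named Location:

SELECT
  Location,
  try_element_at(Location, 1) AS Location_1,
  try_element_at(Location, 2) AS Location_2,
  try_element_at(Location, 3) AS Location_3,
  try_element_at(Location, 4) AS Location_4
FROM mytable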