Check if any of the values from a list are in a PySpark column's list


I have a problem with my PySpark DataFrame. I created a column with collect_list() in a normal groupBy().agg(), and I want to return a Boolean indicating whether at least one of the values in that list is in another list of "constants":

# this is just an example of data

data = [
    (111, ["A", "B", "C"]),
    (222, ["C", "D", "E"]),
    (333, ["D", "E", "F"]),
]

schema = ["id", "my_list"]

df = spark_session.createDataFrame(data, schema=schema)

# the list to compare against
constants = ["A", "B", "C", "D"]

# check whether at least one element of the array column is in constants
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

contains_any_udf = udf(lambda x: any(item in constants for item in x), BooleanType())

df_result = df.withColumn("is_in_col", contains_any_udf(df["my_list"]))
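One weakness of a UDF like this is null handling: Spark passes a null array to a Python UDF as None, and iterating over None raises a TypeError that fails the job. Below is a minimal plain-Python sketch of a null-safe predicate. The name contains_any is hypothetical, and the constants are held in a frozenset so each lookup is a hash check rather than a list scan. It could be wrapped as udf(contains_any, BooleanType()), although a built-in column function avoids the Python UDF overhead altogether.

```python
from typing import Iterable, Optional

# Hypothetical constant set, mirroring the `constants` list in the question.
CONSTANTS = frozenset(["A", "B", "C", "D"])


def contains_any(values: Optional[Iterable[str]], constants: frozenset = CONSTANTS) -> Optional[bool]:
    """Return True if at least one element of `values` is in `constants`.

    A null array (None) yields None instead of raising, so Spark would
    store a null in the result column for that row.
    """
    if values is None:
        return None
    # isdisjoint can stop at the first shared element.
    return not constants.isdisjoint(values)


print(contains_any(["A", "B", "C"]))
print(contains_any(["X", None]))
print(contains_any(None))
```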

Is there a better way? I tried array_contains and array_intersect, but couldn't get the result I wanted.

What I'm expecting is the same df with an additional column that contains True if at least one value from the "my_list" column is in the list of constants.

Answer by 过过招 (accepted answer)

What you need is the arrays_overlap function (available since Spark 2.4), which checks whether two array columns share at least one element.

import pyspark.sql.functions as F
...
df = df.withColumn('is_in_col', F.arrays_overlap('my_list', F.array([F.lit(e) for e in constants])))

Answer by Pawan Tolani

Try the method below. I believe you already have the my_list values as a Python list somewhere, so you may want to drop the line that collects my_list from the code below.

from pyspark.sql import Window
from pyspark.sql.functions import monotonically_increasing_id, row_number

# assign row/index numbers
df = df.withColumn("row_idx", row_number().over(Window.orderBy(monotonically_increasing_id())))
# collect my_list as Python lists, in row_idx order (you may have this already)
first_list = [row.my_list for row in df.orderBy("row_idx").select("my_list").collect()]
print(first_list)
condition = []
for element in first_list:
    if set(constants) & set(element):
        condition.append(True)
    else:
        condition.append(False)
# store the flags in a new dataframe
b = spark_session.createDataFrame([(c,) for c in condition], ["is_in_col"])
# assign matching row/index numbers
b = b.withColumn("row_idx", row_number().over(Window.orderBy(monotonically_increasing_id())))
# join the two dataframes on row/index numbers, then drop the helper column
final_df = df.join(b, on="row_idx").drop("row_idx")
final_df.show()
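The driver-side loop in this answer can be condensed into a single expression. Below is a minimal plain-Python sketch; the helper name overlap_flags is hypothetical, and the sample lists mirror the question's data plus one made-up row ["X", "Y"] to show a False result. It builds the lookup set once instead of once per row. Note that, like the loop, it still requires collecting every row to the driver, which generally won't scale as well as a column expression such as arrays_overlap.

```python
from typing import List, Sequence


def overlap_flags(lists: Sequence[Sequence[str]], constants: Sequence[str]) -> List[bool]:
    """Return one flag per list: True if it shares any element with `constants`."""
    const_set = set(constants)  # build the lookup set once, not per row
    # isdisjoint can stop at the first shared element, without building an intersection set.
    return [not const_set.isdisjoint(element) for element in lists]


# Question's data plus one hypothetical row with no overlap.
first_list = [["A", "B", "C"], ["C", "D", "E"], ["D", "E", "F"], ["X", "Y"]]
constants = ["A", "B", "C", "D"]
print(overlap_flags(first_list, constants))  # [True, True, True, False]
```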