PySpark: remove rows that are derived from others

381 views Asked by At

I have the following DataFrame, which contains all the paths within a tree after traversing all nodes. For each step between nodes, a row is created where "dist" is the number of edges traversed so far (the path length minus one), "node" is the current node, and "path" is the path so far.

dist   |  node     |  path
0      |     1     |    [1]   
1      |     2     |    [1,2] 
1      |     5     |    [1,5] 
2      |     3     |    [1,2,3] 
2      |     4     |    [1,2,4] 

In the end, I only want a DataFrame containing the complete paths, without the intermediate steps (i.e. every row whose path is a prefix of another row's path should be removed):

dist   |  node     |  path
1      |     5     |    [1,5] 
2      |     3     |    [1,2,3] 
2      |     4     |    [1,2,4]

I also tried storing the path column as a string ("1;2;3") and checking which rows are substrings of other rows, but I could not find a way to do that.

1

There is 1 answer

1
Alex Ortner On

I found my old code and adapted it into an example for your problem. I used the Spark graph library GraphFrames for this. The paths can be determined with a Pregel-like message-aggregation loop.

Here is the code. First, import all modules:

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import SQLContext

import pyspark.sql.functions as f
from graphframes import GraphFrame
from pyspark.sql.types import *

from graphframes.lib import *
# shortcut for the aggregate message object from the graphframes.lib
AM=AggregateMessages


# to plot the graph
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt


# Create (or reuse an already running) Spark session for this example and
# keep a handle on its SparkContext.
spark = SparkSession.builder.appName("PathReduction").getOrCreate()
sc = spark.sparkContext

Then create a sample dataset:

# Sample edge list (src -> dst) describing two separate trees:
# one rooted at "0" and one rooted at "a".
children_of = {
    "0": ["1"],
    "1": ["2", "5"],
    "2": ["3", "4"],
    "a": ["b"],
    "b": ["c"],
    "c": ["d"],
}
# Flatten the adjacency mapping into (src, dst) tuples, keeping insertion order.
raw_data = [
    (parent, child) for parent, children in children_of.items() for child in children
]

schema = ["src", "dst"]
data = spark.createDataFrame(data=raw_data, schema=schema)
data.show()

+---+---+
|src|dst|
+---+---+
|  0|  1|
|  1|  2|
|  1|  5|
|  2|  3|
|  2|  4|
|  a|  b|
|  b|  c|
|  c|  d|
+---+---+

For visualisation, run:

# Collect the (small) edge list to the driver as plain (src, dst) tuples.
# Only suitable for small graphs: collect() pulls every edge into the driver.
plot_edges = [(row["src"], row["dst"]) for row in data.select("src", "dst").collect()]

# nx.DiGraph is directed by construction. Extra keyword arguments to the
# constructor are only stored as graph attributes, so no flag is needed.
G = nx.DiGraph()
G.add_edges_from(plot_edges)

options = {
    "node_color": "orange",
    "node_size": 500,
    "width": 2,
    "arrowstyle": "-|>",
    "arrowsize": 20,
}

nx.draw(G, arrows=True, with_labels=True, **options)
# Needed to actually display the figure when run as a plain script;
# harmless in notebooks with inline plotting.
plt.show()

plot of graph

With this message-aggregation algorithm you get exactly the paths you are looking for. If you set the flag show_steps to True, the intermediate results of each step are printed, which helps in understanding the algorithm.

# If True, print intermediate results inside the loop for debugging.
show_steps = False
# Maximum number of loop iterations. The loop needs one iteration per edge of
# the longest path plus one more to notice that no messages are left, so this
# must be larger than the number of edges in the longest expected path.
max_iter = 10

# Create the vertex set from the edge data set (every src and every dst id).
vertices = (
    data.select("src")
    .union(data.select("dst"))
    .distinct()
    .withColumnRenamed("src", "id")
)
edges = data

# Build a first graph only to obtain the in- and out-degree of each node.
gx = GraphFrame(vertices, edges)
inDegrees = gx.inDegrees
outDegrees = gx.outDegrees

if show_steps:
    print("in and out degrees")
    inDegrees.show()
    outDegrees.show()

# Initial vertices: classify each node and build its first message.
init_vertices = (
    vertices
    # left joins: the degree stays null for nodes without out-/in-edges
    .join(outDegrees, on="id", how="left")
    .join(inDegrees, on="id", how="left")
    # root = no incoming edge, leaf = no outgoing edge, child = in between;
    # used later to recognise complete root-to-leaf paths
    .withColumn(
        "nodeType",
        f.when(f.col("inDegree").isNull(), "root")
        .when(f.col("outDegree").isNull(), "leaf")
        .otherwise("child"),
    )
    # message = [[id], [nodeType]]: the path so far plus the type of the node
    # the message originated from. Built with f.array rather than
    # f.array_union, which drops duplicates and would collapse the message to
    # one element for a node whose id equals its type (e.g. a node "leaf").
    .withColumn(
        "message",
        f.array(f.array(f.col("id")), f.array(f.col("nodeType"))),
    )
    .drop("inDegree", "outDegree")
)

if show_steps:
    print("init vertices")
    init_vertices.show()

# Rebuild the graph with the initialised vertices.
gx = GraphFrame(init_vertices, edges)

# Empty DataFrame that complete paths are appended to in the loop.
results = spark.createDataFrame(
    sc.emptyRDD(),
    StructType([StructField("paths", ArrayType(StringType()), True)]),
)

for iter_ in range(max_iter):
    if show_steps:
        print("iteration step=" + str(iter_))
        print("##################################################")

    # Send each destination vertex's message backwards to the source vertex
    # and collect all incoming messages per vertex into a set.
    # NOTE: collect_set shuffles all messages and can be expensive.
    msgToSrc = AM.dst["message"]
    agg = gx.aggregateMessages(
        f.collect_set(AM.msg).alias("aggMess"),
        sendToSrc=msgToSrc,
        sendToDst=None,
    )

    if show_steps:
        print("aggregated message")
        agg.show(truncate=False)

    # Stop as soon as no vertex received a message any more.
    if len(agg.take(1)) == 0:
        print("All paths found in " + str(iter_) + " iterations")
        break

    # Turn every received message into its own row and extend its path by the
    # receiving vertex. Columns prefixed with "_" are temporary helpers.
    vertices_update = (
        agg
        # add the nodeType of the receiving vertex
        .join(init_vertices, on="id", how="left")
        # one row per received message [[path...], [typeOfOrigin]]
        .withColumn("_explode_to_flatten_array", f.explode(f.col("aggMess")))
        .withColumn("_dataMsg", f.col("_explode_to_flatten_array")[0])
        .withColumn("_typeMsg", f.col("_explode_to_flatten_array")[1][0])
        # a path is complete once it reaches a root and originated at a leaf;
        # when/otherwise guarantees a non-null boolean
        .withColumn(
            "pathComplete",
            f.when(
                (f.col("nodeType") == "root") & (f.col("_typeMsg") == "leaf"), True
            ).otherwise(False),
        )
        # prepend the receiving vertex id to the path; concat keeps order and
        # repeated ids, whereas array_union would silently drop repeats
        .withColumn("_message", f.concat(f.array(f.col("id")), f.col("_dataMsg")))
        # next message: [[id, path...], [typeOfOrigin]]
        .withColumn(
            "message", f.array(f.col("_message"), f.array(f.col("_typeMsg")))
        )
    )

    if show_steps:
        print("new vertices with all temp columns")
        vertices_update.show()

    # Append the complete paths found in this iteration to the results.
    results = results.union(
        vertices_update
        .where(f.col("pathComplete"))
        .select(f.col("_message").alias("paths"))
    )

    # Cache the new vertices and carry only the two columns needed for the
    # next iteration, to keep the data shuffled between executors small.
    cachedNewVertices = AM.getCachedDataFrame(vertices_update.select("id", "message"))
    gx = GraphFrame(cachedNewVertices, gx.edges)
else:
    # The loop never hit `break`: messages were still flowing after max_iter
    # iterations, so longer paths may be missing or the graph may have a cycle.
    print(
        "WARNING: max_iter=" + str(max_iter) + " reached before all paths were found"
    )

print("##################################################")
print("Collecting result set")
results.show()

It then shows the correct results:

All paths found in 3 iterations
##################################################
Collecting result set
+------------+
|       paths|
+------------+
|   [0, 1, 5]|
|[0, 1, 2, 3]|
|[0, 1, 2, 4]|
|[a, b, c, d]|
+------------+

To get your final DataFrame, you can either join the result back to your original data or extract the columns you need directly from the path array into separate columns:

# Derive the question's columns from each complete path:
#   dist = number of edges on the path (array length minus one)
#   node = the final (leaf) node of the path
result2 = (
    results
    .withColumn("dist", f.size(f.col("paths")) - 1)
    .withColumn("node", f.element_at(f.col("paths"), -1))
)
result2.show()

# +------------+----+----+
# |       paths|dist|node|
# +------------+----+----+
# |   [0, 1, 5]|   2|   5|
# |[0, 1, 2, 3]|   3|   3|
# |[0, 1, 2, 4]|   3|   4|
# |[a, b, c, d]|   3|   d|
# +------------+----+----+

I suppose the same algorithm could also be written with the GraphFrames Pregel API.

P.S.: In this form, the algorithm may cause problems if the graph contains loops or backward-directed edges (cycles). I had a separate algorithm to clean up loops and cycles first.