Find the top level node for each node in a tree structure using graphx pregel API

27 views Asked by At

I'm new to graphx and facing issue with this.

Employee table data is as below

Emp_ID Department_id
10 100
11 200
12 300
20 400
21 500
22 600
50 1000
60 2000
1 10000
2 20000

Employee_manager table data is as below

Manager_ID Emp_ID
10 11
10 12
20 21
20 22
50 10
60 20
1 50
2 60

The tree structure can be described as below.

  1                2
  |                |
  50               60
  |                | 
  10               20
  |                |              
  11               21   

For Emp_ID 11 the manager is 10. For Emp_ID 10 the manager is 50 For Emp_ID 50 the manager is 1. Ultimately the top_level_manager for Emp_ID 11 is 1 and his corresponding Department_id is 10000.

Similarly the top_level_manager for Emp_ID 21 is 2 , for Emp_ID 10 is 1 so on This is in tree structure.

For each Emp_ID in Employee table I need to find his top_level_manager and his manager's Department

I need the result in below format

Emp_ID Emp_Department_id top_level_manager_id Managers_Department_id
11 200 1 10000
12 300 1 10000
21 500 2 20000
22 600 2 20000
10 100 1 10000
20 100 2 20000
50 1000 1 10000
60 2000 2 20000

I have tried this but not seems to be working as expected. I'm getting -1 as result

import org.apache.spark.graphx._
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

case class employeeInfo(department_id: Long, topLevelParentId: Long, topLeveldepartment_id: Long)

def createGraph(employeeDF: DataFrame, Employee_managerDF: DataFrame): Graph[employeeInfo, Int] = {
  // Create vertices RDD
  val verticesRDD = employeeDF.select(col("employee_id").cast("Long"), col("department_id").cast("Long"))
    .rdd.map(row => (row.getAs[Long]("employee_id"), employeeInfo(row.getAs[Long]("department_id"), -1, -1)))

  // Create edges RDD
  val edgesRDD = Employee_managerDF.select(col("manager_id").cast("Long"), col("Emp_ID").cast("Long"))
    .rdd.map(row => Edge(row.getAs[Long]("manager_id"), row.getAs[Long]("Emp_ID"), 1))

  // Create the graph
  Graph(verticesRDD, edgesRDD)
}



// Function to run Pregel algorithm
def runPregel(graph: Graph[employeeInfo, Int]): Graph[employeeInfo, Int] = {
  // Initialize the graph
  val initialGraph = graph.mapVertices((_, attr) => employeeInfo(attr.department_id, -1, -1))

  // Define the update vertex function
  def updateVertex(id: VertexId, attr: employeeInfo, msg: employeeInfo): employeeInfo = {
    if (msg.topLevelParentId != -1) employeeInfo(attr.department_id, msg.topLevelParentId, msg.topLevelParentId)
    else attr
  }

  // Define the send message function
  def sendMessage(triplet: EdgeTriplet[employeeInfo, Int]): Iterator[(VertexId, employeeInfo)] = {
    Iterator((triplet.srcId, employeeInfo(-1, triplet.dstAttr.topLevelParentId, triplet.dstAttr.topLeveldepartment_id)))
  }

  // Define the merge messages function
  def mergeMessages(msg1: employeeInfo, msg2: employeeInfo): employeeInfo = {
    if (msg1.topLevelParentId != -1) msg1 else msg2
  }

  // Run Pregel algorithm
  Pregel(initialGraph, employeeInfo(-1, -1, -1), Int.MaxValue, EdgeDirection.In)(updateVertex, sendMessage, mergeMessages)
}

def extractResult(graph: Graph[employeeInfo, Int]): DataFrame = {
  val resultDF = graph.vertices.map {
    case (employeeId, info) => (employeeId, info.department_id, info.topLevelParentId)
  }.toDF("employee_id", "department_id", "top_level_parent_id")
  resultDF
}

// Define SparkSession
val spark = SparkSession.builder()
  .appName("Employee Analysis")
  .config("spark.master", "local")
  .getOrCreate()

// Define the data for the Employee table
val employeeData = Seq(
  (10L, 100L),
  (11L, 200L),
  (12L, 300L),
  (20L, 400L),
  (21L, 500L),
  (22L, 600L),
  (50L, 1000L),
  (60L, 2000L),
  (1L, 10000L),
  (2L, 20000L)
)

// Define the schema for the employee table
val employeeSchema = List("employee_id", "department_id")

// Create the Employee DataFrame
val employeeDF = spark.createDataFrame(employeeData).toDF(employeeSchema: _*)

// Define the data for the employee_manager table
val Employee_managerData = Seq(
  (10L, 11L),
  (10L, 12L),
  (20L, 21L),
  (20L, 22L),
  (50L, 10L),
  (60L, 20L),
  (1L, 50L),
  (2L, 60L)
)

// Define the schema for the Employee_manager table
val Employee_managerSchema = List("Manager_ID", "Emp_ID")

// Create the bom DataFrame
val Employee_managerDF = spark.createDataFrame(Employee_managerData).toDF(Employee_managerSchema: _*)

// Show the data for verification
println("employee Table:")
employeeDF.show()

println("Employee_manager Table:")
Employee_managerDF.show()


val graph = createGraph(employeeDF, Employee_managerDF)

val pregelGraph = runPregel(graph)

val resultDF = extractResult(pregelGraph)

resultDF.show()


0

There are 0 answers