Is the implementation of the RecursiveTask below correct?


I am beginning to understand how RecursiveTask and RecursiveAction are implemented. Based on my understanding and some of the Java documentation, I came up with the code below to add up all the numbers in an array.

Please help me correct it and point out where I have gone wrong.

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(4);
    long[] numbers = {1,2,3,4,5,6,7,8,9};
    AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length -1 );
    ForkJoinTask<Long> submit = pool.submit(newTask);
    System.out.println(submit.join());
    
}
}

class AdditionTask extends RecursiveTask<Long> {

long[] numbers;
int start;
int end;

public AdditionTask(long[] numbers, int start, int end) {
    this.numbers = numbers;
    this.start = start;
    this.end = end;
}

@Override
protected Long compute() {

    if ((end - start) > 2) {

        int length = numbers.length;
        int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;
        AdditionTask leftSide = new AdditionTask(numbers, 0, mid);

        leftSide.fork();

        AdditionTask rightSide = new AdditionTask(numbers, mid+1, length-1);
        return rightSide.compute() + leftSide.join();

    } else {
        return numbers[0] + numbers[1];
    }


}
}

New Code [Fixed]: This is the code I fixed. It seems to work well, but only with small arrays. In the example below the array size is 100000 and the sum is wrong. Why does it calculate the wrong sum?

import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

    public static void main(String[] args) {

        Random r = new Random();
        int low = 10000;
        int high = 100000;

        int size = 100000;

        long[] numbers = new long[size];
        int sum = 0;
        for (int i = 0; i < size; i++) {
            int n = r.nextInt(high - low) + low;
            numbers[i] = n;
            sum += numbers[i];
        }

        long s = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool(1);
        AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length-1);
        ForkJoinTask<Long> submit = pool.submit(newTask);
        System.out.println("Expected Answer: " + sum + ", Actual: " + submit.join());
        long e = System.currentTimeMillis();
        System.out.println("Total time taken: " + (e - s) + " ms in parallel Operation");




        long s2 = System.currentTimeMillis();
        System.out.println("Started: " + s2);

        int manualSum = 0;
        for (long number : numbers) {
            manualSum += number;
        }

        System.out.println("Expected Answer: " + sum + ", Actual: " + manualSum);
        long e2 = System.currentTimeMillis();
        System.out.println("Ended: " + e2);
        System.out.println("Total time taken: " + (e2 - s2) + " ms in sequential Operation");
    }
}

class AdditionTask extends RecursiveTask<Long> {

    long[] numbers;
    int start;
    int end;

    public AdditionTask(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {


        int length = (start == 0) ? end +1 : (end - (start - 1));

        if (length > 2) {

            int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;
            
            AdditionTask leftSide = new AdditionTask(numbers, start, (start+mid));
            leftSide.fork();
            
            AdditionTask rightSide = new AdditionTask(numbers, (start+mid)+1, end);

            Long rightSideLong = rightSide.compute();

            Long leftSideLong = leftSide.join();
            Long total = rightSideLong + leftSideLong;
            
            return total;

        } else {

            if (start == end) {
                return numbers[start];
            }
            return numbers[start] + numbers[end];

        }

    }
}
Answer by Holger (best answer)

The second version of your parallel calculation is correct. However, both non-parallel computations in your code are broken: they accumulate the sum in an int, which overflows for large arrays. Once you change them to use long as well, they produce the same result as your parallel computation.
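To make the overflow concrete, here is a minimal sketch. The class and method names (SumOverflowDemo, sumAsInt, sumAsLong) are made up for illustration, and the array is filled with a fixed value instead of random numbers so the result is reproducible. It assumes the same pattern as in the question: a long[] summed into an int accumulator with `+=`, which compiles without warning because compound assignment narrows implicitly.

```java
import java.util.Arrays;

class SumOverflowDemo {

    // Mirrors the question's "sum += numbers[i]" with an int accumulator.
    // The compound assignment silently casts the long result back to int,
    // so the total wraps around once it exceeds Integer.MAX_VALUE.
    static int sumAsInt(long[] numbers) {
        int sum = 0;
        for (long n : numbers) {
            sum += n;
        }
        return sum;
    }

    // Same loop with a long accumulator; no wrap-around for these magnitudes.
    static long sumAsLong(long[] numbers) {
        long sum = 0;
        for (long n : numbers) {
            sum += n;
        }
        return sum;
    }

    public static void main(String[] args) {
        // 100,000 elements of 50,000 each: the true total is 5,000,000,000,
        // which does not fit into an int.
        long[] numbers = new long[100_000];
        Arrays.fill(numbers, 50_000L);

        System.out.println("int accumulator:  " + sumAsInt(numbers));
        System.out.println("long accumulator: " + sumAsLong(numbers));
    }
}
```

The two printed values differ, which is the same mismatch the question shows between "Expected Answer" and "Actual".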

Still, there are some things to improve. First, you should get rid of those conditionals:

int length = (start == 0) ? end +1 : (end - (start - 1));

and

int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;

They provide no benefit over the simpler

int length = end - (start - 1); // or end - start + 1

and

int mid = length / 2;
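If you want to convince yourself that the simpler expressions are equivalent, a brute-force check over small ranges is enough. This is only a sketch: the class LengthFormulaCheck and its method names are hypothetical, and it assumes 0 <= start <= end, which is the only case the task ever sees.

```java
class LengthFormulaCheck {

    // The original expressions from the question.
    static int lengthWithConditional(int start, int end) {
        return (start == 0) ? end + 1 : (end - (start - 1));
    }

    static int midWithConditional(int length) {
        return (length % 2 == 0) ? length / 2 : (length - 1) / 2;
    }

    // The simplified expressions.
    static int lengthSimple(int start, int end) {
        return end - start + 1;
    }

    static int midSimple(int length) {
        // Integer division already rounds down for non-negative values.
        return length / 2;
    }

    // Returns true if both variants agree for every 0 <= start <= end < limit.
    static boolean formulasAgree(int limit) {
        for (int start = 0; start < limit; start++) {
            for (int end = start; end < limit; end++) {
                int length = lengthSimple(start, end);
                if (length != lengthWithConditional(start, end)) return false;
                if (midSimple(length) != midWithConditional(length)) return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println("Formulas agree: " + formulasAgree(1_000));
    }
}
```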

Second, efficient parallel processing should not decompose the work as far as possible, but only as far as the actually achievable parallelism warrants. You can use getSurplusQueuedTaskCount() for that:

@Override
protected Long compute() {
    int length = end - (start - 1);
    // only split when benefit from parallel processing is likely
    if (length > 2 && getSurplusQueuedTaskCount() < 2) {
        int mid = length / 2;

        AdditionTask leftSide = new AdditionTask(numbers, start, (start+mid));
        leftSide.fork();

        AdditionTask rightSide = new AdditionTask(numbers, (start+mid)+1, end);

        Long rightSideLong = rightSide.compute();

        // do in this thread if no worker thread has picked it up yet
        Long leftSideLong = leftSide.tryUnfork()? leftSide.compute(): leftSide.join();
        Long total = rightSideLong + leftSideLong;

        return total;
    } else { // do sequential
        long sum = 0;
        for(int ix = start; ix <= end; ix++) sum += numbers[ix];
        return sum;
    }
}
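For completeness, here is a sketch of how the pieces could fit together in one runnable program. The class name RangeSumTask, the helper parallelSum, the pool size of 4 and the deterministic test data are assumptions for illustration, not part of the original code; the compute() logic follows the method above, using primitive long for the partial sums.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class RangeSumTask extends RecursiveTask<Long> {

    private final long[] numbers;
    private final int start; // inclusive
    private final int end;   // inclusive

    RangeSumTask(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        int length = end - start + 1;
        // Only split while the current worker has few surplus queued tasks.
        if (length > 2 && getSurplusQueuedTaskCount() < 2) {
            int mid = length / 2;

            RangeSumTask left = new RangeSumTask(numbers, start, start + mid);
            left.fork();

            RangeSumTask right = new RangeSumTask(numbers, start + mid + 1, end);
            long rightSum = right.compute();

            // Run the left half here if no other worker has taken it yet.
            long leftSum = left.tryUnfork() ? left.compute() : left.join();
            return leftSum + rightSum;
        }
        // Sequential fallback; also covers an empty range (end < start).
        long sum = 0;
        for (int i = start; i <= end; i++) {
            sum += numbers[i];
        }
        return sum;
    }

    // Hypothetical convenience wrapper: sums the whole array on a private pool.
    static long parallelSum(long[] numbers) {
        ForkJoinPool pool = new ForkJoinPool(4);
        try {
            return pool.invoke(new RangeSumTask(numbers, 0, numbers.length - 1));
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Deterministic values in the same range as the question (10000..99999).
        long[] numbers = new long[100_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = 10_000 + (i % 90_000);
        }

        long expected = 0; // long accumulator, so the reference sum cannot wrap
        for (long n : numbers) {
            expected += n;
        }

        System.out.println("Expected: " + expected + ", parallel: " + parallelSum(numbers));
    }
}
```

With the reference sum held in a long, both values printed by main should match.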