Cumulative sum to find subarrays whose sum equals a given value


I'm trying to understand the logic behind the following code, but two parts of it are unclear to me, partly because I don't yet fully understand the math behind the logic.

  1. CONFUSION 1: I don't understand why we put 0 with a count of 1 into the map before we start computing the running sum of the array. How does it help?

  2. CONFUSION 2: If I move prefixSumMap.put(sum, prefixSumMap.getOrDefault(sum, 0)+1) after the if() condition, I get the correct result. However, if I leave it where it is in the code below, I get a wrong result. Why does its position matter, when we are only searching the map for sum - k to find the count?

    public int subarraySum(int[] nums, int k) {
    
         HashMap<Integer,Integer> prefixSumMap = new HashMap<>();
         prefixSumMap.put(0, 1); // CONFUSION 1
    
         int sum = 0;
         int count = 0;
    
         for(int i=0; i<nums.length; i++) {
             sum += nums[i];
    
             prefixSumMap.put(sum, prefixSumMap.getOrDefault(sum, 0)+1); //CONFUSION 2
             if(prefixSumMap.containsKey(sum - k)) {
                 count += prefixSumMap.get(sum - k);
             }
         }
    
         return count;
     }
    

There are 2 answers

WJS (best answer)

You may find this interesting. I modified the methods to use longs to prevent integer overflow, which would otherwise produce negative sums.

Both of these methods work for arrays of positive numbers. Even though the first one is much simpler, both return the same count for the test array.

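    import java.util.HashMap;
    import java.util.Map;
    import java.util.Random;

    public class SubarraySumTest { // class name is illustrative

        public static void main(String[] args) {
            Random r = new Random();
            long[] vals = r.longs(10_000_000, 1, 1000).toArray(); // values in [1, 1000)
            long k = 29329;
            System.out.println(positiveValues(vals, k));
            System.out.println(anyValues(vals, k));
        }

        // Correct only when no prefix sum can repeat (e.g. all values positive).
        public static int positiveValues(long[] array, long k) {
            Map<Long, Long> map = new HashMap<>(Map.of(0L, 1L));
            int count = 0;
            long sum = 0;

            for (long v : array) {
                sum += v;
                if (map.containsKey(sum - k)) {
                    count++;
                }
                map.put(sum, 1L); // record the sum after the lookup so k == 0 is handled correctly
            }
            return count;
        }

        // Handles positive, negative and zero values by counting repeated prefix sums.
        public static int anyValues(long[] nums, long k) {
            Map<Long, Long> prefixSumMap = new HashMap<>();
            prefixSumMap.put(0L, 1L);

            long sum = 0;
            int count = 0;

            for (int i = 0; i < nums.length; i++) {
                sum += nums[i];
                if (prefixSumMap.containsKey(sum - k)) {
                    count += prefixSumMap.get(sum - k);
                }
                // record the sum after the lookup so k == 0 does not count empty subarrays
                prefixSumMap.put(sum, prefixSumMap.getOrDefault(sum, 0L) + 1L);
            }
            return count;
        }
    }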
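To illustrate the overflow concern: with 10,000,000 values drawn from [1, 1000), the total can reach 9,990,000,000, well above Integer.MAX_VALUE (2,147,483,647), so an int running sum would wrap around to negative values. The sketch below uses small illustrative inputs (the class and method names are made up for this example) to compare an int sum with a long sum:

```java
public class OverflowCheck {

    // Sums the values into an int; silently wraps around on overflow.
    static int intSum(long[] values) {
        int sum = 0;
        for (long v : values) {
            sum += (int) v;
        }
        return sum;
    }

    // Sums the values into a long, which has room for much larger totals.
    static long longSum(long[] values) {
        long sum = 0;
        for (long v : values) {
            sum += v;
        }
        return sum;
    }

    public static void main(String[] args) {
        // Three values of one billion: the true total is 3,000,000,000.
        long[] values = {1_000_000_000L, 1_000_000_000L, 1_000_000_000L};
        System.out.println(intSum(values));  // -1294967296 (wrapped around)
        System.out.println(longSum(values)); // 3000000000
    }
}
```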

Additionally, the statement

    long v = prefixSumMap.getOrDefault(sum,  0L) + 1L;

always evaluates to 1 for arrays of positive values. This is because the running sum strictly increases, so a previous sum can never be encountered again.

That statement, together with the one that adds the stored map value to count, is what allows the array to contain both positive and negative numbers. The same is true for a negative k with all positive values.
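A minimal sketch of this claim, with illustrative names: it builds the same prefix-sum-to-count map and checks whether any count exceeds 1. For a strictly positive array every count stays 1, while the mixed array {1, 2, 3, -3, 0, 3} revisits the sums 3 and 6.

```java
import java.util.HashMap;
import java.util.Map;

public class PrefixSumRepeats {

    // Maps each prefix sum to how many times it occurs, seeded with the empty prefix (0 -> 1).
    static Map<Long, Long> prefixSumCounts(long[] values) {
        Map<Long, Long> counts = new HashMap<>();
        counts.put(0L, 1L);
        long sum = 0;
        for (long v : values) {
            sum += v;
            counts.put(sum, counts.getOrDefault(sum, 0L) + 1L);
        }
        return counts;
    }

    // True if no prefix sum occurs more than once.
    static boolean allCountsAreOne(long[] values) {
        for (long c : prefixSumCounts(values).values()) {
            if (c != 1L) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        long[] positive = {3, 1, 4, 1, 5};
        long[] mixed = {1, 2, 3, -3, 0, 3};
        System.out.println(allCountsAreOne(positive)); // true
        System.out.println(allCountsAreOne(mixed));    // false: sums 3 and 6 repeat
        System.out.println(prefixSumCounts(mixed));    // contents: 0=1, 1=1, 3=3, 6=2
    }
}
```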

For the following input:

    long[] vals = {1, 2, 3, -3, 0, 3};

The subarrays that sum to 3 are

(1+2), (3), (1+2+3-3), (1+2+3-3+0), (3-3+0+3), (0+3), (3)
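A brute-force check of this list (O(n²), meant only for verification; the names are illustrative) enumerates every contiguous subarray and keeps those summing to 3. It finds the same seven subarrays, printed in start-index order rather than the order above:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ListSubarrays {

    // Brute force: returns every contiguous subarray whose sum equals k.
    static List<long[]> subarraysWithSum(long[] values, long k) {
        List<long[]> result = new ArrayList<>();
        for (int start = 0; start < values.length; start++) {
            long sum = 0;
            for (int end = start; end < values.length; end++) {
                sum += values[end];
                if (sum == k) {
                    result.add(Arrays.copyOfRange(values, start, end + 1));
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        long[] vals = {1, 2, 3, -3, 0, 3};
        List<long[]> found = subarraysWithSum(vals, 3);
        for (long[] sub : found) {
            System.out.println(Arrays.toString(sub));
        }
        System.out.println(found.size() + " subarrays"); // 7 subarrays
    }
}
```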

Since adding negative numbers can bring the running sum back to a value it had before, repeated prefix sums must be counted. The positive-only solution does not do this.
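To make the difference concrete, here is a sketch with two illustrative methods: one that only remembers which prefix sums have been seen (a HashSet standing in for the map of 1s in positiveValues) and one that counts how often each prefix sum occurs. On {1, 2, 3, -3, 0, 3} with k = 3, the first misses the repeated prefix sums:

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class PositiveOnlyVsCounting {

    // Positive-only idea: remember which prefix sums were seen, but not how often.
    static int seenOnce(long[] values, long k) {
        Set<Long> seen = new HashSet<>();
        seen.add(0L); // the empty prefix
        int count = 0;
        long sum = 0;
        for (long v : values) {
            sum += v;
            if (seen.contains(sum - k)) {
                count++;
            }
            seen.add(sum);
        }
        return count;
    }

    // General idea: count how many times each prefix sum was seen.
    static int withCounts(long[] values, long k) {
        Map<Long, Integer> counts = new HashMap<>();
        counts.put(0L, 1);
        int count = 0;
        long sum = 0;
        for (long v : values) {
            sum += v;
            count += counts.getOrDefault(sum - k, 0);
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
        }
        return count;
    }

    public static void main(String[] args) {
        long[] vals = {1, 2, 3, -3, 0, 3};
        System.out.println(seenOnce(vals, 3));   // 5: misses repeated prefix sums
        System.out.println(withCounts(vals, 3)); // 7
    }
}
```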

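The positive-only method also works for arrays of all-negative values, since the running sum then strictly decreases and no prefix sum can repeat. If k is positive, no subarray will be found, because every subarray sum is negative; if k is negative, one or more subarrays may be found.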
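A small sketch of the all-negative case, reusing the set-based positive-only idea (the names and the input array are illustrative):

```java
import java.util.HashSet;
import java.util.Set;

public class AllNegativeValues {

    // Set of seen prefix sums. Exact for all-negative input too, because the
    // running sum strictly decreases and no prefix sum repeats.
    static int subarraySum(long[] nums, long k) {
        Set<Long> seen = new HashSet<>();
        seen.add(0L); // the empty prefix
        int count = 0;
        long sum = 0;
        for (long v : nums) {
            sum += v;
            if (seen.contains(sum - k)) {
                count++;
            }
            seen.add(sum);
        }
        return count;
    }

    public static void main(String[] args) {
        long[] negatives = {-1, -2, -3, -1, -2};
        System.out.println(subarraySum(negatives, 3));  // 0: every subarray sum is negative
        System.out.println(subarraySum(negatives, -3)); // 3: [-1, -2], [-3], [-1, -2]
    }
}
```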

Andreas

#1: put(0, 1) is a convenience so you don't need an extra if statement checking whether sum == k.

Say k = 6 and the input is [1, 2, 3, 4]. After you've processed the 3, sum = 6, which means the subarray [1, 2, 3] must be counted. Since sum - k is 0, get(sum - k) returns 1 to add to count, so we don't need a separate if (sum == k) { count++; }.
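The two formulations can be compared side by side. In this sketch (method names are illustrative), one method seeds the map with 0 -> 1 and the other starts from an empty map and checks sum == k explicitly; both give the same counts:

```java
import java.util.HashMap;
import java.util.Map;

public class SeedZeroVsExplicitCheck {

    // Variant A: seed the map with 0 -> 1, as in the question.
    static int seeded(int[] nums, int k) {
        Map<Integer, Integer> counts = new HashMap<>();
        counts.put(0, 1);
        int sum = 0, count = 0;
        for (int n : nums) {
            sum += n;
            count += counts.getOrDefault(sum - k, 0);
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
        }
        return count;
    }

    // Variant B: no seed; a prefix that itself sums to k is counted explicitly.
    static int explicitCheck(int[] nums, int k) {
        Map<Integer, Integer> counts = new HashMap<>();
        int sum = 0, count = 0;
        for (int n : nums) {
            sum += n;
            if (sum == k) {
                count++; // the subarray starting at index 0
            }
            count += counts.getOrDefault(sum - k, 0);
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
        }
        return count;
    }

    public static void main(String[] args) {
        int[] nums = {1, 2, 3, 4};
        System.out.println(seeded(nums, 6));        // 1: [1, 2, 3]
        System.out.println(explicitCheck(nums, 6)); // 1
    }
}
```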

#2: prefixSumMap.put(sum, prefixSumMap.getOrDefault(sum, 0)+1) means that the first time a sum is seen, it does a put(sum, 1). The second time it becomes put(sum, 2), the third time put(sum, 3), and so on.

Basically, the map maps each sum to the number of times that sum has been seen.
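On the placement asked about in CONFUSION 2: when k == 0, sum - k is the current sum itself, so updating the map first lets the lookup match the current prefix and count an empty subarray once per element. For any k != 0, sum - k differs from sum and the two placements give the same count. A minimal sketch comparing both orders (method names are illustrative):

```java
import java.util.HashMap;
import java.util.Map;

public class UpdateOrder {

    // Map updated BEFORE the lookup (the placement shown in the question).
    static int updateBefore(int[] nums, int k) {
        Map<Integer, Integer> counts = new HashMap<>();
        counts.put(0, 1);
        int sum = 0, count = 0;
        for (int n : nums) {
            sum += n;
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
            count += counts.getOrDefault(sum - k, 0);
        }
        return count;
    }

    // Map updated AFTER the lookup: only strictly earlier prefixes can match.
    static int updateAfter(int[] nums, int k) {
        Map<Integer, Integer> counts = new HashMap<>();
        counts.put(0, 1);
        int sum = 0, count = 0;
        for (int n : nums) {
            sum += n;
            count += counts.getOrDefault(sum - k, 0);
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
        }
        return count;
    }

    public static void main(String[] args) {
        int[] nums = {1, -1, 0};
        // k == 0: "before" also matches the current prefix, adding one per element.
        System.out.println(updateBefore(nums, 0)); // 6
        System.out.println(updateAfter(nums, 0));  // 3: [1, -1], [1, -1, 0], [0]
        // k != 0: both placements agree.
        int[] vals = {1, 2, 3, -3, 0, 3};
        System.out.println(updateBefore(vals, 3)); // 7
        System.out.println(updateAfter(vals, 3));  // 7
    }
}
```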

For example, if k = 3 and the input is [0, 0, 1, 2, 4], by the time the 2 has been processed, sum = 3 and the map contains { 0=3, 1=1, 3=1 }, so get(sum - k), i.e. get(0), returns 3, because there are 3 subarrays ending at the 2 that sum to 3: [0, 0, 1, 2], [0, 1, 2], and [1, 2].

Similarly, if k = 4 and the input is [1, 2, 0, 0, 4], by the time the 4 has been processed, sum = 7 and the map contains { 0=1, 1=1, 3=3, 7=1 }, so get(sum - k), i.e. get(3), returns 3, because there are 3 subarrays ending at the 4 that sum to 4: [0, 0, 4], [0, 4], and [4].
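A short sketch (illustrative names) that rebuilds the prefix-sum-to-count map for both examples, so the map contents quoted above can be checked:

```java
import java.util.HashMap;
import java.util.Map;

public class MapContents {

    // Prefix sum -> number of times it has been seen, including the seeded 0 -> 1.
    static Map<Integer, Integer> prefixSumCounts(int[] nums) {
        Map<Integer, Integer> counts = new HashMap<>();
        counts.put(0, 1);
        int sum = 0;
        for (int n : nums) {
            sum += n;
            counts.put(sum, counts.getOrDefault(sum, 0) + 1);
        }
        return counts;
    }

    public static void main(String[] args) {
        // First example, up to and including the 2: expected contents 0=3, 1=1, 3=1
        System.out.println(prefixSumCounts(new int[]{0, 0, 1, 2}));
        // Second example, up to and including the 4: expected contents 0=1, 1=1, 3=3, 7=1
        System.out.println(prefixSumCounts(new int[]{1, 2, 0, 0, 4}));
    }
}
```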

Note: all of these examples use non-negative values.