I'm dissecting OpenAI's implementation of a segment tree, which is used in their implementation of a prioritized replay buffer.

I am trying to understand whether this returns the correct value or the value one node off from the one requested. I assume I'm misreading the code.

See the code for `SegmentTree` below (it is the superclass of `SumSegmentTree`). In particular, note the method `_reduce_helper`:

```python
import operator


class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.

        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as regular array, but with two
        important differences:

            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient ( O(log segment size) )
               `reduce` operation which reduces `operation` over
               a contiguous subsequence of items in the array.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (eg. sum, max)
            must form a mathematical group together with the set of
            possible values for array elements (i.e. be associative)
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))

        Parameters
        ----------
        start: int
            beginning of the subsequence
        end: int
            end of the subsequence

        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
```

Note that `reduce` passes `1` as the starting `node` argument to `_reduce_helper`, and `_reduce_helper` begins with:

    if start == node_start and end == node_end:
        return self._value[node]

As an example, take an array with a capacity of 2^3 = 8, indexed 0 to 7. If we want the sum of the entire array, `reduce` calls `_reduce_helper(0, 7, 1, 0, 7)` (it subtracts 1 from `end` first), so the first check matches immediately and the function returns the value of node (index) 1. Shouldn't it return the value at index 0?
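One way to check this reading is to re-create the storage and recursion on the 8-element example and inspect the list directly. The sketch below is a minimal standalone copy of the logic in `__setitem__` and `_reduce_helper` above, assuming a sum tree whose array holds the values 1 to 8; `set_leaf` and `reduce_helper` are hypothetical names used only for illustration.

```python
import operator

# Standalone copy of SegmentTree's update and reduce logic (sum only), using the
# same layout: root at index 1, leaves at capacity .. 2*capacity - 1.
capacity = 8
op, neutral = operator.add, 0
value = [neutral] * (2 * capacity)

def set_leaf(idx, val):
    idx += capacity                 # leaf slot for array element idx
    value[idx] = val
    idx //= 2
    while idx >= 1:                 # update ancestors up to the root (index 1)
        value[idx] = op(value[2 * idx], value[2 * idx + 1])
        idx //= 2

def reduce_helper(start, end, node, node_start, node_end):
    # Same recursion as _reduce_helper; `end` is inclusive here.
    if start == node_start and end == node_end:
        return value[node]
    mid = (node_start + node_end) // 2
    if end <= mid:
        return reduce_helper(start, end, 2 * node, node_start, mid)
    if mid + 1 <= start:
        return reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
    return op(reduce_helper(start, mid, 2 * node, node_start, mid),
              reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end))

for i in range(capacity):
    set_leaf(i, i + 1)              # array = [1, 2, ..., 8]

print(value[0])                                  # slot 0 is never written: 0
print(value[1])                                  # root holds the total: 36
print(reduce_helper(0, 7, 1, 0, capacity - 1))   # whole array: 36
print(reduce_helper(2, 4, 1, 0, capacity - 1))   # 3 + 4 + 5 = 12
```

Because `reduce(start, end)` subtracts 1 from `end` before calling `_reduce_helper`, the call `reduce_helper(2, 4, ...)` here corresponds to `reduce(2, 5)` on the class itself.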

It also appears to use the convention that the node at index `i` has children at indices `2*i` and `2*i + 1`. Since Python is zero-indexed, shouldn't each parent instead have children at indices `2*i + 1` and `2*i + 2`?
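The two indexing conventions can be compared directly. The helper functions below are hypothetical and written only for illustration; the sketch assumes a complete binary tree with `capacity` leaves, a power of two as the constructor asserts. It checks that node `k` in the zero-based layout has the same children as node `k + 1` in the one-based layout, shifted by one.

```python
def children_one_based(i):
    # Layout used by SegmentTree above: root at index 1, slot 0 unused.
    return 2 * i, 2 * i + 1

def parent_one_based(i):
    return i // 2

def children_zero_based(i):
    # Alternative layout: root at index 0, no unused slot.
    return 2 * i + 1, 2 * i + 2

def parent_zero_based(i):
    return (i - 1) // 2

capacity = 8
# Internal nodes are 0 .. capacity-2 (zero-based) or 1 .. capacity-1 (one-based).
for k in range(capacity - 1):
    l0, r0 = children_zero_based(k)
    l1, r1 = children_one_based(k + 1)
    assert (l0 + 1, r0 + 1) == (l1, r1)      # same tree, indices shifted by one
    assert parent_zero_based(l0) == parent_zero_based(r0) == k
    assert parent_one_based(l1) == parent_one_based(r1) == k + 1
print("layouts agree up to a shift of one")
```

If the assertions hold, both layouts describe the same tree; the one-based form simply leaves slot 0 of the list unused.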

I must be reading this wrong, but I'd appreciate some input to clarify it.
