Building balanced K-D tree - median split/partition implementation

I am struggling to understand how to split k-dimensional data points according to the median along a given axis (part of balanced K-D tree construction). I was also given an example dataset in two representations, and for neither of them can I come up with an implementation of the point split.

Let's consider the following 2-D points (for simplicity, but the algo should work for k-dimensions):

Repr 1: points = [(5,4), (2,1), (0,9), (3,8), (7,2)]

Repr 2: x=[5,2,0,3,7] y=[4,1,9,8,2]

Let's assume we want to split by the x-axis. We expect the median to be (3,8), and subsequently we want all elements with x<=3 passed to the recursive left-subtree build and all elements with x>3 passed to the right subtree.

Representation 1 - I saw some implementations that sort with comparators, but I am not sure how that translates to k dimensions. So we can have a class

class Point {
    int[] coords;
}

Now how do we sort this cleanly along the k-th axis? I was thinking of Arrays.sort with some comparator, but I cannot really figure it out. Then the split is easy: we can build one sublist up to the median and one after it, but I cannot put together the full algorithm.

Representation 2 - I am really struggling here. Sorting a single array is simple: we select the array for axis k, sort it, and the median is at index length/2. But the other arrays must be partitioned according to the median of that array, and I cannot solve that. In our example, sorting and partitioning x gives [0,2,3,5,7] -> [0,2] [5,7], but we then want y reordered and partitioned the same way, yielding [9,1,8,4,2] -> [9,1] [4,2]. This is where I am almost completely clueless: once we sort x, its original order is lost, so we can no longer recover the original index mapping (0 -> 9 etc.). Is this representation just a bad fit for this problem, or is it even impossible? Is there any way, even a slow and non-optimal one, to sort/partition multiple arrays according to the order of another array?

Any pseudocode/Java would be great, or any pointers in general.

EDIT - Follow-up Question

What if we introduced a heuristic that, at each level, chooses the axis with the largest range? Would we simply traverse the sublists in all k dimensions, track the range of each, and pick the maximum, or is there a better way? Also, we would then need to store that axis in KDTreeNode so that we know which axis each node splits on, correct?

Answer by Luatic

Point class

Let's start with a point class, in proper Java style:

class Point {
    private int[] components;
    public Point(int... components) {
        this.components = components;
    }
    public int getDimension() {
        return components.length;
    }
    public int getComponent(int i) {
        return components[i];
    }
    public void setComponent(int i, int v) {
        components[i] = v;
    }
}

Sorting by axis

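Indeed we need a comparator for this. In old-school Java this is a bit of a ceremony: we have to write an AxisComparator class implementing Comparator<Point>, with an axis field, which then compares points by that component:

import java.util.Comparator;

// Note: no type parameter on AxisComparator itself; "AxisComparator<Point>" would
// declare a type variable named Point that shadows the Point class.
class AxisComparator implements Comparator<Point> {
    private final int axis;
    public AxisComparator(int axis) {
        this.axis = axis;
    }
    @Override
    public int compare(Point p, Point q) {
        return Integer.compare(p.getComponent(axis), q.getComponent(axis));
    }
}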

You could then use this as new AxisComparator(k).

In "modern" Java, you could get away with a one-liner:

Comparator.comparingInt((Point p) -> p.getComponent(k));

Arrays.sort(T[], Comparator<? super T>) accepts such a comparator, so either variant can be passed to it directly.
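As a quick check against the question's Representation 1 data, here is a minimal sketch that sorts by an arbitrary axis and reads the median at index length / 2. The AxisSortDemo class and its method names are hypothetical, and Point is reduced to the accessors used here:

```java
import java.util.Arrays;
import java.util.Comparator;

// Minimal stand-in for the Point class above (same accessor names).
class Point {
    private final int[] components;
    Point(int... components) { this.components = components; }
    int getDimension() { return components.length; }
    int getComponent(int i) { return components[i]; }
}

// Hypothetical helper class, for illustration only.
class AxisSortDemo {
    // Returns a copy of the points sorted by the given axis; works for any dimension k.
    static Point[] sortedByAxis(Point[] points, int axis) {
        Point[] copy = points.clone();
        Arrays.sort(copy, Comparator.comparingInt((Point p) -> p.getComponent(axis)));
        return copy;
    }

    // Returns the element at index length / 2 after sorting by the given axis.
    static Point medianByAxis(Point[] points, int axis) {
        Point[] sorted = sortedByAxis(points, axis);
        return sorted[sorted.length / 2];
    }
}
```

Sorting the five example points along axis 0 yields (0,9), (2,1), (3,8), (5,4), (7,2), so index 5 / 2 = 2 holds the expected median (3,8). The same call with axis 1 works unchanged, which is why the comparator approach carries over to k dimensions.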

KdTree class

We will want to distinguish between the tree, which stores a potentially null (in case of an empty tree) pointer to the root, and the tree nodes. The nodes get a recursive static build method, taking the array and the axis to split on and producing a k-d tree which splits on that axis. The tree gets a public constructor taking an array of points.

import java.util.Arrays;
import java.util.Comparator;

class KdTree {
    record Node(int axis, Node left, Node right, Point pivot) {
        // Note: points is mutated (sorted by axis)
        static Node build(Point[] points, int axis) {
            if (points.length == 0)
                return null; // empty node
            Arrays.sort(points, Comparator.comparingInt((Point p) -> p.getComponent(axis)));
            var mid = points.length / 2;
            var leqPoints = Arrays.copyOf(points, mid); // points before the median
            // Points after the median; the pivot itself is excluded, otherwise a
            // single-point array would recurse forever.
            var geqPoints = Arrays.copyOfRange(points, mid + 1, points.length);
            var nextAxis = (axis + 1) % points[0].getDimension();
            return new Node(axis, build(leqPoints, nextAxis), build(geqPoints, nextAxis), points[mid]);
        }
    }
    private Node root;
    public KdTree(Point[] points) {
        root = Node.build(points.clone(), 0); // clone so the caller's array is not reordered
    }
    // Implement operations on your k-d-tree, like finding the nearest neighbor to a point here
}
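For the follow-up question about splitting on the axis with the largest range, one straightforward option (an assumption, not something prescribed above) is a single pass that tracks the minimum and maximum per axis, costing O(kn) for n points. The SpreadAxis class and widestAxis method are hypothetical names, and Point is reduced to the accessors used here:

```java
import java.util.Arrays;

// Minimal stand-in for the Point class above (same accessor names).
class Point {
    private final int[] components;
    Point(int... components) { this.components = components; }
    int getDimension() { return components.length; }
    int getComponent(int i) { return components[i]; }
}

// Hypothetical helper: picks the axis along which the points are most spread out.
class SpreadAxis {
    // Requires a non-empty array. One pass over all n points and k axes: O(k * n).
    static int widestAxis(Point[] points) {
        int k = points[0].getDimension();
        int[] min = new int[k];
        int[] max = new int[k];
        Arrays.fill(min, Integer.MAX_VALUE);
        Arrays.fill(max, Integer.MIN_VALUE);
        for (Point p : points) {
            for (int a = 0; a < k; a++) {
                min[a] = Math.min(min[a], p.getComponent(a));
                max[a] = Math.max(max[a], p.getComponent(a));
            }
        }
        int best = 0;
        for (int a = 1; a < k; a++) {
            // long arithmetic avoids int overflow for extreme coordinates
            if ((long) max[a] - min[a] > (long) max[best] - min[best]) best = a;
        }
        return best;
    }
}
```

Inside Node.build, nextAxis could then be replaced by calling widestAxis on each non-empty sublist. Because the Node record already stores axis, searches can read each node's split axis instead of deriving it from the depth. The extra O(kn) pass per level does not change the asymptotic cost of the sort-based build. For the question's points, x spans 0..7 and y spans 1..9, so this would pick axis 1.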

Better algorithms

This "naive" algorithm of sorting by an axis for each split is not ideal in terms of performance: it incurs O(n log n) costs at each of the O(log n) levels of the tree, resulting in O(n (log n)²), which is not bad, but also not optimal.
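Assuming every split roughly halves the point set, the level-by-level argument corresponds to this recurrence:

```latex
T(n) = 2\,T\!\left(\frac{n}{2}\right) + O(n \log n)
\quad\Longrightarrow\quad
T(n) = \sum_{i=0}^{\log_2 n} 2^i \cdot O\!\left(\frac{n}{2^i} \log \frac{n}{2^i}\right)
\le \sum_{i=0}^{\log_2 n} O(n \log n) = O\!\left(n \log^2 n\right)
```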

I implemented all three approaches in Lua a while ago here. If you don't know Lua, read it as pseudocode.

Presorting

One option to optimize this is to pre-sort the points by each of the k axes, incurring costs of O(k n log n) for the pre-sorting, and then to filter the k pre-sorted lists into left and right parts as you split. A split of m points costs O(k m), i.e. O(k n) per level of the tree; with O(log n) levels this totals O(k n log n). This can be better than O(n (log n)²) if n is big and k is small (say, 2, 3 or 4).

To be more precise: What you pre-sort are either (pointers to) instances of a Point class, or indices (if you opt for the PointArray-based solution). So you end up with k pre-sorted arrays of Points / int indices.

When you split, you mark all points to the left of the median. If you have an array of Points, one approach would be to add a boolean _marked field to the Point class. This field then lets you filter the points along all other axes, splitting them into "marked" and "unmarked" points (except the pivot). For looser coupling, you would probably want to create a MarkablePoint class extending Point and convert all points to that initially (before presorting), or you might prefer a HashSet of marked points (by reference). For indices, you could also use a HashSet, or a boolean array.
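A minimal sketch of one presorted split using the boolean-array variant, assuming the points are identified by indices 0..n-1 and sorted[a] already holds those indices ordered by coordinate a. The PresortedSplit class and split method are hypothetical. The marked array is allocated once by the caller and cleared again before returning, so a split of m points costs O(km) rather than O(n):

```java
// Hypothetical helper illustrating the presorted split with a reusable boolean array.
class PresortedSplit {
    // sorted[a] holds the current node's point indices ordered by coordinate a;
    // all k arrays contain the same (non-empty) set of indices.
    // marked must be all false on entry (length = total number of points); it is
    // restored to all false before returning.
    // Returns {left, right}: per-axis index arrays, still sorted, pivot excluded.
    static int[][][] split(int[][] sorted, int axis, boolean[] marked) {
        int k = sorted.length;
        int m = sorted[axis].length;
        int mid = m / 2;
        int pivot = sorted[axis][mid];
        for (int i = 0; i < mid; i++) marked[sorted[axis][i]] = true; // left of the median
        int[][] left = new int[k][];
        int[][] right = new int[k][];
        for (int a = 0; a < k; a++) {
            int[] l = new int[mid];
            int[] r = new int[m - mid - 1];
            int li = 0, ri = 0;
            for (int idx : sorted[a]) { // a stable linear scan keeps the order along axis a
                if (idx == pivot) continue;
                if (marked[idx]) l[li++] = idx; else r[ri++] = idx;
            }
            left[a] = l;
            right[a] = r;
        }
        for (int i = 0; i < mid; i++) marked[sorted[axis][i]] = false; // reset for reuse
        return new int[][][] { left, right };
    }
}
```

With the question's data (x=[5,2,0,3,7], y=[4,1,9,8,2]), sorted = {{2,1,3,0,4}, {1,4,0,3,2}}, and a split on axis 0 picks index 3, the point (3,8), as pivot. The left part becomes {{2,1}, {1,2}} and the right part {{0,4}, {4,0}}, each still sorted along both axes, so no re-sorting is needed further down.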

Median of medians

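The asymptotically optimal option is to use a linear-time median selection algorithm (median of medians), which finds the median along the split axis and partitions the points around it without fully sorting them. Using this, you get O(n) costs for each of the O(log n) layers, for a total of O(n log n).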
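A full median-of-medians implementation is somewhat long, so the sketch below swaps in randomized quickselect instead. It runs in expected (not worst-case) linear time but shows the same idea of partitioning around the median without fully sorting. AxisSelect and its methods are hypothetical names, and Point is reduced to the accessor used here:

```java
import java.util.Random;

// Minimal stand-in for the Point class above (same accessor name).
class Point {
    private final int[] components;
    Point(int... components) { this.components = components; }
    int getComponent(int i) { return components[i]; }
}

// Randomized quickselect: expected O(n), a simpler stand-in for median of medians.
class AxisSelect {
    private static final Random RNG = new Random();

    // Rearranges points[lo..hi] so that points[k] holds the element that would sit at
    // index k after sorting by axis; elements before it are <= and after it are >=.
    static void select(Point[] points, int lo, int hi, int k, int axis) {
        while (lo < hi) {
            int p = partition(points, lo, hi, lo + RNG.nextInt(hi - lo + 1), axis);
            if (k == p) return;
            if (k < p) hi = p - 1; else lo = p + 1;
        }
    }

    // Lomuto partition around points[pivotIdx]; returns the pivot's final index.
    private static int partition(Point[] a, int lo, int hi, int pivotIdx, int axis) {
        int pivotValue = a[pivotIdx].getComponent(axis);
        swap(a, pivotIdx, hi); // park the pivot at the end
        int store = lo;
        for (int i = lo; i < hi; i++) {
            if (a[i].getComponent(axis) < pivotValue) swap(a, i, store++);
        }
        swap(a, store, hi); // move the pivot to its sorted position
        return store;
    }

    private static void swap(Point[] a, int i, int j) {
        Point t = a[i];
        a[i] = a[j];
        a[j] = t;
    }
}
```

In Node.build, the Arrays.sort call could be replaced by AxisSelect.select(points, 0, points.length - 1, mid, axis); the copyOf/copyOfRange slicing stays the same, since everything before mid then has a coordinate <= the pivot's and everything after has one >=. This Lomuto-style partition degrades toward quadratic time when many coordinates are equal, which true median of medians (or a three-way partition) avoids.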

Out-of-band representation (low-level optimizations)

Representing an array of points as, well, an array of points is definitely the most idiomatic, straightforward, simple way to implement this in an OOP language like Java. I'd stick to this for an initial implementation unless you have a good reason not to. I would not prematurely optimize this.

As you said, you want to keep the components of a point together, and that is cumbersome to do if you store separate arrays of components.

But if you do have to optimize this at a low level (maybe you have determined that the points eat up too much memory, that heap allocation, GC or indirection overhead is too large, or that cache locality is too poor) by using multiple arrays, write yourself something like a PointArray class which manages an array of k arrays of components. Sorting could simply sort a permutation of indices into these k arrays (after the first split it is only a permutation of a subset of the indices, but you get the idea).

This could look like this:

import java.util.Arrays;
import java.util.Comparator;

class PointArray {
    private final int[][] points; // points[i][j] = i-th coordinate of j-th point
    private int[] indices; // into points
    public PointArray(int[][] points) {
        this.points = points;
        int n = points[0].length;
        assert n > 0;
        for (int i = 1; i < points.length; i++)
            assert points[i].length == n;
        indices = new int[n];
        for (int i = 0; i < n; i++)
            indices[i] = i; // identity permutation
    }
    private PointArray(int[] indices, int[][] points) {
        this.indices = indices;
        this.points = points;
    }
    public int getDimension() {
        return points.length;
    }
    public int getLength() {
        return indices.length;
    }
    public int getComponent(int pointIdx, int axis) {
        return points[axis][indices[pointIdx]];
    }
    public Point getPoint(int i) {
        var components = new int[getDimension()];
        for (int axis = 0; axis < components.length; axis++)
            components[axis] = getComponent(i, axis);
        return new Point(components);
    }
    public void sortByAxis(int axis) {
        // Arrays.sort has no (int[], Comparator) overload, so sort boxed indices instead.
        indices = Arrays.stream(indices)
                .boxed()
                .sorted(Comparator.comparingInt((Integer i) -> points[axis][i]))
                .mapToInt(Integer::intValue)
                .toArray();
    }
    // To slice the array, it suffices to slice the indices.
    // We do not have to slice the points.
    public PointArray slice(int from, int to /*exclusive*/) {
        return new PointArray(Arrays.copyOfRange(indices, from, to), points);
    }
}

Really this approach isn't all that different from having an array of points, except instead of pointers to heap-allocated points, we have indices into our "array-allocated" matrix of point coordinates. This requires some small changes to our k-d-tree:

class KdTree {
    record Node(int axis, Node left, Node right, Point pivot) {
        // Note: points is mutated (its indices are sorted by axis)
        static Node build(PointArray points, int axis) {
            if (points.getLength() == 0)
                return null; // empty node
            points.sortByAxis(axis);
            var mid = points.getLength() / 2;
            var leqPoints = points.slice(0, mid); // before the median
            var geqPoints = points.slice(mid + 1, points.getLength()); // after the median, pivot excluded
            var nextAxis = (axis + 1) % points.getDimension();
            return new Node(axis, build(leqPoints, nextAxis), build(geqPoints, nextAxis), points.getPoint(mid));
        }
    }
    private Node root;
    public KdTree(int[][] points) {
        root = Node.build(new PointArray(points), 0);
    }
    // Implement operations on your k-d-tree, like finding the nearest neighbor to a point here
}
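To tie this back to the question's Representation 2: the trick is to sort a permutation of indices instead of the coordinate arrays themselves, then read every coordinate array through that permutation. A minimal standalone sketch with hypothetical names (ParallelArraySort, orderBy, permute):

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical helper: reorders every coordinate array by the order of one axis.
class ParallelArraySort {
    // Returns the indices 0..n-1 ordered by coords[axis][index]; coords is not modified.
    static int[] orderBy(int[][] coords, int axis) {
        int n = coords[axis].length;
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        // Object sort is stable, so equal coordinates keep their original index order.
        Arrays.sort(idx, Comparator.comparingInt((Integer i) -> coords[axis][i]));
        return Arrays.stream(idx).mapToInt(Integer::intValue).toArray();
    }

    // Applies the permutation to one coordinate array: out[j] = values[perm[j]].
    static int[] permute(int[] values, int[] perm) {
        int[] out = new int[perm.length];
        for (int j = 0; j < perm.length; j++) out[j] = values[perm[j]];
        return out;
    }
}
```

For x=[5,2,0,3,7] and y=[4,1,9,8,2], ordering by axis 0 gives the permutation [2,1,3,0,4]; applying it yields x=[0,2,3,5,7] and y=[9,1,8,4,2], so splitting around index 2 gives [0,2]/[5,7] and [9,1]/[4,2], the result the question asks for. The original arrays are never reordered, so the index mapping is not lost.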