I am struggling to understand how to split k-dimensional data points according to the median along a given axis (part of balanced k-d tree construction). I was also given an example dataset in two representations, and for neither of them can I come up with an implementation of the point splits.
Let's consider the following 2-D points (for simplicity, but the algo should work for k dimensions):
Repr 1:
points = [(5,4), (2,1), (0,9), (3,8), (7,2)]
Repr 2:
x=[5,2,0,3,7] y=[4,1,9,8,2]
Let's assume we want to split by the x-axis. We expect the median to be (3,8), and subsequently we want all remaining elements with x<=3 passed to the recursive build of the left subtree, and elements with x>3 passed to the right subtree.
Representation 1 - I saw some implementations using sorting with comparators etc. but I am not sure how that translates to k-dimensions. So we can have a class
class Point {
    int[] coords;
}
Now how do we cleanly sort this along the k-th axis? I was thinking of Arrays.sort with some comparator, but I cannot really figure it out... Then the split is easy: we can build two sublists, one up to the median and one after it, but I cannot get the full algo.
Representation 2 - I am really struggling here. Sorting a single array is simple: we select the array for axis k and sort it, and the median is at index length/2. But we have multiple arrays that need to be partitioned according to the median in another array, which I cannot solve. In our example, sorting and partitioning x results in [0,2,3,5,7] -> [0,2] [5,7], but we now want y sorted/partitioned the same way, yielding [9,1,8,4,2] -> [9,1] [4,2]. This is where I am almost completely clueless: once we sort x, we lose the original order, so we cannot recover the original index mapping (0 -> 9 etc.). Is this representation just bad for this problem, or is it potentially impossible to use? Is there any way, even a non-optimal and slow one, to sort/partition multiple arrays according to the order of another array?
Any pseudocode/Java would be great, or any pointers in general.
EDIT - Followup Question
What if we introduced a heuristic that, at each level, chooses the axis with the largest range? Would we simply traverse the sublists in all k dimensions, keep the range for each, and then choose the max, or is there a better way? Also, we would then need to store that information in KDTreeNode, so that we know which axis to split by, correct?
Point class
Let's start with a point class, in proper Java style:
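A minimal sketch of such a class, assuming integer coordinates as in the question. The accessor names dimension() and get(axis) are illustrative choices: the coordinates are private and final, and the constructor makes a defensive copy so the point cannot be changed from outside.

```java
import java.util.Arrays;

// A k-dimensional point with immutable integer coordinates.
final class Point {
    private final int[] coords;

    Point(int... coords) {
        // Defensive copy, so later changes to the caller's array don't affect this point.
        this.coords = coords.clone();
    }

    // Number of dimensions k.
    int dimension() {
        return coords.length;
    }

    // Component along the given axis (0 = x, 1 = y, ...).
    int get(int axis) {
        return coords[axis];
    }

    @Override
    public String toString() {
        return Arrays.toString(coords);
    }
}
```

Usage: new Point(5, 4) creates the first point of the question's dataset; the varargs constructor also accepts an existing int[].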
Sorting by axis
Indeed we need a comparator for this. In old-school Java this is a bit of a ceremony: we have to write an AxisComparator class implementing Comparator<Point>, with an axis field, which then compares points based on that component. You could then use this as new AxisComparator(k).
In "modern" Java, you could get away with a one-liner:
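Both variants could look roughly like this. The Point type is a minimal stand-in repeated so the snippet compiles on its own, and the helper class AxisSorting with its methods byAxis and sortByAxis is a hypothetical name. Integer.compare is used instead of subtracting coordinates, which could overflow.

```java
import java.util.Arrays;
import java.util.Comparator;

// Minimal point type, repeated so this snippet compiles on its own.
final class Point {
    private final int[] coords;
    Point(int... coords) { this.coords = coords.clone(); }
    int get(int axis) { return coords[axis]; }
    @Override public String toString() { return Arrays.toString(coords); }
}

// "Old school": a named comparator class that compares a single component.
final class AxisComparator implements Comparator<Point> {
    private final int axis;

    AxisComparator(int axis) {
        this.axis = axis;
    }

    @Override
    public int compare(Point a, Point b) {
        // Integer.compare avoids the overflow that "a - b" could cause.
        return Integer.compare(a.get(axis), b.get(axis));
    }
}

final class AxisSorting {
    // "Modern" Java: the same comparator as a one-liner.
    static Comparator<Point> byAxis(int axis) {
        return Comparator.comparingInt(p -> p.get(axis));
    }

    // Sorts the array in place by the given axis (new AxisComparator(axis) works too).
    static void sortByAxis(Point[] points, int axis) {
        Arrays.sort(points, byAxis(axis));
    }
}
```

With the question's five points, sorting by axis 0 puts (3,8) at index points.length / 2, i.e. the expected median.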
Arrays.sort accepts a comparator.
KdTree class
We will want to distinguish between the tree, which stores a potentially null (in case of an empty tree) pointer to the root, and tree nodes. The tree nodes should get a recursive private constructor, taking the array and the axis we want to split on and producing a k-d tree which splits on that axis. The tree should get a public constructor taking an array of points.
Better algorithms
This "naive" algorithm of sorting by an axis for each split is not ideal in terms of performance: it incurs O(n log n) costs at each of the O(log n) levels of the tree, resulting in O(n (log n)²), which is not bad, but also not optimal.
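For reference, the naive construction whose cost is analysed here (sort by the current axis, store the element at index length / 2 in the node, recurse on both halves) could be sketched as follows. The names KdTree, Node and root(), and cycling the axes via (axis + 1) % k, are assumptions for illustration. Each node records the axis it splits on, which is also where a different axis-selection heuristic would store its choice.

```java
import java.util.Arrays;
import java.util.Comparator;

// Minimal point type, repeated so this snippet compiles on its own.
final class Point {
    private final int[] coords;
    Point(int... coords) { this.coords = coords.clone(); }
    int dimension() { return coords.length; }
    int get(int axis) { return coords[axis]; }
}

final class KdTree {
    private final Node root; // null for an empty tree

    KdTree(Point[] points) {
        // Work on a copy so the caller's array is not reordered.
        root = points.length == 0 ? null : new Node(points.clone(), 0);
    }

    Node root() {
        return root;
    }

    static final class Node {
        final Point point; // the median along `axis`
        final int axis;    // the axis this node splits on
        final Node left;   // points sorted before the median (coordinate <= median's)
        final Node right;  // points sorted after the median (coordinate >= median's)

        // Recursive private constructor: sorts by `axis`, stores the median, recurses.
        private Node(Point[] points, int axis) {
            this.axis = axis;
            Comparator<Point> byAxis = Comparator.comparingInt(p -> p.get(axis));
            Arrays.sort(points, byAxis);
            int mid = points.length / 2;
            this.point = points[mid];
            int next = (axis + 1) % point.dimension(); // assumption: cycle through the axes
            Point[] lower = Arrays.copyOfRange(points, 0, mid);
            Point[] upper = Arrays.copyOfRange(points, mid + 1, points.length);
            this.left = lower.length == 0 ? null : new Node(lower, next);
            this.right = upper.length == 0 ? null : new Node(upper, next);
        }
    }
}
```

With duplicate coordinates, points equal to the median can end up on either side, so searches should treat the left side as <= and the right side as >= along the node's axis.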
I implemented all three approaches in Lua a while ago here. If you don't know Lua, read it as pseudocode.
Presorting
One option to optimize this is to pre-sort the points by each of the k axes, incurring costs of O(k n log n) for the pre-sorting, and then to filter the k pre-sorted lists into left and right parts as you split. That costs O(k n) per level of the tree (summed over all splits on that level), and with O(log n) levels the total is O(k n log n). This can be better than O(n (log n)²) if n is big and k is small (say, 2, 3 or 4).
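A minimal sketch of the presorting and of one split, assuming the index-based variant with coordinates stored per axis (coords[axis][i], i.e. the question's second representation) and a boolean array for the marks, one of the options discussed for this approach. The class Presorted, its Split holder and the method names are hypothetical. Because the filtering pass walks each presorted list in order, both halves stay sorted along every axis, which is also why the per-axis-arrays representation is workable at all.

```java
import java.util.Arrays;
import java.util.Comparator;

final class Presorted {
    // Result of one split: the median and, per axis, the two halves in sorted order.
    static final class Split {
        final int pivot;     // index of the median point
        final int[][] left;  // left[a]: indices before the median, sorted by axis a
        final int[][] right; // right[a]: indices after the median, sorted by axis a

        Split(int pivot, int[][] left, int[][] right) {
            this.pivot = pivot;
            this.left = left;
            this.right = right;
        }
    }

    // coords[axis][i] is component `axis` of point i.
    // Returns, for every axis, all point indices sorted by that component.
    static int[][] presort(int[][] coords) {
        int k = coords.length, n = coords[0].length;
        int[][] sorted = new int[k][];
        for (int axis = 0; axis < k; axis++) {
            int[] component = coords[axis];
            Integer[] idx = new Integer[n];
            for (int i = 0; i < n; i++) idx[i] = i;
            Comparator<Integer> byComponent = Comparator.comparingInt(j -> component[j]);
            Arrays.sort(idx, byComponent);
            sorted[axis] = Arrays.stream(idx).mapToInt(Integer::intValue).toArray();
        }
        return sorted;
    }

    // Splits every presorted list around the median of `axis` in O(k * m) for m points.
    // `marks` has one slot per point, must be all false on entry, and is all false again on return.
    static Split split(int[][] sorted, int axis, boolean[] marks) {
        int[] byAxis = sorted[axis];
        int m = byAxis.length, mid = m / 2;
        int pivot = byAxis[mid];
        for (int i = 0; i < mid; i++) marks[byAxis[i]] = true; // mark the left half
        int k = sorted.length;
        int[][] left = new int[k][mid];
        int[][] right = new int[k][m - mid - 1];
        for (int a = 0; a < k; a++) {
            int l = 0, r = 0;
            // One ordered pass keeps both halves sorted along axis a.
            for (int p : sorted[a]) {
                if (marks[p]) left[a][l++] = p;
                else if (p != pivot) right[a][r++] = p;
            }
        }
        for (int i = 0; i < mid; i++) marks[byAxis[i]] = false; // reset for the next split
        return new Split(pivot, left, right);
    }
}
```

For the question's data, split(presort(coords), 0, new boolean[5]) yields pivot 3 (the point (3,8)), left[0] = [2, 1] and right[0] = [0, 4]; reading y through those indices gives [9, 1] and [4, 2]. The marks array is allocated once and reused, so a split does not pay O(n) for clearing it.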
To be more precise: what you pre-sort are either (pointers to) instances of a Point class, or indices (if you opt for the PointArray-based solution). So you end up with k pre-sorted arrays of Points / int indices.
When you split, you mark all points left of the median. If you have an array of Points, one approach would be to add a boolean _marked field to the Point class. This field then lets you filter the points along all other axes, splitting them into "marked" and "unmarked" points (except the pivot). For looser coupling, you would probably want to create a MarkablePoint class extending Point and convert all points to that initially (before presorting), or you might prefer a HashSet of marked points (by reference). For indices, you could also use a HashSet, or a boolean array.
Median of medians
The asymptotically optimal option is to use a linear-time median selection algorithm (median of medians). Using this, you get O(n) costs for each of the O(log n) levels, for a total of O(n log n).
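A sketch of the selection step, assuming points stored as int[][] rows (pts[i][axis]) to keep it short; the class MedianOfMedians and its method names are hypothetical. It takes medians of groups of five, recursively selects their median as the pivot, and uses a three-way partition so that many equal coordinates cannot slow it down.

```java
final class MedianOfMedians {
    // Rearranges pts[lo..hi] so that pts[k] holds what would be at position k if the
    // range were sorted by `axis`; points before k are <= it, points after k are >= it.
    static void select(int[][] pts, int lo, int hi, int k, int axis) {
        while (true) {
            if (hi - lo < 5) { // small range: just sort it
                insertionSort(pts, lo, hi, axis);
                return;
            }
            int pv = pts[medianOfMedians(pts, lo, hi, axis)][axis];
            // Three-way partition: [lo, lt) < pv, [lt, gt] == pv, (gt, hi] > pv.
            int lt = lo, i = lo, gt = hi;
            while (i <= gt) {
                int c = pts[i][axis];
                if (c < pv) swap(pts, lt++, i++);
                else if (c > pv) swap(pts, i, gt--);
                else i++;
            }
            if (k < lt) hi = lt - 1;
            else if (k > gt) lo = gt + 1;
            else return; // pts[k] equals the pivot value
        }
    }

    // Moves the median of each group of five to the front of the range, then
    // recursively selects the median of those medians and returns its index.
    private static int medianOfMedians(int[][] pts, int lo, int hi, int axis) {
        int count = 0;
        for (int g = lo; g <= hi; g += 5) {
            int end = Math.min(g + 4, hi);
            insertionSort(pts, g, end, axis);
            swap(pts, lo + count, g + (end - g) / 2);
            count++;
        }
        int mid = lo + (count - 1) / 2;
        select(pts, lo, lo + count - 1, mid, axis);
        return mid;
    }

    private static void insertionSort(int[][] pts, int lo, int hi, int axis) {
        for (int i = lo + 1; i <= hi; i++) {
            int[] p = pts[i];
            int j = i - 1;
            while (j >= lo && pts[j][axis] > p[axis]) {
                pts[j + 1] = pts[j];
                j--;
            }
            pts[j + 1] = p;
        }
    }

    private static void swap(int[][] pts, int i, int j) {
        int[] t = pts[i];
        pts[i] = pts[j];
        pts[j] = t;
    }
}
```

To build a node in place, one could call select(pts, lo, hi, mid, axis) with mid = lo + (hi - lo + 1) / 2, store pts[mid] in the node, and recurse on [lo, mid - 1] and [mid + 1, hi], with no sorting or copying.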
Out-of-band representation (low-level optimizations)
Representing an array of points as, well, an array of points is definitely the most idiomatic, straightforward, simple way to implement this in an OOP language like Java. I'd stick to this for an initial implementation unless you have a good reason not to. I would not prematurely optimize this.
As you said, you want to keep the components of a point together, and that is cumbersome to do if you store separate arrays of components.
But if you do have to optimize this at a low level (maybe you have determined that the points eat up too much memory, that the heap allocation, GC or indirection overhead is too large, or that the cache locality is too bad) by using multiple arrays, write yourself something like a PointArray class which manages an array of k arrays of components. Sorting could simply sort a permutation of indices into these k arrays (well, after the first split of these indices, it is only a permutation of the restricted set of indices, but you get the idea). This could look like this:
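A minimal sketch of such a PointArray, assuming int coordinates; the method names (get, allIndices, sortIndices) are illustrative. It sorts an int[] of indices by boxing them into an Integer[], because Arrays.sort has no comparator overload for primitive int arrays. That is simple but gives back part of the memory advantage; a hand-written index sort would avoid the boxing.

```java
import java.util.Arrays;
import java.util.Comparator;

// Points stored "out of band": coords[axis][i] is component `axis` of point i,
// as in the question's x = [...], y = [...] representation.
final class PointArray {
    private final int[][] coords;

    PointArray(int[]... coords) {
        this.coords = coords;
    }

    int dimension() { return coords.length; }

    int size() { return coords[0].length; }

    int get(int index, int axis) { return coords[axis][index]; }

    // The identity permutation 0, 1, ..., size() - 1.
    int[] allIndices() {
        int[] indices = new int[size()];
        for (int i = 0; i < indices.length; i++) indices[i] = i;
        return indices;
    }

    // Reorders `indices` by their component along `axis`. Only the index permutation
    // moves; the component arrays are untouched, so an index still refers to the
    // same point in every array.
    void sortIndices(int[] indices, int axis) {
        Integer[] boxed = new Integer[indices.length];
        for (int i = 0; i < indices.length; i++) boxed[i] = indices[i];
        Comparator<Integer> byAxis = Comparator.comparingInt(p -> coords[axis][p]);
        Arrays.sort(boxed, byAxis);
        for (int i = 0; i < indices.length; i++) indices[i] = boxed[i];
    }
}
```

For the question's data, sorting all indices by x gives [2, 1, 3, 0, 4]; reading y through that permutation gives [9, 1, 8, 4, 2], which is exactly the order the question was after.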
Really, this approach isn't all that different from having an array of points, except that instead of pointers to heap-allocated points, we have indices into our "array-allocated" matrix of point coordinates. This requires some small changes to our k-d tree:
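A sketch of those changes, following the Point-based tree described in the KdTree class section: each node stores an int index instead of a Point reference. To keep the snippet self-contained, it reads coordinates straight from an int[][] coords (coords[axis][i]) rather than through a PointArray. IndexKdTree, Node and root() are hypothetical names, and cycling the axes via (axis + 1) % k is an assumption.

```java
import java.util.Arrays;
import java.util.Comparator;

final class IndexKdTree {
    private final Node root; // null for an empty tree

    // coords[axis][i] is component `axis` of point i.
    IndexKdTree(int[][] coords) {
        int n = coords.length == 0 ? 0 : coords[0].length;
        Integer[] indices = new Integer[n];
        for (int i = 0; i < n; i++) indices[i] = i;
        root = n == 0 ? null : new Node(coords, indices, 0);
    }

    Node root() {
        return root;
    }

    static final class Node {
        final int index; // index of the median point instead of a Point reference
        final int axis;  // the axis this node splits on
        final Node left;
        final Node right;

        // Same recursion as with Point objects, but sorting indices by their coordinate.
        private Node(int[][] coords, Integer[] indices, int axis) {
            this.axis = axis;
            Comparator<Integer> byAxis = Comparator.comparingInt(p -> coords[axis][p]);
            Arrays.sort(indices, byAxis);
            int mid = indices.length / 2;
            this.index = indices[mid];
            int next = (axis + 1) % coords.length; // assumption: cycle through the axes
            Integer[] lower = Arrays.copyOfRange(indices, 0, mid);
            Integer[] upper = Arrays.copyOfRange(indices, mid + 1, indices.length);
            this.left = lower.length == 0 ? null : new Node(coords, lower, next);
            this.right = upper.length == 0 ? null : new Node(coords, upper, next);
        }
    }
}
```

For the question's data, the root holds index 3, i.e. the point (3,8), with index 2 (0,9) and index 0 (5,4) as its children.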