Friday, November 30, 2012

KD Tree Range Bounded Algorithm

Searching a K-d Tree is simple enough: it's just a basic binary search based on the axis each node splits on. Nearest neighbor search is one of the most basic operations you can do on a K-d Tree and is typically what you use this type of tree for. But because of the geometric properties of the tree itself, and because each node creates a splitting plane through the space, we can write an awesome algorithm that returns the points in the tree that lie within a specific region of space.
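To make the "basic binary search" part concrete, here's a minimal sketch of insertion and exact-match lookup in a 2-d tree. It assumes every node stores one point, the splitting axis alternates x, y, x, ... with depth, and smaller values go left. KdTree2D and KdNode are hypothetical names for this example, not the Node class used later in this post.

```java
// Hypothetical minimal 2-d tree: every node stores one point, and the
// splitting axis alternates x, y, x, ... with depth.
class KdNode {
    final double[] point;
    KdNode left, right;

    KdNode(double[] point) { this.point = point; }
}

class KdTree2D {
    private KdNode root;

    // Walks down the tree comparing only the coordinate on the current
    // node's splitting axis: smaller goes left, equal or larger goes right.
    void insert(double x, double y) {
        double[] p = {x, y};
        if (root == null) { root = new KdNode(p); return; }
        KdNode node = root;
        int depth = 0;
        while (true) {
            int axis = depth % 2;
            if (p[axis] < node.point[axis]) {
                if (node.left == null) { node.left = new KdNode(p); return; }
                node = node.left;
            } else {
                if (node.right == null) { node.right = new KdNode(p); return; }
                node = node.right;
            }
            depth++;
        }
    }

    // Exact-match search: the same axis-by-axis binary search as insert.
    boolean contains(double x, double y) {
        KdNode node = root;
        int depth = 0;
        while (node != null) {
            if (node.point[0] == x && node.point[1] == y) return true;
            int axis = depth % 2;
            double target = axis == 0 ? x : y;
            node = target < node.point[axis] ? node.left : node.right;
            depth++;
        }
        return false;
    }
}
```

For example, after inserting (5,5), (2,8), (9,1) and (3,3), contains(3, 3) goes left from the root, left again, and finds the point, while contains(3, 4) follows the same path and then falls off the tree, returning false.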

This is for a 2-dimensional tree and space, but it could be extended to 3 dimensions easily, and probably up to k dimensions after having a headache for a while.
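Before the 2-D code, here's a hedged sketch of how the Region idea could generalize to k dimensions: one closed [min, max] interval per axis instead of named sides. KBox and its methods are hypothetical names for illustration only.

```java
import java.util.Arrays;

// Hypothetical k-dimensional version of Region: one closed [min, max]
// interval per axis instead of named left/right/bottom/top sides.
class KBox {
    final double[] min;
    final double[] max;

    // Box covering all of k-space, like a default Region.
    KBox(int k) {
        min = new double[k];
        max = new double[k];
        Arrays.fill(min, Double.NEGATIVE_INFINITY);
        Arrays.fill(max, Double.POSITIVE_INFINITY);
    }

    KBox(double[] min, double[] max) {
        this.min = min.clone();
        this.max = max.clone();
    }

    // A point is inside only if it is inside the interval on every axis.
    boolean contains(double[] p) {
        for (int a = 0; a < min.length; a++) {
            if (p[a] < min[a] || max[a] < p[a]) return false;
        }
        return true;
    }

    // Disjoint if separated on at least one axis, so intersecting means
    // "not separated on any axis" (the same De Morgan step as Region.intersects).
    boolean intersects(KBox o) {
        for (int a = 0; a < min.length; a++) {
            if (max[a] < o.min[a] || o.max[a] < min[a]) return false;
        }
        return true;
    }

    // Copy with one side moved to a splitting value, which is how a child's
    // box is derived from its parent's. The original box is left unchanged.
    KBox withMax(int axis, double value) {
        KBox b = new KBox(min, max);
        b.max[axis] = value;
        return b;
    }

    KBox withMin(int axis, double value) {
        KBox b = new KBox(min, max);
        b.min[axis] = value;
        return b;
    }
}
```

In a k-d tree you'd pick the axis as depth % k and give the left child parent.withMax(axis, split) and the right child parent.withMin(axis, split). Returning copies means a parent's box can be narrowed differently for each child without the two sharing state.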

So first off, we need to be able to define a Region in space. Because the K-d Tree splits space with axis-aligned planes, we're essentially going to define a box in k-space, so it makes sense for our Regions to be axis-aligned rectangles. Here's the code to construct a Region:


public class Region{

    //Infinite sides: a default Region really does cover the whole coordinate plane
    double xleft   = Double.NEGATIVE_INFINITY;
    double xright  = Double.POSITIVE_INFINITY;
    double ybottom = Double.NEGATIVE_INFINITY;
    double ytop    = Double.POSITIVE_INFINITY;


    //Creates a region containing the entire coordinate plane
    public Region(){
    }


    //Creates a Region defined by the four locations of its sides
    public Region(double left, double right, double top, double bottom){
        //Assume we might pass the sides in the wrong order, so sort each pair
        if(left < right){
            xleft = left;
            xright = right;
        }else{
            xleft = right;
            xright = left;
        }
        if(bottom < top){
            ybottom = bottom;
            ytop = top;
        }else{
            ybottom = top;
            ytop = bottom;
        }
    }
}

I've made some assumptions about not being able to tell my left from my right, but the key point we'll take advantage of in the algorithm later is that a Region's default is the entire coordinate space. Next we need to add some helper methods to the Region class to alter the region:


//These methods go inside the Region class.
//Each setter moves one side of this Region in place and returns this same Region, so calls can be chained.
public Region setLeft(double left){
    this.xleft = left;
    return this;
}

public Region setRight(double right){
    this.xright = right;
    return this;
}

public Region setTop(double top){
    this.ytop = top;
    return this;
}

public Region setBottom(double bottom){
    this.ybottom = bottom;
    return this;
}


//Checks to see if this Region contains the node's point (sides included)
public boolean fullyContains(Node node){
    float y = node.getY();
    float x = node.getX();
    return ybottom <= y && y <= ytop && xleft <= x && x <= xright;
}

//Checks to see if Region r lies entirely inside this Region
public boolean fullyContains(Region r){
    return xleft <= r.xleft && r.xright <= xright && r.ytop <= ytop && ybottom <= r.ybottom;
}

//Two Regions intersect unless one lies strictly to one side of the other (touching edges count as intersecting)
public boolean intersects(Region r){
    return !(xright < r.xleft) && !(r.xright < xleft) && !(r.ytop < ybottom) && !(ytop < r.ybottom);
}

It's important that setLeft, setRight, setTop and setBottom return a Region; we'll see why in a bit. The more interesting part of this code is the intersects function. To see if two Regions intersect, you could write an 8-part boolean expression checking all the possible ways the two Regions could overlap. OR you can be clever about it. The expression for two Regions not intersecting is far simpler: they are disjoint if one Region's right side is strictly less than the other Region's left side, or the other way around, or if one Region's top is strictly less than the other's bottom, or the other way around. That's an OR of 4 comparisons. Negate the whole thing, and by De Morgan's laws the ORs become ANDs of negated comparisons, and pow. Your intersects function is quick, efficient and done. Now let's get to the actual algorithm:
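If you'd rather check the De Morgan trick than take it on faith, here's a small brute-force sketch. It compares the negated "separated" test with the equivalent direct overlap test (two closed intervals overlap exactly when the larger of their left ends is at most the smaller of their right ends) over every box whose sides are integers from 0 to 3. IntersectCheck and its methods are hypothetical, and its parameters go left, right, bottom, top, which is not the order Region's constructor uses.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical brute-force check of the De Morgan form of the intersection test.
// Boxes are given as (left, right, bottom, top) with left <= right and bottom <= top.
class IntersectCheck {
    // Same shape as Region.intersects: not separated on x and not separated on y.
    static boolean deMorgan(double l1, double r1, double b1, double t1,
                            double l2, double r2, double b2, double t2) {
        return !(r1 < l2) && !(r2 < l1) && !(t2 < b1) && !(t1 < b2);
    }

    // Direct overlap test: closed intervals overlap when max(lefts) <= min(rights).
    static boolean overlap(double l1, double r1, double b1, double t1,
                           double l2, double r2, double b2, double t2) {
        return Math.max(l1, l2) <= Math.min(r1, r2)
            && Math.max(b1, b2) <= Math.min(t1, t2);
    }

    // Every valid interval [lo, hi] with integer ends in 0..3 (10 of them).
    static int[][] intervals() {
        List<int[]> list = new ArrayList<>();
        for (int lo = 0; lo < 4; lo++) {
            for (int hi = lo; hi < 4; hi++) {
                list.add(new int[]{lo, hi});
            }
        }
        return list.toArray(new int[0][]);
    }

    // Number of box pairs on which the two tests disagree.
    static int countDisagreements() {
        int[][] iv = intervals();
        int bad = 0;
        for (int[] x1 : iv) for (int[] y1 : iv) for (int[] x2 : iv) for (int[] y2 : iv) {
            boolean a = deMorgan(x1[0], x1[1], y1[0], y1[1], x2[0], x2[1], y2[0], y2[1]);
            boolean b = overlap(x1[0], x1[1], y1[0], y1[1], x2[0], x2[1], y2[0], y2[1]);
            if (a != b) bad++;
        }
        return bad;
    }
}
```

countDisagreements() returns 0. Also notice the comparisons are strict, so two Regions that merely share an edge count as intersecting. That's what you want here: a point sitting exactly on a splitting line can still be inside the search region.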

//Recursive helper. sRegion is the search region; bRegion is the part of the plane this
//subtree's points can lie in (pass new Region(), the whole plane, at the root).
private ArrayList<Node> innerBoundedSearch(Region sRegion, ArrayList<Node> results, int depth, Region bRegion){
    int axis = depth % coordinates.length;
    if(this.isLeaf()){
        //A leaf is a single point: report it if the search region contains it
        if(sRegion.fullyContains(this)){
            results.add(this);
        }
        return results;
    }
    if(sRegion.fullyContains(bRegion)){
        //Every point in this subtree is inside the search region, so report this subtree
        //and stop; its nodes are collected afterwards
        results.add(this);
        return results;
    }
    if(sRegion.fullyContains(this)){
        //The subtree is only partly covered, but this node's own point is inside
        results.add(this);
    }
    //The splitting line becomes a side of each child's bounding region. The setters modify
    //the Region they're called on, so each child gets its own copy instead of sharing bRegion.
    double split = this.getCoordinate(axis);
    Region leftRegion;
    Region rightRegion;
    if(axis == 0){
        //Splitting on the x axis: the line bounds the left child on the right, the right child on the left
        leftRegion = copyOf(bRegion).setRight(split);
        rightRegion = copyOf(bRegion).setLeft(split);
    }else{
        //Splitting on the y axis: the line bounds the left child on the top, the right child on the bottom
        leftRegion = copyOf(bRegion).setTop(split);
        rightRegion = copyOf(bRegion).setBottom(split);
    }
    //Search BOTH children (no early return), skipping any child whose region misses the search region
    if(this.left != null && sRegion.intersects(leftRegion)){
        this.left.innerBoundedSearch(sRegion, results, depth+1, leftRegion);
    }
    if(this.right != null && sRegion.intersects(rightRegion)){
        this.right.innerBoundedSearch(sRegion, results, depth+1, rightRegion);
    }
    return results;
}

//Copies a Region (assumes Node is in the same package as Region, whose fields are package-private)
private static Region copyOf(Region r){
    return new Region(r.xleft, r.xright, r.ytop, r.ybottom);
}

Alright (one day I'll figure out how to display code more nicely on these blog posts), let me describe how the algorithm works. The basic idea is that as you recurse down the tree, the splitting line of each node you visit becomes a side of a bounding region. This bounding region begins as the entire coordinate plane, then becomes the part of the plane on one side of the root's splitting line, then part of that, and so on, so each step gives you a smaller and smaller space in which the nodes of that subtree can exist. Until the bounding region fits inside the search region, we check each node's own point individually and keep narrowing. Once the bounding region is contained within the search region, we can report that entire subtree as being contained. If we reach a leaf node, it's a simple check to see whether its point is contained within the search region.
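Here's a self-contained sketch of the whole idea that you can run and poke at. It's a hypothetical stand-alone version (RangeKdTree, Node and Box are illustration names, not the Node/Region classes above), assuming every node stores a point and the axis alternates x, y with depth. Each child gets a freshly built bounding box, children whose box misses the query are skipped, a node whose box falls entirely inside the query has its whole subtree collected on the spot, and otherwise only the node's own point is tested.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: a 2-d tree that stores one point per node, with the
// splitting axis alternating x, y, x, ... by depth (smaller values go left).
class RangeKdTree {
    static final class Node {
        final double x, y;
        Node left, right;
        Node(double x, double y) { this.x = x; this.y = y; }
        double coord(int axis) { return axis == 0 ? x : y; }
    }

    // Closed axis-aligned box [xlo, xhi] x [ylo, yhi]; never modified after construction.
    static final class Box {
        static final Box ALL = new Box(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
                                       Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY);
        final double xlo, xhi, ylo, yhi;
        Box(double xlo, double xhi, double ylo, double yhi) {
            this.xlo = xlo; this.xhi = xhi; this.ylo = ylo; this.yhi = yhi;
        }
        boolean contains(double px, double py) {
            return xlo <= px && px <= xhi && ylo <= py && py <= yhi;
        }
        boolean contains(Box b) {
            return xlo <= b.xlo && b.xhi <= xhi && ylo <= b.ylo && b.yhi <= yhi;
        }
        boolean intersects(Box b) {
            return !(xhi < b.xlo) && !(b.xhi < xlo) && !(yhi < b.ylo) && !(b.yhi < ylo);
        }
    }

    private Node root;

    void insert(double x, double y) { root = insert(root, x, y, 0); }

    private static Node insert(Node n, double x, double y, int depth) {
        if (n == null) return new Node(x, y);
        int axis = depth % 2;
        if ((axis == 0 ? x : y) < n.coord(axis)) n.left = insert(n.left, x, y, depth + 1);
        else n.right = insert(n.right, x, y, depth + 1);
        return n;
    }

    // Returns every stored point inside the query box.
    List<Node> search(Box query) {
        List<Node> out = new ArrayList<>();
        search(root, 0, Box.ALL, query, out);
        return out;
    }

    // bounds is the part of the plane this subtree's points can lie in.
    private static void search(Node n, int depth, Box bounds, Box query, List<Node> out) {
        if (n == null || !query.intersects(bounds)) return; // nothing here can match
        if (query.contains(bounds)) {                        // everything here matches
            collect(n, out);
            return;
        }
        if (query.contains(n.x, n.y)) out.add(n);            // partial overlap: test this point
        int axis = depth % 2;
        double s = n.coord(axis);
        // The splitting line becomes a side of each child's new bounding box.
        Box leftBounds = axis == 0 ? new Box(bounds.xlo, s, bounds.ylo, bounds.yhi)
                                   : new Box(bounds.xlo, bounds.xhi, bounds.ylo, s);
        Box rightBounds = axis == 0 ? new Box(s, bounds.xhi, bounds.ylo, bounds.yhi)
                                    : new Box(bounds.xlo, bounds.xhi, s, bounds.yhi);
        search(n.left, depth + 1, leftBounds, query, out);
        search(n.right, depth + 1, rightBounds, query, out);
    }

    // Reports a whole subtree without any further geometric tests.
    private static void collect(Node n, List<Node> out) {
        if (n == null) return;
        out.add(n);
        collect(n.left, out);
        collect(n.right, out);
    }
}
```

One design difference from the code above: because a fully contained subtree is collected right when it's found and the search never descends into it, this variant's output is already one flat list with no duplicates, so it doesn't need a separate flatten step.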

I find that thinking about this geometrically is much easier than staring at the code and trying to understand what it does. There is some work you have to do after the results are returned, though: each reported subtree stands for all of the nodes in it, so you need to collect every node in those subtrees and add them to the results.

public ArrayList<Node> collectTree(ArrayList<Node> results){
    //Traverse this subtree and add every node in it (children first, then this node)
    if(this.left != null){
        results = this.left.collectTree(results);
    }
    if(this.right != null){
        results = this.right.collectTree(results);
    }
    //Skip nodes already in the list (note that ArrayList.contains is a linear scan)
    if(!results.contains(this)){
        results.add(this);
    }
    return results;
}


Once you've collected these lists for each node, you need to add them all together. But why write an ArrayList-merging function when you can do it all at once with a flatten? Observe:


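//sRegion is the same search region that was passed to the search
private static ArrayList<Node> flatten(ArrayList<Node> nodes, Region sRegion){
    //Flatten each reported node's subtree out
    ArrayList<Node> temp = new ArrayList<Node>();
    for(Node node : nodes){
        temp = node.collectTree(temp);
    }
    //A node reported only for its own point drags its whole subtree in with it,
    //so drop anything that lies outside the search region
    temp.removeIf(node -> !sRegion.fullyContains(node));
    return temp;
}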

Once you have this, it's easy to collect the results of your search and report all of them in one list to a client. There is one important thing I'd like to mention, though: notice the containment check at the end of the collectTree function. It ensures no duplicate nodes are added to the list. Without it, a node would be added once for every reported node whose subtree contains it (itself included), so the further it sits below the topmost reported node, the more copies of it you'd get.
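One cost of that containment check: ArrayList.contains is a linear scan, so with a lot of results the collection step does quadratic work. A hedged alternative is to collect into a LinkedHashSet, which rejects duplicates in expected constant time and keeps first-seen order. TreeNode and Flatten are hypothetical stand-ins for this sketch, not the Node class above.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical tree node. equals/hashCode are Object's identity versions,
// which is also what ArrayList.contains would use for such a class.
class TreeNode {
    final String label;
    TreeNode left, right;

    TreeNode(String label, TreeNode left, TreeNode right) {
        this.label = label;
        this.left = left;
        this.right = right;
    }
}

class Flatten {
    // Set.add returns false for a node that was already collected; its whole
    // subtree was collected at that time, so we can stop walking there.
    static void collect(TreeNode n, Set<TreeNode> out) {
        if (n == null || !out.add(n)) return;
        collect(n.left, out);
        collect(n.right, out);
    }

    // Flattens the reported subtree roots into one duplicate-free list,
    // keeping the order in which nodes were first seen.
    static List<TreeNode> flatten(List<TreeNode> roots) {
        Set<TreeNode> seen = new LinkedHashSet<>();
        for (TreeNode root : roots) collect(root, seen);
        return new ArrayList<>(seen);
    }
}
```

For a chain a, b, c where b is a's child and c is b's child, flatten of the list [a, b, c] returns [a, b, c] once each, and it never re-walks the subtrees of b and c.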


The project I've used this in is located at:
Tree Class
Region