kd-tree traversal - avoid dynamic memory allocation pyrit
authorRadek Brich <radek.brich@devl.cz>
Mon, 21 Apr 2008 08:47:36 +0200
branchpyrit
changeset 74 09aedbf5f95f
parent 73 a5127346fbcd
child 75 20dee9819b17
kd-tree traversal - avoid dynamic memory allocation use minimum storage size for KdNode (8B on 32bit cpu) vector.h - add division operator, fix semicolons
TODO
ccdemos/realtime.cc
demos/render_nff.py
include/kdtree.h
include/vector.h
src/container.cc
src/kdtree.cc
--- a/TODO	Sun Apr 20 19:27:59 2008 +0200
+++ b/TODO	Mon Apr 21 08:47:36 2008 +0200
@@ -10,10 +10,9 @@
  * kd-tree:
    - optimize structures
    - optimize construction: use box-shape intersection instead of bounding boxes of shapes
-   - optimize traversal -- no std::vector, no dynamic memory allocation
    - save/load
  * textures (3D procedural, pixmaps)
- * update Python binding: Camera, other new classes
+ * Python binding for all classes
  * stochastic oversampling
  * absorbtion of refracted rays in dense materials (can be computed using shape distance and some 'absorbance' constant)
  * implement efficient AABB-ray intersection using Plucker coordinates
--- a/ccdemos/realtime.cc	Sun Apr 20 19:27:59 2008 +0200
+++ b/ccdemos/realtime.cc	Mon Apr 21 08:47:36 2008 +0200
@@ -1,14 +1,14 @@
 #include <stdlib.h>
 
 #include "raytracer.h"
-#include "octree.h"
+#include "kdtree.h"
 
 #include "common_sdl.h"
 
 int main(int argc, char **argv)
 {
 	Raytracer rt;
-	Octree top;
+	KdTree top;
 	Camera cam;
 
 	rt.setMaxDepth(3);
--- a/demos/render_nff.py	Sun Apr 20 19:27:59 2008 +0200
+++ b/demos/render_nff.py	Mon Apr 21 08:47:36 2008 +0200
@@ -84,6 +84,8 @@
 		rt.addshape(Triangle(vertices[0], vertices[1], vertices[2], mat))
 		for i in range(vertex_count)[3:]:
 			rt.addshape(Triangle(vertices[0], vertices[i-1], vertices[i], mat))
+	elif ln[0] == '#':	# Comment
+		pass
 	else:
 		print "Not implemented:", line
 f.close()
--- a/include/kdtree.h	Sun Apr 20 19:27:59 2008 +0200
+++ b/include/kdtree.h	Mon Apr 21 08:47:36 2008 +0200
@@ -29,6 +29,7 @@
 
 #include <iostream>
 #include <fstream>
+#include <assert.h>
 
 #include "container.h"
 #include "vector.h"
@@ -42,29 +43,30 @@
 class KdNode
 {
 	Float split;
-	short axis; /* 0,1,2 => x,y,z; 3 => leaf */
-public:
 	union {
 		KdNode *children;
 		ShapeList *shapes;
+		int flags; /* 0,1,2 => x,y,z; 3 => leaf */
 	};
-
-	KdNode() : axis(3) { shapes = new ShapeList(); };
+public:
+	KdNode() { shapes = new ShapeList(); assert((flags & 3) == 0); setLeaf(); };
 	~KdNode();
 
-	void setAxis(short aAxis) { axis = aAxis; };
-	short getAxis() { return axis; };
+	void setLeaf() { flags |= 3; };
+	bool isLeaf() { return (flags & 3) == 3; };
+
+	void setAxis(int aAxis) { flags &= ~3; flags |= aAxis; };
+	short getAxis() { return flags & 3; };
 
 	void setSplit(Float aSplit) { split = aSplit; };
 	Float getSplit() { return split; };
 
-	void setLeaf() { axis = 3; };
-	bool isLeaf() { return axis == 3; };
+	void setChildren(KdNode *node) { children = node; assert((flags & 3) == 0); };
+	KdNode *getLeftChild() { return (KdNode*)((off_t)children & ~3); };
+	KdNode *getRightChild() { return (KdNode*)((off_t)children & ~3) + 1; };
 
-	KdNode *getLeftChild() { return children; };
-	KdNode *getRightChild() { return children+1; };
-
-	void addShape(Shape* aShape) { shapes->push_back(aShape); };
+	ShapeList *getShapes() { return (ShapeList*)((off_t)shapes & ~3); };
+	void addShape(Shape* aShape) { getShapes()->push_back(aShape); };
 
 	void subdivide(BBox bbox, int maxdepth);
 };
--- a/include/vector.h	Sun Apr 20 19:27:59 2008 +0200
+++ b/include/vector.h	Mon Apr 21 08:47:36 2008 +0200
@@ -66,21 +66,21 @@
 		y *= f;
 		z *= f;
 		return *this;
-	}
+	};
 
 	// get normalized copy
 	Vector3 unit() const
 	{
 		Vector3 u(*this);
-		return u.normalize();;
-	}
+		return u.normalize();
+	};
 
 	// square magnitude, magnitude
-	Float mag2() const	{ return x * x + y * y + z * z; }
-	Float mag() const	{ return sqrtf(mag2()); }
+	Float mag2() const	{ return x * x + y * y + z * z; };
+	Float mag() const	{ return sqrtf(mag2()); };
 
 	// negative
-	Vector3 operator-() const { return Vector3(-x, -y, -z); }
+	Vector3 operator-() const { return Vector3(-x, -y, -z); };
 
 	// accumulate
 	Vector3 operator+=(const Vector3 &v)
@@ -130,24 +130,35 @@
 	friend Vector3 operator*(const Vector3 &v, const Float &f)
 	{
 		return Vector3(f * v.x, f * v.y, f * v.z);
-	}
+	};
 
 	friend Vector3 operator*(const Float &f, const Vector3 &v)
 	{
 		return v * f;
 	};
 
+	// scalar division
+	friend Vector3 operator/(const Vector3 &v, const Float &f)
+	{
+		return Vector3(v.x / f, v.y / f, v.z / f);
+	};
+
+	friend Vector3 operator/(const Float &f, const Vector3 &v)
+	{
+		return Vector3(f / v.x, f / v.y, f / v.z);
+	};
+
 	// vector plus scalar
 	friend Vector3 operator+(const Vector3 &v, const Float &f)
 	{
 		return Vector3(v.x + f, v.y + f, v.z + f);
-	}
+	};
 
 	// vector minus scalar
 	friend Vector3 operator-(const Vector3 &v, const Float &f)
 	{
 		return Vector3(v.x - f, v.y - f, v.z - f);
-	}
+	};
 
 	// cell by cell product (only usable for colours)
 	friend Vector3 operator*(const Vector3 &a,  const Vector3 &b)
@@ -159,7 +170,7 @@
 	friend ostream & operator<<(ostream &st, const Vector3 &v)
 	{
 		return st << "(" << v.x << ", " << v.y  << ", " << v.z << ")";
-	}
+	};
 };
 
 typedef Vector3 Colour;
--- a/src/container.cc	Sun Apr 20 19:27:59 2008 +0200
+++ b/src/container.cc	Mon Apr 21 08:47:36 2008 +0200
@@ -29,21 +29,22 @@
 
 void Container::addShape(Shape* aShape)
 {
+	const Float e = 10*Eps;
 	shapes.push_back(aShape);
 	if (shapes.size() == 0) {
 		/* initialize bounding box */
 		bbox = aShape->get_bbox();
-		Vector3 eps(Eps,Eps,Eps);
-		bbox = BBox(bbox.L - eps, bbox.H + eps);
+		const Vector3 E(e, e, e);
+		bbox = BBox(bbox.L - E, bbox.H + E);
 	} else {
 		/* adjust bounding box */
 		BBox shapebb = aShape->get_bbox();
-		if (shapebb.L.x - Eps < bbox.L.x)  bbox.L.x = shapebb.L.x - Eps;
-		if (shapebb.L.y - Eps < bbox.L.y)  bbox.L.y = shapebb.L.y - Eps;
-		if (shapebb.L.z - Eps < bbox.L.z)  bbox.L.z = shapebb.L.z - Eps;
-		if (shapebb.H.x + Eps > bbox.H.x)  bbox.H.x = shapebb.H.x + Eps;
-		if (shapebb.H.y + Eps > bbox.H.y)  bbox.H.y = shapebb.H.y + Eps;
-		if (shapebb.H.z + Eps > bbox.H.z)  bbox.H.z = shapebb.H.z + Eps;
+		if (shapebb.L.x - e < bbox.L.x)  bbox.L.x = shapebb.L.x - e;
+		if (shapebb.L.y - e < bbox.L.y)  bbox.L.y = shapebb.L.y - e;
+		if (shapebb.L.z - e < bbox.L.z)  bbox.L.z = shapebb.L.z - e;
+		if (shapebb.H.x + e > bbox.H.x)  bbox.H.x = shapebb.H.x + e;
+		if (shapebb.H.y + e > bbox.H.y)  bbox.H.y = shapebb.H.y + e;
+		if (shapebb.H.z + e > bbox.H.z)  bbox.H.z = shapebb.H.z + e;
 	}
 };
 
--- a/src/kdtree.cc	Sun Apr 20 19:27:59 2008 +0200
+++ b/src/kdtree.cc	Mon Apr 21 08:47:36 2008 +0200
@@ -48,14 +48,12 @@
 };
 
 // stack element for kd-tree traversal
-class StackElem
+struct StackElem
 {
-public:
 	KdNode* node; /* pointer to far child */
 	Float t; /* the entry/exit signed distance */
 	Vector3 pb; /* the coordinates of entry/exit point */
-	StackElem(KdNode *anode, const Float &at, const Vector3 &apb):
-		node(anode), t(at), pb(apb) {};
+	int prev;
 };
 
 // ----------------------------------------
@@ -63,15 +61,15 @@
 KdNode::~KdNode()
 {
 	if (isLeaf())
-		delete shapes;
+		delete getShapes();
 	else
-		delete[] children;
+		delete[] getLeftChild();
 }
 
 // kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
 void KdNode::subdivide(BBox bounds, int maxdepth)
 {
-	if (maxdepth <= 0 || shapes->size() <= 2)
+	if (maxdepth <= 0 || getShapes()->size() <= 2)
 	{
 		setLeaf();
 		return;
@@ -87,7 +85,7 @@
 	// create sorted list of shape bounds (= find all posible splits)
 	vector<ShapeBound> edges[3];
 	ShapeList::iterator shape;
-	for (shape = shapes->begin(); shape != shapes->end(); shape++)
+	for (shape = getShapes()->begin(); shape != getShapes()->end(); shape++)
 	{
 		BBox shapebounds = (*shape)->get_bbox();
 		for (int ax = 0; ax < 3; ax++)
@@ -103,12 +101,13 @@
 	const Float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
 	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
 		bounds.w()*bounds.d() + bounds.h()*bounds.d());
-	Float cost = SAV * (K + shapes->size()); // initial cost = non-split cost
+	Float cost = SAV * (K + getShapes()->size()); // initial cost = non-split cost
 
 	vector<ShapeBound>::iterator edge, splitedge = edges[2].end();
+	int axis = 0;
 	for (int ax = 0; ax < 3; ax++)
 	{
-		int lnum = 0, rnum = shapes->size();
+		int lnum = 0, rnum = getShapes()->size();
 		BBox lbb = bounds;
 		BBox rbb = bounds;
 		for (edge = edges[ax].begin(); edge != edges[ax].end(); edge++)
@@ -166,22 +165,24 @@
 #endif
 
 	// split this node
-	delete shapes;
+	delete getShapes();
+
 	BBox lbb = bounds;
 	BBox rbb = bounds;
 	lbb.H.cell[axis] = split;
 	rbb.L.cell[axis] = split;
-	children = new KdNode[2];
+	setChildren(new KdNode[2]);
+	setAxis(axis);
 
 	for (edge = edges[axis].begin(); edge != splitedge; edge++)
 		if (!edge->end && edge->shape->intersect_bbox(lbb))
-			children[0].addShape(edge->shape);
+			getLeftChild()->addShape(edge->shape);
 	for (edge = splitedge; edge < edges[axis].end(); edge++)
 		if (edge->end && edge->shape->intersect_bbox(rbb))
-			children[1].addShape(edge->shape);
+			getRightChild()->addShape(edge->shape);
 
-	children[0].subdivide(lbb, maxdepth-1);
-	children[1].subdivide(rbb, maxdepth-1);
+	getLeftChild()->subdivide(lbb, maxdepth-1);
+	getRightChild()->subdivide(rbb, maxdepth-1);
 }
 
 void KdTree::build()
@@ -211,40 +212,46 @@
 
 	/* pointers to the far child node and current node */
 	KdNode *farchild, *node;
-	node = root; /* start from the kd-tree root node */
+	node = root;
 
-	/* std vector is much faster than stack */
-	vector<StackElem*> st;
+	StackElem stack[max_depth];
 
-	StackElem *enPt = new StackElem(NULL, a,
-		/* distinguish between internal and external origin of a ray*/
-		a >= 0.0 ?
-			ray.o + ray.dir * a : /* external */
-			ray.o);               /* internal */
+	int entry = 0, exit = 1;
+	stack[entry].t = a;
+
+	/* distinguish between internal and external origin of a ray*/
+	if (a >= 0.0)
+		stack[entry].pb = ray.o + ray.dir * a; /* external */
+	else
+		stack[entry].pb = ray.o;               /* internal */
 
 	/* setup initial exit point in the stack */
-	StackElem *exPt = new StackElem(NULL, b, ray.o + ray.dir * b);
-	st.push_back(exPt);
+	stack[exit].t = b;
+	stack[exit].pb = ray.o + ray.dir * b;
+	stack[exit].node = NULL;
 
 	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
+	Float splitVal;
+	int axis;
+	static const int mod3[] = {0,1,2,0,1};
+	const Vector3 invdir = 1 / ray.dir;
 	while (node)
 	{
-		exPt = st.back();
 		/* loop until a leaf is found */
 		while (!node->isLeaf())
 		{
 			/* retrieve position of splitting plane */
-			Float splitVal = node->getSplit();
-			short axis = node->getAxis();
+			splitVal = node->getSplit();
+			axis = node->getAxis();
 
-			if (enPt->pb[axis] <= splitVal)
+			if (stack[entry].pb[axis] <= splitVal)
 			{
-				if (exPt->pb[axis] <= splitVal)
+				if (stack[exit].pb[axis] <= splitVal)
 				{ /* case N1, N2, N3, P5, Z2, and Z3 */
 					node = node->getLeftChild();
 					continue;
 				}
-				if (exPt->pb[axis] == splitVal)
+				if (stack[exit].pb[axis] == splitVal)
 				{ /* case Z1 */
 					node = node->getRightChild();
 					continue;
@@ -254,14 +261,14 @@
 				node = node->getLeftChild();
 			}
 			else
-			{ /* (enPt->pb[axis] > splitVal) */
-				if (splitVal < exPt->pb[axis])
+			{ /* (stack[entry].pb[axis] > splitVal) */
+				if (splitVal < stack[exit].pb[axis])
 				{
 					/* case P1, P2, P3, and N5 */
 					node = node->getRightChild();
 					continue;
 				}
-				/* case P4*/
+				/* case P4 */
 				farchild = node->getLeftChild();
 				node = node->getRightChild();
 			}
@@ -269,50 +276,48 @@
 			/* case P4 or N4 . . . traverse both children */
 
 			/* signed distance to the splitting plane */
-			t = (splitVal - ray.o.cell[axis]) / ray.dir.cell[axis];
+			t = (splitVal - ray.o[axis]) * invdir[axis];
 
 			/* setup the new exit point and push it onto stack */
-			exPt = new StackElem(farchild, t, Vector3());
-			exPt->pb.cell[axis] = splitVal;
-			exPt->pb.cell[(axis+1)%3] = ray.o.cell[(axis+1)%3] + t * ray.dir.cell[(axis+1)%3];
-			exPt->pb.cell[(axis+2)%3] = ray.o.cell[(axis+2)%3] + t * ray.dir.cell[(axis+2)%3];
-			st.push_back(exPt);
-		} /* while */
+			register int tmp = exit;
+
+			exit++;
+			if (exit == entry)
+				exit++;
+			assert(exit < max_depth);
+
+			stack[exit].prev = tmp;
+			stack[exit].t = t;
+			stack[exit].node = farchild;
+			stack[exit].pb.cell[axis] = splitVal;
+			stack[exit].pb.cell[mod3[axis+1]] = ray.o.cell[mod3[axis+1]]
+				+ t * ray.dir.cell[mod3[axis+1]];
+			stack[exit].pb.cell[mod3[axis+2]] = ray.o.cell[mod3[axis+2]]
+				+ t * ray.dir.cell[mod3[axis+2]];
+		}
 
 		/* current node is the leaf . . . empty or full */
-		/* "intersect ray with each object in the object list, discarding "
-		"those lying before stack[enPt].t or farther than stack[exPt].t" */
 		Shape *nearest_shape = NULL;
-		Float dist = exPt->t;
+		Float dist = stack[exit].t;
 		ShapeList::iterator shape;
-		for (shape = node->shapes->begin(); shape != node->shapes->end(); shape++)
+		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
 			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
-			&& dist >= enPt->t)
+			&& dist >= stack[entry].t)
 			{
 				nearest_shape = *shape;
 				nearest_distance = dist;
 			}
 
-		delete enPt;
-
 		if (nearest_shape)
-		{
-			while (!st.empty())
-			{
-				delete st.back();
-				st.pop_back();
-			}
 			return nearest_shape;
-		}
 
-		enPt = exPt;
+		entry = exit;
 
-		/* retrieve the pointer to the next node, it is possible that ray traversal terminates */
-		node = enPt->node;
-		st.pop_back();
-	} /* while */
-
-	delete enPt;
+		/* retrieve the pointer to the next node,
+		it is possible that ray traversal terminates */
+		node = stack[entry].node;
+		exit = stack[entry].prev;
+	}
 
 	/* ray leaves the scene */
 	return NULL;
@@ -326,7 +331,7 @@
 	if (node == NULL)
 		node = root;
 	if (node->isLeaf())
-		str << "(leaf: " << node->shapes->size() << " shapes)";
+		str << "(leaf: " << node->getShapes()->size() << " shapes)";
 	else
 	{
 		str << "(split " << (char)('x'+node->getAxis()) << " at " << node->getSplit() << "; L=";