src/kdtree.cc
author Radek Brich <radek.brich@devl.cz>
Sun, 27 Apr 2008 22:55:17 +0200 (2008-04-27)
branch pyrit
changeset 87 1081e3dd3f3e
parent 86 ce6abe0aeeae
child 91 9d66d323c354
permissions -rw-r--r--
Sphere, Box: RayPacket intersection; replace 5x oversampling with 4x uniform oversampling
/*
 * kdtree.cc: KdTree class
 *
 * This file is part of Pyrit Ray Tracer.
 *
 * Copyright 2006, 2007  Radek Brich
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <algorithm>
#include <functional>
#include <stack>
#include <string>
#include <sstream>

#include "kdtree.h"
#include "serialize.h"

class ShapeBound
{
public:
	Shape *shape;
	Float pos;
	bool end;
	ShapeBound(Shape *ashape, const Float apos, const bool aend):
		shape(ashape), pos(apos), end(aend) {};
	friend bool operator<(const ShapeBound& a, const ShapeBound& b)
	{
		if (a.pos == b.pos)
			return a.shape < b.shape;
		else
			return a.pos < b.pos;
	};
};

// stack element for kd-tree traversal
struct StackElem
{
	KdNode* node; /* pointer to far child */
	Float t; /* the entry/exit signed distance */
	Vector3 pb; /* the coordinates of entry/exit point */
	int prev;
};

// ----------------------------------------

KdNode::~KdNode()
{
	// an interior node owns both children, allocated together as one array
	if (!isLeaf())
	{
		delete[] getLeftChild();
		return;
	}
	// a leaf owns its shape list (but not the shapes themselves)
	delete getShapes();
}

// kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
//
// Splits `node` using the surface area heuristic. Every shape-bbox bound
// along all three axes is a candidate plane. The function then recurses
// into both children. The node becomes a leaf when maxdepth is exhausted,
// when it holds two or fewer shapes, or when no candidate split is cheaper
// than not splitting.
//
// node      node to split; its shape list must be filled and is deleted
//           when the node is split
// bounds    bounding box of the node
// maxdepth  remaining number of split levels
void KdTree::recursive_build(KdNode *node, BBox bounds, int maxdepth)
{
	ShapeList *shapes = node->getShapes();

	if (maxdepth <= 0 || shapes->size() <= 2)
	{
		node->setLeaf();
		return;
	}

	// create sorted lists of shape bounds (= all possible splits) per axis;
	// the split axis itself is chosen by the SAH below
	vector<ShapeBound> edges[3];
	for (int ax = 0; ax < 3; ax++)
		edges[ax].reserve(2 * shapes->size());
	ShapeList::iterator shape;
	for (shape = shapes->begin(); shape != shapes->end(); shape++)
	{
		BBox shapebounds = (*shape)->get_bbox();
		for (int ax = 0; ax < 3; ax++)
		{
			edges[ax].push_back(ShapeBound(*shape, shapebounds.L[ax], false));
			edges[ax].push_back(ShapeBound(*shape, shapebounds.H[ax], true));
		}
	}
	for (int ax = 0; ax < 3; ax++)
		sort(edges[ax].begin(), edges[ax].end());

	// choose best split pos
	const Float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
		bounds.w()*bounds.d() + bounds.h()*bounds.d());
	Float cost = SAV * (K + shapes->size()); // initial cost = non-split cost

	// `found` tells whether splitedge is valid. Comparing splitedge with the
	// end() of one axis' vector as a sentinel would compare iterators of
	// different containers when the best edge lies on another axis (UB).
	bool found = false;
	vector<ShapeBound>::iterator edge, splitedge;
	int axis = 0;
	for (int ax = 0; ax < 3; ax++)
	{
		int lnum = 0, rnum = shapes->size();
		BBox lbb = bounds;
		BBox rbb = bounds;
		for (edge = edges[ax].begin(); edge != edges[ax].end(); edge++)
		{
			if (edge->end)
				rnum--;

			// shape bboxes may stick out of the node; a plane outside the
			// node would give one child an inverted box, so skip it
			if (edge->pos >= bounds.L[ax] && edge->pos <= bounds.H[ax])
			{
				// calculate SAH cost of this split
				lbb.H.cell[ax] = edge->pos;
				rbb.L.cell[ax] = edge->pos;
				Float SAL = (lbb.w()*lbb.h() + lbb.w()*lbb.d() + lbb.h()*lbb.d());
				Float SAR = (rbb.w()*rbb.h() + rbb.w()*rbb.d() + rbb.h()*rbb.d());
				Float splitcost = K*SAV + SAL*(K + lnum) + SAR*(K + rnum);

				if (splitcost < cost)
				{
					axis = ax;
					splitedge = edge;
					cost = splitcost;
					found = true;
				}
			}

			if (!edge->end)
				lnum++;
		}
	}

	if (!found)
	{
		node->setLeaf();
		return;
	}

	node->setSplit(splitedge->pos);

#if 0
// export kd-tree as .obj for visualization
// this would be hard to reconstruct later
	static ofstream F("kdtree.obj");
	Vector3 v;
	static int f=0;
	v.cell[axis] = node->getSplit();
	v.cell[(axis+1)%3] = bounds.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
	f += 4;
#endif

	// split this node; the edges still hold the shape pointers
	delete shapes;

	BBox lbb = bounds;
	BBox rbb = bounds;
	lbb.H.cell[axis] = node->getSplit();
	rbb.L.cell[axis] = node->getSplit();
	node->setChildren(new KdNode[2]);
	node->setAxis(axis);

	// left child: shapes starting before the split edge;
	// right child: shapes ending at or after it
	for (edge = edges[axis].begin(); edge != splitedge; edge++)
		if (!edge->end && edge->shape->intersect_bbox(lbb))
			node->getLeftChild()->addShape(edge->shape);
	for (edge = splitedge; edge != edges[axis].end(); edge++)
		if (edge->end && edge->shape->intersect_bbox(rbb))
			node->getRightChild()->addShape(edge->shape);

	recursive_build(node->getLeftChild(), lbb, maxdepth-1);
	recursive_build(node->getRightChild(), rbb, maxdepth-1);
}

void KdTree::build()
{
	dbgmsg(1, "* building kd-tree\n");
	root = new KdNode();
	ShapeList::iterator shape;
	for (shape = shapes.begin(); shape != shapes.end(); shape++)
		root->addShape(*shape);
	recursive_build(root, bbox, max_depth);
	built = true;
}

/*
 * Find the nearest shape hit by a single ray.
 *
 * Front-to-back kd-tree traversal; algorithm by Vlastimil Havran,
 * Heuristic Ray Shooting Algorithms, appendix C.
 *
 * origin_shape      shape the ray starts from; it is never tested
 * ray               the ray to trace
 * nearest_distance  out: distance to the returned shape; the tree path
 *                   writes it only when a shape is found (the fallback's
 *                   behaviour is that of Container::nearest_intersection)
 *
 * Returns the nearest intersected shape, or NULL.
 */
Shape *KdTree::nearest_intersection(const Shape *origin_shape, const Ray &ray,
	Float &nearest_distance)
{
	Float a, b; /* entry/exit signed distance */
	Float t;    /* signed distance to the splitting plane */

	/* if we have no tree, fall back to naive test */
	if (!built)
		return Container::nearest_intersection(origin_shape, ray, nearest_distance);

	if (!bbox.intersect(ray, a, b))
		return NULL;

	/* pointers to the far child node and current node */
	KdNode *farchild, *node;
	node = root;

	/* NOTE: variable-length array (GCC/Clang extension); the depth is
	   checked against max_depth by the assert below */
	StackElem stack[max_depth];

	int entry = 0, exit = 1;
	stack[entry].t = a;

	/* distinguish between internal and external origin of a ray*/
	if (a >= 0.0)
		stack[entry].pb = ray.o + ray.dir * a; /* external */
	else
		stack[entry].pb = ray.o;               /* internal */

	/* setup initial exit point in the stack; its NULL node ends the
	   traversal, and prev is set so it is never read uninitialized */
	stack[exit].t = b;
	stack[exit].pb = ray.o + ray.dir * b;
	stack[exit].node = NULL;
	stack[exit].prev = 0;

	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
	Float splitVal;
	int axis;
	static const int mod3[] = {0,1,2,0,1};
	const Vector3 invdir = 1 / ray.dir;
	while (node)
	{
		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			splitVal = node->getSplit();
			axis = node->getAxis();

			if (stack[entry].pb[axis] <= splitVal)
			{
				if (stack[exit].pb[axis] <= splitVal)
				{ /* case N1, N2, N3, P5, Z2, and Z3 */
					node = node->getLeftChild();
					continue;
				}
				if (stack[entry].pb[axis] == splitVal)
				{ /* case Z1 */
					node = node->getRightChild();
					continue;
				}
				/* case N4 */
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{ /* (stack[entry].pb[axis] > splitVal) */
				if (stack[exit].pb[axis] > splitVal)
				{
					/* case P1, P2, P3, and N5 */
					node = node->getRightChild();
					continue;
				}
				/* case P4 */
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* case P4 or N4 . . . traverse both children */

			/* signed distance to the splitting plane */
			t = (splitVal - ray.o[axis]) * invdir[axis];

			/* setup the new exit point and push it onto stack
			   (`register` was dropped: it is ill-formed since C++17) */
			const int tmp = exit;

			exit++;
			if (exit == entry)
				exit++;
			assert(exit < max_depth);

			stack[exit].prev = tmp;
			stack[exit].t = t;
			stack[exit].node = farchild;
			stack[exit].pb.cell[axis] = splitVal;
			stack[exit].pb.cell[mod3[axis+1]] = ray.o.cell[mod3[axis+1]]
				+ t * ray.dir.cell[mod3[axis+1]];
			stack[exit].pb.cell[mod3[axis+2]] = ray.o.cell[mod3[axis+2]]
				+ t * ray.dir.cell[mod3[axis+2]];
		}

		/* current node is the leaf . . . empty or full;
		   accept only hits between this cell's entry and exit distances */
		Shape *nearest_shape = NULL;
		Float dist = stack[exit].t;
		ShapeList::iterator shape;
		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
			&& dist >= stack[entry].t - Eps)
			{
				nearest_shape = *shape;
				nearest_distance = dist;
			}

		if (nearest_shape)
			return nearest_shape;

		entry = exit;

		/* retrieve the pointer to the next node,
		it is possible that ray traversal terminates */
		node = stack[entry].node;
		exit = stack[entry].prev;
	}

	/* ray leaves the scene */
	return NULL;
}

/* Traversal stack entry for the four-ray packet kd-tree walk: the same
   fields as StackElem, with one SSE lane per ray. */
struct StackElem4
{
	KdNode *node;     // far child still to be visited (NULL ends the walk)
	__m128 t;         // per-ray signed distance of the entry/exit point
	VectorPacket pb;  // per-ray coordinates of the entry/exit point
	int prev;         // stack index of the previous exit point
};

// Find the nearest intersections for a packet of four rays.
//
// origin_shapes      per-ray shape the ray starts from; a ray never reports
//                    a hit on its own origin shape
// rays               the four rays
// nearest_distances  out: per-ray hit distance, written only for rays that hit
// nearest_shapes     out: per-ray nearest shape, NULL for rays that miss
//
// The rays are traversed together while they agree on which child to visit
// first. Otherwise, the rays still without a hit are finished one by one
// with nearest_intersection().
void KdTree::packet_intersection(const Shape **origin_shapes, const RayPacket &rays,
		Float *nearest_distances, Shape **nearest_shapes)
{
	__m128 a, b; /* entry/exit signed distance */
	__m128 t;    /* signed distance to the splitting plane */
	__m128 mask = mZero;

	/* if we have no tree, fall back to naive test -- and stop there:
	   continuing would clear its results (memset below) and traverse a
	   tree that was never built */
	if (!built)
	{
		Container::packet_intersection(origin_shapes, rays, nearest_distances, nearest_shapes);
		return;
	}

	// nearest_shapes[0..3] = NULL
	memset(nearest_shapes, 0, 4*sizeof(Shape*));

	/* mask: lanes of rays still being traced (they hit the scene bbox and
	   have not found their nearest shape yet) */
	mask = bbox.intersect_packet(rays, a, b);
	if (!_mm_movemask_ps(mask))
		return;

	/* pointers to the far child node and current node */
	KdNode *farchild, *node;
	node = root;

	/* NOTE: variable-length array (GCC/Clang extension); the depth is
	   checked against max_depth by the assert below */
	StackElem4 stack[max_depth];

	int entry = 0, exit = 1;
	stack[entry].t = a;

	/* distinguish between internal and external origin of a ray:
	   lanes with a < 0 start inside the bbox and use the ray origin */
	t = _mm_cmplt_ps(a, mZero);
	stack[entry].pb = rays.o + rays.dir * a;
	stack[entry].pb.mx = _mm_or_ps(_mm_andnot_ps(t, stack[entry].pb.mx),
		_mm_and_ps(t, rays.o.mx));
	stack[entry].pb.my = _mm_or_ps(_mm_andnot_ps(t, stack[entry].pb.my),
		_mm_and_ps(t, rays.o.my));
	stack[entry].pb.mz = _mm_or_ps(_mm_andnot_ps(t, stack[entry].pb.mz),
		_mm_and_ps(t, rays.o.mz));

	/* setup initial exit point in the stack; its NULL node ends the
	   traversal, and prev is set so it is never read uninitialized */
	stack[exit].t = b;
	stack[exit].pb = rays.o + rays.dir * b;
	stack[exit].node = NULL;
	stack[exit].prev = 0;

	/* loop, traverse through the whole kd-tree,
	until an object is intersected or ray leaves the scene */
	__m128 splitVal;
	int axis;
	static const int mod3[] = {0,1,2,0,1};
	const VectorPacket invdirs = mOne / rays.dir;
	while (node)
	{
		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			splitVal = _mm_set_ps1(node->getSplit());
			axis = node->getAxis();

			// mask out invalid rays with near > far
			const __m128 curmask = _mm_and_ps(mask, _mm_cmple_ps(stack[entry].t, stack[exit].t));
			const __m128 entry_lt = _mm_cmplt_ps(stack[entry].pb.ma[axis], splitVal);
			const __m128 entry_gt = _mm_cmpgt_ps(stack[entry].pb.ma[axis], splitVal);
			const __m128 exit_lt = _mm_cmplt_ps(stack[exit].pb.ma[axis], splitVal);
			const __m128 exit_gt = _mm_cmpgt_ps(stack[exit].pb.ma[axis], splitVal);

			// if all of
			// stack[entry].pb[axis] <= splitVal,
			// stack[exit].pb[axis] <= splitVal
			if (!_mm_movemask_ps(
				_mm_and_ps(_mm_or_ps(entry_gt, exit_gt), curmask)))
			{
				node = node->getLeftChild();
				continue;
			}

			// if all of
			// stack[entry].pb[axis] >= splitVal,
			// stack[exit].pb[axis] >= splitVal
			if (!_mm_movemask_ps(
				_mm_and_ps(_mm_or_ps(entry_lt, exit_lt), curmask)))
			{
				node = node->getRightChild();
				continue;
			}

			// any of
			// stack[entry].pb[axis] < splitVal,
			// stack[exit].pb[axis] > splitVal
			bool cond1 = _mm_movemask_ps(
				_mm_and_ps(_mm_and_ps(entry_lt, exit_gt), curmask));

			// any of
			// stack[entry].pb[axis] > splitVal,
			// stack[exit].pb[axis] < splitVal
			bool cond2 = _mm_movemask_ps(
				_mm_and_ps(_mm_and_ps(entry_gt, exit_lt), curmask));

			if ((!cond1 && !cond2) || (cond1 && cond2))
			{
				// fall back to single rays
				// FIXME: split rays and continue
				for (int i = 0; i < 4; i++)
					if (!nearest_shapes[i])
						nearest_shapes[i] = nearest_intersection(origin_shapes[i],
							rays[i], nearest_distances[i]);
				return;
			}

			if (cond1)
			{
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* traverse both children */

			/* signed distance to the splitting plane */
			t = _mm_mul_ps(_mm_sub_ps(splitVal, rays.o.ma[axis]), invdirs.ma[axis]);

			/* setup the new exit point and push it onto stack
			   (`register` was dropped: it is ill-formed since C++17) */
			const int tmp = exit;

			exit++;
			if (exit == entry)
				exit++;
			assert(exit < max_depth);

			stack[exit].prev = tmp;
			stack[exit].t = t;
			stack[exit].node = farchild;
			stack[exit].pb.ma[axis] = splitVal;
			stack[exit].pb.ma[mod3[axis+1]] =
				_mm_add_ps(rays.o.ma[mod3[axis+1]], _mm_mul_ps(t, rays.dir.ma[mod3[axis+1]]));
			stack[exit].pb.ma[mod3[axis+2]] =
				_mm_add_ps(rays.o.ma[mod3[axis+2]], _mm_mul_ps(t, rays.dir.ma[mod3[axis+2]]));
		}

		/* current node is the leaf . . . empty or full */
		__m128 dists = stack[exit].t;
		ShapeList::iterator shape;
		__m128 results;
		__m128 newmask = mask;
		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
		{
			const __m128 prevdists = dists;
			results = (*shape)->intersect_packet(rays, dists);

			/* A ray must ignore its origin shape entirely; the single-ray
			   version never even tests it. Undo any distance the origin
			   shape wrote into that ray's lane. Otherwise, farther hits on
			   other shapes in this leaf would be rejected for that ray. */
			for (int i = 0; i < 4; i++)
				if (*shape == origin_shapes[i])
					((float*)&dists)[i] = ((const float*)&prevdists)[i];

			int valid = _mm_movemask_ps(
				_mm_and_ps(mask, _mm_and_ps(results,
				_mm_cmpge_ps(dists, _mm_sub_ps(stack[entry].t, mEps)))));
			for (int i = 0; i < 4; i++)
			{
				if (*shape != origin_shapes[i] && ((valid>>i)&1))
				{
					nearest_shapes[i] = *shape;
					nearest_distances[i] = ((float*)&dists)[i];
					((int*)&newmask)[i] = 0;
				}
			}
		}

		mask = newmask;
		if (!_mm_movemask_ps(mask))
			return;

		entry = exit;

		/* retrieve the pointer to the next node,
		it is possible that ray traversal terminates */
		node = stack[entry].node;
		exit = stack[entry].prev;
	}

	/* ray leaves the scene */
}

// Serialize a kd-tree node (recursively) in the textual form read back by
// KdTree::recursive_load:
//   split node: (s,<axis letter x/y/z>,<split position>,<left>,<right>)
//   leaf node:  (l,<shape count>,<shape index>,...)
ostream & operator<<(ostream &st, KdNode &node)
{
	if (!node.isLeaf())
	{
		const char axis_letter = static_cast<char>('x' + node.getAxis());
		st << "(s," << axis_letter << "," << node.getSplit() << ",";
		st << *node.getLeftChild() << ",";
		st << *node.getRightChild() << ")";
		return st;
	}

	ShapeList *leaf_shapes = node.getShapes();
	st << "(l," << leaf_shapes->size();
	for (ShapeList::iterator it = leaf_shapes->begin(); it != leaf_shapes->end(); ++it)
		st << "," << shape_index[*it];
	st << ")";
	return st;
}

// Write the tree as "(kdtree,<shape count>,<shapes...>,<root node>)".
// Without a built tree, the plain container format is written instead.
// Shapes for which shape_index.get() succeeds are not written again.
ostream & KdTree::dump(ostream &st)
{
	if (!built)
		return Container::dump(st);

	st << "(kdtree," << shapes.size();
	for (ShapeList::iterator it = shapes.begin(); it != shapes.end(); ++it)
	{
		int known_idx;
		const bool indexed = shape_index.get(*it, known_idx);
		if (!indexed)
			st << "," << endl << **it;
	}
	st << "," << endl << *getRootNode() << ")";
	return st;
}

// Recursively read one node written by operator<<(ostream&, KdNode&):
//   split node: (s,<axis letter>,<split>,<left>,<right>)
//   leaf node:  (l,<count>,<shape index>,...)
// Shape indices refer to the `shapes` list, which must already be loaded.
// On an invalid axis letter or shape index, the stream is put into the
// fail state (axis falls back to 0; the leaf stops at the bad index), so
// the tree never indexes out of range.
void KdTree::recursive_load(istream &st, KdNode *node)
{
	string s;

	getline(st, s, ',');
	trim(s);

	if (s.compare("(s") == 0)
	{
		// split
		int axis;
		Float split = 0;

		delete node->getShapes();
		node->setChildren(new KdNode[2]);

		getline(st, s, ',');
		axis = s.c_str()[0]-'x';
		if (axis < 0 || axis > 2)
		{
			st.setstate(std::ios_base::failbit);
			axis = 0;
		}
		node->setAxis(axis);

		st >> split;
		getline(st, s, ',');
		node->setSplit(split);

		recursive_load(st, node->getLeftChild());
		getline(st, s, ',');
		recursive_load(st, node->getRightChild());
		getline(st, s, ')');
	}
	else if (s.compare("(l") == 0)
	{
		// leaf; count stays 0 if it cannot be read
		int count = 0, idx = -1;

		node->setLeaf();

		st >> count;
		for (int i = 0; i < count; i++)
		{
			getline(st, s, ',');
			if (!(st >> idx) || idx < 0 || static_cast<size_t>(idx) >= shapes.size())
			{
				st.setstate(std::ios_base::failbit);
				break;
			}
			node->addShape(shapes[idx]);
		}
		getline(st, s, ')');
	}
}

// Load a kd-tree (shapes and node structure) in the format written by dump().
// If the first comma-delimited token is not "(kdtree", the stream is
// returned unchanged apart from that token having been consumed.
// mat is passed on to loadShape() for every shape read.
istream & KdTree::load(istream &st, Material *mat)
{
	string s;
	istringstream is;

	getline(st, s, ',');
	if (s.compare("(kdtree") != 0)
		return st;

	dbgmsg(1, "* loading kd-tree\n");

	shapes.clear();
	if (root) delete root;
	root = new KdNode();

	getline(st, s, ',');
	// stays 0 when the count token is empty or unparsable, instead of
	// driving the loop below with an indeterminate value
	int shape_count = 0;
	is.str(s);
	is >> shape_count;

	Shape *shape;
	for (int i = 0; i < shape_count; i++)
	{
		shape = loadShape(st, mat);
		Container::addShape(shape);
		getline(st, s, ',');
	}

	recursive_load(st, root);

	built = true;
	return st;
}