src/kdtree.cc
author Radek Brich <radek.brich@devl.cz>
Tue, 26 Jul 2016 18:19:37 +0200 (2016-07-26)
branchpyrit
changeset 104 2274a07510c1
parent 103 3b3257a410fe
permissions -rw-r--r--
Cleanup, dropped Windows support
/*
 * kdtree.cc: KdTree class
 *
 * This file is part of Pyrit Ray Tracer.
 *
 * Copyright 2006, 2007, 2008  Radek Brich
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <algorithm>
#include <stack>
#include <string>
#include <sstream>

#include "kdtree.h"
#include "serialize.h"

/*
 * One boundary of a shape's bounding box along a single axis.
 *
 * shape - the shape this boundary belongs to
 * pos   - coordinate of the boundary on the axis
 * end   - false for the lower (L) boundary, true for the upper (H) one
 *
 * Instances are sorted to enumerate the candidate split positions.
 */
class ShapeBound
{
public:
	const Shape *shape;
	Float pos;
	bool end;

	ShapeBound(const Shape *ashape, const Float apos, const bool aend)
		: shape(ashape), pos(apos), end(aend)
	{
	}

	/* order by position; equal positions are ordered by shape address */
	friend bool operator<(const ShapeBound& lhs, const ShapeBound& rhs)
	{
		return (lhs.pos != rhs.pos) ? (lhs.pos < rhs.pos)
		                            : (lhs.shape < rhs.shape);
	}
};

// Stack element for the single-ray kd-tree traversal in
// KdTree::nearest_intersection(). Elements are linked through `prev`
// rather than being popped strictly in index order.
struct StackElem
{
	KdNode* node; /* pointer to far child (NULL for the initial exit element) */
	Float t; /* the entry/exit signed distance */
	Vector pb; /* the coordinates of entry/exit point */
	int prev; /* stack index of the exit element that was current when this one was pushed */
};

// ----------------------------------------

/* A leaf node deletes its shape list; an inner node releases nothing here. */
KdNode::~KdNode()
{
	if (!isLeaf())
		return;

	delete getShapes();
}

// Depth budget handed to recursive_build() by build(); the same constant
// sizes the fixed traversal stacks in nearest_intersection() and
// packet_intersection().
const int KdTree::MAX_DEPTH = 32;

/*
 * kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
 *
 * Takes the shapes currently stored in `node`, searches all three axes for
 * the split plane with the lowest surface-area-heuristic cost and, if that
 * cost beats the cost of not splitting, turns `node` into an inner node and
 * recurses into both children. Otherwise `node` is marked as a leaf.
 *
 * node     - node to subdivide; must currently hold a shape list
 * bounds   - bounding box of the space covered by `node`
 * maxdepth - remaining depth budget; the node becomes a leaf when it is <= 0
 */
void KdTree::recursive_build(KdNode *node, const BBox &bounds, int maxdepth)
{
	ShapeList *shapes = node->getShapes();

	// stop subdividing: depth budget exhausted or too few shapes to bother
	if (maxdepth <= 0 || shapes->size() <= 2)
	{
		node->setLeaf();
		return;
	}

	// choose split axis
	/*axis = 0;
	if (bounds.h() > bounds.w() && bounds.h() > bounds.d())
		axis = 1;
	if (bounds.d() > bounds.w() && bounds.d() > bounds.h())
		axis = 2;
*/
	// create sorted list of shape bounds (= find all possible splits);
	// every shape contributes a start edge (end = 0) and an end edge (end = 1)
	// on each of the three axes
	vector<ShapeBound> edges[3];
	ShapeList::iterator shape;
	for (shape = shapes->begin(); shape != shapes->end(); shape++)
	{
		BBox shapebounds = (*shape)->get_bbox();
		for (int ax = 0; ax < 3; ax++)
		{
			edges[ax].push_back(ShapeBound(*shape, shapebounds.L[ax], 0));
			edges[ax].push_back(ShapeBound(*shape, shapebounds.H[ax], 1));
		}
	}
	for (int ax = 0; ax < 3; ax++)
		sort(edges[ax].begin(), edges[ax].end());

	// choose best split pos
	const Float K = 1.4f; // constant, K = cost of traversal / cost of ray-triangle intersection
	// SAV is w*h + w*d + h*d, i.e. half of the box surface; SAL and SAR below
	// use the same formula, so all compared costs share the same scale
	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
		bounds.w()*bounds.d() + bounds.h()*bounds.d());
	Float cost = SAV * (K + shapes->size()); // initial cost = non-split cost

	int axis = 0;
	// splitedge stays at edges[0].end() (axis is still 0) while no candidate
	// has beaten the non-split cost
	vector<ShapeBound>::iterator edge, splitedge = edges[axis].end();
	for (int ax = 0; ax < 3; ax++)
	{
		// lnum = shapes that started before the current edge,
		// rnum = shapes that have not ended before the current edge
		int lnum = 0, rnum = (int)shapes->size();
		BBox lbb = bounds;
		BBox rbb = bounds;
		for (edge = edges[ax].begin(); edge != edges[ax].end(); edge++)
		{
			// a shape ending exactly here is not counted on the right side
			if (edge->end)
				rnum--;

			// calculate SAH cost of this split
			lbb.H[ax] = edge->pos;
			rbb.L[ax] = edge->pos;
			Float SAL = (lbb.w()*lbb.h() + lbb.w()*lbb.d() + lbb.h()*lbb.d());
			Float SAR = (rbb.w()*rbb.h() + rbb.w()*rbb.d() + rbb.h()*rbb.d());
			Float splitcost = K*SAV + SAL*(K + lnum) + SAR*(K + rnum);

			if (splitcost < cost)
			{
				axis = ax;
				splitedge = edge;
				cost = splitcost;
			}

			// a shape starting here counts on the left side from the next edge on
			if (!edge->end)
				lnum++;
		}
	}

	// no candidate split was cheaper than leaving the node unsplit
	if (splitedge == edges[axis].end())
	{
		node->setLeaf();
		return;
	}

	node->setSplit(splitedge->pos);

#if 0
// export kd-tree as .obj for visualization
// this would be hard to reconstruct later
	static ofstream F("kdtree.obj");
	Vector v;
	static int f=0;
	v[axis] = node->getSplit();
	v[(axis+1)%3] = bounds.L[(axis+1)%3];
	v[(axis+2)%3] = bounds.L[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v[(axis+1)%3] = bounds.L[(axis+1)%3];
	v[(axis+2)%3] = bounds.H[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v[(axis+1)%3] = bounds.H[(axis+1)%3];
	v[(axis+2)%3] = bounds.H[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v[(axis+1)%3] = bounds.H[(axis+1)%3];
	v[(axis+2)%3] = bounds.L[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
	f += 4;
#endif

	// split this node
	// the node's own shape list is released here; the Shape pointers needed
	// for redistribution are still available through `edges`
	delete shapes;

	BBox lbb = bounds;
	BBox rbb = bounds;
	lbb.H[axis] = node->getSplit();
	rbb.L[axis] = node->getSplit();
	// Two KdNodes are placement-constructed from the pool, but only the first
	// is passed to setChildren().
	// NOTE(review): presumably mempool.alloc() hands out adjacent slots so that
	// getRightChild() finds the second node right after the first -- confirm
	// against the mempool implementation and KdNode::getRightChild().
	node->setChildren(new (mempool.alloc()) KdNode);
	new (mempool.alloc()) KdNode;
	node->setAxis(axis);

	// left child: shapes whose start edge sorts before the split edge
	for (edge = edges[axis].begin(); edge != splitedge; edge++)
		if (!edge->end && edge->shape->intersect_bbox(lbb))
			node->getLeftChild()->addShape(edge->shape);
	// right child: shapes whose end edge sorts at or after the split edge
	for (edge = splitedge; edge < edges[axis].end(); edge++)
		if (edge->end && edge->shape->intersect_bbox(rbb))
			node->getRightChild()->addShape(edge->shape);

	recursive_build(node->getLeftChild(), lbb, maxdepth-1);
	recursive_build(node->getRightChild(), rbb, maxdepth-1);
}

void KdTree::build()
{
	dbgmsg(1, "* building kd-tree\n");
	root = new KdNode();
	ShapeList::iterator shape;
	for (shape = shapes.begin(); shape != shapes.end(); shape++)
		root->addShape(*shape);
	recursive_build(root, bbox, MAX_DEPTH);
	built = true;
}

/*
 * Find the first shape hit by a ray, using kd-tree traversal.
 * Algorithm by Vlastimil Havran, Heuristic Ray Shooting Algorithms, appendix C.
 *
 * origin_shape     - shape to be skipped (never reported as a hit)
 * ray              - the ray to trace
 * nearest_distance - written with the hit distance when a shape is returned
 *
 * Returns the shape that was hit, or NULL when the ray misses the tree's
 * bounding box or leaves the scene without hitting anything. When the tree
 * has not been built, the call is forwarded to
 * Container::nearest_intersection().
 */
const Shape *KdTree::nearest_intersection(const Shape *origin_shape, const Ray &ray,
	Float &nearest_distance)
{
	Float a, b; /* entry/exit signed distance */
	Float t;    /* signed distance to the splitting plane */

	/* if we have no tree, fall back to naive test */
	if (!built)
		return Container::nearest_intersection(origin_shape, ray, nearest_distance);

	if (!bbox.intersect(ray, a, b))
		return NULL;

	/* pointers to the far child node and current node */
	KdNode *farchild, *node;
	node = root;

	/* fixed-size traversal stack; overflow is only guarded by the assert below */
	StackElem stack[MAX_DEPTH];

	int entry = 0, exit = 1;
	stack[entry].t = a;

	/* distinguish between internal and external origin of a ray*/
	if (a >= 0.0)
		stack[entry].pb = ray.o + ray.dir * a; /* external */
	else
		stack[entry].pb = ray.o;               /* internal */

	/* setup initial exit point in the stack */
	stack[exit].t = b;
	stack[exit].pb = ray.o + ray.dir * b;
	stack[exit].node = NULL; /* NULL far child terminates the outer loop */

	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
	Float splitVal;
	int axis;
	static const int mod3[] = {0,1,2,0,1}; /* mod3[axis+k] == (axis+k) % 3 for k in 1..2 */
	const Vector invdir = 1 / ray.dir;
	while (node)
	{
		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			splitVal = node->getSplit();
			axis = node->getAxis();

			if (stack[entry].pb[axis] <= splitVal)
			{
				if (stack[exit].pb[axis] <= splitVal)
				{ /* case N1, N2, N3, P5, Z2, and Z3 */
					node = node->getLeftChild();
					continue;
				}
				if (stack[entry].pb[axis] == splitVal)
				{ /* case Z1 */
					node = node->getRightChild();
					continue;
				}
				/* case N4 */
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{ /* (stack[entry].pb[axis] > splitVal) */
				if (stack[exit].pb[axis] > splitVal)
				{
					/* case P1, P2, P3, and N5 */
					node = node->getRightChild();
					continue;
				}
				/* case P4 */
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* case P4 or N4 . . . traverse both children */

			/* signed distance to the splitting plane */
			t = (splitVal - ray.o[axis]) * invdir[axis];

			/* setup the new exit point and push it onto stack */
			/* (plain local: the `register` specifier no longer exists in C++17) */
			const int prev_exit = exit;

			exit++;
			if (exit == entry) /* never overwrite the current entry element */
				exit++;
			assert(exit < MAX_DEPTH);

			stack[exit].prev = prev_exit;
			stack[exit].t = t;
			stack[exit].node = farchild;
			stack[exit].pb[axis] = splitVal;
			stack[exit].pb[mod3[axis+1]] = ray.o[mod3[axis+1]]
				+ t * ray.dir[mod3[axis+1]];
			stack[exit].pb[mod3[axis+2]] = ray.o[mod3[axis+2]]
				+ t * ray.dir[mod3[axis+2]];
		}

		/* current node is the leaf . . . empty or full */
		const Shape *nearest_shape = NULL;
		/* NOTE(review): intersect() presumably treats dist as the maximum
		   accepted distance and overwrites it with the hit distance --
		   confirm against Shape::intersect */
		Float dist = stack[exit].t;
		ShapeList::iterator shape;
		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
			&& dist >= stack[entry].t - Eps)
			{
				nearest_shape = *shape;
				nearest_distance = dist;
			}

		if (nearest_shape)
			return nearest_shape;

		/* nothing hit in this leaf: the old exit point becomes the new entry point */
		entry = exit;

		/* retrieve the pointer to the next node,
		it is possible that ray traversal terminates */
		node = stack[entry].node;
		exit = stack[entry].prev;
	}

	/* ray leaves the scene */
	return NULL;
}

#ifndef NO_SIMD
// Stack element for the packet kd-tree traversal in
// KdTree::packet_intersection(); same role as StackElem, with the
// distance and point stored as packet types.
struct StackElem4
{
	KdNode* node; /* pointer to far child (NULL for the initial exit element) */
	mfloat4 t; /* the entry/exit signed distance */
	VectorPacket pb; /* the coordinates of entry/exit point */
	int prev; /* stack index of the exit element that was current when this one was pushed */
};

void KdTree::packet_intersection(const Shape* const* origin_shapes, const RayPacket &rays,
		Float *nearest_distances, const Shape **nearest_shapes)
{
	mfloat4 a, b; /* entry/exit signed distance */
	mfloat4 t;    /* signed distance to the splitting plane */
	mfloat4 mask = mZero;

	/* if we have no tree, fall back to naive test */
	if (!built)
		Container::packet_intersection(origin_shapes, rays, nearest_distances, nearest_shapes);

	// nearest_shapes[0..4] = NULL
	memset(nearest_shapes, 0, 4*sizeof(Shape*));

	mask = bbox.intersect_packet(rays, a, b);
	if (!mmovemask(mask))
		return;

	/* pointers to the far child node and current node */
	KdNode *farchild, *node;
	node = root;

	StackElem4 stack[MAX_DEPTH];

	int entry = 0, exit = 1;
	stack[entry].t = a;

	/* distinguish between internal and external origin of a ray*/
	t = mcmplt(a, mZero);
	stack[entry].pb = rays.o + rays.dir * a;
	stack[entry].pb.mx = mselect(t, rays.o.mx, stack[entry].pb.mx);
	stack[entry].pb.my = mselect(t, rays.o.my, stack[entry].pb.my);
	stack[entry].pb.mz = mselect(t, rays.o.mz, stack[entry].pb.mz);

	/* setup initial exit point in the stack */
	stack[exit].t = b;
	stack[exit].pb = rays.o + rays.dir * b;
	stack[exit].node = NULL;

	/* loop, traverse through the whole kd-tree,
	until an object is intersected or ray leaves the scene */
	mfloat4 splitVal;
	int axis;
	static const int mod3[] = {0,1,2,0,1};
	const VectorPacket invdirs = mOne / rays.dir;
	while (node)
	{
		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			splitVal = mset1(node->getSplit());
			axis = node->getAxis();

			// mask out invalid rays with near > far
			const mfloat4 curmask = mand(mask, mcmple(stack[entry].t, stack[exit].t));
			const mfloat4 entry_lt = mcmplt(stack[entry].pb.ma[axis], splitVal);
			const mfloat4 entry_gt = mcmpgt(stack[entry].pb.ma[axis], splitVal);
			const mfloat4 exit_lt = mcmplt(stack[exit].pb.ma[axis], splitVal);
			const mfloat4 exit_gt = mcmpgt(stack[exit].pb.ma[axis], splitVal);

			// if all of
			// stack[entry].pb[axis] <= splitVal,
			// stack[exit].pb[axis] <= splitVal
			if (!mmovemask(
				mand(mor(entry_gt, exit_gt), curmask)))
			{
				node = node->getLeftChild();
				continue;
			}

			// if all of
			// stack[entry].pb[axis] >= splitVal,
			// stack[exit].pb[axis] >= splitVal
			if (!mmovemask(
				mand(mor(entry_lt, exit_lt), curmask)))
			{
				node = node->getRightChild();
				continue;
			}

			// any of
			// stack[entry].pb[axis] < splitVal,
			// stack[exit].pb[axis] > splitVal
			int cond1 = mmovemask(
				mand(mand(entry_lt, exit_gt), curmask));

			// any of
			// stack[entry].pb[axis] > splitVal,
			// stack[exit].pb[axis] < splitVal
			int cond2 = mmovemask(
				mand(mand(entry_gt, exit_lt), curmask));

			if ((!cond1 && !cond2) || (cond1 && cond2))
			{
				// fall back to single rays
				// FIXME: split rays and continue
				for (int i = 0; i < 4; i++)
					if (!nearest_shapes[i])
						nearest_shapes[i] = nearest_intersection(origin_shapes[i],
							rays[i], nearest_distances[i]);
				return;
			}

			if (cond1)
			{
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* traverse both children */

			/* signed distance to the splitting plane */
			t = mmul(msub(splitVal, rays.o.ma[axis]), invdirs.ma[axis]);

			/* setup the new exit point and push it onto stack */
			register int tmp = exit;

			exit++;
			if (exit == entry)
				exit++;
			assert(exit < MAX_DEPTH);

			stack[exit].prev = tmp;
			stack[exit].t = t;
			stack[exit].node = farchild;
			stack[exit].pb.ma[axis] = splitVal;
			stack[exit].pb.ma[mod3[axis+1]] =
				madd(rays.o.ma[mod3[axis+1]], mmul(t, rays.dir.ma[mod3[axis+1]]));
			stack[exit].pb.ma[mod3[axis+2]] =
				madd(rays.o.ma[mod3[axis+2]], mmul(t, rays.dir.ma[mod3[axis+2]]));
		}

		/* current node is the leaf . . . empty or full */
		mfloat4 dists = stack[exit].t;
		ShapeList::iterator shape;
		mfloat4 results;
		mfloat4 newmask = mask;
		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
		{
			results = (*shape)->intersect_packet(rays, dists);
			int valid = mmovemask(
				mand(mask, mand(results,
				mcmpge(dists, msub(stack[entry].t, mEps)))));
			for (int i = 0; i < 4; i++)
			{
				if (*shape != origin_shapes[i] && ((valid>>i)&1))
				{
					nearest_shapes[i] = *shape;
					nearest_distances[i] = ((float*)&dists)[i];
					((int*)&newmask)[i] = 0;
				}
			}
		}

		mask = newmask;
		if (!mmovemask(mask))
			return;

		entry = exit;

		/* retrieve the pointer to the next node,
		it is possible that ray traversal terminates */
		node = stack[entry].node;
		exit = stack[entry].prev;
	}

	/* ray leaves the scene */
}
#endif

/*
 * Write a kd-tree node and, recursively, its subtree.
 *
 * A leaf is written as "(l,<count>,<index>,...)" with one shape_index entry
 * per shape; an inner node as "(s,<axis letter>,<split>,<left>,<right>)",
 * where the axis letter is 'x' plus the axis number.
 */
ostream & operator<<(ostream &st, KdNode &node)
{
	if (!node.isLeaf())
	{
		const char axisname = (char)('x'+node.getAxis());
		st << "(s," << axisname << "," << node.getSplit() << ",";
		st << *node.getLeftChild() << ",";
		st << *node.getRightChild() << ")";
		return st;
	}

	st << "(l," << node.getShapes()->size();
	for (ShapeList::iterator it = node.getShapes()->begin();
		it != node.getShapes()->end(); ++it)
	{
		st << "," << shape_index[*it];
	}
	st << ")";
	return st;
}

/*
 * Serialize the tree as "(kdtree,<shape count>[,<shape>...],<root node>)".
 *
 * Only shapes for which shape_index.get() reports no entry are written out.
 * When the tree has not been built, Container::dump() is used instead.
 */
ostream & KdTree::dump(ostream &st) const
{
	if (!built)
		return Container::dump(st);

	st << "(kdtree," << shapes.size();

	for (ShapeList::const_iterator it = shapes.begin(); it != shapes.end(); ++it)
	{
		int found_idx; // required by get(); the value itself is not used
		if (shape_index.get(*it, found_idx))
			continue;
		st << "," << endl << **it;
	}

	st << "," << endl << *getRootNode();
	st << ")";
	return st;
}

/*
 * Read one node, and recursively its subtree, in the format written by
 * operator<<(ostream&, KdNode&).
 *
 * st   - stream positioned at the opening "(s," or "(l," of a node
 * node - node to fill in; for a split it receives two new children, for a
 *        leaf it receives the listed shapes
 *
 * A node tag other than "(s" or "(l" leaves `node` untouched. Shape indices
 * outside the range of this tree's shape list are skipped, and reading of a
 * leaf stops as soon as the stream fails.
 */
void KdTree::recursive_load(istream &st, KdNode *node)
{
	string s;

	getline(st, s, ',');
	trim(s);

	if (s.compare("(s") == 0)
	{
		// split
		int axis;
		Float split;

		delete node->getShapes();
		node->setChildren(new KdNode[2]);

		// the axis is stored as a letter: 'x', 'y' or 'z'
		getline(st, s, ',');
		axis = s.c_str()[0]-'x';
		node->setAxis(axis);

		st >> split;
		getline(st, s, ',');
		node->setSplit(split);

		recursive_load(st, node->getLeftChild());
		getline(st, s, ',');
		recursive_load(st, node->getRightChild());
		getline(st, s, ')');
	}
	// `else`: s has been overwritten by the split branch and must not be
	// tested again
	else if (s.compare("(l") == 0)
	{
		// leaf
		int count = 0, idx = 0;

		node->setLeaf();

		st >> count;
		for (int i = 0; i < count && st; i++)
		{
			getline(st, s, ',');
			st >> idx;
			if (!st)
				break; // truncated or malformed input
			// never index the shape list with an unchecked value from the file
			if (idx < 0 || idx >= (int)shapes.size())
				continue;
			node->addShape(shapes[idx]);
		}
		getline(st, s, ')');
	}
}

/*
 * Load a tree in the "(kdtree,<shape count>,<shapes...>,<root node>)" format.
 *
 * st  - input stream
 * mat - material handed on to loadShape() for every shape
 *
 * If the input does not start with "(kdtree", the stream is returned with
 * nothing else changed. Otherwise the current shapes and root are discarded,
 * the shapes and nodes are read, and the tree is marked as built.
 */
istream & KdTree::load(istream &st, Material *mat)
{
	string token;
	istringstream countstream;

	getline(st, token, ',');
	if (token != "(kdtree")
		return st;

	dbgmsg(1, "* loading kd-tree\n");

	// drop whatever the tree held before
	shapes.clear();
	delete root; // deleting a null pointer is a no-op
	root = new KdNode();

	// number of shapes that follow
	getline(st, token, ',');
	int shape_count;
	countstream.str(token);
	countstream >> shape_count;

	// each shape is followed by a ',' separator
	for (int left = shape_count; left > 0; --left)
	{
		Shape *loaded = loadShape(st, mat);
		Container::addShape(loaded);
		getline(st, token, ',');
	}

	recursive_load(st, root);

	built = true;
	return st;
}