src/kdtree.cc
branchpyrit
changeset 74 09aedbf5f95f
parent 72 7c3f38dff082
child 76 3b60fd0bea64
equal deleted inserted replaced
73:a5127346fbcd 74:09aedbf5f95f
    46 			return a.pos < b.pos;
    46 			return a.pos < b.pos;
    47 	};
    47 	};
    48 };
    48 };
    49 
    49 
    50 // stack element for kd-tree traversal
    50 // stack element for kd-tree traversal
    51 class StackElem
    51 struct StackElem
    52 {
    52 {
    53 public:
       
    54 	KdNode* node; /* pointer to far child */
    53 	KdNode* node; /* pointer to far child */
    55 	Float t; /* the entry/exit signed distance */
    54 	Float t; /* the entry/exit signed distance */
    56 	Vector3 pb; /* the coordinates of entry/exit point */
    55 	Vector3 pb; /* the coordinates of entry/exit point */
    57 	StackElem(KdNode *anode, const Float &at, const Vector3 &apb):
    56 	int prev;
    58 		node(anode), t(at), pb(apb) {};
       
    59 };
    57 };
    60 
    58 
    61 // ----------------------------------------
    59 // ----------------------------------------
    62 
    60 
    63 KdNode::~KdNode()
    61 KdNode::~KdNode()
    64 {
    62 {
    65 	if (isLeaf())
    63 	if (isLeaf())
    66 		delete shapes;
    64 		delete getShapes();
    67 	else
    65 	else
    68 		delete[] children;
    66 		delete[] getLeftChild();
    69 }
    67 }
    70 
    68 
    71 // kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
    69 // kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
    72 void KdNode::subdivide(BBox bounds, int maxdepth)
    70 void KdNode::subdivide(BBox bounds, int maxdepth)
    73 {
    71 {
    74 	if (maxdepth <= 0 || shapes->size() <= 2)
    72 	if (maxdepth <= 0 || getShapes()->size() <= 2)
    75 	{
    73 	{
    76 		setLeaf();
    74 		setLeaf();
    77 		return;
    75 		return;
    78 	}
    76 	}
    79 
    77 
    85 		axis = 2;
    83 		axis = 2;
    86 */
    84 */
    87 	// create sorted list of shape bounds (= find all posible splits)
    85 	// create sorted list of shape bounds (= find all posible splits)
    88 	vector<ShapeBound> edges[3];
    86 	vector<ShapeBound> edges[3];
    89 	ShapeList::iterator shape;
    87 	ShapeList::iterator shape;
    90 	for (shape = shapes->begin(); shape != shapes->end(); shape++)
    88 	for (shape = getShapes()->begin(); shape != getShapes()->end(); shape++)
    91 	{
    89 	{
    92 		BBox shapebounds = (*shape)->get_bbox();
    90 		BBox shapebounds = (*shape)->get_bbox();
    93 		for (int ax = 0; ax < 3; ax++)
    91 		for (int ax = 0; ax < 3; ax++)
    94 		{
    92 		{
    95 			edges[ax].push_back(ShapeBound(*shape, shapebounds.L[ax], 0));
    93 			edges[ax].push_back(ShapeBound(*shape, shapebounds.L[ax], 0));
   101 
    99 
   102 	// choose best split pos
   100 	// choose best split pos
   103 	const Float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
   101 	const Float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
   104 	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
   102 	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
   105 		bounds.w()*bounds.d() + bounds.h()*bounds.d());
   103 		bounds.w()*bounds.d() + bounds.h()*bounds.d());
   106 	Float cost = SAV * (K + shapes->size()); // initial cost = non-split cost
   104 	Float cost = SAV * (K + getShapes()->size()); // initial cost = non-split cost
   107 
   105 
   108 	vector<ShapeBound>::iterator edge, splitedge = edges[2].end();
   106 	vector<ShapeBound>::iterator edge, splitedge = edges[2].end();
       
   107 	int axis = 0;
   109 	for (int ax = 0; ax < 3; ax++)
   108 	for (int ax = 0; ax < 3; ax++)
   110 	{
   109 	{
   111 		int lnum = 0, rnum = shapes->size();
   110 		int lnum = 0, rnum = getShapes()->size();
   112 		BBox lbb = bounds;
   111 		BBox lbb = bounds;
   113 		BBox rbb = bounds;
   112 		BBox rbb = bounds;
   114 		for (edge = edges[ax].begin(); edge != edges[ax].end(); edge++)
   113 		for (edge = edges[ax].begin(); edge != edges[ax].end(); edge++)
   115 		{
   114 		{
   116 			if (edge->end)
   115 			if (edge->end)
   164 	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
   163 	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
   165 	f += 4;
   164 	f += 4;
   166 #endif
   165 #endif
   167 
   166 
   168 	// split this node
   167 	// split this node
   169 	delete shapes;
   168 	delete getShapes();
       
   169 
   170 	BBox lbb = bounds;
   170 	BBox lbb = bounds;
   171 	BBox rbb = bounds;
   171 	BBox rbb = bounds;
   172 	lbb.H.cell[axis] = split;
   172 	lbb.H.cell[axis] = split;
   173 	rbb.L.cell[axis] = split;
   173 	rbb.L.cell[axis] = split;
   174 	children = new KdNode[2];
   174 	setChildren(new KdNode[2]);
       
   175 	setAxis(axis);
   175 
   176 
   176 	for (edge = edges[axis].begin(); edge != splitedge; edge++)
   177 	for (edge = edges[axis].begin(); edge != splitedge; edge++)
   177 		if (!edge->end && edge->shape->intersect_bbox(lbb))
   178 		if (!edge->end && edge->shape->intersect_bbox(lbb))
   178 			children[0].addShape(edge->shape);
   179 			getLeftChild()->addShape(edge->shape);
   179 	for (edge = splitedge; edge < edges[axis].end(); edge++)
   180 	for (edge = splitedge; edge < edges[axis].end(); edge++)
   180 		if (edge->end && edge->shape->intersect_bbox(rbb))
   181 		if (edge->end && edge->shape->intersect_bbox(rbb))
   181 			children[1].addShape(edge->shape);
   182 			getRightChild()->addShape(edge->shape);
   182 
   183 
   183 	children[0].subdivide(lbb, maxdepth-1);
   184 	getLeftChild()->subdivide(lbb, maxdepth-1);
   184 	children[1].subdivide(rbb, maxdepth-1);
   185 	getRightChild()->subdivide(rbb, maxdepth-1);
   185 }
   186 }
   186 
   187 
   187 void KdTree::build()
   188 void KdTree::build()
   188 {
   189 {
   189 	dbgmsg(1, "* building kd-tree\n");
   190 	dbgmsg(1, "* building kd-tree\n");
   209 	if (!bbox.intersect(ray, a, b))
   210 	if (!bbox.intersect(ray, a, b))
   210 		return NULL;
   211 		return NULL;
   211 
   212 
   212 	/* pointers to the far child node and current node */
   213 	/* pointers to the far child node and current node */
   213 	KdNode *farchild, *node;
   214 	KdNode *farchild, *node;
   214 	node = root; /* start from the kd-tree root node */
   215 	node = root;
   215 
   216 
   216 	/* std vector is much faster than stack */
   217 	StackElem stack[max_depth];
   217 	vector<StackElem*> st;
   218 
   218 
   219 	int entry = 0, exit = 1;
   219 	StackElem *enPt = new StackElem(NULL, a,
   220 	stack[entry].t = a;
   220 		/* distinguish between internal and external origin of a ray*/
   221 
   221 		a >= 0.0 ?
   222 	/* distinguish between internal and external origin of a ray*/
   222 			ray.o + ray.dir * a : /* external */
   223 	if (a >= 0.0)
   223 			ray.o);               /* internal */
   224 		stack[entry].pb = ray.o + ray.dir * a; /* external */
       
   225 	else
       
   226 		stack[entry].pb = ray.o;               /* internal */
   224 
   227 
   225 	/* setup initial exit point in the stack */
   228 	/* setup initial exit point in the stack */
   226 	StackElem *exPt = new StackElem(NULL, b, ray.o + ray.dir * b);
   229 	stack[exit].t = b;
   227 	st.push_back(exPt);
   230 	stack[exit].pb = ray.o + ray.dir * b;
       
   231 	stack[exit].node = NULL;
   228 
   232 
   229 	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
   233 	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
       
   234 	Float splitVal;
       
   235 	int axis;
       
   236 	static const int mod3[] = {0,1,2,0,1};
       
   237 	const Vector3 invdir = 1 / ray.dir;
   230 	while (node)
   238 	while (node)
   231 	{
   239 	{
   232 		exPt = st.back();
       
   233 		/* loop until a leaf is found */
   240 		/* loop until a leaf is found */
   234 		while (!node->isLeaf())
   241 		while (!node->isLeaf())
   235 		{
   242 		{
   236 			/* retrieve position of splitting plane */
   243 			/* retrieve position of splitting plane */
   237 			Float splitVal = node->getSplit();
   244 			splitVal = node->getSplit();
   238 			short axis = node->getAxis();
   245 			axis = node->getAxis();
   239 
   246 
   240 			if (enPt->pb[axis] <= splitVal)
   247 			if (stack[entry].pb[axis] <= splitVal)
   241 			{
   248 			{
   242 				if (exPt->pb[axis] <= splitVal)
   249 				if (stack[exit].pb[axis] <= splitVal)
   243 				{ /* case N1, N2, N3, P5, Z2, and Z3 */
   250 				{ /* case N1, N2, N3, P5, Z2, and Z3 */
   244 					node = node->getLeftChild();
   251 					node = node->getLeftChild();
   245 					continue;
   252 					continue;
   246 				}
   253 				}
   247 				if (exPt->pb[axis] == splitVal)
   254 				if (stack[exit].pb[axis] == splitVal)
   248 				{ /* case Z1 */
   255 				{ /* case Z1 */
   249 					node = node->getRightChild();
   256 					node = node->getRightChild();
   250 					continue;
   257 					continue;
   251 				}
   258 				}
   252 				/* case N4 */
   259 				/* case N4 */
   253 				farchild = node->getRightChild();
   260 				farchild = node->getRightChild();
   254 				node = node->getLeftChild();
   261 				node = node->getLeftChild();
   255 			}
   262 			}
   256 			else
   263 			else
   257 			{ /* (enPt->pb[axis] > splitVal) */
   264 			{ /* (stack[entry].pb[axis] > splitVal) */
   258 				if (splitVal < exPt->pb[axis])
   265 				if (splitVal < stack[exit].pb[axis])
   259 				{
   266 				{
   260 					/* case P1, P2, P3, and N5 */
   267 					/* case P1, P2, P3, and N5 */
   261 					node = node->getRightChild();
   268 					node = node->getRightChild();
   262 					continue;
   269 					continue;
   263 				}
   270 				}
   264 				/* case P4*/
   271 				/* case P4 */
   265 				farchild = node->getLeftChild();
   272 				farchild = node->getLeftChild();
   266 				node = node->getRightChild();
   273 				node = node->getRightChild();
   267 			}
   274 			}
   268 
   275 
   269 			/* case P4 or N4 . . . traverse both children */
   276 			/* case P4 or N4 . . . traverse both children */
   270 
   277 
   271 			/* signed distance to the splitting plane */
   278 			/* signed distance to the splitting plane */
   272 			t = (splitVal - ray.o.cell[axis]) / ray.dir.cell[axis];
   279 			t = (splitVal - ray.o[axis]) * invdir[axis];
   273 
   280 
   274 			/* setup the new exit point and push it onto stack */
   281 			/* setup the new exit point and push it onto stack */
   275 			exPt = new StackElem(farchild, t, Vector3());
   282 			register int tmp = exit;
   276 			exPt->pb.cell[axis] = splitVal;
   283 
   277 			exPt->pb.cell[(axis+1)%3] = ray.o.cell[(axis+1)%3] + t * ray.dir.cell[(axis+1)%3];
   284 			exit++;
   278 			exPt->pb.cell[(axis+2)%3] = ray.o.cell[(axis+2)%3] + t * ray.dir.cell[(axis+2)%3];
   285 			if (exit == entry)
   279 			st.push_back(exPt);
   286 				exit++;
   280 		} /* while */
   287 			assert(exit < max_depth);
       
   288 
       
   289 			stack[exit].prev = tmp;
       
   290 			stack[exit].t = t;
       
   291 			stack[exit].node = farchild;
       
   292 			stack[exit].pb.cell[axis] = splitVal;
       
   293 			stack[exit].pb.cell[mod3[axis+1]] = ray.o.cell[mod3[axis+1]]
       
   294 				+ t * ray.dir.cell[mod3[axis+1]];
       
   295 			stack[exit].pb.cell[mod3[axis+2]] = ray.o.cell[mod3[axis+2]]
       
   296 				+ t * ray.dir.cell[mod3[axis+2]];
       
   297 		}
   281 
   298 
   282 		/* current node is the leaf . . . empty or full */
   299 		/* current node is the leaf . . . empty or full */
   283 		/* "intersect ray with each object in the object list, discarding "
       
   284 		"those lying before stack[enPt].t or farther than stack[exPt].t" */
       
   285 		Shape *nearest_shape = NULL;
   300 		Shape *nearest_shape = NULL;
   286 		Float dist = exPt->t;
   301 		Float dist = stack[exit].t;
   287 		ShapeList::iterator shape;
   302 		ShapeList::iterator shape;
   288 		for (shape = node->shapes->begin(); shape != node->shapes->end(); shape++)
   303 		for (shape = node->getShapes()->begin(); shape != node->getShapes()->end(); shape++)
   289 			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
   304 			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
   290 			&& dist >= enPt->t)
   305 			&& dist >= stack[entry].t)
   291 			{
   306 			{
   292 				nearest_shape = *shape;
   307 				nearest_shape = *shape;
   293 				nearest_distance = dist;
   308 				nearest_distance = dist;
   294 			}
   309 			}
   295 
   310 
   296 		delete enPt;
       
   297 
       
   298 		if (nearest_shape)
   311 		if (nearest_shape)
   299 		{
       
   300 			while (!st.empty())
       
   301 			{
       
   302 				delete st.back();
       
   303 				st.pop_back();
       
   304 			}
       
   305 			return nearest_shape;
   312 			return nearest_shape;
   306 		}
   313 
   307 
   314 		entry = exit;
   308 		enPt = exPt;
   315 
   309 
   316 		/* retrieve the pointer to the next node,
   310 		/* retrieve the pointer to the next node, it is possible that ray traversal terminates */
   317 		it is possible that ray traversal terminates */
   311 		node = enPt->node;
   318 		node = stack[entry].node;
   312 		st.pop_back();
   319 		exit = stack[entry].prev;
   313 	} /* while */
   320 	}
   314 
       
   315 	delete enPt;
       
   316 
   321 
   317 	/* ray leaves the scene */
   322 	/* ray leaves the scene */
   318 	return NULL;
   323 	return NULL;
   319 }
   324 }
   320 
   325 
   324 	if (!built)
   329 	if (!built)
   325 		return;
   330 		return;
   326 	if (node == NULL)
   331 	if (node == NULL)
   327 		node = root;
   332 		node = root;
   328 	if (node->isLeaf())
   333 	if (node->isLeaf())
   329 		str << "(leaf: " << node->shapes->size() << " shapes)";
   334 		str << "(leaf: " << node->getShapes()->size() << " shapes)";
   330 	else
   335 	else
   331 	{
   336 	{
   332 		str << "(split " << (char)('x'+node->getAxis()) << " at " << node->getSplit() << "; L=";
   337 		str << "(split " << (char)('x'+node->getAxis()) << " at " << node->getSplit() << "; L=";
   333 		save(str, node->getLeftChild());
   338 		save(str, node->getLeftChild());
   334 		str << "; R=";
   339 		str << "; R=";