/*
 * kdtree.cc: KdTree class
 *
 * This file is part of Pyrit Ray Tracer.
 *
 * Copyright 2006, 2007  Radek Brich
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <algorithm>
#include <stack>

#include "common.h"
#include "kdtree.h"

// One edge of a shape's bounding box along a single axis, used as a
// candidate split position while building the kd-tree.
class ShapeBound
{
public:
	Shape *shape; /* shape this edge belongs to */
	Float pos;    /* coordinate of the edge on the axis */
	bool end;     /* false = lower (start) edge, true = upper (end) edge */

	ShapeBound(Shape *ashape, const Float apos, const bool aend)
		: shape(ashape), pos(apos), end(aend)
	{
	}

	/* Order edges by position; edges at the same position are ordered by
	 * shape pointer so that equal positions still sort consistently.
	 * NOTE(review): the start and end edge of one shape at the same position
	 * compare equivalent, so their relative order after sorting is
	 * unspecified. */
	friend bool operator<(const ShapeBound& a, const ShapeBound& b)
	{
		if (a.pos != b.pos)
			return a.pos < b.pos;
		return a.shape < b.shape;
	}
};

// stack element for kd-tree traversal: one entry/exit point of the ray
// together with the node that is to be visited after it
class StackElem
{
public:
	KdNode* node; /* pointer to far child */
	Float t;      /* the entry/exit signed distance */
	Vector3 pb;   /* the coordinates of entry/exit point */

	StackElem(KdNode *anode, const Float &at, const Vector3 &apb)
		: node(anode), t(at), pb(apb)
	{
	}
};

// ----------------------------------------

// Release whatever this node holds: the child array for an inner node,
// the shape list for a leaf.
KdNode::~KdNode()
{
	if (!isLeaf())
	{
		/* inner node: children were allocated with new[] */
		delete[] children;
		return;
	}

	/* leaf: only the list itself is deleted here, not the shapes in it */
	delete shapes;
}

// kd-tree recursive build algorithm, inspired by PBRT (www.pbrt.org)
//
// Turns this node either into a leaf that keeps its current shape list, or
// into an inner node whose two children are subdivided recursively.
//
//   bounds   - bounding box of the space covered by this node
//   maxdepth - remaining recursion budget; at <= 0 the node becomes a leaf
//
// Candidate split positions are the bounding-box edges of the node's shapes
// on the chosen axis. Each candidate is rated by a surface area heuristic
// (SAH); if none is cheaper than leaving the node unsplit, the node becomes
// a leaf.
void KdNode::subdivide(BBox bounds, int maxdepth)
{
	// stop when the depth budget is used up or there are too few shapes
	if (maxdepth <= 0 || shapes->size() <= 2)
	{
		setLeaf();
		return;
	}

	// choose split axis: 1 if h() is strictly the largest extent,
	// 2 if d() is strictly the largest, otherwise 0
	axis = 0;
	if (bounds.h() > bounds.w() && bounds.h() > bounds.d())
		axis = 1;
	if (bounds.d() > bounds.w() && bounds.d() > bounds.h())
		axis = 2;

	// create sorted list of shape bounds (= find all possible splits);
	// every shape contributes its lower (start) and upper (end) edge
	vector<ShapeBound> edges;
	edges.reserve(2 * shapes->size());
	for (ShapeList::iterator it = shapes->begin(); it != shapes->end(); ++it)
	{
		BBox shapebounds = (*it)->get_bbox();
		edges.push_back(ShapeBound(*it, shapebounds.L[axis], false));
		edges.push_back(ShapeBound(*it, shapebounds.H[axis], true));
	}
	sort(edges.begin(), edges.end());

	// choose best split pos
	const Float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
	Float SAV = (bounds.w()*bounds.h() +  // surface area of node
		bounds.w()*bounds.d() + bounds.h()*bounds.d());
	Float cost = SAV * (K + shapes->size()); // initial cost = non-split cost
	BBox lbb = bounds;
	BBox rbb = bounds;

	// splitedge == edges.end() means "no split beats the unsplit node"
	vector<ShapeBound>::iterator edge, splitedge = edges.end();
	int lnum = 0, rnum = shapes->size();
	for (edge = edges.begin(); edge != edges.end(); ++edge)
	{
		// a shape that ends here no longer counts for the right side
		if (edge->end)
			rnum--;

		// calculate SAH cost of this split
		lbb.H.cell[axis] = edge->pos;
		rbb.L.cell[axis] = edge->pos;
		Float SAL = (lbb.w()*lbb.h() + lbb.w()*lbb.d() + lbb.h()*lbb.d());
		Float SAR = (rbb.w()*rbb.h() + rbb.w()*rbb.d() + rbb.h()*rbb.d());
		Float splitcost = K*SAV + SAL*(K + lnum) + SAR*(K + rnum);

		// strict comparison: the first of several equally cheap candidates wins
		if (splitcost < cost)
		{
			splitedge = edge;
			cost = splitcost;
			split = edge->pos;
		}

		// a shape that starts here counts for the left side from now on
		if (!edge->end)
			lnum++;
	}

	if (splitedge == edges.end())
	{
		setLeaf();
		return;
	}

#if 0
// export kd-tree as .obj for visualization
// this would be hard to reconstruct later
	static ofstream F("kdtree.obj");
	Vector3 v;
	static int f=0;
	v.cell[axis] = split;
	v.cell[(axis+1)%3] = bounds.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bounds.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bounds.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
	f += 4;
#endif

	// split this node; the edge list keeps its own Shape pointers, so the
	// shape list can be dropped before the children are filled
	delete shapes;
	children = new KdNode[2];

	// left child: shapes whose start edge sorts before the split edge
	for (edge = edges.begin(); edge != splitedge; ++edge)
		if (!edge->end)
			children[0].addShape(edge->shape);

	// right child: shapes whose end edge sorts at or after the split edge
	// (a shape reaching across the split edge is added to both children)
	for (edge = splitedge; edge != edges.end(); ++edge)
		if (edge->end)
			children[1].addShape(edge->shape);

	// the sweep left lbb/rbb at the last candidate; reset them to the chosen split
	lbb.H.cell[axis] = split;
	rbb.L.cell[axis] = split;

	children[0].subdivide(lbb, maxdepth-1);
	children[1].subdivide(rbb, maxdepth-1);
}

void KdTree::build()
{
	dbgmsg(1, "* building kd-tree\n");
	root = new KdNode();
	ShapeList::iterator shape;
	for (shape = shapes.begin(); shape != shapes.end(); shape++)
		root->addShape(*shape);
	root->subdivide(bbox, max_depth);
	built = true;
}

/* algorithm by Vlastimil Havran, Heuristic Ray Shooting Algorithms, appendix C
 *
 * Walk the kd-tree along the ray and return the first shape hit, or NULL if
 * the ray misses the tree's bounding box or leaves it without a hit.
 *
 *   origin_shape     - shape to be skipped during the tests
 *   ray              - the ray to trace
 *   nearest_distance - written only when a shape is returned; receives the
 *                      distance reported for that shape
 *
 * When the tree has not been built, the call is forwarded to
 * Container::nearest_intersection().
 *
 * Entry/exit points are kept by value (no per-step heap allocation), so
 * every return path cleans up automatically.
 */
Shape *KdTree::nearest_intersection(const Shape *origin_shape, const Ray &ray,
	Float &nearest_distance)
{
	Float a, b; /* entry/exit signed distance */
	Float t;    /* signed distance to the splitting plane */

	/* if we have no tree, fall back to naive test */
	if (!built)
		return Container::nearest_intersection(origin_shape, ray, nearest_distance);

	if (!bbox.intersect(ray, a, b))
		return NULL;

	/* pointers to the far child node and current node */
	KdNode *farchild, *node;
	node = root; /* start from the kd-tree root node */

	/* std vector is much faster than stack */
	vector<StackElem> st;

	StackElem enPt(NULL, a,
		/* distinguish between internal and external origin of a ray*/
		a >= 0.0 ?
			ray.o + ray.dir * a : /* external */
			ray.o);               /* internal */

	/* setup initial exit point in the stack; its NULL node ends the traversal */
	st.push_back(StackElem(NULL, b, ray.o + ray.dir * b));

	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
	while (node)
	{
		/* working copy of the top of the stack; always equal to st.back() */
		StackElem exPt = st.back();

		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			Float splitVal = node->getSplit();
			short axis = node->getAxis();

			if (enPt.pb[axis] <= splitVal)
			{
				if (exPt.pb[axis] <= splitVal)
				{ /* case N1, N2, N3, P5, Z2, and Z3 */
					node = node->getLeftChild();
					continue;
				}
				/* an extra equality test on the exit point could never be
				   reached after the `<=` test above, so none is made */
				/* case N4 */
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{ /* (enPt.pb[axis] > splitVal) */
				if (splitVal < exPt.pb[axis])
				{
					/* case P1, P2, P3, and N5 */
					node = node->getRightChild();
					continue;
				}
				/* case P4*/
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* case P4 or N4 . . . traverse both children */

			/* signed distance to the splitting plane; entry and exit lie on
			   different sides here, so the ray is not parallel to the plane */
			t = (splitVal - ray.o.cell[axis]) / ray.dir.cell[axis];

			/* setup the new exit point and push it onto stack */
			exPt = StackElem(farchild, t, Vector3());
			exPt.pb.cell[axis] = splitVal;
			exPt.pb.cell[(axis+1)%3] = ray.o.cell[(axis+1)%3] + t * ray.dir.cell[(axis+1)%3];
			exPt.pb.cell[(axis+2)%3] = ray.o.cell[(axis+2)%3] + t * ray.dir.cell[(axis+2)%3];
			st.push_back(exPt);
		} /* while */

		/* current node is the leaf . . . empty or full */
		/* "intersect ray with each object in the object list, discarding "
		"those lying before stack[enPt].t or farther than stack[exPt].t" */
		Shape *nearest_shape = NULL;
		/* dist starts at the exit distance and is handed to every intersect()
		   call - presumably an in/out distance limit; confirm against
		   Shape::intersect */
		Float dist = exPt.t;
		ShapeList::iterator shape;
		for (shape = node->shapes->begin(); shape != node->shapes->end(); shape++)
			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
			&& dist >= enPt.t)
			{
				nearest_shape = *shape;
				nearest_distance = dist;
			}

		if (nearest_shape)
			return nearest_shape;

		/* the exit point of this leaf is the entry point of the next node */
		enPt = exPt;

		/* retrieve the pointer to the next node, it is possible that ray traversal terminates */
		node = enPt.node;
		st.pop_back();
	} /* while */

	/* ray leaves the scene */
	return NULL;
}

// this should save whole kd-tree with triangles distributed into leaves
//
// Currently writes a textual outline of the tree only: for each leaf the
// number of shapes it holds, for each inner node the split axis letter, the
// split position and, recursively, its left and right child.
// Does nothing if the tree is not built. A NULL node means "start at root".
void KdTree::save(ostream &str, KdNode *node)
{
	if (!built)
		return;

	KdNode *cur = (node != NULL) ? node : root;

	if (cur->isLeaf())
	{
		str << "(leaf: " << cur->shapes->size() << " shapes)";
		return;
	}

	/* axis index 0,1,2 is printed as x,y,z */
	const char axisname = static_cast<char>('x' + cur->getAxis());

	str << "(split " << axisname << " at " << cur->getSplit() << "; L=";
	save(str, cur->getLeftChild());
	str << "; R=";
	save(str, cur->getRightChild());
	str << ";)";
}

// load kd-tree from file/stream
//
// Not implemented: the function reads nothing from the stream and leaves
// the tree untouched.
void KdTree::load(istream &str, KdNode *node)
{
	/* parameters are intentionally unused until loading is implemented */
	(void)str;
	(void)node;
}
