#include <algorithm>
#include <stack>

#include "kdtree.h"

/**
 * A shape paired with a cached copy of its bounding box, sortable by the
 * lower bound of that box along a chosen axis.
 *
 * `mark` is a scratch flag used by KdNode::subdivide while it counts how many
 * shapes fall on each side of a candidate split plane.
 */
class SortableShape
{
public:
	Shape *shape; /* the wrapped shape (not owned) */
	BBox bbox;    /* bounding box fetched once from the shape at construction */
	short axis;   /* index into bbox.L.cell / bbox.H.cell used for ordering */
	short mark;   /* 0 = unmarked, 1 = marked */

	/* The axis is only read, so it is taken by value; this also accepts
	   constants and temporaries, which a non-const reference would reject. */
	SortableShape(Shape *aShape, short aAxis)
		: shape(aShape), bbox(aShape->get_bbox()), axis(aAxis), mark(0) {}

	/* Orders by the lower bbox bound; each operand is read along its own axis. */
	friend bool operator<(const SortableShape& a, const SortableShape& b)
		{ return a.bbox.L.cell[a.axis] < b.bbox.L.cell[b.axis]; }

	void setAxis(short aAxis) { axis = aAxis; }
	void setMark() { mark = 1; }
	short hasMark() const { return mark; }
};

/* Vector of SortableShape built from a ShapeList; every element is tagged
   with the same sort axis and keeps the order of the input list. */
class SortableShapeList: public vector<SortableShape>
{
public:
	SortableShapeList(ShapeList &shapes, short axis)
	{
		ShapeList::iterator it = shapes.begin();
		while (it != shapes.end())
		{
			this->push_back(SortableShape(*it, axis));
			it++;
		}
	}
};

/**
 * Candidate split plane position along one axis, together with the number of
 * shapes that would land in the left (lnum) and right (rnum) sub-cells.
 *
 * Ordering and equality look at `pos` only, so sorting a list of SplitPos and
 * running unique() on it collapses candidates at the same coordinate.
 */
class SplitPos
{
public:
	float pos;
	int lnum, rnum;

	SplitPos(): pos(0.0f), lnum(0), rnum(0) {}

	/* Takes the coordinate by value: it is only read, and a non-const
	   reference would refuse temporaries and const floats. */
	SplitPos(float aPos): pos(aPos), lnum(0), rnum(0) {}

	friend bool operator<(const SplitPos& a, const SplitPos& b)
		{ return a.pos < b.pos; }
	friend bool operator==(const SplitPos& a, const SplitPos& b)
		{ return a.pos == b.pos; }
};

// List of candidate split positions; adds no members of its own on top of
// vector<SplitPos>.
class SplitList: public vector<SplitPos>
{
};

// Stack element for kd-tree traversal: describes one entry or exit point of
// the ray as used by KdTree::nearest_intersection.
struct StackElem
{
	KdNode* node; /* pointer to far child */
	float t; /* the entry/exit signed distance */
	Vector3 pb; /* the coordinates of entry/exit point */
};

// ----------------------------------------

/**
 * Appends a shape to the container and grows the container's bounding box so
 * that it encloses the shape.
 *
 * The very first shape initializes the bounding box outright. The emptiness
 * check has to happen before push_back(): tested afterwards, the list is never
 * empty and the box would only ever be merged with its previous contents.
 */
void Container::addShape(Shape* aShape)
{
	const bool firstShape = shapes.empty();
	shapes.push_back(aShape);

	BBox shapebb = aShape->get_bbox();
	if (firstShape) {
		/* initialize bounding box */
		bbox = shapebb;
		return;
	}

	/* adjust bounding box */
	if (shapebb.L.x < bbox.L.x)  bbox.L.x = shapebb.L.x;
	if (shapebb.L.y < bbox.L.y)  bbox.L.y = shapebb.L.y;
	if (shapebb.L.z < bbox.L.z)  bbox.L.z = shapebb.L.z;
	if (shapebb.H.x > bbox.H.x)  bbox.H.x = shapebb.H.x;
	if (shapebb.H.y > bbox.H.y)  bbox.H.y = shapebb.H.y;
	if (shapebb.H.z > bbox.H.z)  bbox.H.z = shapebb.H.z;
}

/* Naive nearest-hit search: tests the ray against every stored shape except
   origin_shape. nearest_distance is handed to each intersect() call; the
   last shape whose intersect() returned true is the result (NULL if none). */
Shape *Container::nearest_intersection(const Shape *origin_shape, const Ray &ray,
	float &nearest_distance)
{
	Shape *hit = NULL;
	for (ShapeList::iterator it = shapes.begin(); it != shapes.end(); it++)
	{
		Shape *candidate = *it;
		if (candidate == origin_shape)
			continue; /* never test the shape the ray starts from */
		if (candidate->intersect(ray, nearest_distance))
			hit = candidate;
	}
	return hit;
}

/**
 * Recursively splits this node along the longest axis of `bbox`.
 *
 * The node becomes a leaf when depth reaches 20, when it holds 4 shapes or
 * fewer, or when no candidate split position has a lower estimated cost than
 * the non-split cost. Otherwise two children are allocated, the shapes are
 * distributed between them (shapes straddling the plane go to both) and both
 * children are subdivided with depth+1.
 *
 * While the `#if 1` section is enabled, every split plane is also appended as
 * a quad to the file "kdtree.obj".
 *
 * @param bbox   spatial extent of this node
 * @param depth  recursion depth of this node (the root is built with 0)
 */
void KdNode::subdivide(BBox bbox, int depth)
{
	if (depth >= 20 || shapes.size() <= 4)
	{
		leaf = true;
		return;
	}

	// choose split axis: the one along which the box is longest (x on ties)
	axis = 0;
	if (bbox.h() > bbox.w() && bbox.h() > bbox.d())
		axis = 1;
	if (bbox.d() > bbox.w() && bbox.d() > bbox.h())
		axis = 2;

	// *** find optimal split position
	// shapes sorted by the lower bound of their bbox along the split axis
	SortableShapeList sslist(shapes, axis);
	sort(sslist.begin(), sslist.end());

	// candidate planes: both bbox bounds of every shape
	SplitList splitlist;

	SortableShapeList::iterator sh;
	for (sh = sslist.begin(); sh != sslist.end(); sh++)
	{
		splitlist.push_back(SplitPos(sh->bbox.L.cell[axis]));
		splitlist.push_back(SplitPos(sh->bbox.H.cell[axis]));
	}
	sort(splitlist.begin(), splitlist.end());
	// remove duplicate splits
	splitlist.resize(unique(splitlist.begin(), splitlist.end()) - splitlist.begin());

	// find all possible splits: count shapes left/right of every candidate.
	// Candidates are visited in ascending order, so a shape found entirely on
	// the left of one plane is marked and counted as "left" for all later ones.
	SplitList::iterator spl;
	int rest;
	for (spl = splitlist.begin(); spl != splitlist.end(); spl++)
	{
		for (sh = sslist.begin(), rest = sslist.size(); sh != sslist.end(); sh++, rest--)
		{
			if (sh->hasMark())
			{
				spl->lnum++;
				continue;
			}
			// if shape is completely contained in split plane
			// (two explicit comparisons: `a == b == c` would compare the
			// bool result of `a == b` against c)
			if (spl->pos == sh->bbox.L.cell[axis] && spl->pos == sh->bbox.H.cell[axis])
			{
				if (spl->pos - bbox.L.cell[axis] < bbox.H.cell[axis] - spl->pos)
				{
					// left subcell is smaller -> if not empty, put shape here
					if (spl->lnum)
						spl->lnum++;
					else
						spl->rnum++;
				} else {
					// right subcell is smaller
					if (spl->rnum)
						spl->rnum++;
					else
						spl->lnum++;
				}
				sh->setMark();
			} else
			// if shape is on left side of split plane
			if (sh->bbox.H.cell[axis] <= spl->pos)
			{
				spl->lnum++;
				sh->setMark();
			} else
			// if shape occupies both sides of split plane
			if (sh->bbox.L.cell[axis] < spl->pos && sh->bbox.H.cell[axis] > spl->pos)
			{
				spl->lnum++;
				spl->rnum++;
			} else
			// if shape is on right side of split plane
			if (sh->bbox.L.cell[axis] >= spl->pos)
			{
				// the list is sorted by lower bound, so this shape and all
				// remaining ones start at or beyond the plane
				spl->rnum += rest;
				break;
			}
		}
	}

	// choose best split pos
	const float K = 1.4; // constant, K = cost of traversal / cost of ray-triangle intersection
	float SAV = 2*(bbox.w()*bbox.h() + bbox.w()*bbox.d() + bbox.h()*bbox.d()); // surface area of node
	// NOTE(review): this non-split cost is scaled by SAV while splitcost below
	// uses the ratios SAL/SAV and SAR/SAV, so the comparison depends on the
	// absolute size of the node -- confirm this is intended.
	float cost = SAV * (K + shapes.size()); // initial cost = non-split cost
	SplitPos *splitpos = NULL;
	leaf = true;
	BBox lbb = bbox;
	BBox rbb = bbox;
	for (spl = splitlist.begin(); spl != splitlist.end(); spl++)
	{
		// calculate SAH cost of this split
		lbb.H.cell[axis] = spl->pos;
		rbb.L.cell[axis] = spl->pos;
		float SAL = 2*(lbb.w()*lbb.h() + lbb.w()*lbb.d() + lbb.h()*lbb.d());
		float SAR = 2*(rbb.w()*rbb.h() + rbb.w()*rbb.d() + rbb.h()*rbb.d());
		float splitcost = K + SAL/SAV*(K+spl->lnum) + SAR/SAV*(K+spl->rnum);

		if (splitcost < cost)
		{
			leaf = false;
			cost = splitcost;
			splitpos = &*spl; // stays valid: splitlist is not modified any more
		}
	}

	// no split beats keeping the node whole
	if (leaf)
		return;

#if 1
// export kd-tree as .obj for visualization
// this would be hard to reconstruct later
	static ofstream F("kdtree.obj");
	Vector3 v;
	static int f=0; // number of vertices written so far (.obj indices are 1-based)
	v.cell[axis] = splitpos->pos;
	v.cell[(axis+1)%3] = bbox.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bbox.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bbox.L.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bbox.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bbox.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bbox.H.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	v.cell[(axis+1)%3] = bbox.H.cell[(axis+1)%3];
	v.cell[(axis+2)%3] = bbox.L.cell[(axis+2)%3];
	F << "v " << v.x << " " << v.y << " " << v.z << endl;
	F << "f " << f+1 << " " << f+2 << " " << f+3 << " " << f+4 << endl;
	f += 4;
#endif

	split = splitpos->pos;
	// the counts are only tested for being non-zero below
	const int lnum = splitpos->lnum;
	const int rnum = splitpos->rnum;

	// split this node
	children = new KdNode[2];
	int state = 0; // becomes 1 once a shape lying fully on the right was seen
	for (sh = sslist.begin(); sh != sslist.end(); sh++)
	{
		// sorted by lower bound: after the first right-only shape all
		// remaining shapes are right-only as well
		if (state == 1)
		{ // only right
			children[1].addShape(sh->shape);
			continue;
		}
		if (state == 0)
		{
			if (sh->bbox.H.cell[axis] < split)
			{ // left
				children[0].addShape(sh->shape);
			} else
			if (sh->bbox.H.cell[axis] > split)
			{
				if (sh->bbox.L.cell[axis] < split)
				{ // both
					children[0].addShape(sh->shape);
					children[1].addShape(sh->shape);
				} else
				{ // right
					children[1].addShape(sh->shape);
					state = 1;
				}
			} else
			{ // H == split
				if (sh->bbox.L.cell[axis] < split)
				{ // left
					children[0].addShape(sh->shape);
				} else
				{ // contained
					if (split - bbox.L.cell[axis] < bbox.H.cell[axis] - split)
					{
						// left subcell is smaller -> if not empty, put shape here
						if (lnum)
							children[0].addShape(sh->shape);
						else
							children[1].addShape(sh->shape);
					} else {
						// right subcell is smaller
						if (rnum)
							children[1].addShape(sh->shape);
						else
							children[0].addShape(sh->shape);
					}
				}
			}
		}
	}

	// lbb/rbb still hold the last candidate tried above; reset to the chosen plane
	lbb.H.cell[axis] = split;
	rbb.L.cell[axis] = split;

	children[0].subdivide(lbb, depth+1);
	children[1].subdivide(rbb, depth+1);
}

void KdTree::build()
{
	root = new KdNode();
	root->shapes = shapes;
	root->subdivide(bbox, 0);
	built = true;
}

/* algorithm by Vlastimil Havran, Heuristic Ray Shooting Algorithms, appendix C */
/**
 * Finds the shape hit by `ray` by walking the kd-tree front to back.
 *
 * Falls back to the linear Container search while the tree is not built.
 * Entry/exit points are held by value, so nothing has to be freed on any
 * return path (heap-allocated elements would leak on every successful hit).
 *
 * @param origin_shape      shape to skip (never tested against the ray)
 * @param ray               the ray to trace
 * @param nearest_distance  on a hit inside the tree, receives the distance
 *                          left in `dist` by the leaf's intersect() calls;
 *                          not written by the tree walk on a miss
 * @return the shape that was hit, or NULL if the ray misses the tree's
 *         bounding box or leaves the scene without a hit
 */
Shape *KdTree::nearest_intersection(const Shape *origin_shape, const Ray &ray,
	float &nearest_distance)
{
	float a, b; /* entry/exit signed distance */
	float t;    /* signed distance to the splitting plane */

	/* if we have no tree, fall back to naive test */
	if (!built)
		return Container::nearest_intersection(origin_shape, ray, nearest_distance);

	if (!bbox.intersect(ray, a, b))
		return NULL;

	stack<StackElem> st;

	/* pointers to the far child node and current node */
	KdNode *farchild = NULL;
	KdNode *node = root; /* start from the kd-tree root node */

	StackElem enPt = StackElem();
	enPt.t = a; /* set the signed distance */
	enPt.node = NULL;

	/* distinguish between internal and external origin */
	if (a >= 0.0) /* a ray with external origin */
		enPt.pb = ray.a + ray.dir * a;
	else /* a ray with internal origin */
		enPt.pb = ray.a;

	/* setup initial exit point in the stack */
	StackElem exPt = StackElem();
	exPt.t = b;
	exPt.pb = ray.a + ray.dir * b;
	exPt.node = NULL; /* set termination flag */

	/* bottom-of-stack sentinel: keeps the final pop below well defined */
	st.push(enPt);

	/* loop, traverse through the whole kd-tree, until an object is intersected or ray leaves the scene */
	while (node)
	{
		/* loop until a leaf is found */
		while (!node->isLeaf())
		{
			/* retrieve position of splitting plane */
			float splitVal = node->getSplit();
			short axis = node->getAxis();

			if (enPt.pb[axis] <= splitVal)
			{
				if (exPt.pb[axis] <= splitVal)
				{ /* case N1, N2, N3, P5, Z2, and Z3 */
					node = node->getLeftChild();
					continue;
				}
				/* NOTE(review): unreachable as written -- equality is already
				   consumed by the `<=` test above */
				if (exPt.pb[axis] == splitVal)
				{ /* case Z1 */
					node = node->getRightChild();
					continue;
				}
				/* case N4 */
				farchild = node->getRightChild();
				node = node->getLeftChild();
			}
			else
			{ /* (enPt.pb[axis] > splitVal) */
				if (splitVal < exPt.pb[axis])
				{
					/* case P1, P2, P3, and N5 */
					node = node->getRightChild();
					continue;
				}
				/* case P4*/
				farchild = node->getLeftChild();
				node = node->getRightChild();
			}

			/* case P4 or N4 . . . traverse both children */

			/* signed distance to the splitting plane */
			t = (splitVal - ray.a.cell[axis]) / ray.dir.cell[axis];

			/* setup the new exit point: save the current one first */
			st.push(exPt);
			exPt = StackElem();

			/* the new exit point lies on the splitting plane */
			exPt.t = t;
			exPt.node = farchild;
			exPt.pb[axis] = splitVal;
			exPt.pb[(axis+1)%3] = ray.a.cell[(axis+1)%3] + t * ray.dir.cell[(axis+1)%3];
			exPt.pb[(axis+2)%3] = ray.a.cell[(axis+2)%3] + t * ray.dir.cell[(axis+2)%3];
		} /* while */

		/* current node is the leaf . . . empty or full */
		/* "intersect ray with each object in the object list, discarding "
		"those lying before stack[enPt].t or farther than stack[exPt].t" */
		Shape *nearest_shape = NULL;
		float dist = exPt.t;
		ShapeList::iterator shape;
		for (shape = node->shapes.begin(); shape != node->shapes.end(); shape++)
			if (*shape != origin_shape && (*shape)->intersect(ray, dist)
			&& dist >= enPt.t)
				nearest_shape = *shape;

		if (nearest_shape)
		{
			nearest_distance = dist;
			return nearest_shape;
		}

		/* pop from the stack */
		enPt = exPt; /* the signed distance intervals are adjacent */

		/* retrieve the pointer to the next node, it is possible that ray traversal terminates */
		node = enPt.node;

		// pop
		exPt = st.top();
		st.pop();
	} /* while */

	/* ray leaves the scene */
	return NULL;
}

// Writes a textual dump of the kd-tree structure to `str`: split axis and
// position for inner nodes, shape count for leaves. Does nothing unless the
// tree is built; a NULL `node` means "start at the root".
// (Intended to eventually save the whole tree with triangles distributed
// into leaves.)
void KdTree::save(ostream &str, KdNode *node)
{
	if (!built)
		return;

	KdNode *cur = (node == NULL) ? root : node;

	if (!cur->isLeaf())
	{
		// axis 0/1/2 is printed as the letter x/y/z
		const char axisName = static_cast<char>('x' + cur->getAxis());
		str << "(split " << axisName << " at " << cur->getSplit() << "; L=";
		save(str, cur->getLeftChild());
		str << "; R=";
		save(str, cur->getRightChild());
		str << ";)";
		return;
	}

	str << "(leaf: " << cur->shapes.size() << " shapes)";
}

// load kd-tree from file/stream
// NOTE: not implemented -- the body is empty, so both arguments are ignored
// and the tree is left unchanged.
void KdTree::load(istream &str, KdNode *node)
{

}
