/**
 * @file  shapes.h
 * @brief Shape classes: Box, Sphere, Triangle and helpers
 *
 * This file is part of Pyrit Ray Tracer.
 *
 * Copyright 2006, 2007, 2008  Radek Brich
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#ifndef SHAPES_H
#define SHAPES_H

#include "common.h"
#include "scene.h"
#include "material.h"

#include <cstddef> // NULL

/*
Triangle intersection algorithm selection.

Define exactly one of the following before including this header:
  TRI_PLUCKER   - Triangle stores six-component pla/plb/plc arrays
                  (presumably Pluecker coordinates of the edges)
  TRI_BARI      - barycentric test; Triangle stores only the dominant axis k
  TRI_BARI_PRE  - barycentric test with precomputed plane/edge coefficients
                  (nu, nv, nd, bnu, bnv, cnu, cnv, k); this is the only variant
                  with a SIMD Triangle::intersect_packet and with smooth normal
                  interpolation

If none is defined, TRI_BARI_PRE is used. Defining both TRI_BARI and
TRI_BARI_PRE declares Triangle::k twice and will not compile.
*/
#if !defined(TRI_PLUCKER) && !defined(TRI_BARI) && !defined(TRI_BARI_PRE)
#	define TRI_BARI_PRE
#endif

/**
 * abstract shape class
 */
class Shape
{
public:
	Material *material;

	Shape() {};
	virtual ~Shape() {};

	/**
	 * intersect ray with sphere
	 * @param[in] ray	the ray
	 * @param[in] dist	maximum allowed distance of intersection
	 * @param[out] dist	distance of the intersection if found, unchanged otherwise
	 * @return true if ray intersects the sphere
	 */
	virtual bool intersect(const Ray &ray, Float &dist) const = 0;

	/**
	 * same as intersect, but for ray packets
	 */
#ifndef NO_SIMD
	virtual mfloat4 intersect_packet(const RayPacket &rays, mfloat4 &dists) const
	{
		union {
			mfloat4 mresults;
			int32_t results[4];
		};
		union {
			mfloat4 m;
			float arr[4];
		} d = { dists };
		for (int i = 0; i < 4; i++)
			results[i] = intersect(rays[i], d.arr[i]) ? -1 : 0;
		return mresults;
	};
#endif

	/** get all intersections -- not needed nor used currently */
	virtual bool intersect_all(const Ray &ray, Float dist, vector<Float> &allts) const = 0;

	/** test intersection with bounding box */
	virtual bool intersect_bbox(const BBox &bbox) const = 0;

	/** get surface normal at point P */
	virtual const Vector normal(const Vector &P) const = 0;

	/** get bounding box of this shape */
	virtual BBox get_bbox() const = 0;

	/** write textual representation of the shape to stream */
	virtual ostream & dump(ostream &st) const = 0;
};

/**
 * list of shapes
 *
 * Holds pointers to const shapes. The type itself expresses no ownership;
 * NOTE(review): presumably the shapes are owned elsewhere -- confirm with callers.
 */
typedef vector<const Shape*> ShapeList;

/**
 * sphere shape
 *
 * Besides centre and radius, the squared radius and the reciprocal radius are
 * computed once in the constructor; normal() uses the reciprocal to avoid a
 * division (the squared radius is presumably used by the intersection code).
 */
class Sphere: public Shape
{
	Vector center;
	Float radius;

	Float sqr_radius;	// radius * radius
	Float inv_radius;	// 1 / radius
public:
	/**
	 * @param acenter	centre of the sphere
	 * @param aradius	radius (should be non-zero, its reciprocal is cached)
	 * @param amaterial	surface material
	 */
	Sphere(const Vector &acenter, const Float aradius, Material *amaterial):
		center(acenter),
		radius(aradius),
		sqr_radius(aradius * aradius),
		inv_radius(1.0f / aradius)
	{
		material = amaterial;
	}

	bool intersect(const Ray &ray, Float &dist) const;
#ifndef NO_SIMD
	mfloat4 intersect_packet(const RayPacket &rays, mfloat4 &dists) const;
#endif
	bool intersect_all(const Ray &ray, Float dist, vector<Float> &allts) const;
	bool intersect_bbox(const BBox &bbox) const;

	/** unit normal: offset from the centre scaled by 1/radius */
	const Vector normal(const Vector &P) const
	{
		const Vector offset = P - center;
		return offset * inv_radius;
	}

	BBox get_bbox() const;

	/** @return centre of the sphere */
	const Vector getCenter() const { return center; }
	/** @return radius of the sphere */
	Float getRadius() const { return radius; }

	ostream & dump(ostream &st) const;
};

/**
 * box shape
 */
class Box: public Shape
{
	Vector L;
	Vector H;
public:
	Box(const Vector &aL, const Vector &aH, Material *amaterial): L(aL), H(aH)
	{
		for (int i = 0; i < 3; i++)
			if (L[i] > H[i])
				swap(L[i], H[i]);
		material = amaterial;
	};
	bool intersect(const Ray &ray, Float &dist) const;
#ifndef NO_SIMD
	mfloat4 intersect_packet(const RayPacket &rays, mfloat4 &dists) const;
#endif
	bool intersect_all(const Ray &ray, Float dist, vector<Float> &allts) const { return false; };
	bool intersect_bbox(const BBox &bbox) const;
	const Vector normal(const Vector &P) const;
	BBox get_bbox() const { return BBox(L, H); };

	const Vector getL() const { return L; };
	const Vector getH() const { return H; };

	ostream & dump(ostream &st) const;
};

/**
 * triangle vertex
 *
 * Has a virtual destructor so derived vertices (e.g. NormalVertex) can be
 * handled through Vertex pointers.
 */
class Vertex
{
public:
	/** vertex position */
	Vector P;

	Vertex(const Vector &aP):
		P(aP)
	{
	}

	virtual ~Vertex()
	{
	}

	/** write textual representation of the vertex to stream */
	virtual ostream & dump(ostream &st) const;
};

/**
 * triangle vertex with normal
 *
 * Triangle::smooth_normal() casts its vertices to this type to interpolate
 * the per-vertex normals.
 */
class NormalVertex: public Vertex
{
public:
	/** vertex normal */
	Vector N;

	/** copy position and normal from another vertex (v must not be NULL) */
	NormalVertex(const NormalVertex *v): Vertex(v->P), N(v->N) {};
	/** vertex with a value-initialized normal; set it later with setNormal() */
	NormalVertex(const Vector &aP): Vertex(aP), N() {};
	NormalVertex(const Vector &aP, const Vector &aN): Vertex(aP), N(aN) {};
	/** @return the vertex normal (const so it is usable on const vertices) */
	const Vector &getNormal() const { return N; };
	void setNormal(const Vector &aN) { N = aN; };
	ostream & dump(ostream &st) const;
};

/**
 * triangle shape
 */
class Triangle: public Shape
{
	Vector N;
#ifdef TRI_BARI_PRE
	Float nu, nv, nd;
	Float bnu, bnv;
	Float cnu, cnv;
	int k; // dominant axis
#endif
#ifdef TRI_BARI
	int k; // dominant axis
#endif
#ifdef TRI_PLUCKER
	Float pla[6], plb[6], plc[6];
#endif

	const Vector smooth_normal(const Vector &P) const
	{
#ifdef TRI_BARI_PRE
		const Vector &NA = static_cast<NormalVertex*>(A)->N;
		const Vector &NB = static_cast<NormalVertex*>(B)->N;
		const Vector &NC = static_cast<NormalVertex*>(C)->N;
		static const int modulo3[5] = {0,1,2,0,1};
		register const int ku = modulo3[k+1];
		register const int kv = modulo3[k+2];
		const Float pu = P[ku] - A->P[ku];
		const Float pv = P[kv] - A->P[kv];
		const Float u = pv * bnu + pu * bnv;
		const Float v = pu * cnv + pv * cnu;
		Vector n = NA + u * (NB - NA) + v * (NC - NA);
		n.normalize();
		return n;
#else
		return N; // not implemented for other algorithms
#endif
	};

public:
	Vertex *A, *B, *C;

	Triangle() {};
	Triangle(Vertex *aA, Vertex *aB, Vertex *aC, Material *amaterial);
	bool intersect(const Ray &ray, Float &dist) const;
#if !defined(NO_SIMD) && defined(TRI_BARI_PRE)
	mfloat4 intersect_packet(const RayPacket &rays, mfloat4 &dists) const;
#endif
	bool intersect_all(const Ray &ray, Float dist, vector<Float> &allts) const {return false;};
	bool intersect_bbox(const BBox &bbox) const;
	const Vector normal(const Vector &P) const { return (material->smooth ? smooth_normal(P) : N); };
	BBox get_bbox() const;

	/** get real normal of the triangle */
	const Vector getNormal() const { return N; };

	ostream & dump(ostream &st) const;
};

/**
 * fixed-size heap array for triangles/vertices (currently not used)
 *
 * Owns its storage and releases it in the destructor. Copying is disabled:
 * the implicit copy would duplicate the raw pointer and delete it twice.
 */
template <class T> class Array
{
	T *array;

	// non-copyable: declared but intentionally never defined
	Array(const Array &);
	Array &operator=(const Array &);
public:
	/** allocate n default-initialized elements (T must be default-constructible) */
	Array(int n): array(new T[n]) {};
	~Array() { delete[] array; };
	/** read-only element access, no bounds checking */
	const T &operator[](int i) const { return array[i]; };
	/** writable element access, no bounds checking */
	T &operator[](int i) { return array[i]; };
};

// Array aliases. Array's constructor uses new T[n], which requires a default
// constructor: Vertex and NormalVertex declare none, so VertexArray and
// NormalVertexArray cannot currently be constructed (TriangleArray can).
typedef Array<Vertex> VertexArray;
typedef Array<NormalVertex> NormalVertexArray;
typedef Array<Triangle> TriangleArray;

#endif
