/**
 * @file  vector.h
 * @brief Vector class with Colour alias
 *
 * This file is part of Pyrit Ray Tracer.
 *
 * Copyright 2006, 2007, 2008  Radek Brich
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#ifndef VECTOR_H
#define VECTOR_H

#include <math.h>
#include <iostream>

#include "common.h"
#include "simd.h"

// NOTE(review): a using-directive at header scope leaks into every file that
// includes this header. It is left in place because code in this header or in
// its includers may rely on unqualified standard-library names.
using namespace std;

// When defined, the Vector operators below use their scalar, per-component
// implementations instead of the mfloat4-based alternatives.
#define NO_SIMD_VECTOR

/**
 * three (four) cell vector
 */
class Vector
{
public:
	// data
	union {
#ifndef NO_SIMD
		mfloat4 mf4;
#endif
		Float cell[4];
		struct { Float x, y, z, w; };
		struct { Float r, g, b, a; };
	};

	// constructors
#ifndef NO_SIMD
	Vector(mfloat4 m): mf4(m) {};
#endif
	Vector(): x(0.0f), y(0.0f), z(0.0f), w(1.0) {};
	Vector(Float ax, Float ay, Float az): x(ax), y(ay), z(az), w(1.0) {};

	// index operator
	const Float &operator[](int index) const { return cell[index]; };
	Float &operator[](int index) { return cell[index]; };

	bool operator==(const Vector &v) const { return x==v.x && y==v.y && z==v.z; };

	/** Normalize the vector */
	Vector normalize()
	{
		const Float f = 1.0f / mag();
		*this *= f;
		return *this;
	};

	/** Get normalized copy of vector */
	friend Vector normalize(const Vector &v)
	{
		const Float f = 1.0f / v.mag();
		return v * f;
	};

	/** Square magnitude */
	Float mag2() const	{ return dot(*this, *this); };

	/** Vector magnitude */
	Float mag() const	{ return sqrtf(mag2()); };

	/** Get negative vector */
	Vector operator-() const { return Vector(-x, -y, -z); };

	/** Accumulate. Useful for colors. */
	Vector operator+=(const Vector &v)
	{
#ifdef NO_SIMD_VECTOR
		x += v.x;
		y += v.y;
		z += v.z;
#else
		mf4 = madd(mf4, v.mf4);
#endif
		return *this;
	};

	/** Multiply by scalar */
	Vector operator*=(const Float &f)
	{
		x *= f;
		y *= f;
		z *= f;
		return *this;
	};

	/** Cut with scalar */
	Vector operator/=(const Float &f)
	{
		Float finv = 1.0f / f;
		x *= finv;
		y *= finv;
		z *= finv;
		return *this;
	};

	/** Sum of two vectors */
	friend Vector operator+(const Vector &a, const Vector &b)
	{
#ifdef NO_SIMD_VECTOR
		return Vector(a.x + b.x, a.y + b.y, a.z + b.z);
#else
		return Vector(madd(a.mf4, b.mf4));
#endif
	};

	/** Difference of two vectors */
	friend Vector operator-(const Vector &a, const Vector &b)
	{
#ifdef NO_SIMD_VECTOR
		return Vector(a.x - b.x, a.y - b.y, a.z - b.z);
#else
		return Vector(msub(a.mf4, b.mf4));
#endif
	};

	/** Dot product */
	friend Float dot(const Vector &a, const Vector &b)
	{
		return a.x * b.x + a.y * b.y + a.z * b.z;
	};

	/** Cross product */
	friend Vector cross(const Vector &a, const Vector &b)
	{
		return Vector(a.y * b.z - a.z * b.y,
			a.z * b.x - a.x * b.z,
			a.x * b.y - a.y * b.x);
	};

	/** Get vector multiplied by scalar */
	friend Vector operator*(const Vector &v, const Float &f)
	{
		return Vector(f * v.x, f * v.y, f * v.z);
	};

	/** Get vector multiplied by scalar */
	friend Vector operator*(const Float &f, const Vector &v)
	{
		return v * f;
	};

	/** Get vector divided by scalar */
	friend Vector operator/(const Vector &v, const Float &f)
	{
		const Float finv = 1.0f / f;
		return Vector(v.x * finv, v.y * finv, v.z * finv);
	};

	/** Get f/v, i.e. inverted vector multiplied by scalar */
	friend Vector operator/(const Float &f, const Vector &v)
	{
#ifdef NO_SIMD_VECTOR
		return Vector(f / v.x, f / v.y, f / v.z);
#else
		return Vector(mdiv(mset1(f), v.mf4));
#endif
	};

	/** Add scalar to the vector */
	friend Vector operator+(const Vector &v, const Float &f)
	{
		return Vector(v.x + f, v.y + f, v.z + f);
	};

	/** Subtract scalar from the vector */
	friend Vector operator-(const Vector &v, const Float &f)
	{
		return Vector(v.x - f, v.y - f, v.z - f);
	};

	/** Cell by cell product (only useful for colors) */
	friend Vector operator*(const Vector &a, const Vector &b)
	{
#ifdef NO_SIMD_VECTOR
		return Vector(a.x * b.x, a.y * b.y, a.z * b.z);
#else
		return Vector(mmul(a.mf4, b.mf4));
#endif
	};

	/** Write textual representation of the vector to stream */
	friend ostream & operator<<(ostream &st, const Vector &v)
	{
		return st << "(" << v.x << "," << v.y  << "," << v.z << ")";
	};

	/** Read the vector from stream */
	friend istream & operator>>(istream &st, Vector &v)
	{
		char s[10];
		st.getline(s, 10, '(');
		st >> v.x;
		st.getline(s, 10, ',');
		st >> v.y;
		st.getline(s, 10, ',');
		st >> v.z;
		st.getline(s, 10, ')');
		return st;
	};
};

/** Colour is an alias for Vector; its cells can also be read as r, g, b, a. */
typedef Vector Colour;

#ifndef NO_SIMD
/**
  * Packet of four Vectors
  */
class VectorPacket
{
public:
	union {
		mfloat4 ma[3];
		struct {
			mfloat4 mx;
			mfloat4 my;
			mfloat4 mz;
		};
		struct {
			float x[4];
			float y[4];
			float z[4];
		};
	};

	VectorPacket() {};
	VectorPacket(mfloat4 ax, mfloat4 ay, mfloat4 az):
		mx(ax), my(ay), mz(az) {};
	VectorPacket(const Vector &v):
		mx(mset1(v.x)), my(mset1(v.y)), mz(mset1(v.z)) {};

	Vector getVector(int i) const
	{
		return Vector(x[i], y[i], z[i]);
	};

	void setVector(int i, const Vector &v)
	{
		x[i] = v.x; y[i] = v.y; z[i] = v.z;
	};

	void normalize()
	{
		mfloat4 m,x,y,z;
		x = mmul(mx, mx); // x*x
		y = mmul(my, my); // y*y
		z = mmul(mz, mz); // z*z
		m = madd(madd(x, y), z);     // x*x + y*y + z*z
		m = mdiv(mOne, msqrt(m));   // m = 1/sqrt(m)
		mx = mmul(mx, m);
		my = mmul(my, m);
		mz = mmul(mz, m);
	};

	// accumulate
	VectorPacket operator+=(const VectorPacket &v)
	{
		mx = madd(mx, v.mx);
		my = madd(my, v.my);
		mz = madd(mz, v.mz);
		return *this;
	};

	// add to non-masked components
	VectorPacket selectiveAdd(const mfloat4 &mask, const VectorPacket &v)
	{
		mx = mselect(mask, madd(mx, v.mx), mx);
		my = mselect(mask, madd(my, v.my), my);
		mz = mselect(mask, madd(mz, v.mz), mz);
		return *this;
	};

	// add scalar to non-masked components
	VectorPacket selectiveAdd(const mfloat4 &mask, const mfloat4 &m)
	{
		mx = mselect(mask, madd(mx, m), mx);
		my = mselect(mask, madd(my, m), my);
		mz = mselect(mask, madd(mz, m), mz);
		return *this;
	};

	// dot product
	friend mfloat4 dot(const VectorPacket &a, const VectorPacket &b)
	{
		return madd(madd(
			mmul(a.mx, b.mx),
			mmul(a.my, b.my)),
			mmul(a.mz, b.mz));
	};

	friend VectorPacket operator+(const VectorPacket &a, const VectorPacket &b)
	{
		return VectorPacket(
			madd(a.mx, b.mx),
			madd(a.my, b.my),
			madd(a.mz, b.mz));
	};

	friend VectorPacket operator-(const VectorPacket &a, const VectorPacket &b)
	{
		return VectorPacket(
			msub(a.mx, b.mx),
			msub(a.my, b.my),
			msub(a.mz, b.mz));
	};

	friend VectorPacket operator*(const VectorPacket &v,  const mfloat4 &m)
	{
		return VectorPacket(
			mmul(v.mx, m),
			mmul(v.my, m),
			mmul(v.mz, m));
	};

	friend VectorPacket operator/(const mfloat4 &m, const VectorPacket &v)
	{
		return VectorPacket(
			mdiv(m, v.mx),
			mdiv(m, v.my),
			mdiv(m, v.mz));
	};

	// cell by cell product (only usable for colours)
	friend VectorPacket operator*(const VectorPacket &a,  const VectorPacket &b)
	{
		return VectorPacket(
			mmul(a.mx, b.mx),
			mmul(a.my, b.my),
			mmul(a.mz, b.mz));
	};

	// write to character stream
	friend ostream & operator<<(ostream &st, const VectorPacket &v)
	{
		return st << "[" << v.getVector(0) << "," << v.getVector(1)
			<< "," << v.getVector(2) << "," << v.getVector(3) << ")";
	};

};
#endif

#endif
