include/vector.h
branchpyrit
changeset 91 9d66d323c354
parent 87 1081e3dd3f3e
child 92 9af5c039b678
--- a/include/vector.h	Tue Apr 29 23:31:08 2008 +0200
+++ b/include/vector.h	Fri May 02 13:27:47 2008 +0200
@@ -1,5 +1,5 @@
 /*
- * vector.h: Vector3 class with Colour alias
+ * vector.h: Vector class with Colour alias
  *
  * This file is part of Pyrit Ray Tracer.
  *
@@ -37,145 +37,177 @@
 /**
  * three cell vector
  */
-class Vector3
+class Vector
 {
 public:
 	// data
 	union {
-		struct {
-			Float x, y, z;
-		};
-		struct {
-			Float r, g, b;
-		};
-		Float cell[3];
+#ifndef NO_SSE
+		__m128 mps;
+#endif
+		Float cell[4];
+		struct { Float x, y, z, w; };
+		struct { Float r, g, b, a; };
 	};
 
 	// constructors
-	Vector3(): x(0.0f), y(0.0f), z(0.0f) {};
-	Vector3(Float ax, Float ay, Float az): x(ax), y(ay), z(az) {};
+#ifndef NO_SSE
+	Vector(__m128 m): mps(m) {};
+#endif
+	Vector(): x(0.0f), y(0.0f), z(0.0f), w(1.0) {};
+	Vector(Float ax, Float ay, Float az): x(ax), y(ay), z(az), w(1.0) {};
 
 	// index operator
 	const Float &operator[](int index) const { return cell[index]; };
 
-	bool operator==(Vector3 &v) const { return x==v.x && y==v.y && z==v.z; };
+	bool operator==(Vector &v) const { return x==v.x && y==v.y && z==v.z; };
 
 	// normalize
-	Vector3 normalize()
+	Vector normalize()
 	{
-		Float f = 1.0f / mag();
-		x *= f;
-		y *= f;
-		z *= f;
+		const Float f = 1.0f / mag();
+		*this *= f;
 		return *this;
 	};
 
 	// get normalized copy
-	friend Vector3 normalize(Vector3 &v)
+	friend Vector normalize(const Vector &v)
 	{
 		const Float f = 1.0f / v.mag();
 		return v * f;
 	};
 
 	// square magnitude, magnitude
-	Float mag2() const	{ return x * x + y * y + z * z; };
+	Float mag2() const	{ return dot(*this, *this); };
 	Float mag() const	{ return sqrtf(mag2()); };
 
 	// negative
-	Vector3 operator-() const { return Vector3(-x, -y, -z); };
+	Vector operator-() const { return Vector(-x, -y, -z); };
 
 	// accumulate
-	Vector3 operator+=(const Vector3 &v)
+	Vector operator+=(const Vector &v)
 	{
+#ifdef NO_SSE
 		x += v.x;
 		y += v.y;
 		z += v.z;
+#else
+		mps = _mm_add_ps(mps, v.mps);
+#endif
 		return *this;
 	};
 
-	// cut
-	Vector3 operator/=(const Float &f)
+	// multiply
+	Vector operator*=(const Float &f)
 	{
-		x /= f;
-		y /= f;
-		z /= f;
+		x *= f;
+		y *= f;
+		z *= f;
+		return *this;
+	};
+
+
+	// cut
+	Vector operator/=(const Float &f)
+	{
+		Float finv = 1./f;
+		x *= finv;
+		y *= finv;
+		z *= finv;
 		return *this;
 	};
 
 	// sum
-	friend Vector3 operator+(const Vector3 &a, const Vector3 &b)
+	friend Vector operator+(const Vector &a, const Vector &b)
 	{
-		return Vector3(a.x + b.x, a.y + b.y, a.z + b.z);
+#ifdef NO_SSE
+		return Vector(a.x + b.x, a.y + b.y, a.z + b.z);
+#else
+		return Vector(_mm_add_ps(a.mps, b.mps));
+#endif
 	};
 
 	// difference
-	friend Vector3 operator-(const Vector3 &a, const Vector3 &b)
+	friend Vector operator-(const Vector &a, const Vector &b)
 	{
-		return Vector3(a.x - b.x, a.y - b.y, a.z - b.z);
+#ifdef NO_SSE
+		return Vector(a.x - b.x, a.y - b.y, a.z - b.z);
+#else
+		return Vector(_mm_sub_ps(a.mps, b.mps));
+#endif
 	};
 
 	// dot product
-	friend Float dot(const Vector3 &a, const Vector3 &b)
+	friend Float dot(const Vector &a, const Vector &b)
 	{
 		return a.x * b.x + a.y * b.y + a.z * b.z;
 	};
 
 	// cross product
-	friend Vector3 cross(const Vector3 &a, const Vector3 &b)
+	friend Vector cross(const Vector &a, const Vector &b)
 	{
-		return Vector3(a.y * b.z - a.z * b.y,
+		return Vector(a.y * b.z - a.z * b.y,
 			a.z * b.x - a.x * b.z,
 			a.x * b.y - a.y * b.x);
 	};
 
 	// product of vector and scalar
-	friend Vector3 operator*(const Vector3 &v, const Float &f)
+	friend Vector operator*(const Vector &v, const Float &f)
 	{
-		return Vector3(f * v.x, f * v.y, f * v.z);
+		return Vector(f * v.x, f * v.y, f * v.z);
 	};
 
-	friend Vector3 operator*(const Float &f, const Vector3 &v)
+	friend Vector operator*(const Float &f, const Vector &v)
 	{
 		return v * f;
 	};
 
 	// scalar division
-	friend Vector3 operator/(const Vector3 &v, const Float &f)
+	friend Vector operator/(const Vector &v, const Float &f)
 	{
-		return Vector3(v.x / f, v.y / f, v.z / f);
+		const Float finv = 1./f;
+		return Vector(v.x * finv, v.y * finv, v.z * finv);
 	};
 
-	friend Vector3 operator/(const Float &f, const Vector3 &v)
+	friend Vector operator/(const Float &f, const Vector &v)
 	{
-		return Vector3(f / v.x, f / v.y, f / v.z);
+#ifdef NO_SSE
+		return Vector(f / v.x, f / v.y, f / v.z);
+#else
+		return Vector(_mm_div_ps(_mm_set_ps1(f), v.mps));
+#endif
 	};
 
 	// vector plus scalar
-	friend Vector3 operator+(const Vector3 &v, const Float &f)
+	friend Vector operator+(const Vector &v, const Float &f)
 	{
-		return Vector3(v.x + f, v.y + f, v.z + f);
+		return Vector(v.x + f, v.y + f, v.z + f);
 	};
 
 	// vector minus scalar
-	friend Vector3 operator-(const Vector3 &v, const Float &f)
+	friend Vector operator-(const Vector &v, const Float &f)
 	{
-		return Vector3(v.x - f, v.y - f, v.z - f);
+		return Vector(v.x - f, v.y - f, v.z - f);
 	};
 
 	// cell by cell product (only usable for colours)
-	friend Vector3 operator*(const Vector3 &a,  const Vector3 &b)
+	friend Vector operator*(const Vector &a,  const Vector &b)
 	{
-		return Vector3(a.x * b.x, a.y * b.y, a.z * b.z);
+#ifdef NO_SSE
+		return Vector(a.x * b.x, a.y * b.y, a.z * b.z);
+#else
+		return Vector(_mm_mul_ps(a.mps, b.mps));
+#endif
 	};
 
 	// write
-	friend ostream & operator<<(ostream &st, const Vector3 &v)
+	friend ostream & operator<<(ostream &st, const Vector &v)
 	{
 		return st << "(" << v.x << "," << v.y  << "," << v.z << ")";
 	};
 
 	// read
-	friend istream & operator>>(istream &st, Vector3 &v)
+	friend istream & operator>>(istream &st, Vector &v)
 	{
 		char s[10];
 		st.getline(s, 10, '(');
@@ -189,8 +221,9 @@
 	};
 };
 
-typedef Vector3 Colour;
+typedef Vector Colour;
 
+#ifndef NO_SSE
 class VectorPacket
 {
 public:
@@ -211,12 +244,17 @@
 	VectorPacket() {};
 	VectorPacket(__m128 ax, __m128 ay, __m128 az):
 		mx(ax), my(ay), mz(az) {};
-	VectorPacket(const Vector3 &v):
+	VectorPacket(const Vector &v):
 		mx(_mm_set_ps1(v.x)), my(_mm_set_ps1(v.y)), mz(_mm_set_ps1(v.z)) {};
 
-	Vector3 getVector(int i) const
+	Vector getVector(int i) const
 	{
-		return Vector3(x[i], y[i], z[i]);
+		return Vector(x[i], y[i], z[i]);
+	};
+
+	void setVector(int i, const Vector &v)
+	{
+		x[i] = v.x; y[i] = v.y; z[i] = v.z;
 	};
 
 	void normalize()
@@ -234,6 +272,39 @@
 		mz = _mm_mul_ps(mz, m);
 	};
 
+	// accumulate
+	VectorPacket operator+=(const VectorPacket &v)
+	{
+		mx = _mm_add_ps(mx, v.mx);
+		my = _mm_add_ps(my, v.my);
+		mz = _mm_add_ps(mz, v.mz);
+		return *this;
+	};
+
+	// add to non-masked components
+	VectorPacket selectiveAdd(__m128 mask, const VectorPacket &v)
+	{
+		mx = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(mx, v.mx)),
+			_mm_andnot_ps(mask, mx));
+		my = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(my, v.my)),
+			_mm_andnot_ps(mask, my));
+		mz = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(mz, v.mz)),
+			_mm_andnot_ps(mask, mz));
+		return *this;
+	};
+
+	// add scalar to non-masked components
+	VectorPacket selectiveAdd(__m128 mask, const __m128 m)
+	{
+		mx = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(mx, m)),
+			_mm_andnot_ps(mask, mx));
+		my = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(my, m)),
+			_mm_andnot_ps(mask, my));
+		mz = _mm_or_ps(_mm_and_ps(mask, _mm_add_ps(mz, m)),
+			_mm_andnot_ps(mask, mz));
+		return *this;
+	};
+
 	// dot product
 	friend __m128 dot(const VectorPacket &a, const VectorPacket &b)
 	{
@@ -275,6 +346,15 @@
 			_mm_div_ps(m, v.mz));
 	};
 
+	// cell by cell product (only usable for colours)
+	friend VectorPacket operator*(const VectorPacket &a,  const VectorPacket &b)
+	{
+		return VectorPacket(
+			_mm_mul_ps(a.mx, b.mx),
+			_mm_mul_ps(a.my, b.my),
+			_mm_mul_ps(a.mz, b.mz));
+	};
+
 	// write to character stream
 	friend ostream & operator<<(ostream &st, const VectorPacket &v)
 	{
@@ -283,5 +363,6 @@
 	};
 
 };
+#endif
 
 #endif