211 lines
6.4 KiB
C++
211 lines
6.4 KiB
C++
|
|
#include "test_common.h"
|
|
#include "./utils/utils.h"
|
|
|
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
|
|
|
//
|
|
// ---------- Basic construction & access ----------
|
|
//
|
|
// Default construction, (n, fill) construction, element read/write and resize.
// resize() is checked in both directions: growing must keep the existing
// prefix and fill the new tail with the given value; shrinking must drop the
// tail while leaving the surviving prefix untouched.
TEST_CASE(Vector_Construct_Size_Access) {
    Vd a; // default
    CHECK(a.size() == 0, "default size must be 0");

    Vf b(3, 1.0f); // (n, fill)
    CHECK(b.size() == 3, "size wrong");
    CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");

    b[1] = 2.5f;
    CHECK(b[1] == 2.5f, "operator[] write failed");

    // resize (grow + value): prefix [1, 2.5, 1] kept, new slots set to 7
    b.resize(5, 7.0f);
    CHECK(b.size() == 5, "resize grow size wrong");
    CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
          "resize grow values wrong");

    // resize (shrink): only the first two elements survive, and their values
    // must be unchanged (a size-only check would miss a clobbering shrink)
    b.resize(2);
    CHECK(b.size() == 2, "resize shrink size wrong");
    CHECK(b[0] == 1.0f && b[1] == 2.5f, "resize shrink values wrong");
}
|
|
|
|
// push_back() appends elements in call order; clear() brings the size back to 0.
TEST_CASE(Vector_Clear_PushBack) {
    Vi vec(0, 0); // empty vector built through the (n, fill) constructor

    // two appends -> [10, 20]
    vec.push_back(10);
    vec.push_back(20);
    CHECK(vec.size() == 2, "push_back size wrong");
    CHECK(vec[0] == 10 && vec[1] == 20, "push_back values wrong");

    // emptying the vector
    vec.clear();
    CHECK(vec.size() == 0, "clear failed");
}
|
|
//
|
|
// ---------- Equality / Inequality (tolerant for float/double) ----------
|
|
//
|
|
// operator== and operator!= on double vectors: identical vectors are equal,
// a difference below the comparison tolerance is still "equal", and a
// difference well above it is "not equal".
TEST_CASE(Vector_Equality_Tolerant) {
    Vd lhs(3, 1.0);
    Vd rhs(3, 1.0);
    CHECK(lhs == rhs, "== identical failed");
    CHECK(!(lhs != rhs), "!= identical failed");

    // Nudge one element by less than the tolerance (1e-6 by default):
    // the two vectors must still compare equal.
    rhs[1] += 1e-7;
    CHECK(lhs == rhs, "== tolerant failed");

    // Nudge it again by far more than the tolerance: now they must differ.
    rhs[1] += 1e-4;
    CHECK(lhs != rhs, "!= with difference failed");
}
|
|
//
|
|
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
|
|
//
|
|
// Vector-with-scalar arithmetic on a 3-element float vector, in three flavours:
// in-place member functions, value-returning operators, and operators with
// the scalar on the left. int and double scalars are passed on purpose to
// exercise conversion of the scalar argument to float.
TEST_CASE(Vector_Scalar_Arithmetic) {
    // true when all three elements compare exactly equal to `x`
    auto all_equal = [](auto& vec, float x) {
        return vec[0] == x && vec[1] == x && vec[2] == x;
    };
    // true when all three elements lie strictly within 1e-6f of `x`
    auto all_near = [](auto& vec, float x) {
        return std::fabs(vec[0] - x) < 1e-6f &&
               std::fabs(vec[1] - x) < 1e-6f &&
               std::fabs(vec[2] - x) < 1e-6f;
    };

    Vf a(3, 1.0f);

    // ---- in-place forms: `a` is mutated step by step ----
    a.inplace_add(2);            // int scalar -> [3, 3, 3]
    CHECK(all_equal(a, 3.0f), "inplace_add failed");

    a.inplace_subtract(1.5f);    // -> [1.5, 1.5, 1.5]
    CHECK(all_near(a, 1.5f), "inplace_subtract failed");

    a.inplace_multiply(4.0);     // double scalar -> [6, 6, 6]
    CHECK(all_equal(a, 6.0f), "inplace_multiply failed");

    a.inplace_divide(2);         // int scalar -> [3, 3, 3]
    CHECK(all_equal(a, 3.0f), "inplace_divide failed");

    // ---- value-returning forms: each result is expected from a == [3, 3, 3] ----
    auto b = a + 1.0f;
    CHECK(all_equal(b, 4.0f), "operator+(scalar) failed");

    b = a - 2.0f;
    CHECK(all_equal(b, 1.0f), "operator-(scalar) failed");

    b = a * 10;                  // int scalar
    CHECK(all_equal(b, 30.0f), "operator*(scalar) failed");

    b = a / 3.0f;
    CHECK(all_near(b, 1.0f), "operator/(scalar) failed");

    // ---- scalar on the left-hand side (only + and * are exercised) ----
    Vf c(3, 2.0f);

    auto d = 5 + c;              // friend operator+(U, Vector<T>)
    CHECK(all_equal(d, 7.0f), "scalar + vector failed");

    d = 3 * c;                   // friend operator*(U, Vector<T>)
    CHECK(all_equal(d, 6.0f), "scalar * vector failed");
}
|
|
//
|
|
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
|
|
//
|
|
// Element-wise vector (op) vector arithmetic, value-returning and in-place.
// For the non-commutative operators (-, /) at least one check uses operands
// with different values, so an implementation that swaps left and right
// operands produces a different result and fails. Every expected value is
// exactly representable in double, so exact == comparisons are safe.
TEST_CASE(Vector_Vector_Arithmetic) {
    Vd a(3, 1.0), b(3, 2.0);

    // returning
    auto c = a + b;
    CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");

    c = b - a;   // swapped operands would give -1
    CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");

    c = a * b;
    CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");

    c = b / b;   // self-division (same object on both sides)
    CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");

    c = a / b;   // 1/2 = 0.5; swapped operands would give 2
    CHECK(c[0]==0.5 && c[1]==0.5 && c[2]==0.5, "vec / vec operand order failed");

    // inplace
    a = Vd(3, 1.0);
    a += b;
    CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
    a -= b;      // swapped operands would give -1
    CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
    a *= b;
    CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
    a /= b;      // here a == b, so this alone cannot detect swapped operands
    CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");

    Vd four(3, 4.0);
    a /= four;   // 1/4 = 0.25; swapped operands would give 4
    CHECK(a[0]==0.25 && a[1]==0.25 && a[2]==0.25, "inplace vec / vec operand order failed");
}
|
|
//
|
|
// ---------- Size mismatch error paths ----------
|
|
//
|
|
// Binary operations on vectors of different lengths (3 vs 4) are expected
// to throw std::runtime_error. Each operation is attempted independently,
// in the same order as the checks below.
TEST_CASE(Vector_SizeMismatch_Throws) {
    Vd a(3, 1.0), b(4, 2.0);   // lengths deliberately differ

    // Invokes `op` and reports whether it threw std::runtime_error.
    // Exceptions of any other type are not caught and propagate out.
    auto throws_runtime_error = [](auto&& op) {
        try {
            op();
        } catch (const std::runtime_error&) {
            return true;
        }
        return false;
    };

    const bool add_threw =
        throws_runtime_error([&] { auto sum = a + b; (void)sum; });
    CHECK(add_threw, "add should throw on size mismatch");

    const bool sub_threw =
        throws_runtime_error([&] { a.inplace_subtract(b); });
    CHECK(sub_threw, "inplace_subtract should throw on size mismatch");

    const bool mul_threw =
        throws_runtime_error([&] { auto prod = a * b; (void)prod; });
    CHECK(mul_threw, "multiply should throw on size mismatch");

    const bool dot_threw =
        throws_runtime_error([&] { auto dp = a.dot(b); (void)dp; });
    CHECK(dot_threw, "dot should throw on size mismatch");
}
|
|
|
|
//
|
|
// ---------- Power / sqrt ----------
|
|
//
|
|
// Element-wise power (scalar exponent and per-element exponent vector) and
// square root, in value-returning and in-place forms. All expected values
// are exactly representable, so exact comparisons are used.
TEST_CASE(Vector_Power_Sqrt) {
    Vd base(3, 2.0);                 // [2, 2, 2]
    auto cubed = base.power(3.0);    // 2^3 -> [8, 8, 8]
    CHECK(cubed[0]==8.0 && cubed[1]==8.0 && cubed[2]==8.0, "scalar power failed");

    Vd exponents(3, 3.0);            // [3, 3, 3]
    auto raised = cubed.power(exponents);   // 8^3 = 512 per element
    CHECK(raised[0]==512.0 && raised[1]==512.0 && raised[2]==512.0, "vector power failed");

    Vd nines(3, 9.0);
    auto roots = nines.sqrt();       // [3, 3, 3]
    CHECK(roots[0]==3.0 && roots[1]==3.0 && roots[2]==3.0, "sqrt failed");

    // in-place: `nines` itself turns into [3, 3, 3] and must equal `roots`
    nines.inplace_sqrt();
    CHECK(nines == roots, "inplace_sqrt failed");
}
|
|
|
|
//
|
|
// ---------- Dot / Sum / Norm / Normalize ----------
|
|
//
|
|
// sum(), dot(), norm(), normalize() and inplace_normalize() on a = [1, 2, 2],
// whose Euclidean norm is exactly 3. Normalization results are checked
// component-wise against a / |a| = [1/3, 2/3, 2/3]: a norm-only check would
// accept any unit vector, including one pointing in the wrong direction.
// Normalizing a zero vector in place must throw std::runtime_error.
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
    Vd a(3, 0.0);
    a[0]=1.0; a[1]=2.0; a[2]=2.0;   // 1 + 4 + 4 = 9, so |a| = 3

    CHECK(a.sum() == 5.0, "sum failed");
    CHECK(a.dot(a) == 9.0, "dot self failed");

    auto n = a.norm();
    CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");

    // value-returning normalize: unit length AND same direction as `a`
    auto b = a.normalize();
    CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
    CHECK(std::fabs(b[0] - 1.0 / 3.0) < 1e-12 &&
          std::fabs(b[1] - 2.0 / 3.0) < 1e-12 &&
          std::fabs(b[2] - 2.0 / 3.0) < 1e-12, "normalize() wrong direction");
    // the value-returning form must not modify its source
    CHECK(std::fabs(a.norm() - 3.0) < 1e-12, "normalize() modified its source");

    // inplace normalize: same expectations, applied to `a` itself
    a.inplace_normalize();
    CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
    CHECK(std::fabs(a[0] - 1.0 / 3.0) < 1e-12 &&
          std::fabs(a[1] - 2.0 / 3.0) < 1e-12 &&
          std::fabs(a[2] - 2.0 / 3.0) < 1e-12, "inplace_normalize wrong direction");

    // zero-norm error
    Vd z(3, 0.0);
    bool threw = false;
    try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
    CHECK(threw, "normalize zero vector must throw");
}
|
|
//
|
|
// ---------- Stream output (basic sanity) ----------
|
|
//
|
|
// Streaming a 3-element int vector of 2s through operator<< is expected to
// produce exactly "[2, 2, 2]".
TEST_CASE(Vector_StreamOutput) {
    Vi twos(3, 2);

    std::ostringstream out;
    out << twos;

    const auto text = out.str();
    CHECK(text == "[2, 2, 2]", "ostream<< wrong format");
}