Finishing up and starting lu decomp

This commit is contained in:
2025-09-13 21:44:20 +02:00
parent 320436ce98
commit 88087ea6a6
24 changed files with 1502 additions and 699 deletions
+211
View File
@@ -0,0 +1,211 @@
#include "test_common.h"
#include "./utils/utils.h"
using utils::Vf; using utils::Vd; using utils::Vi;
//
// ---------- Basic construction & access ----------
//
TEST_CASE(Vector_Construct_Size_Access) {
Vd a; // default
CHECK(a.size() == 0, "default size must be 0");
Vf b(3, 1.0f); // (n, fill)
CHECK(b.size() == 3, "size wrong");
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");
b[1] = 2.5f;
CHECK(b[1] == 2.5f, "operator[] write failed");
// resize (grow + value)
b.resize(5, 7.0f);
CHECK(b.size() == 5, "resize grow size wrong");
CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
"resize grow values wrong");
// resize (shrink)
b.resize(2);
CHECK(b.size() == 2, "resize shrink size wrong");
}
TEST_CASE(Vector_Clear_PushBack) {
Vi v(0, 0);
v.push_back(10);
v.push_back(20);
CHECK(v.size() == 2, "push_back size wrong");
CHECK(v[0] == 10 && v[1] == 20, "push_back values wrong");
v.clear();
CHECK(v.size() == 0, "clear failed");
}
//
// ---------- Equality / Inequality (tolerant for float/double) ----------
//
TEST_CASE(Vector_Equality_Tolerant) {
Vd a(3, 1.0), b(3, 1.0);
CHECK(a == b, "== identical failed");
CHECK(!(a != b), "!= identical failed");
// Tiny perturbation within eps (1e-6 default)
b[1] += 1e-7;
CHECK(a == b, "== tolerant failed");
// Larger perturbation should fail equality
b[1] += 1e-4;
CHECK(a != b, "!= with difference failed");
}
//
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
//
TEST_CASE(Vector_Scalar_Arithmetic) {
Vf a(3, 1.0f);
// inplace
a.inplace_add(2); // int convertible to float
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_add failed");
a.inplace_subtract(1.5f);
CHECK(std::fabs(a[0] - 1.5f) < 1e-6f &&
std::fabs(a[1] - 1.5f) < 1e-6f &&
std::fabs(a[2] - 1.5f) < 1e-6f, "inplace_subtract failed");
a.inplace_multiply(4.0);
CHECK(a[0] == 6.0f && a[1] == 6.0f && a[2] == 6.0f, "inplace_multiply failed");
a.inplace_divide(2);
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_divide failed");
// returning
auto b = a + 1.0f;
CHECK(b[0] == 4.0f && b[1] == 4.0f && b[2] == 4.0f, "operator+(scalar) failed");
b = a - 2.0f;
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "operator-(scalar) failed");
b = a * 10; // int -> float
CHECK(b[0] == 30.0f && b[1] == 30.0f && b[2] == 30.0f, "operator*(scalar) failed");
b = a / 3.0f;
CHECK(std::fabs(b[0] - 1.0f) < 1e-6f &&
std::fabs(b[1] - 1.0f) < 1e-6f &&
std::fabs(b[2] - 1.0f) < 1e-6f, "operator/(scalar) failed");
// scalar on the left (friends implemented for + and *)
Vf c(3, 2.0f);
auto d = 5 + c; // friend operator+(U, Vector<T>)
CHECK(d[0] == 7.0f && d[1] == 7.0f && d[2] == 7.0f, "scalar + vector failed");
d = 3 * c; // friend operator*(U, Vector<T>)
CHECK(d[0] == 6.0f && d[1] == 6.0f && d[2] == 6.0f, "scalar * vector failed");
}
//
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
//
TEST_CASE(Vector_Vector_Arithmetic) {
Vd a(3, 1.0), b(3, 2.0);
// returning
auto c = a + b;
CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");
c = b - a;
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");
c = a * b;
CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");
c = b / b;
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");
// inplace
a = Vd(3, 1.0);
a += b;
CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
a -= b;
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
a *= b;
CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
a /= b;
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");
}
//
// ---------- Size mismatch error paths ----------
//
TEST_CASE(Vector_SizeMismatch_Throws) {
Vd a(3, 1.0), b(4, 2.0);
bool threw = false;
try { auto c = a + b; (void)c; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "add should throw on size mismatch");
threw = false;
try { a.inplace_subtract(b); } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "inplace_subtract should throw on size mismatch");
threw = false;
try { auto d = a * b; (void)d; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "multiply should throw on size mismatch");
threw = false;
try { auto s = a.dot(b); (void)s; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "dot should throw on size mismatch");
}
//
// ---------- Power / sqrt ----------
//
TEST_CASE(Vector_Power_Sqrt) {
Vd a(3, 2.0); // [2,2,2]
auto b = a.power(3.0); // [8,8,8]
CHECK(b[0]==8.0 && b[1]==8.0 && b[2]==8.0, "scalar power failed");
Vd p(3, 3.0); // [3,3,3]
auto c = b.power(p); // 8^3 = 512
CHECK(c[0]==512.0 && c[1]==512.0 && c[2]==512.0, "vector power failed");
Vd d(3, 9.0);
auto e = d.sqrt(); // [3,3,3]
CHECK(e[0]==3.0 && e[1]==3.0 && e[2]==3.0, "sqrt failed");
// inplace
d.inplace_sqrt(); // becomes [3,3,3]
CHECK(d == e, "inplace_sqrt failed");
}
//
// ---------- Dot / Sum / Norm / Normalize ----------
//
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
Vd a(3, 0.0);
a[0]=1.0; a[1]=2.0; a[2]=2.0;
CHECK(a.sum() == 5.0, "sum failed");
CHECK(a.dot(a) == 9.0, "dot self failed");
auto n = a.norm();
CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");
auto b = a.normalize();
CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
// inplace normalize
a.inplace_normalize();
CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
// zero-norm error
Vd z(3, 0.0);
bool threw = false;
try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "normalize zero vector must throw");
}
//
// ---------- Stream output (basic sanity) ----------
//
TEST_CASE(Vector_StreamOutput) {
Vi a(3, 2);
std::ostringstream oss;
oss << a;
auto s = oss.str();
CHECK(s == "[2, 2, 2]", "ostream<< wrong format");
}