ready for parralization

This commit is contained in:
2025-09-12 22:58:52 +02:00
parent cb825aec40
commit 320436ce98
14 changed files with 920 additions and 294 deletions
+393 -9
View File
@@ -1,10 +1,12 @@
#include "./utils/utils.h"
#include "./numerics/numerics.h"
#include "./core/omp_config.h"
#include <iostream>
#include <stdexcept>
#define CHECK(cond, msg) \
do { if (!(cond)) throw std::runtime_error(msg); } while (0)
@@ -21,9 +23,17 @@ void expect_throw(F&& f, const char* msg_if_no_throw) {
int main(int argc, char const *argv[])
{
{
using utils::Vf;
// Single-level, 16 threads, runtime may adjust
omp_configure(/*max_levels=*/1, /*dynamic=*/true, /*threads_per_level=*/{16});
using utils::Vi;
using utils::Vf;
using utils::Vd;
using utils::Mi;
using utils::Mf;
using utils::Md;
// ---------------- Equality / Inequality ----------------
{
@@ -154,6 +164,82 @@ int main(int argc, char const *argv[])
CHECK(a == expect, "a /= 2 should produce [3,3,3]");
}
// ---------- sum ----------
{
Vf a(3, 2.0f); // [2,2,2]
CHECK(a.sum() == 6.0f, "sum failed");
}
// ---------- dot ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf b(3, 3.0f); // [3,3,3]
CHECK(a.dot(b) == 18.0f, "dot failed"); // 2*3 * 3 = 18
Vf c(4, 1.0f);
expect_throw([&]{ (void)a.dot(c); }, "dot should throw on size mismatch");
}
// ---------- norm ----------
{
Vf a(3, 2.0f); // [2,2,2]
float n = a.norm();
CHECK(std::fabs(n - std::sqrt(12.0f)) < 1e-6f, "norm failed");
}
// ---------- normalize ----------
{
Vf a(3, 3.0f); // [3,3,3], norm = sqrt(27)
Vf b = a.normalize();
float n = b.norm();
CHECK(std::fabs(n - 1.0f) < 1e-6f, "normalize failed");
expect_throw([&]{ Vf z(3, 0.0f); z.inplace_normalize(); }, "normalize should throw on zero norm");
}
// ---------- scalar power ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf c = a.power(3); // [8,8,8]
CHECK(c == Vf(3, 8.0f), "power(scalar) failed");
Vf d = a; d.inplace_power(4); // [16,16,16]
CHECK(d == Vf(3, 16.0f), "inplace_power(scalar) failed");
}
// ---------- vector power ----------
{
Vf base(3, 2.0f); // [2,2,2]
Vf exps; exps.v = {1.0f, 2.0f, 3.0f}; // explicit construction for clarity
Vf out = base.power(exps); // [2^1, 2^2, 2^3] = [2,4,8]
Vf expect; expect.v = {2.0f, 4.0f, 8.0f};
CHECK(out == expect, "power(vector) failed");
expect_throw([&]{ Vf bad(2, 1.0f); (void)base.power(bad); },
"power(vector) should throw on size mismatch");
}
// ---------- square ----------
{
Vf a; a.v = {4.0f, 9.0f, 16.0f};
Vf b = a.sqrt(); // [4,9,16]
Vf expect; expect.v = {2.0f, 3.0f, 4.0f};
CHECK(b == expect, "sqrt failed");
a.inplace_sqrt(); // mutate a to [4,9,16]
CHECK(a == expect, "inplace_square failed");
}
// ---------- scalar commutative friends (s + v, s * v) ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf b = 3.0f + a; // [5,5,5]
Vf c = a + 3.0f; // [5,5,5]
CHECK(b == c, "s+v commutative failed");
Vf d = 4.0f * a; // [8,8,8]
Vf e = a * 4.0f; // [8,8,8]
CHECK(d == e, "s*v commutative failed");
}
// ---------------- Size mismatch throws ----------------
{
Vf a(3, 1.0f);
@@ -177,12 +263,310 @@ int main(int argc, char const *argv[])
"operator/ should throw (through divide) on size mismatch");
}
Vf b(3, 8.0f); // [1,1,1]
Vf c(3, 2.0f); // [2,2,2]
b.print();
b.inplace_power(2);
b.print();
std::cout << b.norm() << std::endl;
{
auto* a = new utils::Vf(3, 1.0f); // constructor runs
delete a; // <- calls ~Vector() and frees memory
}
{
Vf a(2, 1.0f); // a = [1, 1]
Vf b(2, 1.0f); // b = [1, 1]
a.clear(); // a = []
CHECK(a.size() == 0, "clear() did not empty vector");
a.resize(2, 1.0f); // a = [1, 1]
CHECK(a == b, "clear/resize lifecycle failed");
}
std::cout << "All Vector tests passed ✅\n";
// shape + element access
{
Mf M(3, 4, 0.0f);
CHECK(M.rows()==3 && M.cols()==4, "shape failed");
M(1,1) = 5.0f;
CHECK(M(1,1) == 5.0f, "write/read element failed");
// ensure independence of other cells
CHECK(M(0,0) == 0.0f && M(2,3) == 0.0f, "unexpected element modified");
}
// set/get row (with size checks)
{
Mf M(2, 3, 0.0f); // 2x3
Vf r(3, 0.0f);
r[0]=1; r[1]=2; r[2]=3;
M.set_row(1, r);
Vf g = M.get_row(1);
CHECK(g.size()==3, "get_row size wrong");
CHECK(g[0]==1 && g[1]==2 && g[2]==3, "get_row values wrong");
// size mismatch should throw
bool threw=false;
try {
Vf bad(2, 9.0f);
M.set_row(0, bad);
} catch (const std::exception&) { threw=true; }
CHECK(threw, "set_row should throw on size mismatch");
}
// set/get col (with size checks)
{
Mf M(3, 2, 0.0f); // 3x2
Vf c(3, 0.0f);
c[0]=4; c[1]=5; c[2]=6;
M.set_col(1, c);
Vf h = M.get_col(1);
CHECK(h.size()==3, "get_col size wrong");
CHECK(h[0]==4 && h[1]==5 && h[2]==6, "get_col values wrong");
bool threw=false;
try {
Vf bad(2, 7.0f);
M.set_col(0, bad);
} catch (const std::exception&) { threw=true; }
CHECK(threw, "set_col should throw on size mismatch");
}
// swap_rows / swap_cols
{
Mf M(3, 3, 0.0f);
// set rows to [1,2,3], [4,5,6], [7,8,9]
for (uint64_t j=0;j<3;++j) M(0,j) = 1.0f + j;
for (uint64_t j=0;j<3;++j) M(1,j) = 4.0f + j;
for (uint64_t j=0;j<3;++j) M(2,j) = 7.0f + j;
M.swap_rows(0,2);
CHECK(M(0,0)==7 && M(0,1)==8 && M(0,2)==9, "swap_rows top row wrong");
CHECK(M(2,0)==1 && M(2,1)==2 && M(2,2)==3, "swap_rows bottom row wrong");
M.swap_cols(0,2);
// after col swap: first row should be [9,8,7]
CHECK(M(0,0)==9 && M(0,1)==8 && M(0,2)==7, "swap_cols first row wrong");
// bottom row should be [3,2,1]
CHECK(M(2,0)==3 && M(2,1)==2 && M(2,2)==1, "swap_cols last row wrong");
}
// Exact integer comparison / Floating-point exact equality / Floating-point with small perturbation
{
Mi A(2,2,0);
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
Mi B(2,2,0);
B(0,0)=1; B(0,1)=2;
B(1,0)=3; B(1,1)=4;
Mi C(2,2,0);
C(0,0)=9; C(0,1)=9;
C(1,0)=9; C(1,1)=9;
CHECK(A == B, "Matrix == failed on identical int matrices");
CHECK(!(A != B), "Matrix != failed on identical int matrices");
CHECK(A != C, "Matrix != failed on different int matrices");
// Floating-point exact equality
Md F1(2,2,0.0);
F1(0,0)=1.0; F1(0,1)=2.0;
F1(1,0)=3.0; F1(1,1)=4.0;
Md F2(2,2,0.0);
F2(0,0)=1.0; F2(0,1)=2.0;
F2(1,0)=3.0; F2(1,1)=4.0;
CHECK(F1 == F2, "Matrix == failed on identical float matrices");
// Floating-point with small perturbation
Md F3 = F1;
F3(1,1) += 1e-10; // tiny difference
CHECK(!(F1 == F3), "Matrix == should fail on exact compare with perturbation");
CHECK(F1.nearly_equal(F3, 1e-9), "Matrix nearly_equal failed with tolerance");
// Larger perturbation
F3(1,1) += 1e-3;
CHECK(!F1.nearly_equal(F3, 1e-6), "Matrix nearly_equal should fail when tolerance too small");
CHECK(F1.nearly_equal(F3, 1e-2), "Matrix nearly_equal should pass with loose tolerance");
}
std::cout << "Matrix basic tests passed ✅\n";
// --- Test: normal transpose ---
{
Mf M(2, 3, 0.0f);
// Fill: [ [1,2,3],
// [4,5,6] ]
M(0,0)=1; M(0,1)=2; M(0,2)=3;
M(1,0)=4; M(1,1)=5; M(1,2)=6;
Mf MT = numerics::transpose(M);
// Should be shape 3x2
CHECK(MT.rows()==3 && MT.cols()==2, "transpose shape wrong");
// Values: [ [1,4], [2,5], [3,6] ]
CHECK(MT(0,0)==1 && MT(0,1)==4, "transpose value (0,*) wrong");
CHECK(MT(1,0)==2 && MT(1,1)==5, "transpose value (1,*) wrong");
CHECK(MT(2,0)==3 && MT(2,1)==6, "transpose value (2,*) wrong");
//std::cout << "Original M:\n" << M << "\n";
//std::cout << "Transposed MT:\n" << MT << "\n\n";
}
// --- Test: inplace transpose (square only) ---
{
Mf S(3, 3, 0.0f);
// Fill with row-major increasing
float val = 1.0f;
for (uint64_t i=0;i<S.rows();++i) {
for (uint64_t j=0;j<S.cols();++j) {
S(i,j) = val++;
}
}
// S =
// [1,2,3]
// [4,5,6]
// [7,8,9]
numerics::inplace_transpose(S);
// Expected after transpose:
// [1,4,7]
// [2,5,8]
// [3,6,9]
CHECK(S(0,1)==4 && S(0,2)==7, "inplace_transpose first row wrong");
CHECK(S(1,0)==2 && S(1,2)==8, "inplace_transpose second row wrong");
CHECK(S(2,0)==3 && S(2,1)==6, "inplace_transpose third row wrong");
//std::cout << "Square matrix after inplace_transpose:\n" << S << "\n\n";
}
// --- Test: inplace transpose throws on non-square ---
{
Mf Rect(2, 3, 1.0f);
bool threw = false;
try {
numerics::inplace_transpose(Rect);
} catch (const std::runtime_error&) {
threw = true;
}
CHECK(threw, "inplace_transpose should throw on non-square matrix");
}
std::cout << "Transpose tests passed ✅\n";
// matmul test
{
Md A(2,2,0.0);
A(0,0) = 1; A(0,1) = 2;
A(1,0) = 3; A(1,1) = 4;
Md B(2,2,0.0);
B(0,0) = 2; B(0,1) = 0;
B(1,0) = 1; B(1,1) = 2;
Md C = numerics::matmul(A, B);
// Expected result:
// [1*2+2*1, 1*0+2*2] = [4, 4]
// [3*2+4*1, 3*0+4*2] = [10, 8]
CHECK(C(0,0)==4 && C(0,1)==4, "matmul: first row wrong");
CHECK(C(1,0)==10 && C(1,1)==8, "matmul: second row wrong");
}
std::cout << "Matmul test passed ✅\n";
// matvec test
{
// A = [[1,2,3],
// [4,5,6]] (2x3)
Md A(2,3,0.0);
A(0,0)=1; A(0,1)=2; A(0,2)=3;
A(1,0)=4; A(1,1)=5; A(1,2)=6;
// x = [7,8,9]
Vd x(3,0.0);
x[0]=7; x[1]=8; x[2]=9;
// y = A*x = [50, 122]
Vd y = numerics::matvec<double>(A, x);
CHECK(y.size()==2, "matvec size wrong");
CHECK(y[0]==50 && y[1]==122, "matvec values wrong");
// dimension mismatch should throw
bool threw = false;
try {
Vd bad(4,1.0);
(void)numerics::matvec<double>(A, bad);
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "matvec: expected throw on dim mismatch");
}
std::cout << "matvec tests passed ✅\n";
// vecmat test
{
// A = [[1,2],
// [3,4]] (2x2)
Md A(2,2,0.0);
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
// x^T = [5,6]
Vd x(2,0.0);
x[0]=5; x[1]=6;
// y = x^T * A = [5*1+6*3, 5*2+6*4] = [23, 34]
Vd y = numerics::vecmat<double>(x, A);
CHECK(y.size()==2, "vecmat size wrong");
CHECK(y[0]==23 && y[1]==34, "vecmat values wrong");
// mismatch should throw
bool threw = false;
try {
Md B(3,2,0.0); // 3x2, doesn't match x size 2
(void)numerics::vecmat<double>(x, B);
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "vecmat: expected throw on dim mismatch");
}
std::cout << "vecmat tests passed ✅\n";
// Inverse 'Gauss-Jordan' tests
{
Md A(2,2,0.0);
A(0,0)=4; A(0,1)=7;
A(1,0)=2; A(1,1)=6;
Md Ai = numerics::inverse(A, "Gauss-Jordan");
Md I1 = numerics::matmul(A, Ai);
Md I2 = numerics::matmul(Ai, A);
Md I(2,2,0.0);
I(0,0)=1; I(1,1)=1;
CHECK(I1.nearly_equal(I), "A*inv(A) != I");
CHECK(I2.nearly_equal(I), "inv(A)*A != I");
}
std::cout << "Inverse 'Gauss-Jordan' tests passed ✅\n";
return 0;
}