Finishing up and starting lu decomp
This commit is contained in:
@@ -0,0 +1,2 @@
|
||||
#define TEST_MAIN
|
||||
#include "test_common.h"
|
||||
@@ -0,0 +1,52 @@
|
||||
#ifndef _test_common_n_
|
||||
#define _test_common_n_
|
||||
|
||||
#pragma once
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct TestFailure : public std::runtime_error {
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
|
||||
#define CHECK(cond, msg) do { if (!(cond)) throw TestFailure(msg); } while (0)
|
||||
#define CHECK_EQ(a,b,msg) do { if (!((a)==(b))) { throw TestFailure(std::string(msg) + " (" #a " != " #b ")"); } } while (0)
|
||||
|
||||
#define TEST_CASE(name) \
|
||||
static void name(); \
|
||||
struct name##_registrar { name##_registrar(){ TestRegistry::add(#name, &name);} } name##_registrar_instance; \
|
||||
static void name()
|
||||
|
||||
struct TestRegistry {
|
||||
using Fn = void(*)();
|
||||
static std::vector<std::pair<std::string, Fn>>& list() {
|
||||
static std::vector<std::pair<std::string, Fn>> v; return v;
|
||||
}
|
||||
static void add(const std::string& name, Fn fn) { list().push_back({name, fn}); }
|
||||
};
|
||||
|
||||
// Default test runner main()
|
||||
#ifdef TEST_MAIN
|
||||
int main() {
|
||||
int fails = 0;
|
||||
for (auto& t : TestRegistry::list()) {
|
||||
try {
|
||||
t.second();
|
||||
std::cout << "[PASS] " << t.first << "\n";
|
||||
} catch (const TestFailure& e) {
|
||||
std::cerr << "[FAIL] " << t.first << " -> " << e.what() << "\n";
|
||||
++fails;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "[ERROR] " << t.first << " -> " << e.what() << "\n";
|
||||
++fails;
|
||||
}
|
||||
}
|
||||
std::cout << (fails ? "Some tests failed ❌\n" : "All tests passed ✅\n");
|
||||
return fails ? 1 : 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // _test_common_n_
|
||||
@@ -0,0 +1,126 @@
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
#include "./numerics/inverse.h"
|
||||
#include "./numerics/matmul.h"
|
||||
|
||||
|
||||
TEST_CASE(Inverse_2x2_WellConditioned) {
|
||||
using T = double;
|
||||
// A = [[4,7],[2,6]] inverse = (1/10) * [[6,-7],[-2,4]]
|
||||
utils::Matrix<T> A(2,2, T{0});
|
||||
A(0,0)=4; A(0,1)=7;
|
||||
A(1,0)=2; A(1,1)=6;
|
||||
|
||||
auto Ainv = numerics::inverse<T>(A); // out-of-place
|
||||
|
||||
// Check A * Ainv ≈ I and Ainv * A ≈ I
|
||||
auto Ileft = numerics::matmul(A, Ainv);
|
||||
auto Iright = numerics::matmul(Ainv, A);
|
||||
|
||||
utils::Md Iref(2,2, T{0});
|
||||
for (uint64_t i=0;i<Iref.rows();++i) Iref(i,i)=T{1};
|
||||
|
||||
//auto Iref = eye<T>(2);
|
||||
|
||||
CHECK((Ileft.nearly_equal(Iref, 1e-12)), "A * inverse(A) ≠ I");
|
||||
CHECK((Iright.nearly_equal(Iref, 1e-12)), "inverse(A) * A ≠ I");
|
||||
}
|
||||
|
||||
TEST_CASE(Inverse_InPlace_Equals_OutOfPlace) {
|
||||
using T = double;
|
||||
utils::Matrix<T> A(3,3, T{0});
|
||||
// A = [[3, 0, 2],
|
||||
// [2, 0, -2],
|
||||
// [0, 1, 1]]
|
||||
A(0,0)=3; A(0,1)=0; A(0,2)= 2;
|
||||
A(1,0)=2; A(1,1)=0; A(1,2)=-2;
|
||||
A(2,0)=0; A(2,1)=1; A(2,2)= 1;
|
||||
|
||||
auto Ainv_ref = numerics::inverse<T>(A); // copy path
|
||||
|
||||
auto A_inp = A;
|
||||
numerics::inplace_inverse<T>(A_inp); // in-place path
|
||||
|
||||
CHECK((A_inp.nearly_equal(Ainv_ref, 1e-12)), "in-place inverse differs from out-of-place");
|
||||
}
|
||||
|
||||
TEST_CASE(Inverse_Singular_Throws) {
|
||||
using T = double;
|
||||
utils::Matrix<T> S(2,2, T{0});
|
||||
// Singular: rows are multiples → det = 0
|
||||
S(0,0)=1; S(0,1)=2;
|
||||
S(1,0)=2; S(1,1)=4;
|
||||
|
||||
bool threw=false;
|
||||
try {
|
||||
auto _ = numerics::inverse<T>(S);
|
||||
(void)_;
|
||||
} catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "inverse should throw on singular matrix");
|
||||
|
||||
threw=false;
|
||||
try {
|
||||
numerics::inplace_inverse<T>(S);
|
||||
} catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "inplace_inverse should throw on singular matrix");
|
||||
}
|
||||
|
||||
TEST_CASE(Inverse_RoundTrip_DiagonallyDominant_5x5) {
|
||||
// Build a well-conditioned 5x5: diagonally dominant
|
||||
utils::Md A(5,5,0.0);
|
||||
for (uint64_t i=0;i<5;++i) {
|
||||
double rowsum = 0.0;
|
||||
for (uint64_t j=0;j<5;++j) {
|
||||
if (i==j) continue;
|
||||
A(i,j) = 0.01 * double(1 + ((i+1)*(j+3)) % 7);
|
||||
rowsum += std::fabs(A(i,j));
|
||||
}
|
||||
A(i,i) = rowsum + 1.0; // strictly diagonally dominant
|
||||
}
|
||||
|
||||
utils::Md A_copy = A; // ensure wrapper doesn't mutate input
|
||||
utils::Md Ainv = numerics::inverse<double>(A);
|
||||
|
||||
// Input must be unchanged by the non-inplace wrapper
|
||||
CHECK(A.nearly_equal(A_copy, 0.0), "inverse wrapper modified input");
|
||||
|
||||
|
||||
utils::Md I(5,5, 0);
|
||||
for (uint64_t i=0;i<I.rows();++i) I(i,i)=1;
|
||||
|
||||
|
||||
auto L = numerics::matmul<double>(A, Ainv);
|
||||
auto R = numerics::matmul<double>(Ainv, A);
|
||||
|
||||
CHECK(L.nearly_equal(I, 1e-10), "A * Ainv not close to I");
|
||||
CHECK(R.nearly_equal(I, 1e-10), "Ainv * A not close to I");
|
||||
}
|
||||
|
||||
TEST_CASE(Inverse_NonSquare_Throws) {
|
||||
// Non-square: 2x3 — algorithm expects square; should throw
|
||||
utils::Md A(2,3,0.0);
|
||||
bool threw = false;
|
||||
try {
|
||||
numerics::inplace_inverse<double>(A);
|
||||
} catch (const std::runtime_error&) {
|
||||
threw = true;
|
||||
} catch (...) {
|
||||
threw = true; // any failure is fine; must not silently succeed
|
||||
}
|
||||
CHECK(threw, "inplace_inverse should throw on non-square matrix");
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE(Inverse_Unknown_Method_Throws) {
|
||||
|
||||
utils::Md A(3,3, 0);
|
||||
for (uint64_t i=0;i<A.rows();++i) A(i,i)=1;
|
||||
|
||||
bool threw = false;
|
||||
try {
|
||||
numerics::inplace_inverse<double>(A, "NotARealMethod");
|
||||
} catch (const std::runtime_error&) {
|
||||
threw = true;
|
||||
}
|
||||
CHECK(threw, "should throw for unknown inverse method");
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
#include "./numerics/matmul.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
|
||||
// ============ Basic correctness ============
|
||||
TEST_CASE(Matmul_Serial_Simple3x3) {
|
||||
utils::Md A(3,3,0.0), B(3,3,0.0);
|
||||
// A = [[1,2,3],[4,5,6],[7,8,9]]
|
||||
double v=1.0;
|
||||
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) A(i,j)=v++;
|
||||
// B = [[9,8,7],[6,5,4],[3,2,1]]
|
||||
double w=9.0;
|
||||
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) B(i,j)=w--;
|
||||
|
||||
auto C = numerics::matmul<double>(A,B);
|
||||
// Hand-checked first row:
|
||||
// row0 dot columns:
|
||||
// c00 = 1*9 + 2*6 + 3*3 = 30
|
||||
// c01 = 1*8 + 2*5 + 3*2 = 24
|
||||
// c02 = 1*7 + 2*4 + 3*1 = 18
|
||||
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
|
||||
CHECK(C(0,0)==30.0 && C(0,1)==24.0 && C(0,2)==18.0, "first row wrong");
|
||||
}
|
||||
|
||||
TEST_CASE(Matmul_OMP_Equals_Serial) {
|
||||
utils::Md A(4,5,0.0), B(5,3,0.0);
|
||||
// Fill deterministic
|
||||
for (uint64_t i=0;i<A.rows();++i)
|
||||
for (uint64_t j=0;j<A.cols();++j)
|
||||
A(i,j) = 0.1*(1 + (i*17 + j*19)%10);
|
||||
for (uint64_t i=0;i<B.rows();++i)
|
||||
for (uint64_t j=0;j<B.cols();++j)
|
||||
B(i,j) = 0.2*(1 + (i*23 + j*29)%10);
|
||||
|
||||
auto Cs = numerics::matmul<double>(A,B);
|
||||
auto Cr = numerics::matmul_rows_omp<double>(A,B);
|
||||
auto Cc = numerics::matmul_collapse_omp<double>(A,B);
|
||||
auto Ca = numerics::matmul_auto<double>(A,B);
|
||||
|
||||
CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
|
||||
CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
|
||||
CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
|
||||
}
|
||||
|
||||
// ============ Dimension mismatch ============
|
||||
TEST_CASE(Matmul_DimensionMismatch_Throws) {
|
||||
utils::Md A(2,3,0.0), B(4,2,0.0);
|
||||
bool threw=false;
|
||||
try { auto _ = numerics::matmul<double>(A,B); (void)_; }
|
||||
catch (const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "serial should throw on dim mismatch");
|
||||
|
||||
threw=false; try { auto _ = numerics::matmul_rows_omp<double>(A,B); (void)_; }
|
||||
catch (const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "rows_omp should throw on dim mismatch");
|
||||
|
||||
threw=false; try { auto _ = numerics::matmul_collapse_omp<double>(A,B); (void)_; }
|
||||
catch (const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "collapse_omp should throw on dim mismatch");
|
||||
}
|
||||
|
||||
// ============ Edge cases ============
|
||||
TEST_CASE(Matmul_Edges_ZeroDims) {
|
||||
// (0xK) * (KxP) -> (0xP)
|
||||
utils::Md A0(0,5,0.0), B1(5,3,0.0);
|
||||
auto C0 = numerics::matmul<double>(A0,B1);
|
||||
CHECK(C0.rows()==0 && C0.cols()==3, "0xK * KxP shape wrong");
|
||||
|
||||
// (MxK) * (Kx0) -> (Mx0)
|
||||
utils::Md A2(7,4,0.0), B0(4,0,0.0);
|
||||
auto C1 = numerics::matmul<double>(A2,B0);
|
||||
CHECK(C1.rows()==7 && C1.cols()==0, "MxK * Kx0 shape wrong");
|
||||
}
|
||||
|
||||
// ============ Identity sanity ============
|
||||
TEST_CASE(Matmul_Identity) {
|
||||
const uint64_t n=5;
|
||||
utils::Md I(n,n,0.0), A(n,n,0.0);
|
||||
for (uint64_t i=0;i<n;++i) I(i,i)=1.0;
|
||||
for (uint64_t i=0;i<n;++i)
|
||||
for (uint64_t j=0;j<n;++j)
|
||||
A(i,j) = (i==j)? 2.0 : ( (i<j)? 1.0 : -1.0 );
|
||||
|
||||
auto L = numerics::matmul<double>(I,A);
|
||||
auto R = numerics::matmul<double>(A,I);
|
||||
CHECK(L == A, "I*A != A");
|
||||
CHECK(R == A, "A*I != A");
|
||||
}
|
||||
|
||||
// ============ Perf sanity (same kernel: 1 thread vs many) ============
|
||||
template <class F>
|
||||
static double time_it(F&& f, int iters=1) {
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
for (int i=0;i<iters;++i) f();
|
||||
auto t1 = std::chrono::high_resolution_clock::now();
|
||||
return std::chrono::duration<double>(t1 - t0).count();
|
||||
}
|
||||
|
||||
TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
|
||||
#ifndef _OPENMP
|
||||
return;
|
||||
#else
|
||||
int hw = omp_get_max_threads();
|
||||
if (hw <= 1) return;
|
||||
|
||||
const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
|
||||
utils::Md A(m,k,0.0), B(k,p,0.0);
|
||||
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i+j%7)*0.001;
|
||||
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i*3+j%5)*0.0005;
|
||||
|
||||
// Warm-up
|
||||
(void) numerics::matmul_rows_omp<double>(A,B);
|
||||
|
||||
int prev = omp_get_max_threads();
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
omp_set_num_threads(1);
|
||||
numerics::matmul_rows_omp<double>(A,B);
|
||||
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
omp_set_num_threads(hw);
|
||||
t0 = std::chrono::high_resolution_clock::now();
|
||||
numerics::matmul_rows_omp<double>(A,B);
|
||||
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
omp_set_num_threads(prev);
|
||||
|
||||
// Must not be notably slower with many threads
|
||||
CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
|
||||
#ifndef _OPENMP
|
||||
return;
|
||||
#else
|
||||
int hw = omp_get_max_threads();
|
||||
if (hw <= 1) return;
|
||||
|
||||
const uint64_t m=512, k=512, p=512;
|
||||
utils::Md A(m,k,0.0), B(k,p,0.0);
|
||||
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i*7+j%11)*0.0003;
|
||||
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i%13+j)*0.0002;
|
||||
|
||||
(void) numerics::matmul_collapse_omp<double>(A,B); // warm-up
|
||||
|
||||
int prev = omp_get_max_threads();
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
omp_set_num_threads(1);
|
||||
numerics::matmul_collapse_omp<double>(A,B);
|
||||
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
omp_set_num_threads(hw);
|
||||
t0 = std::chrono::high_resolution_clock::now();
|
||||
numerics::matmul_collapse_omp<double>(A,B);
|
||||
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
omp_set_num_threads(prev);
|
||||
|
||||
CHECK(tN <= t1 * 1.05, "collapse_omp: multi-thread slower than single-thread");
|
||||
#endif
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
|
||||
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||
using utils::Mf; using utils::Md; using utils::Mi;
|
||||
|
||||
|
||||
// ---------- Construction & element access ----------
|
||||
TEST_CASE(Matrix_Construct_Access) {
|
||||
Md M; // default
|
||||
CHECK(M.rows()==0 && M.cols()==0, "default ctor dims wrong");
|
||||
|
||||
Mf A(2,3, 1.0f);
|
||||
CHECK(A.rows()==2 && A.cols()==3, "ctor dims wrong");
|
||||
CHECK(A(0,0)==1.0f && A(1,2)==1.0f, "fill wrong");
|
||||
|
||||
A(0,1)=2.5f; A(1,0)=3.5f;
|
||||
CHECK(A(0,1)==2.5f && A(1,0)==3.5f, "operator() set/get failed");
|
||||
}
|
||||
|
||||
// ---------- Equality, inequality, nearly_equal ----------
|
||||
TEST_CASE(Matrix_Equality) {
|
||||
Mi A(2,2,0), B(2,2,0), C(2,2,1);
|
||||
A(0,0)=1; A(1,1)=1; // A = I
|
||||
B(0,0)=1; B(1,1)=1; // B = I
|
||||
|
||||
CHECK(A == B, "== failed identical");
|
||||
CHECK(!(A != B), "!= failed identical");
|
||||
CHECK(A != C, "!= failed different");
|
||||
|
||||
Md F1(2,2,0.0), F2(2,2,0.0);
|
||||
F1(0,0)=1.0; F1(1,1)=2.0;
|
||||
F2(0,0)=1.0; F2(1,1)=2.0 + 5e-10; // tiny perturbation
|
||||
CHECK(!(F1 == F2), "operator== is exact; should differ");
|
||||
CHECK(F1.nearly_equal(F2, 1e-9), "nearly_equal should accept tiny delta");
|
||||
CHECK(!F1.nearly_equal(F2, 1e-12), "nearly_equal too strict should fail");
|
||||
}
|
||||
|
||||
// ---------- Row helpers ----------
|
||||
TEST_CASE(Matrix_Row_Get_Set) {
|
||||
Mf M(3,4, 0.0f);
|
||||
Vf r(4, 0.0f);
|
||||
for (uint64_t j=0;j<4;++j) r[j] = float(j+1); // [1,2,3,4]
|
||||
|
||||
M.set_row(1, r);
|
||||
auto out = M.get_row(1);
|
||||
CHECK(out == r, "set_row/get_row mismatch");
|
||||
|
||||
// size mismatch should throw
|
||||
bool threw=false;
|
||||
try { Vf bad(3, 9.0f); M.set_row(2, bad); } catch (const std::exception&) { threw=true; }
|
||||
CHECK(threw, "set_row should throw on size mismatch");
|
||||
|
||||
// out of range
|
||||
threw=false;
|
||||
try { (void)M.get_row(3); } catch (const std::out_of_range&) { threw=true; }
|
||||
CHECK(threw, "get_row should throw on OOB index");
|
||||
}
|
||||
|
||||
// ---------- Column helpers ----------
|
||||
TEST_CASE(Matrix_Col_Get_Set) {
|
||||
Md M(3,2, 0.0);
|
||||
Vd c(3, 0.0);
|
||||
c[0]=10; c[1]=20; c[2]=30;
|
||||
|
||||
M.set_col(1, c);
|
||||
auto out = M.get_col(1);
|
||||
CHECK(out == c, "set_col/get_col mismatch");
|
||||
|
||||
// size mismatch should throw
|
||||
bool threw=false;
|
||||
try { Vd bad(2, 9.0); M.set_col(0, bad); } catch (const std::exception&) { threw=true; }
|
||||
CHECK(threw, "set_col should throw on size mismatch");
|
||||
|
||||
// out of range
|
||||
threw=false;
|
||||
try { (void)M.get_col(2); } catch (const std::out_of_range&) { threw=true; }
|
||||
CHECK(threw, "get_col should throw on OOB index");
|
||||
}
|
||||
|
||||
// ---------- swap_rows / swap_cols ----------
|
||||
TEST_CASE(Matrix_Swap_Rows_Cols) {
|
||||
Mi M(2,3,0);
|
||||
// Row 0: [1,2,3], Row 1: [4,5,6]
|
||||
M(0,0)=1; M(0,1)=2; M(0,2)=3;
|
||||
M(1,0)=4; M(1,1)=5; M(1,2)=6;
|
||||
|
||||
M.swap_rows(0,1);
|
||||
CHECK(M(0,0)==4 && M(0,1)==5 && M(0,2)==6, "swap_rows row0 wrong");
|
||||
CHECK(M(1,0)==1 && M(1,1)==2 && M(1,2)==3, "swap_rows row1 wrong");
|
||||
|
||||
// swap back via cols
|
||||
M.swap_cols(0,2);
|
||||
// After swapping col0<->col2:
|
||||
// Row0: [6,5,4], Row1: [3,2,1]
|
||||
CHECK(M(0,0)==6 && M(0,1)==5 && M(0,2)==4, "swap_cols row0 wrong");
|
||||
CHECK(M(1,0)==3 && M(1,1)==2 && M(1,2)==1, "swap_cols row1 wrong");
|
||||
|
||||
// no-op swap (a==b) should not crash or change
|
||||
M.swap_rows(1,1);
|
||||
M.swap_cols(2,2);
|
||||
|
||||
// OOB checks
|
||||
bool threw=false;
|
||||
try { M.swap_rows(5,1); } catch (const std::out_of_range&) { threw=true; }
|
||||
CHECK(threw, "swap_rows should throw on OOB");
|
||||
threw=false;
|
||||
try { M.swap_cols(0,9); } catch (const std::out_of_range&) { threw=true; }
|
||||
CHECK(threw, "swap_cols should throw on OOB");
|
||||
}
|
||||
|
||||
// ---------- data() layout (contiguous row-major) ----------
|
||||
TEST_CASE(Matrix_Data_Layout) {
|
||||
Md M(2,3, 0.0);
|
||||
// Fill increasing sequence
|
||||
double val=1.0;
|
||||
for (uint64_t i=0;i<M.rows();++i)
|
||||
for (uint64_t j=0;j<M.cols();++j)
|
||||
M(i,j) = val++;
|
||||
|
||||
const double* p = M.data();
|
||||
// Expect row-major: [1,2,3,4,5,6]
|
||||
CHECK(p[0]==1.0 && p[1]==2.0 && p[2]==3.0 && p[3]==4.0 && p[4]==5.0 && p[5]==6.0,
|
||||
"data() row-major layout wrong");
|
||||
}
|
||||
|
||||
// ---------- Stream output ----------
|
||||
TEST_CASE(Matrix_StreamOutput) {
|
||||
Mf M(2,2,0.0f);
|
||||
M(0,0)=1.0f; M(0,1)=2.0f;
|
||||
M(1,0)=3.0f; M(1,1)=4.0f;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << M;
|
||||
const std::string s = oss.str();
|
||||
// Format example:
|
||||
// [[1.000, 2.000],
|
||||
// [3.000, 4.000]]
|
||||
CHECK(s.find("[[1.000, 2.000],") != std::string::npos, "ostream first row format");
|
||||
CHECK(s.find("[3.000, 4.000]]") != std::string::npos, "ostream second row format");
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h" // matrix.h, vector.h
|
||||
#include "./numerics/matvec.h" // numerics::matvec / inplace_transpose
|
||||
|
||||
#include <chrono>
|
||||
|
||||
using utils::Vi; using utils::Vf; using utils::Vd;
|
||||
using utils::Mi; using utils::Mf; using utils::Md;
|
||||
|
||||
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// matvec: y = A * x
|
||||
// ------------------------------------------------------------
|
||||
TEST_CASE(Matvec_Serial_Simple) {
|
||||
// A = [[1,2,3],
|
||||
// [4,5,6]]
|
||||
Md A(2,3,0.0);
|
||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
|
||||
|
||||
auto y = numerics::matvec<double>(A,x); // [ 1*7+2*8+3*9 , 4*7+5*8+6*9 ] = [50, 122]
|
||||
CHECK(y.size()==2, "matvec size wrong");
|
||||
CHECK(y[0]==50.0 && y[1]==122.0, "matvec values wrong");
|
||||
}
|
||||
|
||||
TEST_CASE(Matvec_OMP_Equals_Serial) {
|
||||
Md A(3,3,0.0);
|
||||
// A = I * 2
|
||||
for (uint64_t i=0;i<3;++i) A(i,i)=2.0;
|
||||
Vd x(3,0.0); x[0]=1; x[1]=2; x[2]=3;
|
||||
|
||||
auto ys = numerics::matvec<double>(A,x);
|
||||
auto yp = numerics::matvec_omp<double>(A,x);
|
||||
|
||||
CHECK((ys.nearly_equal_vec(yp)), "matvec_omp != matvec");
|
||||
}
|
||||
|
||||
TEST_CASE(Matvec_Auto_Equals_Serial) {
|
||||
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=0.5; A(1,1)=3;
|
||||
Vd x(2,0.0); x[0]=4; x[1]=5;
|
||||
|
||||
auto ys = numerics::matvec<double>(A,x);
|
||||
auto ya = numerics::matvec_auto<double>(A,x);
|
||||
|
||||
CHECK((ys.nearly_equal_vec(ya)), "matvec_auto != serial");
|
||||
}
|
||||
|
||||
TEST_CASE(Matvec_DimensionMismatch_Throws) {
|
||||
Md A(2,3,0.0);
|
||||
Vd x(4,0.0);
|
||||
bool threw=false;
|
||||
try { auto _ = numerics::matvec<double>(A,x); (void)_; }
|
||||
catch (const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "matvec must throw on dimension mismatch");
|
||||
}
|
||||
|
||||
TEST_CASE(Matvec_Zero_Edges) {
|
||||
Md A(0,3,0.0); // 0x3
|
||||
Vd x(3,1.0);
|
||||
auto y = numerics::matvec<double>(A,x);
|
||||
CHECK(y.size()==0, "0xN * x should return size 0 vector");
|
||||
|
||||
Md B(2,0,0.0); // 2x0
|
||||
Vd z(0,0.0);
|
||||
auto y2 = numerics::matvec<double>(B,z);
|
||||
CHECK(y2.size()==2 && y2[0]==0.0 && y2[1]==0.0, "N×0 * 0 should return zeros of size N");
|
||||
}
|
||||
|
||||
TEST_CASE(Matvec_Float_Tolerance) {
|
||||
Mf A(2,2,0.0f); A(0,0)=1.0f; A(0,1)=2.0f; A(1,0)=3.0f; A(1,1)=4.0f;
|
||||
Vf x(2,0.0f); x[0]=0.1f; x[1]=0.2f;
|
||||
|
||||
auto y1 = numerics::matvec<float>(A,x);
|
||||
auto y2 = numerics::matvec_omp<float>(A,x);
|
||||
|
||||
CHECK((y1.nearly_equal_vec(y2,1e-6f)), "matvec float omp mismatch");
|
||||
}
|
||||
//
|
||||
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
|
||||
// We just check correctness; performance is environment-dependent.
|
||||
//
|
||||
TEST_CASE(Matvec_Auto_Inside_Outer_Parallel_Correctness) {
|
||||
const uint64_t m=64, n=64;
|
||||
Md A(m,n,1.0); Vd x(n,2.0);
|
||||
//fill_deterministic(A); fill_deterministic(x);
|
||||
Vd ref = numerics::matvec<double>(A,x);
|
||||
|
||||
// Call auto inside an outer team
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for schedule(static)
|
||||
#endif
|
||||
for (int rep=0; rep<32; ++rep) {
|
||||
auto y = numerics::matvec_auto<double>(A,x);
|
||||
// Each thread checks its own result equals reference
|
||||
if (!(y.nearly_equal_vec(ref))) {
|
||||
throw TestFailure("matvec_auto wrong under outer parallel region");
|
||||
}
|
||||
}
|
||||
}
|
||||
TEST_CASE(Matvec_Speed_Sanity) {
|
||||
const uint64_t m=4096, n=4096; // ~16M MACs; adjust if needed
|
||||
Md A(m,n,1.0); Vd x(n,2.0);
|
||||
//fill_deterministic(A); fill_deterministic(x);
|
||||
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
auto yS = numerics::matvec(A,x);
|
||||
double tp = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
||||
|
||||
#ifdef _OPENMP
|
||||
int threads = omp_get_max_threads();
|
||||
#else
|
||||
int threads = 1;
|
||||
#endif
|
||||
|
||||
t0 = std::chrono::high_resolution_clock::now();
|
||||
auto yP = numerics::matvec_omp(A,x);
|
||||
double ts = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
||||
|
||||
CHECK((yS.nearly_equal_vec(yP)), "matvec_omp != matvec_serial (large)");
|
||||
// Only enforce basic sanity if we *can* use >1 threads:
|
||||
if (threads > 1) {
|
||||
// Be generous: just require not significantly slower.
|
||||
CHECK(tp <= ts, "matvec_omp unexpectedly much slower than serial");
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// vecmat: y = x * A
|
||||
// ------------------------------------------------------------
|
||||
TEST_CASE(Vecmat_Serial_Simple) {
|
||||
// A = [[1,2],
|
||||
// [3,4],
|
||||
// [5,6]] (3x2)
|
||||
Md A(3,2,0.0);
|
||||
A(0,0)=1; A(0,1)=2;
|
||||
A(1,0)=3; A(1,1)=4;
|
||||
A(2,0)=5; A(2,1)=6;
|
||||
|
||||
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
|
||||
|
||||
auto y = numerics::vecmat<double>(x,A); // 1*7+3*8+5*9= 76 ; 2*7+4*8+6*9=100
|
||||
CHECK(y.size()==2, "vecmat size wrong");
|
||||
CHECK(y[0]==76.0 && y[1]==100.0, "vecmat values wrong");
|
||||
}
|
||||
|
||||
TEST_CASE(Vecmat_OMP_Equals_Serial) {
|
||||
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=5; A(1,1)=-1;
|
||||
Vd x(2,0.0); x[0]=0.5; x[1]=1.5;
|
||||
|
||||
auto ys = numerics::vecmat<double>(x,A);
|
||||
auto yp = numerics::vecmat_omp<double>(x,A);
|
||||
|
||||
CHECK((ys.nearly_equal_vec(yp)), "vecmat_omp != vecmat");
|
||||
}
|
||||
|
||||
TEST_CASE(Vecmat_Auto_Equals_Serial) {
|
||||
Md A(2,3,0.0);
|
||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||
Vd x(2,0.0); x[0]=1; x[1]=2;
|
||||
|
||||
auto ys = numerics::vecmat<double>(x,A);
|
||||
auto ya = numerics::vecmat_auto<double>(x,A);
|
||||
|
||||
CHECK((ys.nearly_equal_vec(ya)), "vecmat_auto != serial");
|
||||
}
|
||||
|
||||
TEST_CASE(Vecmat_DimensionMismatch_Throws) {
|
||||
Md A(2,2,0.0);
|
||||
Vd x(3,0.0);
|
||||
bool threw=false;
|
||||
try { auto _ = numerics::vecmat<double>(x,A); (void)_; }
|
||||
catch (const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "vecmat must throw on dimension mismatch");
|
||||
}
|
||||
|
||||
TEST_CASE(Vecmat_Zero_Edges) {
|
||||
Md A(0,3,0.0);
|
||||
Vd x(0,0.0);
|
||||
auto y = numerics::vecmat<double>(x,A); // 0×N times N×M → 0×M
|
||||
CHECK(y.size()==3 && y[0]==0.0 && y[1]==0.0 && y[2]==0.0, "0-length x times A wrong");
|
||||
|
||||
Md B(3,0,0.0);
|
||||
Vd z(3,1.0);
|
||||
auto y2 = numerics::vecmat<double>(z,B); // 1x3 * 3x0 → 1x0
|
||||
CHECK(y2.size()==0, "vecmat with N×0 result size wrong");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
|
||||
// We just check correctness; performance is environment-dependent.
|
||||
//
|
||||
TEST_CASE(Vecmat_Auto_Inside_Outer_Parallel_Correctness) {
|
||||
const uint64_t m=64, n=64;
|
||||
Md A(m,n,1.0); Vd x(m,2.0);
|
||||
//fill_deterministic(A); fill_deterministic(x);
|
||||
Vd ref = numerics::vecmat<double>(x,A);
|
||||
|
||||
#ifdef _OPENMP
|
||||
#pragma omp parallel for schedule(static)
|
||||
#endif
|
||||
for (int rep=0; rep<32; ++rep) {
|
||||
auto y = numerics::vecmat_auto<double>(x,A);
|
||||
if (!(y.nearly_equal_vec(ref))) {
|
||||
throw TestFailure("vecmat_auto wrong under outer parallel region");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_CASE(Vecmat_Speed_Sanity) {
|
||||
const uint64_t m=4096, n=4096;
|
||||
Md A(m,n,1.0); Vd x(m,2.0);
|
||||
//fill_deterministic(A); fill_deterministic(x);
|
||||
|
||||
auto t0 = std::chrono::high_resolution_clock::now();
|
||||
auto yS = numerics::vecmat<double>(x,A);
|
||||
double ts = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
#ifdef _OPENMP
|
||||
int threads = omp_get_max_threads();
|
||||
#else
|
||||
int threads = 1;
|
||||
#endif
|
||||
|
||||
t0 = std::chrono::high_resolution_clock::now();
|
||||
auto yP = numerics::vecmat_omp<double>(x,A);
|
||||
double tp = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||
|
||||
CHECK((yS.nearly_equal_vec(yP)), "vecmat_omp != vecmat_serial (large)");
|
||||
if (threads > 1) {
|
||||
CHECK(tp <= ts, "vecmat_omp unexpectedly much slower than serial");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h" // matrix.h, vector.h
|
||||
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
|
||||
|
||||
using utils::Mi; using utils::Mf; using utils::Md;
|
||||
|
||||
//
|
||||
// ---------- Out-of-place transpose (rectangular) ----------
|
||||
//
|
||||
TEST_CASE(Transpose_Rectangular_OutOfPlace) {
|
||||
// A = [ [1, 2, 3],
|
||||
// [4, 5, 6] ] (2x3)
|
||||
Md A(2,3,0.0);
|
||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||
|
||||
auto AT = numerics::transpose(A); // (3x2)
|
||||
|
||||
CHECK(AT.rows()==3 && AT.cols()==2, "shape wrong after transpose");
|
||||
CHECK(AT(0,0)==1 && AT(1,0)==2 && AT(2,0)==3, "first column wrong");
|
||||
CHECK(AT(0,1)==4 && AT(1,1)==5 && AT(2,1)==6, "second column wrong");
|
||||
|
||||
// Involution: T(T(A)) == A
|
||||
auto ATT = numerics::transpose(AT);
|
||||
CHECK(ATT == A, "transpose(transpose(A)) != A");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- In-place transpose (square) ----------
|
||||
//
|
||||
TEST_CASE(Transpose_Square_InPlace) {
|
||||
// 3x3 with distinct values
|
||||
Mf S(3,3,0.0f);
|
||||
float val = 1.0f;
|
||||
for (uint64_t i=0;i<3;++i)
|
||||
for (uint64_t j=0;j<3;++j)
|
||||
S(i,j) = val++;
|
||||
|
||||
// Make an explicit transpose to compare against
|
||||
auto ST = numerics::transpose(S);
|
||||
|
||||
// In-place should match the out-of-place result
|
||||
numerics::inplace_transpose(S);
|
||||
CHECK(S == ST, "inplace_transpose result mismatch");
|
||||
|
||||
// Involution: applying in-place again should return original
|
||||
numerics::inplace_transpose(S);
|
||||
// Now S should equal the original pre-inplace matrix (which was transposed above)
|
||||
// We can reconstruct original by transposing ST:
|
||||
auto orig = numerics::transpose(ST);
|
||||
CHECK(S == orig, "inplace transpose twice should restore original");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- In-place transpose must throw on non-square ----------
|
||||
//
|
||||
TEST_CASE(Transpose_InPlace_Throws_On_Rectangular) {
|
||||
Md R(2,3,0.0); // rectangular
|
||||
bool threw = false;
|
||||
try {
|
||||
numerics::inplace_transpose(R);
|
||||
} catch (const std::runtime_error&) {
|
||||
threw = true;
|
||||
}
|
||||
CHECK(threw, "inplace_transpose must throw on non-square matrices");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- Edge cases: 0x0 and 1x1 ----------
|
||||
//
|
||||
TEST_CASE(Transpose_Edge_0x0_1x1) {
|
||||
// 0x0 should be fine both ways
|
||||
Md E; // 0x0
|
||||
auto ET = numerics::transpose(E);
|
||||
CHECK(ET.rows()==0 && ET.cols()==0, "0x0 transpose shape wrong");
|
||||
// in-place on 0x0 (rows==cols) should not throw
|
||||
numerics::inplace_transpose(E);
|
||||
CHECK(E.rows()==0 && E.cols()==0, "0x0 inplace transpose changed shape");
|
||||
|
||||
// 1x1 should be a no-op and not throw
|
||||
Mi I(1,1,0);
|
||||
I(0,0) = 42;
|
||||
auto IT = numerics::transpose(I);
|
||||
CHECK(IT.rows()==1 && IT.cols()==1 && IT(0,0)==42, "1x1 transpose wrong");
|
||||
numerics::inplace_transpose(I);
|
||||
CHECK(I(0,0)==42, "1x1 inplace transpose changed value");
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
|
||||
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||
|
||||
//
|
||||
// ---------- Basic construction & access ----------
|
||||
//
|
||||
TEST_CASE(Vector_Construct_Size_Access) {
|
||||
Vd a; // default
|
||||
CHECK(a.size() == 0, "default size must be 0");
|
||||
|
||||
Vf b(3, 1.0f); // (n, fill)
|
||||
CHECK(b.size() == 3, "size wrong");
|
||||
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");
|
||||
|
||||
b[1] = 2.5f;
|
||||
CHECK(b[1] == 2.5f, "operator[] write failed");
|
||||
|
||||
// resize (grow + value)
|
||||
b.resize(5, 7.0f);
|
||||
CHECK(b.size() == 5, "resize grow size wrong");
|
||||
CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
|
||||
"resize grow values wrong");
|
||||
|
||||
// resize (shrink)
|
||||
b.resize(2);
|
||||
CHECK(b.size() == 2, "resize shrink size wrong");
|
||||
}
|
||||
|
||||
TEST_CASE(Vector_Clear_PushBack) {
|
||||
Vi v(0, 0);
|
||||
v.push_back(10);
|
||||
v.push_back(20);
|
||||
CHECK(v.size() == 2, "push_back size wrong");
|
||||
CHECK(v[0] == 10 && v[1] == 20, "push_back values wrong");
|
||||
|
||||
v.clear();
|
||||
CHECK(v.size() == 0, "clear failed");
|
||||
}
|
||||
//
|
||||
// ---------- Equality / Inequality (tolerant for float/double) ----------
|
||||
//
|
||||
TEST_CASE(Vector_Equality_Tolerant) {
|
||||
Vd a(3, 1.0), b(3, 1.0);
|
||||
CHECK(a == b, "== identical failed");
|
||||
CHECK(!(a != b), "!= identical failed");
|
||||
|
||||
// Tiny perturbation within eps (1e-6 default)
|
||||
b[1] += 1e-7;
|
||||
CHECK(a == b, "== tolerant failed");
|
||||
|
||||
// Larger perturbation should fail equality
|
||||
b[1] += 1e-4;
|
||||
CHECK(a != b, "!= with difference failed");
|
||||
}
|
||||
//
|
||||
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
|
||||
//
|
||||
TEST_CASE(Vector_Scalar_Arithmetic) {
|
||||
Vf a(3, 1.0f);
|
||||
|
||||
// inplace
|
||||
a.inplace_add(2); // int convertible to float
|
||||
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_add failed");
|
||||
|
||||
a.inplace_subtract(1.5f);
|
||||
CHECK(std::fabs(a[0] - 1.5f) < 1e-6f &&
|
||||
std::fabs(a[1] - 1.5f) < 1e-6f &&
|
||||
std::fabs(a[2] - 1.5f) < 1e-6f, "inplace_subtract failed");
|
||||
|
||||
a.inplace_multiply(4.0);
|
||||
CHECK(a[0] == 6.0f && a[1] == 6.0f && a[2] == 6.0f, "inplace_multiply failed");
|
||||
|
||||
a.inplace_divide(2);
|
||||
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_divide failed");
|
||||
|
||||
// returning
|
||||
auto b = a + 1.0f;
|
||||
CHECK(b[0] == 4.0f && b[1] == 4.0f && b[2] == 4.0f, "operator+(scalar) failed");
|
||||
|
||||
b = a - 2.0f;
|
||||
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "operator-(scalar) failed");
|
||||
|
||||
b = a * 10; // int -> float
|
||||
CHECK(b[0] == 30.0f && b[1] == 30.0f && b[2] == 30.0f, "operator*(scalar) failed");
|
||||
|
||||
b = a / 3.0f;
|
||||
CHECK(std::fabs(b[0] - 1.0f) < 1e-6f &&
|
||||
std::fabs(b[1] - 1.0f) < 1e-6f &&
|
||||
std::fabs(b[2] - 1.0f) < 1e-6f, "operator/(scalar) failed");
|
||||
|
||||
// scalar on the left (friends implemented for + and *)
|
||||
Vf c(3, 2.0f);
|
||||
auto d = 5 + c; // friend operator+(U, Vector<T>)
|
||||
CHECK(d[0] == 7.0f && d[1] == 7.0f && d[2] == 7.0f, "scalar + vector failed");
|
||||
|
||||
d = 3 * c; // friend operator*(U, Vector<T>)
|
||||
CHECK(d[0] == 6.0f && d[1] == 6.0f && d[2] == 6.0f, "scalar * vector failed");
|
||||
}
|
||||
//
|
||||
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
|
||||
//
|
||||
TEST_CASE(Vector_Vector_Arithmetic) {
|
||||
Vd a(3, 1.0), b(3, 2.0);
|
||||
|
||||
// returning
|
||||
auto c = a + b;
|
||||
CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");
|
||||
|
||||
c = b - a;
|
||||
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");
|
||||
|
||||
c = a * b;
|
||||
CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");
|
||||
|
||||
c = b / b;
|
||||
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");
|
||||
|
||||
// inplace
|
||||
a = Vd(3, 1.0);
|
||||
a += b;
|
||||
CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
|
||||
a -= b;
|
||||
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
|
||||
a *= b;
|
||||
CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
|
||||
a /= b;
|
||||
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");
|
||||
}
|
||||
//
|
||||
// ---------- Size mismatch error paths ----------
|
||||
//
|
||||
TEST_CASE(Vector_SizeMismatch_Throws) {
|
||||
Vd a(3, 1.0), b(4, 2.0);
|
||||
|
||||
bool threw = false;
|
||||
try { auto c = a + b; (void)c; } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "add should throw on size mismatch");
|
||||
|
||||
threw = false;
|
||||
try { a.inplace_subtract(b); } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "inplace_subtract should throw on size mismatch");
|
||||
|
||||
threw = false;
|
||||
try { auto d = a * b; (void)d; } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "multiply should throw on size mismatch");
|
||||
|
||||
threw = false;
|
||||
try { auto s = a.dot(b); (void)s; } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "dot should throw on size mismatch");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- Power / sqrt ----------
|
||||
//
|
||||
TEST_CASE(Vector_Power_Sqrt) {
|
||||
Vd a(3, 2.0); // [2,2,2]
|
||||
auto b = a.power(3.0); // [8,8,8]
|
||||
CHECK(b[0]==8.0 && b[1]==8.0 && b[2]==8.0, "scalar power failed");
|
||||
|
||||
Vd p(3, 3.0); // [3,3,3]
|
||||
auto c = b.power(p); // 8^3 = 512
|
||||
CHECK(c[0]==512.0 && c[1]==512.0 && c[2]==512.0, "vector power failed");
|
||||
|
||||
Vd d(3, 9.0);
|
||||
auto e = d.sqrt(); // [3,3,3]
|
||||
CHECK(e[0]==3.0 && e[1]==3.0 && e[2]==3.0, "sqrt failed");
|
||||
|
||||
// inplace
|
||||
d.inplace_sqrt(); // becomes [3,3,3]
|
||||
CHECK(d == e, "inplace_sqrt failed");
|
||||
}
|
||||
|
||||
//
|
||||
// ---------- Dot / Sum / Norm / Normalize ----------
|
||||
//
|
||||
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
|
||||
Vd a(3, 0.0);
|
||||
a[0]=1.0; a[1]=2.0; a[2]=2.0;
|
||||
|
||||
CHECK(a.sum() == 5.0, "sum failed");
|
||||
CHECK(a.dot(a) == 9.0, "dot self failed");
|
||||
|
||||
auto n = a.norm();
|
||||
CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");
|
||||
|
||||
auto b = a.normalize();
|
||||
CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
|
||||
|
||||
// inplace normalize
|
||||
a.inplace_normalize();
|
||||
CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
|
||||
|
||||
// zero-norm error
|
||||
Vd z(3, 0.0);
|
||||
bool threw = false;
|
||||
try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "normalize zero vector must throw");
|
||||
}
|
||||
//
|
||||
// ---------- Stream output (basic sanity) ----------
|
||||
//
|
||||
TEST_CASE(Vector_StreamOutput) {
|
||||
Vi a(3, 2);
|
||||
std::ostringstream oss;
|
||||
oss << a;
|
||||
auto s = oss.str();
|
||||
CHECK(s == "[2, 2, 2]", "ostream<< wrong format");
|
||||
}
|
||||
Reference in New Issue
Block a user