#include "test_common.h"
#include "./utils/matrix.h"
#include "./utils/vector.h"
#include "./numerics/matmul.h"
#include "./numerics/matvec.h"
#include "./decomp/lu.h"

#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <vector>

// ---------- helpers ----------

// Returns true when X and Y have the same dimensions and every pair of
// corresponding entries differs by at most `tol` in absolute value.
// A NaN difference counts as a mismatch.
template <typename T>
static bool mats_equal_tol(const utils::Matrix<T>& X, const utils::Matrix<T>& Y, double tol = 1e-12) {
    if (X.rows() != Y.rows() || X.cols() != Y.cols()) return false;
    for (std::uint64_t i = 0; i < X.rows(); ++i) {
        for (std::uint64_t j = 0; j < X.cols(); ++j) {
            const double diff = std::fabs(static_cast<double>(X(i, j) - Y(i, j)));
            // Written as !(diff <= tol) rather than (diff > tol) so that a NaN
            // entry (e.g. produced by a zero pivot) fails instead of passing.
            if (!(diff <= tol)) return false;
        }
    }
    return true;
}

// Returns the n x n identity matrix with element type T.
template <typename T>
static utils::Matrix<T> identity(std::uint64_t n) {
    utils::Matrix<T> I(n, n, T(0));
    for (std::uint64_t i = 0; i < n; ++i) I(i, i) = T(1);
    return I;
}

// Unpacks a combined LU factor `lu` (n = lu.rows()) into separate factors:
//   L = strict lower triangle of `lu` plus a unit diagonal,
//   U = diagonal and upper triangle of `lu`.
// Every entry of L and U is assigned explicitly, so the result does not
// depend on whether resize() clears previously stored values.
template <typename T>
static void split_LU(const utils::Matrix<T>& lu, utils::Matrix<T>& L, utils::Matrix<T>& U) {
    const std::uint64_t n = lu.rows();
    L.resize(n, n, T(0));
    U.resize(n, n, T(0));
    for (std::uint64_t i = 0; i < n; ++i) {
        for (std::uint64_t j = 0; j < n; ++j) {
            if (i > j) {            // strictly below the diagonal -> L
                L(i, j) = lu(i, j);
                U(i, j) = T(0);
            } else if (i == j) {    // diagonal: L has an implicit 1, U keeps the pivot
                L(i, i) = T(1);
                U(i, i) = lu(i, i);
            } else {                // strictly above the diagonal -> U
                L(i, j) = T(0);
                U(i, j) = lu(i, j);
            }
        }
    }
}

// Builds the permutation matrix P intended to satisfy P*A == L*U for the
// factorization that produced `indx`.
// NOTE(review): assumes the Numerical Recipes convention for indx: at
// elimination step k, row k was interchanged with row indx[k]. Replaying
// those swaps in order on the identity then yields P. Confirm against
// decomp/lu.h.
template <typename T, typename Index>
static utils::Matrix<T> permutation_from_indx(const std::vector<Index>& indx) {
    const std::uint64_t n = indx.size();
    auto P = identity<T>(n);
    // Apply the same sequence of row swaps that was applied during factorization
    for (std::uint64_t k = 0; k < n; ++k) {
        const auto r = static_cast<std::uint64_t>(indx[k]);
        if (r == k) continue;
        for (std::uint64_t j = 0; j < n; ++j) {
            const T tmp = P(k, j);
            P(k, j) = P(r, j);
            P(r, j) = tmp;
        }
    }
    return P;
}

// Fixed 3x3 symmetric positive-definite test matrix (det = 24):
//   [ 4  3  0
//     3  4 -1
//     0 -1  4 ]
static utils::Matrix<double> make_A_spd() {
    utils::Matrix<double> A(3, 3, 0.0);
    A(0,0) = 4; A(0,1) =  3; A(0,2) =  0;
    A(1,0) = 3; A(1,1) =  4; A(1,2) = -1;
    A(2,0) = 0; A(2,1) = -1; A(2,2) =  4;
    return A;
}

// ---------- tests ----------

// The packed factor plus pivot record must reproduce the row-permuted input.
TEST_CASE(LU_PA_equals_LU) {
    auto A = make_A_spd();
    decomp::LUdcmpd lu(A);

    utils::Matrix<double> L, U;
    split_LU(lu.lu, L, U);
    auto P = permutation_from_indx<double>(lu.indx);

    auto PA = numerics::matmul(P, A);
    auto LU = numerics::matmul(L, U);
    CHECK(mats_equal_tol(PA, LU, 1e-12), "PA should equal LU");
}

TEST_CASE(LU_Solve_Vector) {
    auto A = make_A_spd();
    decomp::LUdcmpd lu(A);

    utils::Vd b(3, 0.0);
    b[0] = 1.0; b[1] = 2.0; b[2] = 3.0;

    auto x = lu.solve(b);
    auto Ax = numerics::matvec(A, x);
    CHECK(b.nearly_equal_vec(Ax, 1e-12), "A*x should equal b");
}

TEST_CASE(LU_Solve_Matrix_MultiRHS) {
    auto A = make_A_spd();
    decomp::LUdcmpd lu(A);

    utils::Matrix<double> B(3, 2, 0.0);   // two RHS columns
    B(0,0) = 1; B(1,0) = 2; B(2,0) = 3;
    B(0,1) = 4; B(1,1) = 5; B(2,1) = 6;

    auto X = lu.solve(B);                 // 3x2

    // Check A*X == B
    auto AX = numerics::matmul(A, X);
    CHECK(mats_equal_tol(AX, B, 1e-12), "A*X should equal B");

    // And that column-wise solve agrees
    utils::Vd b0(3, 0.0), b1(3, 0.0);
    for (int i = 0; i < 3; ++i) { b0[i] = B(i,0); b1[i] = B(i,1); }
    auto x0 = lu.solve(b0);
    auto x1 = lu.solve(b1);

    // '<' comparison deliberately fails on NaN.
    bool col0_ok = true;
    bool col1_ok = true;
    for (int i = 0; i < 3; ++i) {
        col0_ok = col0_ok && std::fabs(static_cast<double>(X(i,0) - x0[i])) < 1e-12;
        col1_ok = col1_ok && std::fabs(static_cast<double>(X(i,1) - x1[i])) < 1e-12;
    }
    CHECK(col0_ok, "column 0 mismatch");
    CHECK(col1_ok, "column 1 mismatch");
}

TEST_CASE(LU_Determinant) {
    auto A = make_A_spd();
    decomp::LUdcmpd lu(A);
    // For this A, det = 24
    double d = lu.det();
    CHECK(std::fabs(d - 24.0) < 1e-12, "determinant incorrect");
}

TEST_CASE(LU_Inverse_via_SolveI) {
    auto A = make_A_spd();
    decomp::LUdcmpd lu(A);

    // Build identity and solve A * X = I
    auto I = identity<double>(3);
    auto Inv = lu.solve(I);

    // Check A*Inv == I (and Inv*A == I for good measure)
    auto AInv = numerics::matmul(A, Inv);
    auto InvA = numerics::matmul(Inv, A);
    CHECK(mats_equal_tol(AInv, I, 1e-11), "A*Inv should be I");
    CHECK(mats_equal_tol(InvA, I, 1e-11), "Inv*A should be I");
}

TEST_CASE(LU_NonSquare_Throws) {
    utils::Matrix<double> A(2, 3, 1.0);
    bool threw = false;
    try {
        decomp::LUdcmpd lu(A);
        (void)lu;
    } catch (const std::runtime_error&) {
        threw = true;
    }
    CHECK(threw, "LU should throw on non-square");
}

TEST_CASE(LU_Singular_Throws) {
    utils::Matrix<double> A(3, 3, 0.0);
    // Make two identical rows
    A(0,0) = 1; A(0,1) = 2; A(0,2) = 3;
    A(1,0) = 1; A(1,1) = 2; A(1,2) = 3;
    A(2,0) = 0; A(2,1) = 1; A(2,2) = 4;
    bool threw = false;
    try {
        decomp::LUdcmpd lu(A);
        (void)lu;
    } catch (const std::runtime_error&) {
        threw = true;
    }
    CHECK(threw, "LU should throw on singular matrix");
}