#include "test_common.h"
#include "./utils/utils.h"      // brings in vector.h, matrix.h, etc.
#include "./numerics/matmul.h"  // numerics::matmul
#include "./decomp/lu.h"

#include <cmath>      // std::fabs
#include <cstdint>    // uint64_t
#include <stdexcept>  // std::runtime_error

// Tests for decomp::LUdcmpd: vector and matrix right-hand-side solves,
// determinant, pivoting on a zero leading entry, non-mutation of inputs,
// in-place vs. out-of-place solve agreement, and throwing on singular /
// non-square input.

TEST_CASE(LU_Solve_Vector_Basic) {
    using T = double;
    // A * x = b with exact solution x = [1, 1, 2]^T
    utils::Matrix A(3,3, T{0});
    A(0,0)=2;  A(0,1)=1;  A(0,2)=1;
    A(1,0)=4;  A(1,1)=-6; A(1,2)=0;
    A(2,0)=-2; A(2,1)=7;  A(2,2)=2;

    utils::Vector b(3, T{0});
    b[0]=5; b[1]=-2; b[2]=9;

    decomp::LUdcmpd lu(A);
    auto x = lu.solve(b);

    utils::Vector x_true(3, T{0});
    x_true[0]=1; x_true[1]=1; x_true[2]=2;
    CHECK( (x.nearly_equal(x_true,1e-12)), "LU solve (vector RHS) failed" );
}

TEST_CASE(LU_Solve_MatrixRHS_TwoColumns) {
    using T = double;
    // Same A, solve two RHS at once
    utils::Matrix A(3,3, T{0});
    A(0,0)=2;  A(0,1)=1;  A(0,2)=1;
    A(1,0)=4;  A(1,1)=-6; A(1,2)=0;
    A(2,0)=-2; A(2,1)=7;  A(2,2)=2;

    utils::Matrix B(3,2, T{0});
    // First column b1 (same as previous test) -> exact solution x1 = [1, 1, 2]^T
    B(0,0)=5; B(1,0)=-2; B(2,0)=9;
    // Second column b2 -> choose solution x2 = [0, 2, 1]^T
    // Compute b2 = A * x2 by hand:
    //   Row0:  2*0 + 1*2 + 1*1 = 3
    //   Row1:  4*0 - 6*2 + 0*1 = -12
    //   Row2: -2*0 + 7*2 + 2*1 = 16
    B(0,1)=3; B(1,1)=-12; B(2,1)=16;

    decomp::LUdcmpd lu(A);
    auto X = lu.solve(B);

    // Residual check: A * X ~= B
    auto AX = numerics::matmul(A, X);
    CHECK( AX.nearly_equal(B, 1e-12), "A * X does not match B for matrix RHS" );

    // Both columns have known exact solutions, so also pin X itself rather
    // than relying on the residual alone.
    utils::Matrix X_true(3,2, T{0});
    X_true(0,0)=1; X_true(1,0)=1; X_true(2,0)=2;
    X_true(0,1)=0; X_true(1,1)=2; X_true(2,1)=1;
    CHECK( X.nearly_equal(X_true, 1e-12), "LU solve (matrix RHS) returned wrong solution" );
}

TEST_CASE(LU_Determinant_Known) {
    using T = double;
    // Determinant of [[1,2,3],[0,1,4],[5,6,0]] by cofactor expansion on row 0:
    //   1*(1*0 - 4*6) - 2*(0*0 - 4*5) + 3*(0*6 - 1*5) = -24 + 40 - 15 = 1
    // NOTE: the largest |A(i,0)| is in row 2, so if LU uses partial pivoting
    // a row swap happens here and the permutation sign is exercised.
    utils::Matrix A(3,3, T{0});
    A(0,0)=1; A(0,1)=2; A(0,2)=3;
    A(1,0)=0; A(1,1)=1; A(1,2)=4;
    A(2,0)=5; A(2,1)=6; A(2,2)=0;

    decomp::LUdcmpd lu(A);
    T det = lu.det();
    CHECK( std::fabs(det - T{1}) < 1e-12, "det(A) should be 1" );
}

TEST_CASE(LU_Pivoting_Handles_Zero_Leading) {
    using T = double;
    // Requires pivoting (A(0,0)=0); system has solution x=[1,2]^T, b = A*x = [2,3]^T
    utils::Matrix A(2,2, T{0});
    A(0,0)=0; A(0,1)=1;
    A(1,0)=1; A(1,1)=1;

    utils::Vector b(2, T{0});
    b[0]=2; b[1]=3;

    decomp::LUdcmpd lu(A);
    auto x = lu.solve(b);

    utils::Vector x_true(2, T{0});
    x_true[0]=1; x_true[1]=2;
    CHECK( (x.nearly_equal(x_true,1e-12)), "Pivoting failed on zero-leading matrix" );
}

TEST_CASE(LU_Input_Unchanged_By_NonInplace_Path) {
    using T = double;
    utils::Matrix A(4,4, T{0});
    for (uint64_t i=0;i<4;++i) {
        for (uint64_t j=0;j<4;++j) {
            A(i,j) = (i==j) ? 3.0 : 0.1 * ((i+1)*(j+2) % 5 + 1);
        }
    }
    utils::Matrix A_copy = A;

    decomp::LUdcmpd lu(A); // constructor should not mutate input A

    // "Unchanged" means bit-for-bit identical, so compare with == instead of
    // nearly_equal(..., 0.0): a zero tolerance only passes if nearly_equal uses
    // a non-strict (<=) comparison, which this test should not depend on.
    bool a_unchanged = true;
    for (uint64_t i=0;i<4;++i) {
        for (uint64_t j=0;j<4;++j) {
            if (A(i,j) != A_copy(i,j)) a_unchanged = false;
        }
    }
    CHECK( a_unchanged, "LU constructor modified input matrix" );

    // Also check solve doesn't mutate RHS copy when using out-of-place convenience
    utils::Vector b(4, 0.0);
    for (uint64_t i=0;i<4;++i) b[i] = static_cast<T>(i+1);
    auto b_copy = b;
    auto x = lu.solve(b);
    (void)x;

    bool b_unchanged = true;
    for (uint64_t i=0;i<4;++i) {
        if (b[i] != b_copy[i]) b_unchanged = false;
    }
    CHECK( b_unchanged, "solve(b) modified its input vector" );
}

TEST_CASE(LU_Inplace_Equals_OutOfPlace_Solve_Vector) {
    using T = double;
    utils::Matrix A(3,3, T{0});
    A(0,0)=4; A(0,1)=1; A(0,2)=2;
    A(1,0)=0; A(1,1)=3; A(1,2)=-1;
    A(2,0)=0; A(2,1)=0; A(2,2)=2;

    utils::Vector b(3, T{0});
    b[0]=7; b[1]=5; b[2]=4;

    decomp::LUdcmpd lu(A);
    // solve(b) runs first so x1 is computed from the original b, even if
    // inplace_solve were to write into its arguments.
    auto x1 = lu.solve(b);
    utils::Vector x2(3, T{0});
    lu.inplace_solve(b, x2);
    CHECK( (x1.nearly_equal(x2,1e-12)), "inplace_solve(b,x) differs from solve(b)" );
}

TEST_CASE(LU_Singular_Throws) {
    using T = double;
    // Singular (row2 = 2 * row1)
    utils::Matrix S(2,2, T{0});
    S(0,0)=1; S(0,1)=2;
    S(1,0)=2; S(1,1)=4;

    bool threw = false;
    bool threw_other = false;
    try {
        decomp::LUdcmpd lu(S);
        (void)lu;
    } catch (const std::runtime_error&) {
        threw = true;
    } catch (...) {
        // Record an unexpected exception type so it is reported through
        // CHECK instead of escaping the test body.
        threw_other = true;
    }
    CHECK(!threw_other, "LU threw an unexpected exception type on singular matrix");
    CHECK(threw, "LU should throw on singular matrix");
}

TEST_CASE(LU_NonSquare_Throws) {
    using T = double;
    utils::Matrix A(3,2, T{0});

    bool threw = false;
    bool threw_other = false;
    try {
        decomp::LUdcmpd lu(A);
        (void)lu;
    } catch (const std::runtime_error&) {
        threw = true;
    } catch (...) {
        // Record an unexpected exception type so it is reported through
        // CHECK instead of escaping the test body.
        threw_other = true;
    }
    CHECK(!threw_other, "LU threw an unexpected exception type on non-square input");
    CHECK(threw, "LU should throw on non-square input");
}

TEST_CASE(LU_Inverse_RoundTrip) {
    using T = double;
    // Build a strictly diagonally dominant 5x5
    utils::Matrix A(5,5, T{0});
    for (uint64_t i=0;i<5;++i) {
        T rowsum = 0;
        for (uint64_t j=0;j<5;++j) {
            if (i==j) continue;
            A(i,j) = T(0.01 *
double(1 + ((i+2)*(j+3)) % 7)); rowsum += std::fabs(A(i,j)); } A(i,i) = rowsum + T{1}; } decomp::LUdcmpd lu(A); auto Ainv = lu.inverse(); utils::Md I(5,5, 0.0); for (uint64_t i=0;i