Sync public subset from Flux
This commit is contained in:
169
test/test_lu.cpp
Normal file
169
test/test_lu.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
#include "test_common.h"
|
||||
|
||||
#include "./utils/matrix.h"
|
||||
#include "./utils/vector.h"
|
||||
#include "./numerics/matmul.h"
|
||||
#include "./numerics/matvec.h"
|
||||
|
||||
#include "./decomp/lu.h"
|
||||
|
||||
//#include <chrono>
|
||||
|
||||
// ---------- helpers ----------
|
||||
template <typename T>
|
||||
static bool mats_equal_tol(const utils::Matrix<T>& X,
|
||||
const utils::Matrix<T>& Y,
|
||||
double tol = 1e-12) {
|
||||
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||
for (std::uint64_t i=0;i<X.rows();++i)
|
||||
for (std::uint64_t j=0;j<X.cols();++j)
|
||||
if (std::fabs(double(X(i,j) - Y(i,j))) > tol) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static utils::Matrix<T> identity(std::uint64_t n) {
|
||||
utils::Matrix<T> I(n,n,T(0));
|
||||
for (std::uint64_t i=0;i<n;++i) I(i,i) = T(1);
|
||||
return I;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void split_LU(const utils::Matrix<T>& lu,
|
||||
utils::Matrix<T>& L,
|
||||
utils::Matrix<T>& U) {
|
||||
const std::uint64_t n = lu.rows();
|
||||
L.resize(n,n,T(0));
|
||||
U.resize(n,n,T(0));
|
||||
for (std::uint64_t i=0;i<n;++i) {
|
||||
for (std::uint64_t j=0;j<n;++j) {
|
||||
if (i>j) L(i,j) = lu(i,j);
|
||||
else if (i==j){ L(i,i) = T(1); U(i,i) = lu(i,i); }
|
||||
else U(i,j) = lu(i,j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static utils::Matrix<T> permutation_from_indx(const std::vector<std::uint64_t>& indx) {
|
||||
const std::uint64_t n = indx.size();
|
||||
auto P = identity<T>(n);
|
||||
// Apply the same sequence of row swaps that was applied during factorization
|
||||
for (std::uint64_t k=0;k<n;++k) {
|
||||
const std::uint64_t imax = indx[k];
|
||||
if (imax != k) P.swap_rows(k, imax);
|
||||
}
|
||||
return P;
|
||||
}
|
||||
|
||||
// A well-conditioned 3x3 (symmetric positive definite)
|
||||
static utils::Matrix<double> make_A_spd() {
|
||||
utils::Matrix<double> A(3,3,0.0);
|
||||
// [ 4 3 0
|
||||
// 3 4 -1
|
||||
// 0 -1 4 ]
|
||||
A(0,0)=4; A(0,1)=3; A(0,2)=0;
|
||||
A(1,0)=3; A(1,1)=4; A(1,2)=-1;
|
||||
A(2,0)=0; A(2,1)=-1; A(2,2)=4;
|
||||
return A;
|
||||
}
|
||||
|
||||
TEST_CASE(LU_PA_equals_LU) {
|
||||
auto A = make_A_spd();
|
||||
decomp::LUdcmpd lu(A);
|
||||
|
||||
utils::Matrix<double> L,U;
|
||||
split_LU(lu.lu, L, U);
|
||||
auto P = permutation_from_indx<double>(lu.indx);
|
||||
|
||||
auto PA = numerics::matmul(P, A);
|
||||
auto LU = numerics::matmul(L, U);
|
||||
|
||||
CHECK(mats_equal_tol(PA, LU, 1e-12), "PA should equal LU");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_Solve_Vector) {
|
||||
auto A = make_A_spd();
|
||||
decomp::LUdcmpd lu(A);
|
||||
|
||||
utils::Vd b(3,0.0);
|
||||
b[0]=1.0; b[1]=2.0; b[2]=3.0;
|
||||
|
||||
auto x = lu.solve(b);
|
||||
auto Ax = numerics::matvec(A, x);
|
||||
|
||||
CHECK(b.nearly_equal_vec(Ax, 1e-12), "A*x should equal b");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_Solve_Matrix_MultiRHS) {
|
||||
auto A = make_A_spd();
|
||||
decomp::LUdcmpd lu(A);
|
||||
|
||||
utils::Matrix<double> B(3,2,0.0);
|
||||
// two RHS columns
|
||||
B(0,0)=1; B(1,0)=2; B(2,0)=3;
|
||||
B(0,1)=4; B(1,1)=5; B(2,1)=6;
|
||||
|
||||
auto X = lu.solve(B); // 3x2
|
||||
|
||||
// Check A*X == B
|
||||
auto AX = numerics::matmul(A, X);
|
||||
CHECK(mats_equal_tol(AX, B, 1e-12), "A*X should equal B");
|
||||
|
||||
// And that column-wise solve agrees
|
||||
utils::Vd b0(3,0.0), b1(3,0.0);
|
||||
for (int i=0;i<3;++i){ b0[i]=B(i,0); b1[i]=B(i,1); }
|
||||
auto x0 = lu.solve(b0);
|
||||
auto x1 = lu.solve(b1);
|
||||
|
||||
CHECK(std::fabs(double(X(0,0)-x0[0]))<1e-12 &&
|
||||
std::fabs(double(X(1,0)-x0[1]))<1e-12 &&
|
||||
std::fabs(double(X(2,0)-x0[2]))<1e-12, "column 0 mismatch");
|
||||
|
||||
CHECK(std::fabs(double(X(0,1)-x1[0]))<1e-12 &&
|
||||
std::fabs(double(X(1,1)-x1[1]))<1e-12 &&
|
||||
std::fabs(double(X(2,1)-x1[2]))<1e-12, "column 1 mismatch");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_Determinant) {
|
||||
auto A = make_A_spd();
|
||||
decomp::LUdcmpd lu(A);
|
||||
// For this A, det = 24
|
||||
double d = lu.det();
|
||||
CHECK(std::fabs(d - 24.0) < 1e-12, "determinant incorrect");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_Inverse_via_SolveI) {
|
||||
auto A = make_A_spd();
|
||||
decomp::LUdcmpd lu(A);
|
||||
|
||||
// Build identity and solve A * X = I
|
||||
auto I = identity<double>(3);
|
||||
auto Inv = lu.solve(I);
|
||||
|
||||
// Check A*Inv == I (and Inv*A == I for good measure)
|
||||
auto AInv = numerics::matmul(A, Inv);
|
||||
auto InvA = numerics::matmul(Inv, A);
|
||||
|
||||
CHECK(mats_equal_tol(AInv, I, 1e-11), "A*Inv should be I");
|
||||
CHECK(mats_equal_tol(InvA, I, 1e-11), "Inv*A should be I");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_NonSquare_Throws) {
|
||||
utils::Matrix<double> A(2,3,1.0);
|
||||
bool threw=false;
|
||||
try { decomp::LUdcmpd lu(A); } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "LU should throw on non-square");
|
||||
}
|
||||
|
||||
TEST_CASE(LU_Singular_Throws) {
|
||||
utils::Matrix<double> A(3,3,0.0);
|
||||
// Make two identical rows
|
||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||
A(1,0)=1; A(1,1)=2; A(1,2)=3;
|
||||
A(2,0)=0; A(2,1)=1; A(2,2)=4;
|
||||
|
||||
bool threw=false;
|
||||
try { decomp::LUdcmpd lu(A); } catch (const std::runtime_error&) { threw = true; }
|
||||
CHECK(threw, "LU should throw on singular matrix");
|
||||
}
|
||||
Reference in New Issue
Block a user