Ready for fvm steady case
This commit is contained in:
BIN
Binary file not shown.
+3
-1
@@ -3,6 +3,8 @@
|
|||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
|
#include "./numerics/initializers/eye.h"
|
||||||
|
|
||||||
namespace decomp{
|
namespace decomp{
|
||||||
|
|
||||||
// Stores PA = LU with partial pivoting (row permutations).
|
// Stores PA = LU with partial pivoting (row permutations).
|
||||||
@@ -150,7 +152,7 @@ namespace decomp{
|
|||||||
}
|
}
|
||||||
|
|
||||||
void inplace_inverse(utils::Matrix<T>& Ainv){
|
void inplace_inverse(utils::Matrix<T>& Ainv){
|
||||||
Ainv.eye(rows);
|
numerics::inplace_eye<T>(Ainv);
|
||||||
inplace_solve(Ainv, Ainv);
|
inplace_solve(Ainv, Ainv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,138 @@
|
|||||||
|
#pragma once
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace numerics {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void inplace_eye(utils::Matrix<T>& A, uint64_t N = 0){
|
||||||
|
|
||||||
|
bool need_full_zero = true;
|
||||||
|
|
||||||
|
if (N != 0){
|
||||||
|
A.resize(N,N,T{0});
|
||||||
|
need_full_zero = false;
|
||||||
|
}else{
|
||||||
|
N = A.rows();
|
||||||
|
if (N != A.cols()) {
|
||||||
|
throw std::runtime_error("inplace_eye: non-square matrix");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 1) Zero the whole matrix if we didn't just resize with zeros
|
||||||
|
if (need_full_zero){
|
||||||
|
for (uint64_t i = 0; i < N; ++i){
|
||||||
|
for (uint64_t j = 0; j < N; ++j){
|
||||||
|
if (i==j){
|
||||||
|
A(i,j) = T{1};
|
||||||
|
}else{
|
||||||
|
A(i,j) = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
for (uint64_t i = 0; i < N; ++i){
|
||||||
|
A(i,i) = T{1};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void inplace_eye_omp(utils::Matrix<T>& A, uint64_t N = 0){
|
||||||
|
|
||||||
|
bool need_full_zero = true;
|
||||||
|
|
||||||
|
if (N != 0){
|
||||||
|
A.resize(N,N,T{0});
|
||||||
|
need_full_zero = false;
|
||||||
|
}else{
|
||||||
|
N = A.rows();
|
||||||
|
if (N != A.cols()) {
|
||||||
|
throw std::runtime_error("inplace_eye_omp: non-square matrix");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1) Zero the whole matrix if we didn't just resize with zeros
|
||||||
|
if (need_full_zero){
|
||||||
|
T* ptr = A.data();
|
||||||
|
uint64_t NN = N*N;
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (uint64_t i = 0; i < NN; ++i){
|
||||||
|
ptr[i] = T{0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2) Set the diagonal to 1
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (uint64_t i = 0; i < N; ++i){
|
||||||
|
A(i,i) = T{1};
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> eye(uint64_t N){
|
||||||
|
utils::Matrix<T> A;
|
||||||
|
inplace_eye(A, N);
|
||||||
|
return A;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> eye_omp(uint64_t N){
|
||||||
|
utils::Matrix<T> A;
|
||||||
|
inplace_eye_omp(A, N);
|
||||||
|
return A;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> eye_omp_auto(uint64_t N){
|
||||||
|
|
||||||
|
uint64_t work = N*N;
|
||||||
|
utils::Matrix<T> A(N,N,T{0});
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||||
|
#else
|
||||||
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel || work > threads * 4ull) {
|
||||||
|
inplace_eye_omp(A, 0);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
inplace_eye(A, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return A;
|
||||||
|
}
|
||||||
|
// Untested:
|
||||||
|
template <typename T>
|
||||||
|
void inplace_eye_omp_auto(utils::Matrix<T>& A, uint64_t N = 0){
|
||||||
|
|
||||||
|
uint64_t work = N*N;
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||||
|
#else
|
||||||
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel || work > threads * 4ull) {
|
||||||
|
inplace_eye_omp(A, 0);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
inplace_eye(A, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
|
||||||
|
//#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
#include "./numerics/interpolation1d/interpolation1d_barycentric.h"
|
||||||
|
#include "./numerics/interpolation1d/interpolation1d_cubic_spline.h"
|
||||||
|
#include "./numerics/interpolation1d/interpolation1d_linear.h"
|
||||||
|
#include "./numerics/interpolation1d/interpolation1d_polynomial.h"
|
||||||
|
#include "./numerics/interpolation1d/interpolation1d_rational.h"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./numerics/interpolation1d_base.h"
|
#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
|
||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
#include "./numerics/min.h"
|
#include "./numerics/min.h"
|
||||||
|
|||||||
-2
@@ -43,11 +43,9 @@ namespace numerics{
|
|||||||
T interp(T x){
|
T interp(T x){
|
||||||
int64_t jlo;
|
int64_t jlo;
|
||||||
if (cor){
|
if (cor){
|
||||||
std::cout << "Uses hunt()" << std::endl;
|
|
||||||
jlo = hunt(x);
|
jlo = hunt(x);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
std::cout << "Uses locate()" << std::endl;
|
|
||||||
jlo = locate(x);
|
jlo = locate(x);
|
||||||
}
|
}
|
||||||
return rawinterp(jlo,x);
|
return rawinterp(jlo,x);
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./numerics/interpolation1d_base.h"
|
#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
|
||||||
//#include "./numerics/abs.h"
|
//#include "./numerics/abs.h"
|
||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./numerics/interpolation1d_base.h"
|
#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./numerics/interpolation1d_base.h"
|
#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
|
||||||
#include "./numerics/abs.h"
|
#include "./numerics/abs.h"
|
||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "./numerics/interpolation1d_base.h"
|
#include "./numerics/interpolation1d/interpolation1d_base.h"
|
||||||
|
|
||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
#include "./numerics/abs.h"
|
#include "./numerics/abs.h"
|
||||||
|
|||||||
@@ -5,6 +5,8 @@
|
|||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
|
#include "./numerics/initializers/eye.h"
|
||||||
|
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
@@ -13,7 +15,7 @@ namespace numerics{
|
|||||||
void inverse_gj(utils::Matrix<T>& A){
|
void inverse_gj(utils::Matrix<T>& A){
|
||||||
//utils::Matrix<T> B(A.rows(),A.cols(), T{0});
|
//utils::Matrix<T> B(A.rows(),A.cols(), T{0});
|
||||||
utils::Matrix<T> B;
|
utils::Matrix<T> B;
|
||||||
B.eye(A.rows());
|
B = eye_omp_auto<T>(A.rows());
|
||||||
|
|
||||||
|
|
||||||
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
|
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
|
||||||
|
|||||||
@@ -3,8 +3,6 @@
|
|||||||
#include "./decomp/lu.h"
|
#include "./decomp/lu.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./numerics/abs.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace numerics{
|
||||||
|
|
||||||
|
// -------------- Serial ----------------
|
||||||
|
template <typename T>
|
||||||
|
bool matequal(const utils::Matrix<T>& A, const utils::Matrix<T>& B, double tol = 1e-9) {
|
||||||
|
|
||||||
|
if (A.rows() != B.rows() || A.cols() != B.cols()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool decimal = std::is_floating_point<T>::value;
|
||||||
|
const uint64_t rows=A.rows(), cols=A.cols();
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < rows; ++i)
|
||||||
|
for (uint64_t j = 0; j < cols; ++j)
|
||||||
|
if (decimal) {
|
||||||
|
if (numerics::abs(A(i,j) - B(i,j)) > static_cast<T>(tol)){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (A(i,j) != B(i,j)){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------- Parallel ----------------
|
||||||
|
template <typename T>
|
||||||
|
bool matequal_omp(const utils::Matrix<T>& A, const utils::Matrix<T>& B, double tol = 1e-9) {
|
||||||
|
|
||||||
|
if (A.rows() != B.rows() || A.cols() != B.cols()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool decimal = std::is_floating_point<T>::value;
|
||||||
|
bool eq = true;
|
||||||
|
const uint64_t rows=A.rows(), cols=A.cols();
|
||||||
|
|
||||||
|
#pragma omp parallel for collapse(2) schedule(static) reduction(&&:eq)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i)
|
||||||
|
for (uint64_t j = 0; j < cols; ++j)
|
||||||
|
if (decimal) {
|
||||||
|
eq = eq && (numerics::abs(A(i,j) - B(i,j)) <= static_cast<T>(tol));
|
||||||
|
} else {
|
||||||
|
eq = eq && (A(i,j) == B(i,j));
|
||||||
|
}
|
||||||
|
return eq;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------- Auto OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
bool matequal_auto(const utils::Matrix<T>& A, const utils::Matrix<T>& B, double tol = 1e-9) {
|
||||||
|
|
||||||
|
if (A.rows() != B.rows() || A.cols() != B.cols()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t work = A.rows() * A.cols();
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||||
|
#else
|
||||||
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel || work > threads * 4ull) {
|
||||||
|
return matequal_omp(A,B,tol);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
return matequal(A,B,tol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace numerics
|
||||||
@@ -85,21 +85,23 @@ utils::Matrix<T> matmul_auto(const utils::Matrix<T>& A,
|
|||||||
const uint64_t m=A.rows(), p=B.cols();
|
const uint64_t m=A.rows(), p=B.cols();
|
||||||
const uint64_t work = m * p;
|
const uint64_t work = m * p;
|
||||||
|
|
||||||
bool can_parallel = omp_config::omp_parallel_allowed();
|
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
int threads = omp_get_max_threads();
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<uint64_t>(omp_get_max_threads());
|
||||||
#else
|
#else
|
||||||
int threads = 1;
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
// Tiny problems: serial is cheapest.
|
// Tiny problems: serial is cheapest.
|
||||||
if (!can_parallel || work < static_cast<uint64_t>(threads)*4ull) {
|
if (!can_parallel || work < threads*4ull) {
|
||||||
return matmul(A,B);
|
return matmul(A,B);
|
||||||
}
|
}
|
||||||
// Plenty of (i,j) work → collapse(2) is a great default.
|
// Plenty of (i,j) work → collapse(2) is a great default.
|
||||||
else if (work >= 8ull * static_cast<uint64_t>(threads)) {
|
else if (work >= 8ull * threads) {
|
||||||
return matmul_collapse_omp(A,B);
|
return matmul_collapse_omp(A,B);
|
||||||
}
|
}
|
||||||
// Many rows and very few columns → rows-only cheaper overhead.
|
// Many rows and very few columns → rows-only cheaper overhead.
|
||||||
|
|||||||
@@ -38,16 +38,24 @@ namespace numerics{
|
|||||||
const uint64_t m = A.rows();
|
const uint64_t m = A.rows();
|
||||||
const uint64_t n = A.cols();
|
const uint64_t n = A.cols();
|
||||||
|
|
||||||
utils::Vector<T> y(m, T{0});
|
utils::Vector<T> y(m, T{0}); // <-- y has length m (rows)
|
||||||
|
|
||||||
|
|
||||||
|
const T* xptr = x.data();
|
||||||
|
const T* Aptr = A.data(); // row-major: A(i,j) == Aptr[i*n + j]
|
||||||
|
|
||||||
|
// Each row i is an independent dot product: y[i] = dot(A[i,*], x)
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (uint64_t i = 0; i < m; ++i) {
|
for (uint64_t i = 0; i < m; ++i) {
|
||||||
T acc = T{0};
|
const T* row = Aptr + i * n; // contiguous row i
|
||||||
|
T acc = T{0};
|
||||||
|
#pragma omp simd reduction(+:acc)
|
||||||
for (uint64_t j = 0; j < n; ++j) {
|
for (uint64_t j = 0; j < n; ++j) {
|
||||||
acc += A(i, j) * x[j];
|
acc += row[j] * xptr[j];
|
||||||
}
|
}
|
||||||
y[i] = acc;
|
y[i] = acc;
|
||||||
}
|
}
|
||||||
|
|
||||||
return y;
|
return y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
// "./numerics/numerics.h"
|
// "./numerics/numerics.h"
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "./numerics/initializers/eye.h"
|
||||||
|
#include "./numerics/matequal.h"
|
||||||
#include "./numerics/transpose.h"
|
#include "./numerics/transpose.h"
|
||||||
#include "./numerics/inverse.h"
|
#include "./numerics/inverse.h"
|
||||||
#include "./numerics/matmul.h"
|
#include "./numerics/matmul.h"
|
||||||
|
|||||||
@@ -3,12 +3,13 @@
|
|||||||
|
|
||||||
|
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void inplace_transpose(utils::Matrix<T>& A){
|
void inplace_transpose_square(utils::Matrix<T>& A){
|
||||||
|
|
||||||
const uint64_t rows = A.rows();
|
const uint64_t rows = A.rows();
|
||||||
const uint64_t cols = A.cols();
|
const uint64_t cols = A.cols();
|
||||||
@@ -27,13 +28,54 @@ namespace numerics{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void inplace_transpose_square_omp(utils::Matrix<T>& A){
|
||||||
|
|
||||||
|
const uint64_t rows = A.rows();
|
||||||
|
const uint64_t cols = A.cols();
|
||||||
|
|
||||||
|
if (rows != cols){
|
||||||
|
throw std::runtime_error("inplace_transpose only valid for square matrices");
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (uint64_t i = 0; i < rows; ++i){
|
||||||
|
for (uint64_t j = i + 1; j < cols; ++j){
|
||||||
|
T tmp = A(j,i);
|
||||||
|
A(j,i) = A(i,j);
|
||||||
|
A(i,j) = tmp;
|
||||||
|
//std::swap(A(j,i), A(i,j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
utils::Matrix<T> transpose(const utils::Matrix<T>& A){
|
utils::Matrix<T> transpose(const utils::Matrix<T>& A){
|
||||||
|
|
||||||
const uint64_t rows = A.rows();
|
const uint64_t rows = A.rows();
|
||||||
const uint64_t cols = A.cols();
|
const uint64_t cols = A.cols();
|
||||||
|
|
||||||
utils::Matrix<T> B(cols, rows, T{});
|
utils::Matrix<T> B(cols, rows, T{0});
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < rows; ++i){
|
||||||
|
for (uint64_t j = 0; j < cols; ++j){
|
||||||
|
B(j,i) = A(i,j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> transpose_omp(const utils::Matrix<T>& A){
|
||||||
|
|
||||||
|
const uint64_t rows = A.rows();
|
||||||
|
const uint64_t cols = A.cols();
|
||||||
|
|
||||||
|
utils::Matrix<T> B(cols, rows, T{0});
|
||||||
|
|
||||||
|
#pragma omp parallel for collapse(2) schedule(static)
|
||||||
for (uint64_t i = 0; i < rows; ++i){
|
for (uint64_t i = 0; i < rows; ++i){
|
||||||
for (uint64_t j = 0; j < cols; ++j){
|
for (uint64_t j = 0; j < cols; ++j){
|
||||||
B(j,i) = A(i,j);
|
B(j,i) = A(i,j);
|
||||||
@@ -43,6 +85,69 @@ namespace numerics{
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -------- Auto selectors --------
|
||||||
|
template <typename T>
|
||||||
|
void inplace_transpose_square_auto(utils::Matrix<T>& A) {
|
||||||
|
const uint64_t rows = A.rows(), cols = A.cols();
|
||||||
|
|
||||||
|
if (rows != cols) {
|
||||||
|
throw std::runtime_error("inplace_transpose_auto: only valid for square matrices");
|
||||||
|
}
|
||||||
|
const std::uint64_t work = static_cast<std::uint64_t>((rows * (rows - 1)) / 2); // number of swaps
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<std::uint64_t>(omp_get_max_threads());
|
||||||
|
#else
|
||||||
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel && work > threads * 4ull) {
|
||||||
|
inplace_transpose_square_omp(A);
|
||||||
|
}else {
|
||||||
|
inplace_transpose_square(A);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> transpose_auto(const utils::Matrix<T>& A) {
|
||||||
|
|
||||||
|
const uint64_t rows = A.rows();
|
||||||
|
const uint64_t cols = A.cols();
|
||||||
|
|
||||||
|
uint64_t work = A.rows() * A.cols();
|
||||||
|
|
||||||
|
if (rows==cols){
|
||||||
|
utils::Matrix<T> B(rows, cols, T{0});
|
||||||
|
inplace_transpose_square_auto(B);
|
||||||
|
return B;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
uint64_t threads = static_cast<std::uint64_t>(omp_get_max_threads());
|
||||||
|
#else
|
||||||
|
bool can_parallel = false;
|
||||||
|
uint64_t threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (!can_parallel || work > threads * 4ull) {
|
||||||
|
return transpose_omp(A);
|
||||||
|
} else {
|
||||||
|
return transpose(A);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
} // namespace numerics
|
} // namespace numerics
|
||||||
|
|
||||||
#endif // _transpose_n_
|
#endif // _transpose_n_
|
||||||
+6
-37
@@ -3,6 +3,11 @@
|
|||||||
|
|
||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
#include <omp.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
namespace utils{
|
namespace utils{
|
||||||
@@ -32,42 +37,6 @@ public:
|
|||||||
T* data() noexcept { return data_.data(); }
|
T* data() noexcept { return data_.data(); }
|
||||||
const T* data() const noexcept { return data_.data(); }
|
const T* data() const noexcept { return data_.data(); }
|
||||||
|
|
||||||
//# MATRIX: equal operator #
|
|
||||||
bool operator==(const Matrix<T>& A) const {
|
|
||||||
if (rows_ != A.rows_ || cols_ != A.cols_) return false;
|
|
||||||
for (uint64_t i = 0; i < rows_; ++i)
|
|
||||||
for (uint64_t j = 0; j < cols_; ++j)
|
|
||||||
if (data_[i*cols_ + j] != A(i,j))
|
|
||||||
return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
bool operator!=(const Matrix<T>& A) const {
|
|
||||||
return !(*this == A);
|
|
||||||
}
|
|
||||||
bool nearly_equal(const Matrix<T>& A, T tol = static_cast<T>(1e-9)) const {
|
|
||||||
if (rows_ != A.rows() || cols_ != A.cols()) return false;
|
|
||||||
for (uint64_t i = 0; i < rows_; ++i)
|
|
||||||
for (uint64_t j = 0; j < cols_; ++j) {
|
|
||||||
T a = (*this)(i,j);
|
|
||||||
T b = A(i,j);
|
|
||||||
if (std::is_floating_point<T>::value) {
|
|
||||||
if (std::fabs(a - b) > tol) return false;
|
|
||||||
} else {
|
|
||||||
if (a != b) return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void eye(uint64_t rows_cols){
|
|
||||||
rows_ = cols_ = rows_cols;
|
|
||||||
|
|
||||||
data_.clear();
|
|
||||||
data_.resize(rows_cols*rows_cols, T{0});
|
|
||||||
for (uint64_t i = 0; i < rows_; ++i){
|
|
||||||
data_[i * cols_ + i] = T{1};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void resize(uint64_t rows, uint64_t cols, const T& value = T(0)){
|
void resize(uint64_t rows, uint64_t cols, const T& value = T(0)){
|
||||||
rows_ = rows;
|
rows_ = rows;
|
||||||
@@ -186,7 +155,7 @@ private:
|
|||||||
std::vector<T> data_;
|
std::vector<T> data_;
|
||||||
|
|
||||||
};
|
};
|
||||||
typedef Matrix<int> Mi;
|
typedef Matrix<int64_t> Mi;
|
||||||
typedef Matrix<float> Mf;
|
typedef Matrix<float> Mf;
|
||||||
typedef Matrix<double> Md;
|
typedef Matrix<double> Md;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
#pragma once
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./numerics/matequal.h"
|
||||||
|
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
// definitions of the previously-declared members
|
||||||
|
template <typename T>
|
||||||
|
inline bool Matrix<T>::equals(const Matrix<T>& B, T tol) const {
|
||||||
|
return numerics::matequal(*this, B, tol);
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
inline bool Matrix<T>::equals_omp(const Matrix<T>& B, T tol) const {
|
||||||
|
return numerics::matequal_omp(*this, B, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline bool Matrix<T>::equals_auto(const Matrix<T>& B, T tol) const {
|
||||||
|
return numerics::matequal_auto(*this, B, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
@@ -6,6 +6,11 @@
|
|||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <type_traits>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
namespace utils{
|
namespace utils{
|
||||||
//######################################
|
//######################################
|
||||||
//# VECTOR TYPE #
|
//# VECTOR TYPE #
|
||||||
@@ -49,6 +54,10 @@ public:
|
|||||||
v.resize(new_size, value);
|
v.resize(new_size, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
T* data() noexcept { return v.data(); }
|
||||||
|
const T* data() const noexcept { return v.data(); }
|
||||||
|
|
||||||
//###########################################
|
//###########################################
|
||||||
//# VECTOR: == and != #
|
//# VECTOR: == and != #
|
||||||
//###########################################
|
//###########################################
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ TEST_MAIN := $(OBJ_DIR)/test/test_all.o
|
|||||||
# === OpenMP runtime configuration (override-able) ===
|
# === OpenMP runtime configuration (override-able) ===
|
||||||
OMP_PROC_BIND ?= close # close|spread|master
|
OMP_PROC_BIND ?= close # close|spread|master
|
||||||
OMP_PLACES ?= cores # cores|threads|sockets
|
OMP_PLACES ?= cores # cores|threads|sockets
|
||||||
OMP_MAX_LEVELS ?= 2 # 1 = no nested teams; set 2+ to allow nesting
|
OMP_MAX_LEVELS ?= 1 # 1 = no nested teams; set 2+ to allow nesting
|
||||||
OMP_THREADS ?= 16 # e.g. "16" or "8,4" for nested (outer,inner)
|
OMP_THREADS ?= 16 # e.g. "16" or "8,4" for nested (outer,inner)
|
||||||
OMP_DYNAMIC ?= TRUE # TRUE/FALSE: let runtime adjust threads
|
OMP_DYNAMIC ?= TRUE # TRUE/FALSE: let runtime adjust threads
|
||||||
OMP_SCHEDULE ?= STATIC # STATIC recommended for matvec/matmul
|
OMP_SCHEDULE ?= STATIC # STATIC recommended for matvec/matmul
|
||||||
@@ -103,13 +103,20 @@ run: clean-test $(TARGET)
|
|||||||
./$(TARGET)
|
./$(TARGET)
|
||||||
|
|
||||||
# Handy presets
|
# Handy presets
|
||||||
.PHONY: run-single
|
.PHONY: run-single-core
|
||||||
run-single: ## Single-level parallel (good default)
|
run-single-core: ## Single-level one core (good default)
|
||||||
|
$(MAKE) run OMP_MAX_LEVELS=1 OMP_THREADS=1 OMP_PROC_BIND=close OMP_PLACES=cores
|
||||||
|
|
||||||
|
# Handy presets
|
||||||
|
.PHONY: run-multi-core
|
||||||
|
run-multi-core: ## Single-level parallel (good default)
|
||||||
$(MAKE) run OMP_MAX_LEVELS=1 OMP_THREADS=16 OMP_PROC_BIND=close OMP_PLACES=cores
|
$(MAKE) run OMP_MAX_LEVELS=1 OMP_THREADS=16 OMP_PROC_BIND=close OMP_PLACES=cores
|
||||||
|
|
||||||
.PHONY: run-nested
|
.PHONY: run-nested-multi-core
|
||||||
run-nested: ## Two-level nested (outer,inner), adjust to your cores
|
run-nested-multi-core: ## Two-level nested (outer,inner), adjust to your cores
|
||||||
$(MAKE) run OMP_MAX_LEVELS=2 OMP_THREADS="4,8" OMP_PROC_BIND=close OMP_PLACES=cores
|
$(MAKE) run OMP_MAX_LEVELS=2 OMP_THREADS="2,8" OMP_PROC_BIND=close OMP_PLACES=cores
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Optional: print debug info
|
# Optional: print debug info
|
||||||
.PHONY: info
|
.PHONY: info
|
||||||
|
|||||||
-41
@@ -1,41 +0,0 @@
|
|||||||
obj/main.o: src/main.cpp include/./utils/utils.h include/./utils/vector.h \
|
|
||||||
include/./utils/matrix.h include/./numerics/numerics.h \
|
|
||||||
include/./numerics/transpose.h include/./numerics/inverse.h \
|
|
||||||
include/./numerics/inverse/inverse_gauss_jordan.h \
|
|
||||||
include/./numerics/inverse/inverse_lu.h include/./decomp/lu.h \
|
|
||||||
include/./numerics/matmul.h include/./core/omp_config.h \
|
|
||||||
include/./numerics/matvec.h include/./numerics/min.h \
|
|
||||||
include/./numerics/max.h include/./numerics/abs.h \
|
|
||||||
include/./numerics/interpolation1d_base.h \
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_linear.h \
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_polynomial.h \
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_cubic_spline.h \
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_rational.h \
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_barycentric.h \
|
|
||||||
include/./decomp/decomp.h include/./modules/grid1d.h \
|
|
||||||
include/utils/vector.h include/./modules/finitedifference1d.h
|
|
||||||
include/./utils/utils.h:
|
|
||||||
include/./utils/vector.h:
|
|
||||||
include/./utils/matrix.h:
|
|
||||||
include/./numerics/numerics.h:
|
|
||||||
include/./numerics/transpose.h:
|
|
||||||
include/./numerics/inverse.h:
|
|
||||||
include/./numerics/inverse/inverse_gauss_jordan.h:
|
|
||||||
include/./numerics/inverse/inverse_lu.h:
|
|
||||||
include/./decomp/lu.h:
|
|
||||||
include/./numerics/matmul.h:
|
|
||||||
include/./core/omp_config.h:
|
|
||||||
include/./numerics/matvec.h:
|
|
||||||
include/./numerics/min.h:
|
|
||||||
include/./numerics/max.h:
|
|
||||||
include/./numerics/abs.h:
|
|
||||||
include/./numerics/interpolation1d_base.h:
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_linear.h:
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_polynomial.h:
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_cubic_spline.h:
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_rational.h:
|
|
||||||
include/./numerics/interpolation1d/interpolation1d_barycentric.h:
|
|
||||||
include/./decomp/decomp.h:
|
|
||||||
include/./modules/grid1d.h:
|
|
||||||
include/utils/vector.h:
|
|
||||||
include/./modules/finitedifference1d.h:
|
|
||||||
BIN
Binary file not shown.
@@ -0,0 +1,2 @@
|
|||||||
|
obj/test/test_all.o: test/test_all.cpp test/test_common.h
|
||||||
|
test/test_common.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,8 @@
|
|||||||
|
obj/test/test_eye.o: test/test_eye.cpp test/test_common.h \
|
||||||
|
include/./utils/matrix.h include/./utils/vector.h \
|
||||||
|
include/./numerics/initializers/eye.h include/./core/omp_config.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./numerics/initializers/eye.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,24 @@
|
|||||||
|
obj/test/test_interpolation1d.o: test/test_interpolation1d.cpp \
|
||||||
|
test/test_common.h include/./utils/matrix.h include/./utils/vector.h \
|
||||||
|
include/./numerics/interpolation1d.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_barycentric.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_base.h \
|
||||||
|
include/./numerics/min.h include/./numerics/max.h \
|
||||||
|
include/./numerics/abs.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_cubic_spline.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_linear.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_polynomial.h \
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_rational.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./numerics/interpolation1d.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_barycentric.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_base.h:
|
||||||
|
include/./numerics/min.h:
|
||||||
|
include/./numerics/max.h:
|
||||||
|
include/./numerics/abs.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_cubic_spline.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_linear.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_polynomial.h:
|
||||||
|
include/./numerics/interpolation1d/interpolation1d_rational.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,17 @@
|
|||||||
|
obj/test/test_inverse.o: test/test_inverse.cpp test/test_common.h \
|
||||||
|
include/./utils/matrix.h include/./utils/vector.h \
|
||||||
|
include/./numerics/inverse.h \
|
||||||
|
include/./numerics/inverse/inverse_gauss_jordan.h \
|
||||||
|
include/./numerics/initializers/eye.h include/./core/omp_config.h \
|
||||||
|
include/./numerics/inverse/inverse_lu.h include/./decomp/lu.h \
|
||||||
|
include/./numerics/matmul.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./numerics/inverse.h:
|
||||||
|
include/./numerics/inverse/inverse_gauss_jordan.h:
|
||||||
|
include/./numerics/initializers/eye.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
|
include/./numerics/inverse/inverse_lu.h:
|
||||||
|
include/./decomp/lu.h:
|
||||||
|
include/./numerics/matmul.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,13 @@
|
|||||||
|
obj/test/test_lu.o: test/test_lu.cpp test/test_common.h \
|
||||||
|
include/./utils/matrix.h include/./utils/vector.h \
|
||||||
|
include/./numerics/matmul.h include/./core/omp_config.h \
|
||||||
|
include/./numerics/matvec.h include/./decomp/lu.h \
|
||||||
|
include/./numerics/initializers/eye.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./numerics/matmul.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
|
include/./numerics/matvec.h:
|
||||||
|
include/./decomp/lu.h:
|
||||||
|
include/./numerics/initializers/eye.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,10 @@
|
|||||||
|
obj/test/test_matequal.o: test/test_matequal.cpp test/test_common.h \
|
||||||
|
include/./numerics/matequal.h include/./core/omp_config.h \
|
||||||
|
include/./utils/matrix.h include/./utils/vector.h \
|
||||||
|
include/./numerics/abs.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./numerics/matequal.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./numerics/abs.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,10 @@
|
|||||||
|
obj/test/test_matmul.o: test/test_matmul.cpp test/test_common.h \
|
||||||
|
include/./utils/utils.h include/./utils/vector.h \
|
||||||
|
include/./utils/matrix.h include/./numerics/matmul.h \
|
||||||
|
include/./core/omp_config.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/utils.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./numerics/matmul.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,5 @@
|
|||||||
|
obj/test/test_matrix.o: test/test_matrix.cpp test/test_common.h \
|
||||||
|
include/./utils/matrix.h include/./utils/vector.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,8 @@
|
|||||||
|
obj/test/test_matvec.o: test/test_matvec.cpp test/test_common.h \
|
||||||
|
include/./numerics/matvec.h include/./utils/matrix.h \
|
||||||
|
include/./utils/vector.h include/./core/omp_config.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./numerics/matvec.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,8 @@
|
|||||||
|
obj/test/test_transpose.o: test/test_transpose.cpp test/test_common.h \
|
||||||
|
include/./numerics/transpose.h include/./utils/matrix.h \
|
||||||
|
include/./utils/vector.h include/./core/omp_config.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./numerics/transpose.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./core/omp_config.h:
|
||||||
Binary file not shown.
@@ -0,0 +1,7 @@
|
|||||||
|
obj/test/test_vector.o: test/test_vector.cpp test/test_common.h \
|
||||||
|
include/./utils/utils.h include/./utils/vector.h \
|
||||||
|
include/./utils/matrix.h
|
||||||
|
test/test_common.h:
|
||||||
|
include/./utils/utils.h:
|
||||||
|
include/./utils/vector.h:
|
||||||
|
include/./utils/matrix.h:
|
||||||
Binary file not shown.
+64
-1
@@ -10,6 +10,9 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
//#include "./numerics/interpolation/interpolation_linear.h"
|
//#include "./numerics/interpolation/interpolation_linear.h"
|
||||||
|
|
||||||
|
|
||||||
@@ -17,7 +20,67 @@
|
|||||||
|
|
||||||
int main(int argc, char const *argv[])
|
int main(int argc, char const *argv[])
|
||||||
{
|
{
|
||||||
|
utils::Md A;
|
||||||
|
|
||||||
|
/*
|
||||||
|
int hw = omp_get_max_active_levels();
|
||||||
|
if (hw <= 1) return 0;
|
||||||
|
|
||||||
|
const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
|
||||||
|
utils::Md A(m,k,1), B(k,p,1), C(k,p,1);
|
||||||
|
|
||||||
|
omp_set_max_active_levels(1);
|
||||||
|
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
for (int i = 0; i < m*k*p; ++i){
|
||||||
|
A==B
|
||||||
|
}
|
||||||
|
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
|
||||||
|
omp_set_max_active_levels(2);
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
for (int i = 0; i < m*k*p; ++i){
|
||||||
|
A==B
|
||||||
|
}
|
||||||
|
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(prev);
|
||||||
|
|
||||||
|
// Must not be notably slower with many threads
|
||||||
|
CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
utils::Md A(5,5, 1);
|
||||||
|
utils::Md B(5,5, 1);
|
||||||
|
utils::Md C(5,5, 2);
|
||||||
|
|
||||||
|
bool result1 = (A==B);
|
||||||
|
bool result2 = (A==C);
|
||||||
|
|
||||||
|
omp_set_max_active_levels(1):
|
||||||
|
|
||||||
|
for (int i = 0; i < 100; ++i){
|
||||||
|
(A==B)
|
||||||
|
}
|
||||||
|
|
||||||
|
omp_set_max_active_levels(2):
|
||||||
|
|
||||||
|
for (int i = 0; i < 100; ++i){
|
||||||
|
(A==B)
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << result1 << std::endl;
|
||||||
|
std::cout << result2 << std::endl;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
utils::Vector<double> x(100, 0), y(100,0);
|
utils::Vector<double> x(100, 0), y(100,0);
|
||||||
for (uint64_t i = 0; i < 100; ++i){
|
for (uint64_t i = 0; i < 100; ++i){
|
||||||
@@ -78,6 +141,6 @@ int main(int argc, char const *argv[])
|
|||||||
std::cout << rational.interp(p) << std::endl;
|
std::cout << rational.interp(p) << std::endl;
|
||||||
std::cout << barycentric.interp(p) << std::endl;
|
std::cout << barycentric.interp(p) << std::endl;
|
||||||
|
|
||||||
|
*/
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
+1
-1
@@ -35,7 +35,7 @@ int main() {
|
|||||||
for (auto& t : TestRegistry::list()) {
|
for (auto& t : TestRegistry::list()) {
|
||||||
try {
|
try {
|
||||||
t.second();
|
t.second();
|
||||||
std::cout << "[PASS] " << t.first << "\n";
|
//std::cout << "[PASS] " << t.first << "\n";
|
||||||
} catch (const TestFailure& e) {
|
} catch (const TestFailure& e) {
|
||||||
std::cerr << "[FAIL] " << t.first << " -> " << e.what() << "\n";
|
std::cerr << "[FAIL] " << t.first << " -> " << e.what() << "\n";
|
||||||
++fails;
|
++fails;
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./numerics/initializers/eye.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Tiny helper
|
||||||
|
template <typename T>
|
||||||
|
static bool is_identity(const utils::Matrix<T>& M) {
|
||||||
|
if (M.rows() != M.cols()) return false;
|
||||||
|
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
||||||
|
for (std::uint64_t j = 0; j < M.cols(); ++j)
|
||||||
|
if ((i == j && M(i,j) != T{1}) ||
|
||||||
|
(i != j && M(i,j) != T{0})) return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ============ Basic correctness ============
|
||||||
|
TEST_CASE(Generate_EYE_double) {
|
||||||
|
utils::Matrix<double> I = numerics::eye<double>(5);
|
||||||
|
CHECK_EQ(I.rows(), 5, "rows should be 5");
|
||||||
|
CHECK_EQ(I.cols(), 5, "cols should be 5");
|
||||||
|
CHECK(is_identity(I), "eye<double>(5) is not identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Generate_EYE_int) {
|
||||||
|
utils::Matrix<uint64_t> I = numerics::eye<uint64_t>(3);
|
||||||
|
CHECK_EQ(I.rows(), 3, "rows should be 3");
|
||||||
|
CHECK_EQ(I.cols(), 3, "cols should be 3");
|
||||||
|
CHECK(is_identity(I), "eye<int>(3) is not identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inplace_EYE_resize_creates_identity) {
|
||||||
|
utils::Md A; // empty
|
||||||
|
numerics::inplace_eye(A, 7); // resize to 7x7 and set identity
|
||||||
|
CHECK_EQ(A.rows(), 7, "rows should be 7");
|
||||||
|
CHECK_EQ(A.cols(), 7, "cols should be 7");
|
||||||
|
CHECK(is_identity(A), "inplace_eye resize did not create identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inplace_EYE_inplace_on_square) {
|
||||||
|
utils::Md A(4,4, 42.0); // junk values
|
||||||
|
numerics::inplace_eye(A); // N==0 → must be square, zero + diag
|
||||||
|
CHECK(is_identity(A), "inplace_eye on square did not produce identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inplace_EYE_throws_on_non_square) {
|
||||||
|
utils::Md A(2,3, 5.0);
|
||||||
|
bool threw = false;
|
||||||
|
try {
|
||||||
|
numerics::inplace_eye(A); // N==0 and non-square → throws
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
threw = true;
|
||||||
|
}
|
||||||
|
CHECK(threw, "inplace_eye should throw on non-square when N==0");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ OpenMP variants (compiled only when -fopenmp is used) ============
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
TEST_CASE(EYE_OMP_matches_serial) {
|
||||||
|
utils::Md I1 = numerics::eye<double>(64);
|
||||||
|
utils::Md I2 = numerics::eye_omp<double>(64);
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < I1.rows(); ++i){
|
||||||
|
for (uint64_t j = 0; j < I1.cols(); ++j){
|
||||||
|
CHECK(I1(i,j) == I2(i,j), "eye_omp != eye");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
utils::Md A(64,64, 3.14);
|
||||||
|
numerics::inplace_eye_omp(A); // N==0 → must be square
|
||||||
|
CHECK(is_identity(A), "inplace_eye_omp did not produce identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(EYE_OMP_speed) {
|
||||||
|
|
||||||
|
uint64_t prev = omp_get_max_threads();
|
||||||
|
if (prev <= 1) return;
|
||||||
|
|
||||||
|
omp_set_num_threads(16);
|
||||||
|
uint64_t N = 16384;
|
||||||
|
|
||||||
|
utils::Matrix<uint8_t> I1(N,N,1), I2(N,N,1);
|
||||||
|
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
numerics::inplace_eye<uint8_t>(I1);
|
||||||
|
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
numerics::inplace_eye_omp<uint8_t>(I2);
|
||||||
|
double t2 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(prev);
|
||||||
|
|
||||||
|
CHECK(t2 <= t1, "collapse_omp: multi-thread slower than single-thread");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(EYE_OMP_nested_ok) {
|
||||||
|
// Ensure nested regions are allowed for this test (won't error if not)
|
||||||
|
|
||||||
|
uint64_t prev_levels = omp_get_max_active_levels();
|
||||||
|
uint64_t prev = omp_get_max_threads();
|
||||||
|
|
||||||
|
omp_set_max_active_levels(2);
|
||||||
|
// Outer team size (small); inner will be large when nesting is allowed
|
||||||
|
const int outer_threads = 1;
|
||||||
|
// Inner team size to request during nested-enabled run
|
||||||
|
const int inner_threads = 12;
|
||||||
|
uint64_t N = 4096*8;
|
||||||
|
utils::Matrix<uint8_t> I1(N, N, 7);
|
||||||
|
utils::Matrix<uint8_t> I2(N, N, 7);
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- baseline: inner runs with a small team (no outer region) ----------
|
||||||
|
omp_set_max_active_levels(1); // disable nesting for baseline
|
||||||
|
omp_set_num_threads(outer_threads);
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < 2; ++i){
|
||||||
|
numerics::inplace_eye_omp<uint8_t>(I1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- nested: outer×inner (only one outer thread launches inner) ----------
|
||||||
|
omp_set_max_active_levels(2); // allow one nested level
|
||||||
|
|
||||||
|
#pragma omp parallel num_threads(outer_threads)
|
||||||
|
{
|
||||||
|
#pragma omp single // avoid racing on I1
|
||||||
|
{
|
||||||
|
omp_set_num_threads(inner_threads);
|
||||||
|
for (uint64_t i = 0; i < 2; ++i){
|
||||||
|
numerics::inplace_eye_omp<uint8_t>(I2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
omp_set_max_active_levels(prev_levels);
|
||||||
|
omp_set_num_threads(prev);
|
||||||
|
CHECK(is_identity(I1), "EYE_OMP_nested did not produce identity");
|
||||||
|
CHECK(is_identity(I2), "EYE_OMP_nested did not produce identity");
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(EYE_OMP_auto_is_identity) {
|
||||||
|
auto I = numerics::eye_omp_auto<double>(32);
|
||||||
|
CHECK(is_identity(I), "eye_omp_auto did not produce identity");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
#include "test_common.h"
|
|
||||||
#include "./utils/utils.h"
|
|
||||||
#include "./numerics/inverse.h"
|
|
||||||
#include "./numerics/matmul.h"
|
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_GJ_Basic_3x3) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
// Well-conditioned 3x3
|
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
|
||||||
|
|
||||||
auto Ainv = numerics::inverse<T>(A, "Gauss-Jordan");
|
|
||||||
utils::Matrix<T> I;
|
|
||||||
I.eye(3);
|
|
||||||
auto prod = numerics::matmul<T>(A, Ainv);
|
|
||||||
|
|
||||||
CHECK(prod.nearly_equal(I, (T)1e-10), "inverse(GJ): A*A^{-1} ≈ I");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_LU_Basic_3x3) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
|
||||||
|
|
||||||
auto Ainv = numerics::inverse<T>(A, "LU");
|
|
||||||
utils::Matrix<T> I;
|
|
||||||
I.eye(3);
|
|
||||||
auto prod = numerics::matmul<T>(A, Ainv);
|
|
||||||
|
|
||||||
CHECK(prod.nearly_equal(I, (T)1e-10), "inverse(LU): A*A^{-1} ≈ I");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_GJ_vs_LU_Consistency) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
A(0,0)=4; A(0,1)=1; A(0,2)=2;
|
|
||||||
A(1,0)=0; A(1,1)=3; A(1,2)=-1;
|
|
||||||
A(2,0)=0; A(2,1)=0; A(2,2)=2;
|
|
||||||
|
|
||||||
auto GJ = numerics::inverse<T>(A, "Gauss-Jordan");
|
|
||||||
auto LU = numerics::inverse<T>(A, "LU");
|
|
||||||
|
|
||||||
CHECK(GJ.nearly_equal(LU, (T)1e-12), "inverse: GJ and LU produce nearly the same result");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE(Inplace_Inverse_LU) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
|
||||||
|
|
||||||
auto Ainv_ref = numerics::inverse<T>(A, "LU"); // out-of-place
|
|
||||||
auto A_copy = A;
|
|
||||||
numerics::inplace_inverse<T>(A_copy, "LU"); // in-place
|
|
||||||
|
|
||||||
CHECK(A_copy.nearly_equal(Ainv_ref, (T)1e-12), "inplace_inverse(LU) equals out-of-place");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inplace_Inverse_GJ) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(2,2, T{0});
|
|
||||||
A(0,0)=2; A(0,1)=1;
|
|
||||||
A(1,0)=1; A(1,1)=3;
|
|
||||||
|
|
||||||
auto Ainv_ref = numerics::inverse<T>(A, "Gauss-Jordan");
|
|
||||||
auto A_copy = A;
|
|
||||||
numerics::inplace_inverse<T>(A_copy, "Gauss-Jordan");
|
|
||||||
|
|
||||||
CHECK(A_copy.nearly_equal(Ainv_ref, (T)1e-12), "inplace_inverse(GJ) equals out-of-place");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_Identity) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> I;
|
|
||||||
I.eye(3);
|
|
||||||
auto invI = numerics::inverse<T>(I, "LU");
|
|
||||||
CHECK(invI.nearly_equal(I, (T)0), "inverse(I) == I");
|
|
||||||
}
|
|
||||||
TEST_CASE(Inverse_NonSquare_Throws) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(2,3, T{0}); // non-square
|
|
||||||
bool threw1=false, threw2=false;
|
|
||||||
try { auto X = numerics::inverse<T>(A, "LU"); (void)X; } catch(...) { threw1=true; }
|
|
||||||
try { numerics::inplace_inverse<T>(A, "Gauss-Jordan"); } catch(...) { threw2=true; }
|
|
||||||
CHECK(threw1 && threw2, "inverse throws on non-square for both methods");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_Singular_Throws) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> S(3,3, T{0});
|
|
||||||
S(0,0)=1; S(0,1)=2; S(0,2)=3;
|
|
||||||
S(1,0)=1; S(1,1)=2; S(1,2)=3; // duplicate row -> singular
|
|
||||||
S(2,0)=0; S(2,1)=1; S(2,2)=0;
|
|
||||||
|
|
||||||
bool threw_gj=false, threw_lu=false;
|
|
||||||
try { auto X = numerics::inverse<T>(S, "Gauss-Jordan"); (void)X; } catch(...) { threw_gj=true; }
|
|
||||||
try { auto X = numerics::inverse<T>(S, "LU"); (void)X; } catch(...) { threw_lu=true; }
|
|
||||||
CHECK(threw_gj && threw_lu, "inverse throws on singular for both methods");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_Unknown_Method_Throws) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(2,2, T{0});
|
|
||||||
A(0,0)=1; A(1,1)=1;
|
|
||||||
bool threw=false;
|
|
||||||
try { auto X = numerics::inverse<T>(A, "Foobar"); (void)X; } catch(...) { threw=true; }
|
|
||||||
CHECK(threw, "inverse unknown method throws");
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
#include "test_common.h"
|
||||||
|
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./utils/vector.h"
|
||||||
|
|
||||||
|
|
||||||
|
#include "./numerics/interpolation1d.h"
|
||||||
|
|
||||||
|
|
||||||
|
// ------------ helpers ------------
|
||||||
|
template <typename T>
|
||||||
|
static void make_uniform_xy(utils::Vector<T>& x, utils::Vector<T>& y,
|
||||||
|
std::uint64_t N, T x0, T dx,
|
||||||
|
T a, T b) {
|
||||||
|
x.resize(N, T(0));
|
||||||
|
y.resize(N, T(0));
|
||||||
|
for (std::uint64_t i=0; i<N; ++i) {
|
||||||
|
T xi = x0 + dx * T(i);
|
||||||
|
x[i] = xi;
|
||||||
|
y[i] = a*xi + b; // y = a x + b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
static void make_uniform_xy_fun(utils::Vector<T>& x, utils::Vector<T>& y,
|
||||||
|
std::uint64_t N, T x0, T dx,
|
||||||
|
T (*f)(T)) {
|
||||||
|
x.resize(N, T(0));
|
||||||
|
y.resize(N, T(0));
|
||||||
|
for (std::uint64_t i=0; i<N; ++i) {
|
||||||
|
T xi = x0 + dx * T(i);
|
||||||
|
x[i] = xi;
|
||||||
|
y[i] = f(xi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
static inline bool almost_eq(T a, T b, double tol=1e-12) {
|
||||||
|
return std::fabs(double(a)-double(b)) <= tol;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------ tests ------------
|
||||||
|
|
||||||
|
// 1) All interpolants reproduce the data points exactly (or to tiny FP error)
|
||||||
|
TEST_CASE(Interp_Reproduces_Data_Nodes) {
|
||||||
|
const std::uint64_t N = 25;
|
||||||
|
utils::Vd x, y;
|
||||||
|
// linear data: y = 0.1 x + 0.9 on x = 1..25
|
||||||
|
make_uniform_xy<double>(x, y, N, 1.0, 1.0, 0.1, 0.9);
|
||||||
|
|
||||||
|
numerics::interp_linear<double> lin(x,y);
|
||||||
|
numerics::interp_polynomial<double> pol(x,y, 3);
|
||||||
|
numerics::interp_cubic_spline<double> spl(x,y);
|
||||||
|
numerics::interp_rational<double> rat(x,y, 2);
|
||||||
|
numerics::interp_barycentric<double> bar(x,y, 2);
|
||||||
|
|
||||||
|
for (std::uint64_t i=0; i<N; ++i) {
|
||||||
|
double xi = x[i];
|
||||||
|
double yi = y[i];
|
||||||
|
CHECK(almost_eq(lin.interp(xi), yi), "linear fails at node");
|
||||||
|
CHECK(almost_eq(pol.interp(xi), yi), "polynomial fails at node");
|
||||||
|
CHECK(almost_eq(spl.interp(xi), yi), "spline fails at node");
|
||||||
|
CHECK(almost_eq(rat.interp(xi), yi), "rational fails at node");
|
||||||
|
CHECK(almost_eq(bar.interp(xi), yi), "barycentric fails at node");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) Linear dataset: all interpolants match the analytic line at off-grid points
|
||||||
|
TEST_CASE(Interp_Linear_Dataset_OffGrid) {
|
||||||
|
const std::uint64_t N = 100;
|
||||||
|
utils::Vd x, y;
|
||||||
|
// your quick test dataset: x = 1..100, y = 0.1 x + 0.9
|
||||||
|
make_uniform_xy<double>(x, y, N, 1.0, 1.0, 0.1, 0.9);
|
||||||
|
|
||||||
|
numerics::interp_linear<double> lin(x,y);
|
||||||
|
numerics::interp_polynomial<double> pol(x,y, 3);
|
||||||
|
numerics::interp_cubic_spline<double> spl(x,y);
|
||||||
|
numerics::interp_rational<double> rat(x,y, 3);
|
||||||
|
numerics::interp_barycentric<double> bar(x,y, 3);
|
||||||
|
|
||||||
|
std::vector<double> qs = { 1.5, 5.5, 5.51, 50.01, 99.9 };
|
||||||
|
for (double q : qs) {
|
||||||
|
const double ytrue = 0.1*q + 0.9;
|
||||||
|
CHECK(almost_eq(lin.interp(q), ytrue, 1e-9), "linear off-grid mismatch");
|
||||||
|
CHECK(almost_eq(pol.interp(q), ytrue, 1e-9), "polynomial off-grid mismatch");
|
||||||
|
CHECK(almost_eq(spl.interp(q), ytrue, 1e-9), "spline off-grid mismatch");
|
||||||
|
CHECK(almost_eq(rat.interp(q), ytrue, 1e-9), "rational off-grid mismatch");
|
||||||
|
CHECK(almost_eq(bar.interp(q), ytrue, 1e-9), "barycentric off-grid mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// endpoints should match exactly (no extrapolation)
|
||||||
|
CHECK(almost_eq(lin.interp(x[0]), y[0]), "linear endpoint");
|
||||||
|
CHECK(almost_eq(lin.interp(x[N-1]), y[N-1]), "linear endpoint");
|
||||||
|
CHECK(almost_eq(spl.interp(x[0]), y[0]), "spline endpoint");
|
||||||
|
CHECK(almost_eq(spl.interp(x[N-1]), y[N-1]), "spline endpoint");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) Quadratic dataset: degree-2 polynomial interpolation should be exact
|
||||||
|
static double f_quad(double t) { return t*t - 3.0*t + 2.0; } // (t-1)(t-2)
|
||||||
|
|
||||||
|
TEST_CASE(Interp_Polynomial_Deg2_Exact_On_Quadratic) {
|
||||||
|
const std::uint64_t N = 21;
|
||||||
|
utils::Vd x, y;
|
||||||
|
make_uniform_xy_fun<double>(x, y, N, -2.0, 0.5, &f_quad);
|
||||||
|
|
||||||
|
numerics::interp_polynomial<double> pol(x,y, 3);
|
||||||
|
|
||||||
|
std::vector<double> qs = { -1.75, -0.1, 0.3, 1.4, 3.9, 7.25 };
|
||||||
|
for (double q : qs) {
|
||||||
|
// only test inside the data domain to avoid extrap behavior
|
||||||
|
if (q >= x[0] && q <= x[N-1]) {
|
||||||
|
const double ytrue = f_quad(q);
|
||||||
|
//std::cout << pol.interp(q) << ", " << ytrue << ", "<< q <<std::endl;
|
||||||
|
CHECK(almost_eq(pol.interp(q), ytrue, 1e-10), "degree-2 polynomial should be exact on quadratic");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) Cross-method consistency: methods agree closely on smooth data (cosine wave)
|
||||||
|
static double f_cos(double t) { return std::cos(0.1*t); }
|
||||||
|
|
||||||
|
TEST_CASE(Interp_Cross_Method_Consistency) {
|
||||||
|
const std::uint64_t N = 60;
|
||||||
|
utils::Vd x, y;
|
||||||
|
make_uniform_xy_fun<double>(x, y, N, 0.0, 0.5, &f_cos);
|
||||||
|
|
||||||
|
numerics::interp_linear<double> lin(x,y);
|
||||||
|
numerics::interp_polynomial<double> pol(x,y, 3); // small local degree
|
||||||
|
numerics::interp_cubic_spline<double> spl(x,y);
|
||||||
|
numerics::interp_rational<double> rat(x,y, 3);
|
||||||
|
numerics::interp_barycentric<double> bar(x,y, 3);
|
||||||
|
|
||||||
|
// sample a handful of interior points and require all methods to be mutually close
|
||||||
|
std::vector<double> qs = { 1.25, 7.75, 12.1, 18.6, 22.75 };
|
||||||
|
for (double q : qs) {
|
||||||
|
// skip if q is outside just in case
|
||||||
|
if (q < x[0] || q > x[N-1]) continue;
|
||||||
|
|
||||||
|
double yl = lin.interp(q);
|
||||||
|
double yp = pol.interp(q);
|
||||||
|
double ys = spl.interp(q);
|
||||||
|
double yr = rat.interp(q);
|
||||||
|
double yb = bar.interp(q);
|
||||||
|
|
||||||
|
//std::cout << "lin: " << yl << std::endl;
|
||||||
|
//std::cout << "pol: " << yp << std::endl;
|
||||||
|
//std::cout << "spl: " << ys << std::endl;
|
||||||
|
//std::cout << "rat: " << yr << std::endl;
|
||||||
|
//std::cout << "bar: " << yb << std::endl;
|
||||||
|
|
||||||
|
// spline is usually the smoothest; use it as the anchor
|
||||||
|
CHECK(almost_eq(yl, ys, 5e-4), "linear vs spline");
|
||||||
|
CHECK(almost_eq(yp, ys, 5e-4), "polynomial vs spline");
|
||||||
|
CHECK(almost_eq(yr, ys, 5e-4), "rational vs spline");
|
||||||
|
CHECK(almost_eq(yb, ys, 5e-4), "barycentric vs spline");
|
||||||
|
}
|
||||||
|
}
|
||||||
+120
-94
@@ -1,115 +1,141 @@
|
|||||||
#include "test_common.h"
|
#include "test_common.h"
|
||||||
#include "./utils/utils.h"
|
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
#include "./utils/vector.h"
|
||||||
#include "./numerics/inverse.h"
|
#include "./numerics/inverse.h"
|
||||||
#include "./numerics/matmul.h"
|
#include "./numerics/matmul.h"
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_GJ_Basic_3x3) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
// Well-conditioned 3x3
|
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
|
||||||
|
|
||||||
auto Ainv = numerics::inverse<T>(A, "Gauss-Jordan");
|
// ---------- helpers ----------
|
||||||
utils::Matrix<T> I;
|
template <typename T>
|
||||||
I.eye(3);
|
static utils::Matrix<T> identity(std::uint64_t n) {
|
||||||
auto prod = numerics::matmul<T>(A, Ainv);
|
utils::Matrix<T> I(n,n,T(0));
|
||||||
|
for (std::uint64_t i=0;i<n;++i) I(i,i) = T(1);
|
||||||
CHECK(prod.nearly_equal(I, (T)1e-10), "inverse(GJ): A*A^{-1} ≈ I");
|
return I;
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_LU_Basic_3x3) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
|
||||||
|
|
||||||
auto Ainv = numerics::inverse<T>(A, "LU");
|
|
||||||
utils::Matrix<T> I;
|
|
||||||
I.eye(3);
|
|
||||||
auto prod = numerics::matmul<T>(A, Ainv);
|
|
||||||
|
|
||||||
CHECK(prod.nearly_equal(I, (T)1e-10), "inverse(LU): A*A^{-1} ≈ I");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_GJ_vs_LU_Consistency) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
|
||||||
A(0,0)=4; A(0,1)=1; A(0,2)=2;
|
|
||||||
A(1,0)=0; A(1,1)=3; A(1,2)=-1;
|
|
||||||
A(2,0)=0; A(2,1)=0; A(2,2)=2;
|
|
||||||
|
|
||||||
auto GJ = numerics::inverse<T>(A, "Gauss-Jordan");
|
|
||||||
auto LU = numerics::inverse<T>(A, "LU");
|
|
||||||
|
|
||||||
CHECK(GJ.nearly_equal(LU, (T)1e-12), "inverse: GJ and LU produce nearly the same result");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_CASE(Inplace_Inverse_LU) {
|
template <typename T>
|
||||||
using T = double;
|
static bool mats_equal_tol(const utils::Matrix<T>& X,
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
const utils::Matrix<T>& Y,
|
||||||
A(0,0)=3; A(0,1)=0.2; A(0,2)=-0.1;
|
double tol = 1e-12) {
|
||||||
A(1,0)=0.1; A(1,1)=2.5; A(1,2)=0.3;
|
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||||
A(2,0)=-0.2;A(2,1)=0.4; A(2,2)=2.0;
|
for (std::uint64_t i=0;i<X.rows();++i)
|
||||||
|
for (std::uint64_t j=0;j<X.cols();++j)
|
||||||
auto Ainv_ref = numerics::inverse<T>(A, "LU"); // out-of-place
|
if (std::fabs(double(X(i,j) - Y(i,j))) > tol) return false;
|
||||||
auto A_copy = A;
|
return true;
|
||||||
numerics::inplace_inverse<T>(A_copy, "LU"); // in-place
|
|
||||||
|
|
||||||
CHECK(A_copy.nearly_equal(Ainv_ref, (T)1e-12), "inplace_inverse(LU) equals out-of-place");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Inplace_Inverse_GJ) {
|
// A small well-conditioned SPD 3x3
|
||||||
using T = double;
|
static utils::Matrix<double> make_A3() {
|
||||||
utils::Matrix<T> A(2,2, T{0});
|
utils::Matrix<double> A(3,3,0.0);
|
||||||
A(0,0)=2; A(0,1)=1;
|
// [ 4 3 0
|
||||||
A(1,0)=1; A(1,1)=3;
|
// 3 4 -1
|
||||||
|
// 0 -1 4 ]
|
||||||
auto Ainv_ref = numerics::inverse<T>(A, "Gauss-Jordan");
|
A(0,0)=4; A(0,1)=3; A(0,2)=0;
|
||||||
auto A_copy = A;
|
A(1,0)=3; A(1,1)=4; A(1,2)=-1;
|
||||||
numerics::inplace_inverse<T>(A_copy, "Gauss-Jordan");
|
A(2,0)=0; A(2,1)=-1; A(2,2)=4;
|
||||||
|
return A;
|
||||||
CHECK(A_copy.nearly_equal(Ainv_ref, (T)1e-12), "inplace_inverse(GJ) equals out-of-place");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Inverse_Identity) {
|
// A random-ish SPD 5x5 (constructed deterministically)
|
||||||
using T = double;
|
static utils::Matrix<double> make_A5_spd() {
|
||||||
utils::Matrix<T> I;
|
utils::Matrix<double> R(5,5,0.0);
|
||||||
I.eye(3);
|
// Fill R with a simple pattern
|
||||||
auto invI = numerics::inverse<T>(I, "LU");
|
double v=1.0;
|
||||||
CHECK(invI.nearly_equal(I, (T)0), "inverse(I) == I");
|
for (std::uint64_t i=0;i<5;++i)
|
||||||
|
for (std::uint64_t j=0;j<5;++j, v+=0.37)
|
||||||
|
R(i,j) = std::fmod(v, 3.0) - 1.0; // values in [-1,2)
|
||||||
|
// A = R^T R + 5 I → SPD and well-conditioned
|
||||||
|
utils::Matrix<double> Rt(5,5,0.0);
|
||||||
|
for (std::uint64_t i=0;i<5;++i)
|
||||||
|
for (std::uint64_t j=0;j<5;++j)
|
||||||
|
Rt(i,j) = R(j,i);
|
||||||
|
auto RtR = numerics::matmul(Rt, R);
|
||||||
|
for (std::uint64_t i=0;i<5;++i) RtR(i,i) += 5.0;
|
||||||
|
return RtR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------- tests ----------
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_GJ_3x3) {
|
||||||
|
auto A = make_A3();
|
||||||
|
auto A_copy = A;
|
||||||
|
|
||||||
|
auto Inv = numerics::inverse(A, "Gauss-Jordan"); // non-inplace
|
||||||
|
auto I = identity<double>(3);
|
||||||
|
|
||||||
|
auto AInv = numerics::matmul(A, Inv);
|
||||||
|
auto InvA = numerics::matmul(Inv, A);
|
||||||
|
|
||||||
|
CHECK(mats_equal_tol(AInv, I, 1e-11), "A*Inv != I (Gauss-Jordan)");
|
||||||
|
CHECK(mats_equal_tol(InvA, I, 1e-11), "Inv*A != I (Gauss-Jordan)");
|
||||||
|
|
||||||
|
// A should be unchanged by numerics::inverse (it copies internally)
|
||||||
|
CHECK(mats_equal_tol(A, A_copy, 1e-15), "inverse() modified input A");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_LU_3x3) {
|
||||||
|
auto A = make_A3();
|
||||||
|
|
||||||
|
auto Inv = numerics::inverse(A, "LU");
|
||||||
|
auto I = identity<double>(3);
|
||||||
|
|
||||||
|
auto AInv = numerics::matmul(A, Inv);
|
||||||
|
auto InvA = numerics::matmul(Inv, A);
|
||||||
|
|
||||||
|
CHECK(mats_equal_tol(AInv, I, 1e-11), "A*Inv != I (LU)");
|
||||||
|
CHECK(mats_equal_tol(InvA, I, 1e-11), "Inv*A != I (LU)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inplace_Inverse_Both_Methods_Agree_5x5) {
|
||||||
|
auto A = make_A5_spd();
|
||||||
|
|
||||||
|
auto GJ = A; numerics::inplace_inverse(GJ, "Gauss-Jordan");
|
||||||
|
auto LU = A; numerics::inplace_inverse(LU, "LU");
|
||||||
|
|
||||||
|
// Both should be valid inverses
|
||||||
|
auto I = identity<double>(5);
|
||||||
|
CHECK(mats_equal_tol(numerics::matmul(A, GJ), I, 1e-10), "A*GJ != I");
|
||||||
|
CHECK(mats_equal_tol(numerics::matmul(GJ, A), I, 1e-10), "GJ*A != I");
|
||||||
|
CHECK(mats_equal_tol(numerics::matmul(A, LU), I, 1e-10), "A*LU != I");
|
||||||
|
CHECK(mats_equal_tol(numerics::matmul(LU, A), I, 1e-10), "LU*A != I");
|
||||||
|
|
||||||
|
// And they should be very close to each other
|
||||||
|
CHECK(mats_equal_tol(GJ, LU, 1e-10), "Gauss-Jordan inverse != LU inverse");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE(Inverse_NonSquare_Throws) {
|
TEST_CASE(Inverse_NonSquare_Throws) {
|
||||||
using T = double;
|
utils::Matrix<double> A(2,3,1.0);
|
||||||
utils::Matrix<T> A(2,3, T{0}); // non-square
|
bool threw=false;
|
||||||
bool threw1=false, threw2=false;
|
try { auto B = numerics::inverse(A, "Gauss-Jordan"); (void)B; } catch(const std::runtime_error&) { threw=true; }
|
||||||
try { auto X = numerics::inverse<T>(A, "LU"); (void)X; } catch(...) { threw1=true; }
|
CHECK(threw, "inverse should throw on non-square (Gauss-Jordan)");
|
||||||
try { numerics::inplace_inverse<T>(A, "Gauss-Jordan"); } catch(...) { threw2=true; }
|
|
||||||
CHECK(threw1 && threw2, "inverse throws on non-square for both methods");
|
threw=false;
|
||||||
|
try { numerics::inplace_inverse(A, "LU"); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "inplace_inverse should throw on non-square (LU)");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Inverse_Singular_Throws) {
|
TEST_CASE(Inverse_Singular_Throws) {
|
||||||
using T = double;
|
utils::Matrix<double> A(3,3,0.0);
|
||||||
utils::Matrix<T> S(3,3, T{0});
|
// Two identical rows → singular
|
||||||
S(0,0)=1; S(0,1)=2; S(0,2)=3;
|
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||||
S(1,0)=1; S(1,1)=2; S(1,2)=3; // duplicate row -> singular
|
A(1,0)=1; A(1,1)=2; A(1,2)=3;
|
||||||
S(2,0)=0; S(2,1)=1; S(2,2)=0;
|
A(2,0)=0; A(2,1)=1; A(2,2)=4;
|
||||||
|
|
||||||
bool threw_gj=false, threw_lu=false;
|
|
||||||
try { auto X = numerics::inverse<T>(S, "Gauss-Jordan"); (void)X; } catch(...) { threw_gj=true; }
|
|
||||||
try { auto X = numerics::inverse<T>(S, "LU"); (void)X; } catch(...) { threw_lu=true; }
|
|
||||||
CHECK(threw_gj && threw_lu, "inverse throws on singular for both methods");
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Inverse_Unknown_Method_Throws) {
|
|
||||||
using T = double;
|
|
||||||
utils::Matrix<T> A(2,2, T{0});
|
|
||||||
A(0,0)=1; A(1,1)=1;
|
|
||||||
bool threw=false;
|
bool threw=false;
|
||||||
try { auto X = numerics::inverse<T>(A, "Foobar"); (void)X; } catch(...) { threw=true; }
|
try { auto B = numerics::inverse(A, "Gauss-Jordan"); (void)B; } catch(const std::runtime_error&) { threw=true; }
|
||||||
CHECK(threw, "inverse unknown method throws");
|
CHECK(threw, "inverse(GJ) should throw on singular");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { auto B = numerics::inverse(A, "LU"); (void)B; } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "inverse(LU) should throw on singular");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_Invalid_Method_Throws) {
|
||||||
|
auto A = make_A3();
|
||||||
|
bool threw=false;
|
||||||
|
try { auto B = numerics::inverse(A, "NotAThing"); (void)B; } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "inverse should throw on unknown method");
|
||||||
}
|
}
|
||||||
+138
-159
@@ -1,190 +1,169 @@
|
|||||||
#include "test_common.h"
|
#include "test_common.h"
|
||||||
|
|
||||||
#include "./utils/utils.h" // brings in vector.h, matrix.h, etc.
|
#include "./utils/matrix.h"
|
||||||
#include "./numerics/matmul.h" // numerics::matmul
|
#include "./utils/vector.h"
|
||||||
|
#include "./numerics/matmul.h"
|
||||||
|
#include "./numerics/matvec.h"
|
||||||
|
|
||||||
#include "./decomp/lu.h"
|
#include "./decomp/lu.h"
|
||||||
|
|
||||||
//#include <chrono>
|
//#include <chrono>
|
||||||
|
|
||||||
|
// ---------- helpers ----------
|
||||||
TEST_CASE(LU_Solve_Vector_Basic) {
|
template <typename T>
|
||||||
using T = double;
|
static bool mats_equal_tol(const utils::Matrix<T>& X,
|
||||||
|
const utils::Matrix<T>& Y,
|
||||||
// A * x = b with exact solution x = [1, 1, 2]^T
|
double tol = 1e-12) {
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||||
A(0,0)=2; A(0,1)=1; A(0,2)=1;
|
for (std::uint64_t i=0;i<X.rows();++i)
|
||||||
A(1,0)=4; A(1,1)=-6; A(1,2)=0;
|
for (std::uint64_t j=0;j<X.cols();++j)
|
||||||
A(2,0)=-2; A(2,1)=7; A(2,2)=2;
|
if (std::fabs(double(X(i,j) - Y(i,j))) > tol) return false;
|
||||||
|
return true;
|
||||||
utils::Vector<T> b(3, T{0});
|
|
||||||
b[0]=5; b[1]=-2; b[2]=9;
|
|
||||||
|
|
||||||
decomp::LUdcmpd lu(A);
|
|
||||||
auto x = lu.solve(b);
|
|
||||||
|
|
||||||
utils::Vector<T> x_true(3, T{0});
|
|
||||||
x_true[0]=1; x_true[1]=1; x_true[2]=2;
|
|
||||||
|
|
||||||
CHECK( (x.nearly_equal(x_true,1e-12)), "LU solve (vector RHS) failed" );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(LU_Solve_MatrixRHS_TwoColumns) {
|
template <typename T>
|
||||||
using T = double;
|
static utils::Matrix<T> identity(std::uint64_t n) {
|
||||||
|
utils::Matrix<T> I(n,n,T(0));
|
||||||
// Same A, solve two RHS at once
|
for (std::uint64_t i=0;i<n;++i) I(i,i) = T(1);
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
return I;
|
||||||
A(0,0)=2; A(0,1)=1; A(0,2)=1;
|
|
||||||
A(1,0)=4; A(1,1)=-6; A(1,2)=0;
|
|
||||||
A(2,0)=-2; A(2,1)=7; A(2,2)=2;
|
|
||||||
|
|
||||||
utils::Matrix<T> B(3,2, T{0});
|
|
||||||
// First column b1 (same as previous test)
|
|
||||||
B(0,0)=5; B(1,0)=-2; B(2,0)=9;
|
|
||||||
// Second column b2 → choose solution x2 = [0, 2, 1]^T
|
|
||||||
// Compute b2 = A * x2 by hand:
|
|
||||||
// Row0: 2*0 + 1*2 + 1*1 = 3
|
|
||||||
// Row1: 4*0 -6*2 + 0*1 = -12
|
|
||||||
// Row2: -2*0 +7*2 + 2*1 = 16
|
|
||||||
B(0,1)=3; B(1,1)=-12; B(2,1)=16;
|
|
||||||
|
|
||||||
decomp::LUdcmpd lu(A);
|
|
||||||
auto X = lu.solve(B);
|
|
||||||
|
|
||||||
// Check A*X ≈ B
|
|
||||||
auto AX = numerics::matmul(A, X);
|
|
||||||
CHECK( AX.nearly_equal(B, 1e-12), "A * X does not match B for matrix RHS" );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
TEST_CASE(LU_Determinant_Known) {
|
static void split_LU(const utils::Matrix<T>& lu,
|
||||||
using T = double;
|
utils::Matrix<T>& L,
|
||||||
|
utils::Matrix<T>& U) {
|
||||||
// Determinant of:
|
const std::uint64_t n = lu.rows();
|
||||||
// [[1,2,3],[0,1,4],[5,6,0]] is 1
|
L.resize(n,n,T(0));
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
U.resize(n,n,T(0));
|
||||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
for (std::uint64_t i=0;i<n;++i) {
|
||||||
A(1,0)=0; A(1,1)=1; A(1,2)=4;
|
for (std::uint64_t j=0;j<n;++j) {
|
||||||
A(2,0)=5; A(2,1)=6; A(2,2)=0;
|
if (i>j) L(i,j) = lu(i,j);
|
||||||
|
else if (i==j){ L(i,i) = T(1); U(i,i) = lu(i,i); }
|
||||||
decomp::LUdcmpd lu(A);
|
else U(i,j) = lu(i,j);
|
||||||
T det = lu.det();
|
|
||||||
CHECK( std::fabs(det - T{1}) < 1e-12, "det(A) should be 1" );
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(LU_Pivoting_Handles_Zero_Leading) {
|
|
||||||
using T = double;
|
|
||||||
|
|
||||||
// Requires pivoting (A(0,0)=0); system has solution x=[1,2]^T, b = A*x = [2,3]^T
|
|
||||||
utils::Matrix<T> A(2,2, T{0});
|
|
||||||
A(0,0)=0; A(0,1)=1;
|
|
||||||
A(1,0)=1; A(1,1)=1;
|
|
||||||
|
|
||||||
utils::Vector<T> b(2, T{0});
|
|
||||||
b[0]=2; b[1]=3;
|
|
||||||
|
|
||||||
decomp::LUdcmpd lu(A);
|
|
||||||
auto x = lu.solve(b);
|
|
||||||
|
|
||||||
utils::Vector<T> x_true(2, T{0}); x_true[0]=1; x_true[1]=2;
|
|
||||||
CHECK( (x.nearly_equal(x_true,1e-12)), "Pivoting failed on zero-leading matrix" );
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(LU_Input_Unchanged_By_NonInplace_Path) {
|
|
||||||
using T = double;
|
|
||||||
|
|
||||||
utils::Matrix<T> A(4,4, T{0});
|
|
||||||
for (uint64_t i=0;i<4;++i) {
|
|
||||||
for (uint64_t j=0;j<4;++j) {
|
|
||||||
A(i,j) = (i==j) ? 3.0 : 0.1 * ((i+1)*(j+2) % 5 + 1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::Matrix<T> A_copy = A;
|
|
||||||
|
|
||||||
decomp::LUdcmpd lu(A); // constructor should not mutate input A
|
|
||||||
CHECK( A.nearly_equal(A_copy, 0.0), "LU constructor modified input matrix" );
|
|
||||||
|
|
||||||
// Also check solve doesn't mutate RHS copy when using out-of-place convenience
|
|
||||||
utils::Vector<T> b(4, 0.0);
|
|
||||||
for (uint64_t i=0;i<4;++i) b[i] = double(i+1);
|
|
||||||
auto b_copy = b;
|
|
||||||
|
|
||||||
auto x = lu.solve(b);
|
|
||||||
(void)x;
|
|
||||||
CHECK( (b.nearly_equal(b_copy, 1e-300)), "solve(b) modified its input vector" );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(LU_Inplace_Equals_OutOfPlace_Solve_Vector) {
|
template <typename T>
|
||||||
using T = double;
|
static utils::Matrix<T> permutation_from_indx(const std::vector<std::uint64_t>& indx) {
|
||||||
|
const std::uint64_t n = indx.size();
|
||||||
|
auto P = identity<T>(n);
|
||||||
|
// Apply the same sequence of row swaps that was applied during factorization
|
||||||
|
for (std::uint64_t k=0;k<n;++k) {
|
||||||
|
const std::uint64_t imax = indx[k];
|
||||||
|
if (imax != k) P.swap_rows(k, imax);
|
||||||
|
}
|
||||||
|
return P;
|
||||||
|
}
|
||||||
|
|
||||||
utils::Matrix<T> A(3,3, T{0});
|
// A well-conditioned 3x3 (symmetric positive definite)
|
||||||
A(0,0)=4; A(0,1)=1; A(0,2)=2;
|
static utils::Matrix<double> make_A_spd() {
|
||||||
A(1,0)=0; A(1,1)=3; A(1,2)=-1;
|
utils::Matrix<double> A(3,3,0.0);
|
||||||
A(2,0)=0; A(2,1)=0; A(2,2)=2;
|
// [ 4 3 0
|
||||||
|
// 3 4 -1
|
||||||
utils::Vector<T> b(3, T{0}); b[0]=7; b[1]=5; b[2]=4;
|
// 0 -1 4 ]
|
||||||
|
A(0,0)=4; A(0,1)=3; A(0,2)=0;
|
||||||
|
A(1,0)=3; A(1,1)=4; A(1,2)=-1;
|
||||||
|
A(2,0)=0; A(2,1)=-1; A(2,2)=4;
|
||||||
|
return A;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(LU_PA_equals_LU) {
|
||||||
|
auto A = make_A_spd();
|
||||||
decomp::LUdcmpd lu(A);
|
decomp::LUdcmpd lu(A);
|
||||||
|
|
||||||
auto x1 = lu.solve(b);
|
utils::Matrix<double> L,U;
|
||||||
utils::Vector<T> x2(3, T{0});
|
split_LU(lu.lu, L, U);
|
||||||
lu.inplace_solve(b, x2);
|
auto P = permutation_from_indx<double>(lu.indx);
|
||||||
|
|
||||||
CHECK( (x1.nearly_equal(x2,1e-12)), "inplace_solve(b,x) differs from solve(b)" );
|
auto PA = numerics::matmul(P, A);
|
||||||
|
auto LU = numerics::matmul(L, U);
|
||||||
|
|
||||||
|
CHECK(mats_equal_tol(PA, LU, 1e-12), "PA should equal LU");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(LU_Singular_Throws) {
|
TEST_CASE(LU_Solve_Vector) {
|
||||||
using T = double;
|
auto A = make_A_spd();
|
||||||
|
decomp::LUdcmpd lu(A);
|
||||||
|
|
||||||
// Singular (row2 = 2 * row1)
|
utils::Vd b(3,0.0);
|
||||||
utils::Matrix<T> S(2,2, T{0});
|
b[0]=1.0; b[1]=2.0; b[2]=3.0;
|
||||||
S(0,0)=1; S(0,1)=2;
|
|
||||||
S(1,0)=2; S(1,1)=4;
|
|
||||||
|
|
||||||
bool threw=false;
|
auto x = lu.solve(b);
|
||||||
try {
|
auto Ax = numerics::matvec(A, x);
|
||||||
decomp::LUdcmpd lu(S);
|
|
||||||
(void)lu;
|
CHECK(b.nearly_equal_vec(Ax, 1e-12), "A*x should equal b");
|
||||||
} catch (const std::runtime_error&) { threw = true; }
|
}
|
||||||
CHECK(threw, "LU should throw on singular matrix");
|
|
||||||
|
TEST_CASE(LU_Solve_Matrix_MultiRHS) {
|
||||||
|
auto A = make_A_spd();
|
||||||
|
decomp::LUdcmpd lu(A);
|
||||||
|
|
||||||
|
utils::Matrix<double> B(3,2,0.0);
|
||||||
|
// two RHS columns
|
||||||
|
B(0,0)=1; B(1,0)=2; B(2,0)=3;
|
||||||
|
B(0,1)=4; B(1,1)=5; B(2,1)=6;
|
||||||
|
|
||||||
|
auto X = lu.solve(B); // 3x2
|
||||||
|
|
||||||
|
// Check A*X == B
|
||||||
|
auto AX = numerics::matmul(A, X);
|
||||||
|
CHECK(mats_equal_tol(AX, B, 1e-12), "A*X should equal B");
|
||||||
|
|
||||||
|
// And that column-wise solve agrees
|
||||||
|
utils::Vd b0(3,0.0), b1(3,0.0);
|
||||||
|
for (int i=0;i<3;++i){ b0[i]=B(i,0); b1[i]=B(i,1); }
|
||||||
|
auto x0 = lu.solve(b0);
|
||||||
|
auto x1 = lu.solve(b1);
|
||||||
|
|
||||||
|
CHECK(std::fabs(double(X(0,0)-x0[0]))<1e-12 &&
|
||||||
|
std::fabs(double(X(1,0)-x0[1]))<1e-12 &&
|
||||||
|
std::fabs(double(X(2,0)-x0[2]))<1e-12, "column 0 mismatch");
|
||||||
|
|
||||||
|
CHECK(std::fabs(double(X(0,1)-x1[0]))<1e-12 &&
|
||||||
|
std::fabs(double(X(1,1)-x1[1]))<1e-12 &&
|
||||||
|
std::fabs(double(X(2,1)-x1[2]))<1e-12, "column 1 mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(LU_Determinant) {
|
||||||
|
auto A = make_A_spd();
|
||||||
|
decomp::LUdcmpd lu(A);
|
||||||
|
// For this A, det = 24
|
||||||
|
double d = lu.det();
|
||||||
|
CHECK(std::fabs(d - 24.0) < 1e-12, "determinant incorrect");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(LU_Inverse_via_SolveI) {
|
||||||
|
auto A = make_A_spd();
|
||||||
|
decomp::LUdcmpd lu(A);
|
||||||
|
|
||||||
|
// Build identity and solve A * X = I
|
||||||
|
auto I = identity<double>(3);
|
||||||
|
auto Inv = lu.solve(I);
|
||||||
|
|
||||||
|
// Check A*Inv == I (and Inv*A == I for good measure)
|
||||||
|
auto AInv = numerics::matmul(A, Inv);
|
||||||
|
auto InvA = numerics::matmul(Inv, A);
|
||||||
|
|
||||||
|
CHECK(mats_equal_tol(AInv, I, 1e-11), "A*Inv should be I");
|
||||||
|
CHECK(mats_equal_tol(InvA, I, 1e-11), "Inv*A should be I");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(LU_NonSquare_Throws) {
|
TEST_CASE(LU_NonSquare_Throws) {
|
||||||
using T = double;
|
utils::Matrix<double> A(2,3,1.0);
|
||||||
|
bool threw=false;
|
||||||
utils::Matrix<T> A(3,2, T{0});
|
try { decomp::LUdcmpd lu(A); } catch (const std::runtime_error&) { threw = true; }
|
||||||
bool threw = false;
|
CHECK(threw, "LU should throw on non-square");
|
||||||
try {
|
|
||||||
decomp::LUdcmpd lu(A);
|
|
||||||
(void)lu;
|
|
||||||
} catch (const std::runtime_error&) { threw = true; }
|
|
||||||
CHECK(threw, "LU should throw on non-square input");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(LU_Inverse_RoundTrip) {
|
TEST_CASE(LU_Singular_Throws) {
|
||||||
using T = double;
|
utils::Matrix<double> A(3,3,0.0);
|
||||||
|
// Make two identical rows
|
||||||
|
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||||
|
A(1,0)=1; A(1,1)=2; A(1,2)=3;
|
||||||
|
A(2,0)=0; A(2,1)=1; A(2,2)=4;
|
||||||
|
|
||||||
// Build a strictly diagonally dominant 5x5
|
bool threw=false;
|
||||||
utils::Matrix<T> A(5,5, T{0});
|
try { decomp::LUdcmpd lu(A); } catch (const std::runtime_error&) { threw = true; }
|
||||||
for (uint64_t i=0;i<5;++i) {
|
CHECK(threw, "LU should throw on singular matrix");
|
||||||
T rowsum = 0;
|
|
||||||
for (uint64_t j=0;j<5;++j) {
|
|
||||||
if (i==j) continue;
|
|
||||||
A(i,j) = T(0.01 * double(1 + ((i+2)*(j+3)) % 7));
|
|
||||||
rowsum += std::fabs(A(i,j));
|
|
||||||
}
|
|
||||||
A(i,i) = rowsum + T{1};
|
|
||||||
}
|
|
||||||
|
|
||||||
decomp::LUdcmpd lu(A);
|
|
||||||
auto Ainv = lu.inverse();
|
|
||||||
|
|
||||||
utils::Md I(5,5, 0.0);
|
|
||||||
for (uint64_t i=0;i<I.rows();++i) I(i,i)=1.0;
|
|
||||||
|
|
||||||
auto L = numerics::matmul(A, Ainv);
|
|
||||||
auto R = numerics::matmul(Ainv, A);
|
|
||||||
|
|
||||||
CHECK(L.nearly_equal(I, 1e-10), "A * inverse(A) not close to I");
|
|
||||||
CHECK(R.nearly_equal(I, 1e-10), "inverse(A) * A not close to I");
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
|
||||||
|
#include "test_common.h"
|
||||||
|
#include "./numerics/matequal.h"
|
||||||
|
|
||||||
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||||
|
using utils::Mf; using utils::Md; using utils::Mi;
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- helpers ----------
|
||||||
|
template <typename T>
|
||||||
|
static void fill_seq(utils::Matrix<T>& M, T start = T{0}, T step = T{1}) {
|
||||||
|
std::uint64_t k = 0;
|
||||||
|
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
||||||
|
for (std::uint64_t j = 0; j < M.cols(); ++j, ++k)
|
||||||
|
M(i,j) = start + step * static_cast<T>(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- tests ----------
|
||||||
|
|
||||||
|
TEST_CASE(matequal_shape_mismatch) {
|
||||||
|
utils::Mi A(3,3,0), B(3,4,0);
|
||||||
|
CHECK(!numerics::matequal(A,B), "shape mismatch should be false (serial)");
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
CHECK(!numerics::matequal_omp(A,B), "shape mismatch should be false (omp)");
|
||||||
|
#endif
|
||||||
|
CHECK(!numerics::matequal_auto(A,B), "shape mismatch should be false (auto)");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(matequal_int_true_false) {
|
||||||
|
utils::Mi A(4,5,0), B(4,5,0);
|
||||||
|
fill_seq(A, int64_t(0), int64_t(1));
|
||||||
|
fill_seq(B, int64_t(0), int64_t(1));
|
||||||
|
CHECK(numerics::matequal(A,B), "ints equal (serial)");
|
||||||
|
#ifdef _OPENMP
|
||||||
|
CHECK(numerics::matequal_omp(A,B), "ints equal (omp)");
|
||||||
|
#endif
|
||||||
|
// flip one element
|
||||||
|
B(2,3) += 1;
|
||||||
|
CHECK(!numerics::matequal(A,B), "ints differ (serial)");
|
||||||
|
#ifdef _OPENMP
|
||||||
|
CHECK(!numerics::matequal_omp(A,B), "ints differ (omp)"); // will FAIL if your omp branch uses '!='
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(matequal_double_tolerance) {
|
||||||
|
utils::Md A(3,3,0.0), B(3,3,0.0);
|
||||||
|
fill_seq(A, double(1.0), double(0.125));
|
||||||
|
fill_seq(B, double(1.0), double(0.125));
|
||||||
|
// tiny perturbation within default tol
|
||||||
|
B(1,1) += 1e-12;
|
||||||
|
CHECK(numerics::matequal(A,B), "double within tol (serial)");
|
||||||
|
#ifdef _OPENMP
|
||||||
|
CHECK(numerics::matequal_omp(A,B), "double within tol (omp)");
|
||||||
|
#endif
|
||||||
|
// larger perturbation exceeds tol
|
||||||
|
B(0,2) += 1e-6;
|
||||||
|
CHECK(!numerics::matequal(A,B, 1e-9), "double exceeds tol (serial)");
|
||||||
|
#ifdef _OPENMP
|
||||||
|
CHECK(!numerics::matequal_omp(A,B, 1e-9), "double exceeds tol (omp)");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(matequal_auto_agrees) {
|
||||||
|
// Choose size so auto likely takes the OMP path when available,
|
||||||
|
// but this test only checks correctness, not which path was taken.
|
||||||
|
utils::Md A(256,256,0.0), B(256,256,0.0);
|
||||||
|
fill_seq(A, double(0.0), double(0.01));
|
||||||
|
fill_seq(B, double(0.0), double(0.01));
|
||||||
|
CHECK(numerics::matequal_auto(A,B), "auto equal");
|
||||||
|
|
||||||
|
B(5,7) += 1e-3;
|
||||||
|
CHECK(!numerics::matequal_auto(A,B, 1e-9), "auto detects mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
TEST_CASE(mateequal_omp_nested_callsite) {
|
||||||
|
// Verify correctness when called inside an outer parallel region.
|
||||||
|
utils::Mi A(128,128,0), B(128,128,0);
|
||||||
|
fill_seq(A, int64_t(0), int64_t(1));
|
||||||
|
fill_seq(B, int64_t(0), int64_t(1));
|
||||||
|
|
||||||
|
// allow one nested level; inner region inside mateequal_omp may spawn a team
|
||||||
|
int prev_levels = omp_get_max_active_levels();
|
||||||
|
omp_set_max_active_levels(2);
|
||||||
|
|
||||||
|
bool ok_equal = false, ok_diff = false;
|
||||||
|
|
||||||
|
#pragma omp parallel num_threads(2) shared(ok_equal, ok_diff)
|
||||||
|
{
|
||||||
|
#pragma omp single
|
||||||
|
{
|
||||||
|
ok_equal = numerics::matequal_omp(A,B);
|
||||||
|
B(10,10) += 1; // introduce a mismatch
|
||||||
|
ok_diff = !numerics::matequal_omp(A,B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
omp_set_max_active_levels(prev_levels);
|
||||||
|
|
||||||
|
CHECK(ok_equal, "nested equal should be true");
|
||||||
|
CHECK(ok_diff, "nested mismatch should be false");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
+105
-142
@@ -4,161 +4,124 @@
|
|||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
|
// ---------- helpers ----------
|
||||||
|
template <typename T>
|
||||||
|
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y, double tol = 0.0) {
|
||||||
|
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
for (std::uint64_t i=0;i<X.rows();++i)
|
||||||
|
for (std::uint64_t j=0;j<X.cols();++j)
|
||||||
|
if (std::fabs(double(X(i,j)) - double(Y(i,j))) > tol) return false;
|
||||||
|
} else {
|
||||||
|
for (std::uint64_t i=0;i<X.rows();++i)
|
||||||
|
for (std::uint64_t j=0;j<X.cols();++j)
|
||||||
|
if (X(i,j) != Y(i,j)) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// ============ Basic correctness ============
|
|
||||||
TEST_CASE(Matmul_Serial_Simple3x3) {
|
|
||||||
utils::Md A(3,3,0.0), B(3,3,0.0);
|
|
||||||
// A = [[1,2,3],[4,5,6],[7,8,9]]
|
|
||||||
double v=1.0;
|
|
||||||
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) A(i,j)=v++;
|
|
||||||
// B = [[9,8,7],[6,5,4],[3,2,1]]
|
|
||||||
double w=9.0;
|
|
||||||
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) B(i,j)=w--;
|
|
||||||
|
|
||||||
auto C = numerics::matmul<double>(A,B);
|
template <typename T>
|
||||||
// Hand-checked first row:
|
static void fill_seq(utils::Matrix<T>& M, T start = T(0), T step = T(1)) {
|
||||||
// row0 dot columns:
|
std::uint64_t k = 0;
|
||||||
// c00 = 1*9 + 2*6 + 3*3 = 30
|
for (std::uint64_t i=0;i<M.rows();++i)
|
||||||
// c01 = 1*8 + 2*5 + 3*2 = 24
|
for (std::uint64_t j=0;j<M.cols();++j,++k)
|
||||||
// c02 = 1*7 + 2*4 + 3*1 = 18
|
M(i,j) = start + step * static_cast<T>(k);
|
||||||
|
}
|
||||||
|
// ---------- tests ----------
|
||||||
|
|
||||||
|
// Small known example: (3x2) · (2x3)
|
||||||
|
TEST_CASE(Matmul_Small_Known) {
|
||||||
|
utils::Mi A(3,2,0), B(2,3,0);
|
||||||
|
// A = [1 2; 3 4; 5 6]
|
||||||
|
A(0,0)=1; A(0,1)=2;
|
||||||
|
A(1,0)=3; A(1,1)=4;
|
||||||
|
A(2,0)=5; A(2,1)=6;
|
||||||
|
// B = [7 8 9; 10 11 12]
|
||||||
|
B(0,0)=7; B(0,1)=8; B(0,2)=9;
|
||||||
|
B(1,0)=10; B(1,1)=11; B(1,2)=12;
|
||||||
|
|
||||||
|
auto C = numerics::matmul(A,B);
|
||||||
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
|
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
|
||||||
CHECK(C(0,0)==30.0 && C(0,1)==24.0 && C(0,2)==18.0, "first row wrong");
|
|
||||||
|
// Expected C:
|
||||||
|
// [27 30 33]
|
||||||
|
// [61 68 75]
|
||||||
|
// [95 106 117]
|
||||||
|
CHECK(C(0,0)==27 && C(0,1)==30 && C(0,2)==33, "row 0 wrong");
|
||||||
|
CHECK(C(1,0)==61 && C(1,1)==68 && C(1,2)==75, "row 1 wrong");
|
||||||
|
CHECK(C(2,0)==95 && C(2,1)==106 && C(2,2)==117, "row 2 wrong");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Matmul_OMP_Equals_Serial) {
|
TEST_CASE(Matmul_DimMismatch_Throws) {
|
||||||
utils::Md A(4,5,0.0), B(5,3,0.0);
|
utils::Md A(2,3,1.0), B(4,2,2.0); // A.cols()!=B.rows()
|
||||||
// Fill deterministic
|
|
||||||
for (uint64_t i=0;i<A.rows();++i)
|
|
||||||
for (uint64_t j=0;j<A.cols();++j)
|
|
||||||
A(i,j) = 0.1*(1 + (i*17 + j*19)%10);
|
|
||||||
for (uint64_t i=0;i<B.rows();++i)
|
|
||||||
for (uint64_t j=0;j<B.cols();++j)
|
|
||||||
B(i,j) = 0.2*(1 + (i*23 + j*29)%10);
|
|
||||||
|
|
||||||
auto Cs = numerics::matmul<double>(A,B);
|
|
||||||
auto Cr = numerics::matmul_rows_omp<double>(A,B);
|
|
||||||
auto Cc = numerics::matmul_collapse_omp<double>(A,B);
|
|
||||||
auto Ca = numerics::matmul_auto<double>(A,B);
|
|
||||||
|
|
||||||
CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
|
|
||||||
CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
|
|
||||||
CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ Dimension mismatch ============
|
|
||||||
TEST_CASE(Matmul_DimensionMismatch_Throws) {
|
|
||||||
utils::Md A(2,3,0.0), B(4,2,0.0);
|
|
||||||
bool threw=false;
|
bool threw=false;
|
||||||
try { auto _ = numerics::matmul<double>(A,B); (void)_; }
|
try { (void)numerics::matmul(A,B); } catch(const std::runtime_error&) { threw=true; }
|
||||||
catch (const std::runtime_error&) { threw=true; }
|
CHECK(threw, "matmul should throw on dim mismatch");
|
||||||
CHECK(threw, "serial should throw on dim mismatch");
|
|
||||||
|
|
||||||
threw=false; try { auto _ = numerics::matmul_rows_omp<double>(A,B); (void)_; }
|
|
||||||
catch (const std::runtime_error&) { threw=true; }
|
|
||||||
CHECK(threw, "rows_omp should throw on dim mismatch");
|
|
||||||
|
|
||||||
threw=false; try { auto _ = numerics::matmul_collapse_omp<double>(A,B); (void)_; }
|
|
||||||
catch (const std::runtime_error&) { threw=true; }
|
|
||||||
CHECK(threw, "collapse_omp should throw on dim mismatch");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Edge cases ============
|
// Compare all variants vs serial on a moderate size
|
||||||
TEST_CASE(Matmul_Edges_ZeroDims) {
|
TEST_CASE(Matmul_Variants_Equal_Int) {
|
||||||
// (0xK) * (KxP) -> (0xP)
|
const std::uint64_t m=32, n=24, p=16;
|
||||||
utils::Md A0(0,5,0.0), B1(5,3,0.0);
|
utils::Mi A(m,n,0), B(n,p,0);
|
||||||
auto C0 = numerics::matmul<double>(A0,B1);
|
|
||||||
CHECK(C0.rows()==0 && C0.cols()==3, "0xK * KxP shape wrong");
|
|
||||||
|
|
||||||
// (MxK) * (Kx0) -> (Mx0)
|
// deterministic fill (no randomness)
|
||||||
utils::Md A2(7,4,0.0), B0(4,0,0.0);
|
fill_seq(A, int64_t(1), int64_t(1));
|
||||||
auto C1 = numerics::matmul<double>(A2,B0);
|
fill_seq(B, int64_t(2), int64_t(3));
|
||||||
CHECK(C1.rows()==7 && C1.cols()==0, "MxK * Kx0 shape wrong");
|
|
||||||
|
auto C_ref = numerics::matmul(A,B);
|
||||||
|
|
||||||
|
auto C_rows = numerics::matmul_rows_omp(A,B);
|
||||||
|
auto C_collapse = numerics::matmul_collapse_omp(A,B);
|
||||||
|
auto C_auto = numerics::matmul_auto(A,B);
|
||||||
|
|
||||||
|
CHECK(mats_equal(C_rows, C_ref), "rows_omp != serial");
|
||||||
|
CHECK(mats_equal(C_collapse, C_ref), "collapse_omp != serial");
|
||||||
|
CHECK(mats_equal(C_auto, C_ref), "auto != serial");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Identity sanity ============
|
TEST_CASE(Matmul_Variants_Equal_Double) {
|
||||||
TEST_CASE(Matmul_Identity) {
|
const std::uint64_t m=33, n=17, p=19;
|
||||||
const uint64_t n=5;
|
utils::Md A(m,n,0.0), B(n,p,0.0);
|
||||||
utils::Md I(n,n,0.0), A(n,n,0.0);
|
|
||||||
for (uint64_t i=0;i<n;++i) I(i,i)=1.0;
|
|
||||||
for (uint64_t i=0;i<n;++i)
|
|
||||||
for (uint64_t j=0;j<n;++j)
|
|
||||||
A(i,j) = (i==j)? 2.0 : ( (i<j)? 1.0 : -1.0 );
|
|
||||||
|
|
||||||
auto L = numerics::matmul<double>(I,A);
|
fill_seq(A, 0.1, 0.01);
|
||||||
auto R = numerics::matmul<double>(A,I);
|
fill_seq(B, 1.0, 0.02);
|
||||||
CHECK(L == A, "I*A != A");
|
|
||||||
CHECK(R == A, "A*I != A");
|
auto C_ref = numerics::matmul(A,B);
|
||||||
|
auto C_rows = numerics::matmul_rows_omp(A,B);
|
||||||
|
auto C_collapse = numerics::matmul_collapse_omp(A,B);
|
||||||
|
auto C_auto = numerics::matmul_auto(A,B);
|
||||||
|
|
||||||
|
CHECK(mats_equal(C_rows, C_ref, 1e-9), "rows_omp != serial (double)");
|
||||||
|
CHECK(mats_equal(C_collapse, C_ref, 1e-9), "collapse_omp != serial (double)");
|
||||||
|
CHECK(mats_equal(C_auto, C_ref, 1e-9), "auto != serial (double)");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============ Perf sanity (same kernel: 1 thread vs many) ============
|
// Nested callsite sanity: call OMP variant from within an outer region
|
||||||
template <class F>
|
#ifdef _OPENMP
|
||||||
static double time_it(F&& f, int iters=1) {
|
TEST_CASE(Matmul_OMP_Nested_Callsite) {
|
||||||
auto t0 = std::chrono::high_resolution_clock::now();
|
const std::uint64_t m=48, n=24, p=32;
|
||||||
for (int i=0;i<iters;++i) f();
|
utils::Mi A(m,n,0), B(n,p,0);
|
||||||
auto t1 = std::chrono::high_resolution_clock::now();
|
fill_seq(A, int64_t(1), int64_t(2));
|
||||||
return std::chrono::duration<double>(t1 - t0).count();
|
fill_seq(B, int64_t(3), int64_t(1));
|
||||||
|
|
||||||
|
auto C_ref = numerics::matmul(A,B);
|
||||||
|
|
||||||
|
int prev_levels = omp_get_max_active_levels();
|
||||||
|
omp_set_max_active_levels(2);
|
||||||
|
|
||||||
|
utils::Mi C_nested;
|
||||||
|
#pragma omp parallel num_threads(2)
|
||||||
|
{
|
||||||
|
#pragma omp single
|
||||||
|
{
|
||||||
|
// either variant is fine; collapse(2) has more parallelism
|
||||||
|
C_nested = numerics::matmul_collapse_omp(A,B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
omp_set_max_active_levels(prev_levels);
|
||||||
|
|
||||||
|
CHECK(mats_equal(C_nested, C_ref), "nested collapse_omp result mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
|
|
||||||
#ifndef _OPENMP
|
|
||||||
return;
|
|
||||||
#else
|
|
||||||
int hw = omp_get_max_threads();
|
|
||||||
if (hw <= 1) return;
|
|
||||||
|
|
||||||
const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
|
|
||||||
utils::Md A(m,k,0.0), B(k,p,0.0);
|
|
||||||
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i+j%7)*0.001;
|
|
||||||
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i*3+j%5)*0.0005;
|
|
||||||
|
|
||||||
// Warm-up
|
|
||||||
(void) numerics::matmul_rows_omp<double>(A,B);
|
|
||||||
|
|
||||||
int prev = omp_get_max_threads();
|
|
||||||
auto t0 = std::chrono::high_resolution_clock::now();
|
|
||||||
omp_set_num_threads(1);
|
|
||||||
numerics::matmul_rows_omp<double>(A,B);
|
|
||||||
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
|
||||||
|
|
||||||
omp_set_num_threads(hw);
|
|
||||||
t0 = std::chrono::high_resolution_clock::now();
|
|
||||||
numerics::matmul_rows_omp<double>(A,B);
|
|
||||||
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
|
||||||
|
|
||||||
omp_set_num_threads(prev);
|
|
||||||
|
|
||||||
// Must not be notably slower with many threads
|
|
||||||
CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
|
|
||||||
#endif
|
#endif
|
||||||
}
|
|
||||||
|
|
||||||
TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
|
|
||||||
#ifndef _OPENMP
|
|
||||||
return;
|
|
||||||
#else
|
|
||||||
int hw = omp_get_max_threads();
|
|
||||||
if (hw <= 1) return;
|
|
||||||
|
|
||||||
const uint64_t m=512, k=512, p=512;
|
|
||||||
utils::Md A(m,k,0.0), B(k,p,0.0);
|
|
||||||
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i*7+j%11)*0.0003;
|
|
||||||
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i%13+j)*0.0002;
|
|
||||||
|
|
||||||
(void) numerics::matmul_collapse_omp<double>(A,B); // warm-up
|
|
||||||
|
|
||||||
int prev = omp_get_max_threads();
|
|
||||||
auto t0 = std::chrono::high_resolution_clock::now();
|
|
||||||
omp_set_num_threads(1);
|
|
||||||
numerics::matmul_collapse_omp<double>(A,B);
|
|
||||||
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
|
||||||
|
|
||||||
omp_set_num_threads(hw);
|
|
||||||
t0 = std::chrono::high_resolution_clock::now();
|
|
||||||
numerics::matmul_collapse_omp<double>(A,B);
|
|
||||||
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
|
||||||
|
|
||||||
omp_set_num_threads(prev);
|
|
||||||
|
|
||||||
CHECK(tN <= t1 * 1.05, "collapse_omp: multi-thread slower than single-thread");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
+139
-117
@@ -1,142 +1,164 @@
|
|||||||
|
|
||||||
#include "test_common.h"
|
#include "test_common.h"
|
||||||
#include "./utils/utils.h"
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
using utils::Vf; using utils::Vd; using utils::Vi;
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||||
using utils::Mf; using utils::Md; using utils::Mi;
|
using utils::Mf; using utils::Md; using utils::Mi;
|
||||||
|
|
||||||
|
|
||||||
// ---------- Construction & element access ----------
|
// tiny helper
|
||||||
TEST_CASE(Matrix_Construct_Access) {
|
template <typename T>
|
||||||
Md M; // default
|
static bool mat_is_filled(const utils::Matrix<T>& M, T v) {
|
||||||
CHECK(M.rows()==0 && M.cols()==0, "default ctor dims wrong");
|
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
||||||
|
for (std::uint64_t j = 0; j < M.cols(); ++j)
|
||||||
Mf A(2,3, 1.0f);
|
if (M(i,j) != v) return false;
|
||||||
CHECK(A.rows()==2 && A.cols()==3, "ctor dims wrong");
|
return true;
|
||||||
CHECK(A(0,0)==1.0f && A(1,2)==1.0f, "fill wrong");
|
|
||||||
|
|
||||||
A(0,1)=2.5f; A(1,0)=3.5f;
|
|
||||||
CHECK(A(0,1)==2.5f && A(1,0)==3.5f, "operator() set/get failed");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Equality, inequality, nearly_equal ----------
|
// ------------------- basic construction -------------------
|
||||||
TEST_CASE(Matrix_Equality) {
|
TEST_CASE(Matrix_Default_Construct) {
|
||||||
Mi A(2,2,0), B(2,2,0), C(2,2,1);
|
Md M;
|
||||||
A(0,0)=1; A(1,1)=1; // A = I
|
CHECK_EQ(M.rows(), 0, "default rows should be 0");
|
||||||
B(0,0)=1; B(1,1)=1; // B = I
|
CHECK_EQ(M.cols(), 0, "default cols should be 0");
|
||||||
|
|
||||||
CHECK(A == B, "== failed identical");
|
|
||||||
CHECK(!(A != B), "!= failed identical");
|
|
||||||
CHECK(A != C, "!= failed different");
|
|
||||||
|
|
||||||
Md F1(2,2,0.0), F2(2,2,0.0);
|
|
||||||
F1(0,0)=1.0; F1(1,1)=2.0;
|
|
||||||
F2(0,0)=1.0; F2(1,1)=2.0 + 5e-10; // tiny perturbation
|
|
||||||
CHECK(!(F1 == F2), "operator== is exact; should differ");
|
|
||||||
CHECK(F1.nearly_equal(F2, 1e-9), "nearly_equal should accept tiny delta");
|
|
||||||
CHECK(!F1.nearly_equal(F2, 1e-12), "nearly_equal too strict should fail");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Row helpers ----------
|
TEST_CASE(Matrix_Filled_Construct) {
|
||||||
TEST_CASE(Matrix_Row_Get_Set) {
|
Md M(3, 4, 2.5);
|
||||||
Mf M(3,4, 0.0f);
|
CHECK_EQ(M.rows(), 3, "rows");
|
||||||
Vf r(4, 0.0f);
|
CHECK_EQ(M.cols(), 4, "cols");
|
||||||
for (uint64_t j=0;j<4;++j) r[j] = float(j+1); // [1,2,3,4]
|
CHECK(mat_is_filled(M, 2.5), "all elements should be 2.5");
|
||||||
|
|
||||||
M.set_row(1, r);
|
|
||||||
auto out = M.get_row(1);
|
|
||||||
CHECK(out == r, "set_row/get_row mismatch");
|
|
||||||
|
|
||||||
// size mismatch should throw
|
|
||||||
bool threw=false;
|
|
||||||
try { Vf bad(3, 9.0f); M.set_row(2, bad); } catch (const std::exception&) { threw=true; }
|
|
||||||
CHECK(threw, "set_row should throw on size mismatch");
|
|
||||||
|
|
||||||
// out of range
|
|
||||||
threw=false;
|
|
||||||
try { (void)M.get_row(3); } catch (const std::out_of_range&) { threw=true; }
|
|
||||||
CHECK(threw, "get_row should throw on OOB index");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Column helpers ----------
|
// ------------------- element access / write -------------------
|
||||||
TEST_CASE(Matrix_Col_Get_Set) {
|
TEST_CASE(Matrix_Set_Get) {
|
||||||
Md M(3,2, 0.0);
|
Mi M(2, 3, 0);
|
||||||
Vd c(3, 0.0);
|
M(0,0) = 42;
|
||||||
c[0]=10; c[1]=20; c[2]=30;
|
M(1,2) = -7;
|
||||||
|
CHECK_EQ(M(0,0), 42, "set/get (0,0)");
|
||||||
M.set_col(1, c);
|
CHECK_EQ(M(1,2), -7, "set/get (1,2)");
|
||||||
auto out = M.get_col(1);
|
|
||||||
CHECK(out == c, "set_col/get_col mismatch");
|
|
||||||
|
|
||||||
// size mismatch should throw
|
|
||||||
bool threw=false;
|
|
||||||
try { Vd bad(2, 9.0); M.set_col(0, bad); } catch (const std::exception&) { threw=true; }
|
|
||||||
CHECK(threw, "set_col should throw on size mismatch");
|
|
||||||
|
|
||||||
// out of range
|
|
||||||
threw=false;
|
|
||||||
try { (void)M.get_col(2); } catch (const std::out_of_range&) { threw=true; }
|
|
||||||
CHECK(threw, "get_col should throw on OOB index");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- swap_rows / swap_cols ----------
|
// ------------------- resize semantics -------------------
|
||||||
TEST_CASE(Matrix_Swap_Rows_Cols) {
|
TEST_CASE(Matrix_Resize_Grow) {
|
||||||
|
Mf M(2, 2, 1.0f);
|
||||||
|
M.resize(3, 4, 9.0f); // grow; newly appended elements get the fill value
|
||||||
|
CHECK_EQ(M.rows(), 3, "rows after resize");
|
||||||
|
CHECK_EQ(M.cols(), 4, "cols after resize");
|
||||||
|
CHECK(M(2,3) == 9.0f, "last element should be the fill value after grow");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matrix_Resize_Shrink) {
|
||||||
|
Mi M(4, 4, 5);
|
||||||
|
M(0,0) = 11;
|
||||||
|
M.resize(2, 2, 999); // shrink; size reduces
|
||||||
|
CHECK_EQ(M.rows(), 2, "rows after shrink");
|
||||||
|
CHECK_EQ(M.cols(), 2, "cols after shrink");
|
||||||
|
// element mapping after shrink is implementation dependent; just check bounds usable
|
||||||
|
M(1,1) = 3;
|
||||||
|
CHECK_EQ(M(1,1), 3, "write after shrink works");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------- row helpers -------------------
|
||||||
|
TEST_CASE(Matrix_Get_Row) {
|
||||||
|
Mi M(3,4,0);
|
||||||
|
// set row 1 to [10,20,30,40]
|
||||||
|
for (std::uint64_t j=0;j<4;++j) M(1,j) = (j+1)*10;
|
||||||
|
auto r = M.get_row(1);
|
||||||
|
CHECK_EQ(r.size(), 4, "row size");
|
||||||
|
CHECK(r[0]==10 && r[1]==20 && r[2]==30 && r[3]==40, "row contents");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matrix_Set_Row) {
|
||||||
Mi M(2,3,0);
|
Mi M(2,3,0);
|
||||||
// Row 0: [1,2,3], Row 1: [4,5,6]
|
utils::Vector<int64_t> v(3, 0);
|
||||||
M(0,0)=1; M(0,1)=2; M(0,2)=3;
|
v[0]=7; v[1]=8; v[2]=9;
|
||||||
M(1,0)=4; M(1,1)=5; M(1,2)=6;
|
M.set_row(0, v);
|
||||||
|
CHECK(M(0,0)==7 && M(0,1)==8 && M(0,2)==9, "set_row contents");
|
||||||
|
}
|
||||||
|
|
||||||
M.swap_rows(0,1);
|
TEST_CASE(Matrix_Row_OutOfRange_Throws) {
|
||||||
CHECK(M(0,0)==4 && M(0,1)==5 && M(0,2)==6, "swap_rows row0 wrong");
|
Mi M(2,2,0);
|
||||||
CHECK(M(1,0)==1 && M(1,1)==2 && M(1,2)==3, "swap_rows row1 wrong");
|
|
||||||
|
|
||||||
// swap back via cols
|
|
||||||
M.swap_cols(0,2);
|
|
||||||
// After swapping col0<->col2:
|
|
||||||
// Row0: [6,5,4], Row1: [3,2,1]
|
|
||||||
CHECK(M(0,0)==6 && M(0,1)==5 && M(0,2)==4, "swap_cols row0 wrong");
|
|
||||||
CHECK(M(1,0)==3 && M(1,1)==2 && M(1,2)==1, "swap_cols row1 wrong");
|
|
||||||
|
|
||||||
// no-op swap (a==b) should not crash or change
|
|
||||||
M.swap_rows(1,1);
|
|
||||||
M.swap_cols(2,2);
|
|
||||||
|
|
||||||
// OOB checks
|
|
||||||
bool threw=false;
|
bool threw=false;
|
||||||
try { M.swap_rows(5,1); } catch (const std::out_of_range&) { threw=true; }
|
try { (void)M.get_row(2); } catch(const std::out_of_range&) { threw=true; }
|
||||||
CHECK(threw, "swap_rows should throw on OOB");
|
CHECK(threw, "get_row out-of-range should throw");
|
||||||
|
|
||||||
threw=false;
|
threw=false;
|
||||||
try { M.swap_cols(0,9); } catch (const std::out_of_range&) { threw=true; }
|
try {
|
||||||
CHECK(threw, "swap_cols should throw on OOB");
|
utils::Vector<int64_t> v(3,1); // wrong size
|
||||||
|
M.set_row(1, v);
|
||||||
|
} catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "set_row size mismatch should throw");
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- data() layout (contiguous row-major) ----------
|
// ------------------- col helpers -------------------
|
||||||
TEST_CASE(Matrix_Data_Layout) {
|
TEST_CASE(Matrix_Get_Col) {
|
||||||
Md M(2,3, 0.0);
|
Mi M(3,2,0);
|
||||||
// Fill increasing sequence
|
M(0,1)=5; M(1,1)=6; M(2,1)=7;
|
||||||
double val=1.0;
|
auto c = M.get_col(1);
|
||||||
for (uint64_t i=0;i<M.rows();++i)
|
CHECK_EQ(c.size(), 3, "col size");
|
||||||
for (uint64_t j=0;j<M.cols();++j)
|
CHECK(c[0]==5 && c[1]==6 && c[2]==7, "col contents");
|
||||||
M(i,j) = val++;
|
|
||||||
|
|
||||||
const double* p = M.data();
|
|
||||||
// Expect row-major: [1,2,3,4,5,6]
|
|
||||||
CHECK(p[0]==1.0 && p[1]==2.0 && p[2]==3.0 && p[3]==4.0 && p[4]==5.0 && p[5]==6.0,
|
|
||||||
"data() row-major layout wrong");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Stream output ----------
|
TEST_CASE(Matrix_Set_Col) {
|
||||||
TEST_CASE(Matrix_StreamOutput) {
|
Mi M(3,2,0);
|
||||||
|
utils::Vector<int64_t> v(3, 0);
|
||||||
|
v[0]=1; v[1]=4; v[2]=9;
|
||||||
|
M.set_col(0, v);
|
||||||
|
CHECK(M(0,0)==1 && M(1,0)==4 && M(2,0)==9, "set_col contents");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matrix_Col_OutOfRange_Throws) {
|
||||||
|
Mi M(2,2,0);
|
||||||
|
bool threw=false;
|
||||||
|
try { (void)M.get_col(2); } catch(const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "get_col out-of-range should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try {
|
||||||
|
utils::Vector<int64_t> v(1,1); // wrong size
|
||||||
|
M.set_col(1, v);
|
||||||
|
} catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "set_col size mismatch should throw");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------- swap rows/cols -------------------
|
||||||
|
TEST_CASE(Matrix_Swap_Rows) {
|
||||||
|
Mi M(2,3,0);
|
||||||
|
// row0: 1 2 3, row1: 4 5 6
|
||||||
|
for (std::uint64_t j=0;j<3;++j){ M(0,j)=j+1; M(1,j)=j+4; }
|
||||||
|
M.swap_rows(0,1);
|
||||||
|
CHECK(M(0,0)==4 && M(0,1)==5 && M(0,2)==6, "row0 after swap");
|
||||||
|
CHECK(M(1,0)==1 && M(1,1)==2 && M(1,2)==3, "row1 after swap");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matrix_Swap_Cols) {
|
||||||
|
Mi M(3,3,0);
|
||||||
|
// col0: 9 8 7, col2: 1 2 3
|
||||||
|
M(0,0)=9; M(1,0)=8; M(2,0)=7;
|
||||||
|
M(0,2)=1; M(1,2)=2; M(2,2)=3;
|
||||||
|
M.swap_cols(0,2);
|
||||||
|
CHECK(M(0,0)==1 && M(1,0)==2 && M(2,0)==3, "col0 after swap");
|
||||||
|
CHECK(M(0,2)==9 && M(1,2)==8 && M(2,2)==7, "col2 after swap");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matrix_Swap_OutOfRange_Throws) {
|
||||||
|
Mi M(2,2,0);
|
||||||
|
bool threw=false;
|
||||||
|
try { M.swap_rows(0,2); } catch(const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "swap_rows out-of-range should throw");
|
||||||
|
threw=false;
|
||||||
|
try { M.swap_cols(0,3); } catch(const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "swap_cols out-of-range should throw");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------- stream output (basic sanity) -------------------
|
||||||
|
TEST_CASE(Matrix_Stream_ToString) {
|
||||||
Mf M(2,2,0.0f);
|
Mf M(2,2,0.0f);
|
||||||
M(0,0)=1.0f; M(0,1)=2.0f;
|
M(0,0)=1.0f; M(0,1)=2.0f; M(1,0)=3.0f; M(1,1)=4.0f;
|
||||||
M(1,0)=3.0f; M(1,1)=4.0f;
|
std::ostringstream os;
|
||||||
|
os << M;
|
||||||
std::ostringstream oss;
|
auto s = os.str();
|
||||||
oss << M;
|
CHECK(s.find("[[") != std::string::npos, "starts with [[");
|
||||||
const std::string s = oss.str();
|
CHECK(s.find("1.000") != std::string::npos, "contains formatted 1.000");
|
||||||
// Format example:
|
CHECK(s.find("4.000") != std::string::npos, "contains formatted 4.000");
|
||||||
// [[1.000, 2.000],
|
|
||||||
// [3.000, 4.000]]
|
|
||||||
CHECK(s.find("[[1.000, 2.000],") != std::string::npos, "ostream first row format");
|
|
||||||
CHECK(s.find("[3.000, 4.000]]") != std::string::npos, "ostream second row format");
|
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
|
|
||||||
#include "test_common.h"
|
#include "test_common.h"
|
||||||
#include "./utils/utils.h" // matrix.h, vector.h
|
|
||||||
#include "./numerics/matvec.h" // numerics::matvec / inplace_transpose
|
|
||||||
|
|
||||||
|
#include "./numerics/matvec.h" // numerics::matvec / inplace_transpose
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
using utils::Vi; using utils::Vf; using utils::Vd;
|
using utils::Vi; using utils::Vf; using utils::Vd;
|
||||||
@@ -107,7 +106,7 @@ TEST_CASE(Matvec_Speed_Sanity) {
|
|||||||
|
|
||||||
auto t0 = std::chrono::high_resolution_clock::now();
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
auto yS = numerics::matvec(A,x);
|
auto yS = numerics::matvec(A,x);
|
||||||
double tp = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
double tp = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
int threads = omp_get_max_threads();
|
int threads = omp_get_max_threads();
|
||||||
@@ -117,13 +116,13 @@ TEST_CASE(Matvec_Speed_Sanity) {
|
|||||||
|
|
||||||
t0 = std::chrono::high_resolution_clock::now();
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
auto yP = numerics::matvec_omp(A,x);
|
auto yP = numerics::matvec_omp(A,x);
|
||||||
double ts = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
double ts = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
CHECK((yS.nearly_equal_vec(yP)), "matvec_omp != matvec_serial (large)");
|
CHECK((yS.nearly_equal_vec(yP)), "matvec_omp != matvec_serial (large)");
|
||||||
// Only enforce basic sanity if we *can* use >1 threads:
|
// Only enforce basic sanity if we *can* use >1 threads:
|
||||||
if (threads > 1) {
|
if (threads > 1) {
|
||||||
// Be generous: just require not significantly slower.
|
// Be generous: just require not significantly slower.
|
||||||
CHECK(tp <= ts, "matvec_omp unexpectedly much slower than serial");
|
CHECK(tp >= ts, "matvec_omp unexpectedly much slower than serial");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+136
-73
@@ -1,88 +1,151 @@
|
|||||||
|
|
||||||
#include "test_common.h"
|
#include "test_common.h"
|
||||||
#include "./utils/utils.h" // matrix.h, vector.h
|
//#include "./utils/matrix.h" // matrix.h, vector.h
|
||||||
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
|
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
|
||||||
|
|
||||||
using utils::Mi; using utils::Mf; using utils::Md;
|
using utils::Mi; using utils::Mf; using utils::Md;
|
||||||
|
|
||||||
//
|
/// ---- helpers ----
|
||||||
// ---------- Out-of-place transpose (rectangular) ----------
|
template <typename T>
|
||||||
//
|
static void fill_seq(utils::Matrix<T>& M, T start = T(0), T step = T(1)) {
|
||||||
TEST_CASE(Transpose_Rectangular_OutOfPlace) {
|
std::uint64_t k = 0;
|
||||||
// A = [ [1, 2, 3],
|
for (std::uint64_t i=0; i<M.rows(); ++i)
|
||||||
// [4, 5, 6] ] (2x3)
|
for (std::uint64_t j=0; j<M.cols(); ++j, ++k)
|
||||||
Md A(2,3,0.0);
|
M(i,j) = start + step * static_cast<T>(k);
|
||||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
|
||||||
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
|
||||||
|
|
||||||
auto AT = numerics::transpose(A); // (3x2)
|
|
||||||
|
|
||||||
CHECK(AT.rows()==3 && AT.cols()==2, "shape wrong after transpose");
|
|
||||||
CHECK(AT(0,0)==1 && AT(1,0)==2 && AT(2,0)==3, "first column wrong");
|
|
||||||
CHECK(AT(0,1)==4 && AT(1,1)==5 && AT(2,1)==6, "second column wrong");
|
|
||||||
|
|
||||||
// Involution: T(T(A)) == A
|
|
||||||
auto ATT = numerics::transpose(AT);
|
|
||||||
CHECK(ATT == A, "transpose(transpose(A)) != A");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// ---------- In-place transpose (square) ----------
|
|
||||||
//
|
|
||||||
TEST_CASE(Transpose_Square_InPlace) {
|
|
||||||
// 3x3 with distinct values
|
|
||||||
Mf S(3,3,0.0f);
|
|
||||||
float val = 1.0f;
|
|
||||||
for (uint64_t i=0;i<3;++i)
|
|
||||||
for (uint64_t j=0;j<3;++j)
|
|
||||||
S(i,j) = val++;
|
|
||||||
|
|
||||||
// Make an explicit transpose to compare against
|
template <typename T>
|
||||||
auto ST = numerics::transpose(S);
|
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y) {
|
||||||
|
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||||
// In-place should match the out-of-place result
|
for (std::uint64_t i=0; i<X.rows(); ++i)
|
||||||
numerics::inplace_transpose(S);
|
for (std::uint64_t j=0; j<X.cols(); ++j)
|
||||||
CHECK(S == ST, "inplace_transpose result mismatch");
|
if (X(i,j) != Y(i,j)) return false;
|
||||||
|
return true;
|
||||||
// Involution: applying in-place again should return original
|
|
||||||
numerics::inplace_transpose(S);
|
|
||||||
// Now S should equal the original pre-inplace matrix (which was transposed above)
|
|
||||||
// We can reconstruct original by transposing ST:
|
|
||||||
auto orig = numerics::transpose(ST);
|
|
||||||
CHECK(S == orig, "inplace transpose twice should restore original");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
// ---- tests ----
|
||||||
// ---------- In-place transpose must throw on non-square ----------
|
|
||||||
//
|
// Empty and 1x1 edge cases
|
||||||
TEST_CASE(Transpose_InPlace_Throws_On_Rectangular) {
|
TEST_CASE(Transpose_Edges) {
|
||||||
Md R(2,3,0.0); // rectangular
|
utils::Mi E; // 0x0
|
||||||
bool threw = false;
|
auto Et = numerics::transpose(E);
|
||||||
try {
|
CHECK(Et.rows()==0 && Et.cols()==0, "transpose of empty should be empty");
|
||||||
numerics::inplace_transpose(R);
|
|
||||||
} catch (const std::runtime_error&) {
|
utils::Mi S(1,1,42);
|
||||||
threw = true;
|
auto St = numerics::transpose(S);
|
||||||
|
CHECK(St.rows()==1 && St.cols()==1, "1x1 stays 1x1");
|
||||||
|
CHECK(St(0,0)==42, "1x1 value preserved");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rectangular out-of-place
|
||||||
|
TEST_CASE(Transpose_Rectangular) {
|
||||||
|
const std::uint64_t r=3, c=5;
|
||||||
|
utils::Mi A(r,c,0);
|
||||||
|
fill_seq(A, int64_t(1), int64_t(1));
|
||||||
|
auto B = numerics::transpose(A);
|
||||||
|
|
||||||
|
CHECK(B.rows()==c && B.cols()==r, "shape swapped");
|
||||||
|
for (std::uint64_t i=0;i<r;++i)
|
||||||
|
for (std::uint64_t j=0;j<c;++j)
|
||||||
|
CHECK(B(j,i)==A(i,j), "transpose content mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Square: in-place equals out-of-place
|
||||||
|
TEST_CASE(Transpose_Inplace_Equals_OutOfPlace) {
|
||||||
|
const std::uint64_t n=7;
|
||||||
|
utils::Mi A(n,n,0);
|
||||||
|
fill_seq(A, int64_t(10), int64_t(3));
|
||||||
|
auto B = numerics::transpose(A);
|
||||||
|
|
||||||
|
auto C = A; // copy
|
||||||
|
numerics::inplace_transpose_square(C);
|
||||||
|
CHECK(mats_equal(B, C), "inplace transpose should match out-of-place");
|
||||||
|
}
|
||||||
|
|
||||||
|
// In-place should throw on non-square
|
||||||
|
TEST_CASE(Transpose_Inplace_Throws_On_Rect) {
|
||||||
|
utils::Mi A(2,3,0);
|
||||||
|
bool threw=false;
|
||||||
|
try { numerics::inplace_transpose_square(A); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "inplace_transpose_square must throw on non-square");
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OMP variants (compile only with -fopenmp) ---
|
||||||
|
#ifdef _OPENMP
|
||||||
|
TEST_CASE(Transpose_OMP_OutOfPlace_Equals_Serial) {
|
||||||
|
const std::uint64_t r=17, c=31;
|
||||||
|
utils::Mi A(r,c,0);
|
||||||
|
fill_seq(A, int64_t(5), int64_t(2));
|
||||||
|
|
||||||
|
auto B_serial = numerics::transpose(A);
|
||||||
|
auto B_omp = numerics::transpose_omp(A);
|
||||||
|
CHECK(mats_equal(B_serial, B_omp), "transpose_omp != transpose");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Transpose_OMP_Inplace_Equals_Serial) {
|
||||||
|
const std::uint64_t n=32;
|
||||||
|
utils::Mi A(n,n,0);
|
||||||
|
fill_seq(A, int64_t(0), int64_t(1));
|
||||||
|
|
||||||
|
auto B_ref = numerics::transpose(A);
|
||||||
|
|
||||||
|
auto C = A;
|
||||||
|
numerics::inplace_transpose_square_omp(C);
|
||||||
|
CHECK(mats_equal(B_ref, C), "inplace_transpose_square_omp != transpose");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto selectors (if you added transpose_auto / inplace_transpose_square_auto_auto)
|
||||||
|
TEST_CASE(Transpose_Auto_Equals_Serial) {
|
||||||
|
// Rectangular: transpose_auto
|
||||||
|
utils::Mi A(23,11,0);
|
||||||
|
fill_seq(A, int64_t(1), int64_t(1));
|
||||||
|
auto B_ref = numerics::transpose(A);
|
||||||
|
auto B_auto = numerics::transpose_auto(A);
|
||||||
|
CHECK(mats_equal(B_ref, B_auto), "transpose_auto != transpose");
|
||||||
|
|
||||||
|
// Square: inplace_transpose_square_auto
|
||||||
|
utils::Mi S(29,29,0);
|
||||||
|
fill_seq(S, int64_t(4), int64_t(7));
|
||||||
|
auto Tref = numerics::transpose(S);
|
||||||
|
numerics::inplace_transpose_square_auto(S);
|
||||||
|
CHECK(mats_equal(Tref, S), "inplace_transpose_square_auto != transpose");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nested callsite sanity: call OMP versions inside an outer region
|
||||||
|
TEST_CASE(Transpose_OMP_Nested_Callsite) {
|
||||||
|
// Out-of-place on rectangular
|
||||||
|
utils::Mi A(19,37,0);
|
||||||
|
fill_seq(A, int64_t(2), int64_t(3));
|
||||||
|
auto Bref = numerics::transpose(A);
|
||||||
|
|
||||||
|
int prev_levels = omp_get_max_active_levels();
|
||||||
|
omp_set_max_active_levels(2);
|
||||||
|
|
||||||
|
utils::Mi Bnest;
|
||||||
|
#pragma omp parallel num_threads(2)
|
||||||
|
{
|
||||||
|
#pragma omp single
|
||||||
|
{
|
||||||
|
Bnest = numerics::transpose_omp(A);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
CHECK(threw, "inplace_transpose must throw on non-square matrices");
|
CHECK(mats_equal(Bref, Bnest), "nested transpose_omp mismatch");
|
||||||
}
|
|
||||||
|
|
||||||
//
|
// In-place on square
|
||||||
// ---------- Edge cases: 0x0 and 1x1 ----------
|
utils::Mi S(41,41,0);
|
||||||
//
|
fill_seq(S, int64_t(1), int64_t(5));
|
||||||
TEST_CASE(Transpose_Edge_0x0_1x1) {
|
auto Sref = numerics::transpose(S);
|
||||||
// 0x0 should be fine both ways
|
|
||||||
Md E; // 0x0
|
|
||||||
auto ET = numerics::transpose(E);
|
|
||||||
CHECK(ET.rows()==0 && ET.cols()==0, "0x0 transpose shape wrong");
|
|
||||||
// in-place on 0x0 (rows==cols) should not throw
|
|
||||||
numerics::inplace_transpose(E);
|
|
||||||
CHECK(E.rows()==0 && E.cols()==0, "0x0 inplace transpose changed shape");
|
|
||||||
|
|
||||||
// 1x1 should be a no-op and not throw
|
#pragma omp parallel num_threads(2)
|
||||||
Mi I(1,1,0);
|
{
|
||||||
I(0,0) = 42;
|
#pragma omp single
|
||||||
auto IT = numerics::transpose(I);
|
{
|
||||||
CHECK(IT.rows()==1 && IT.cols()==1 && IT(0,0)==42, "1x1 transpose wrong");
|
numerics::inplace_transpose_square_omp(S);
|
||||||
numerics::inplace_transpose(I);
|
}
|
||||||
CHECK(I(0,0)==42, "1x1 inplace transpose changed value");
|
}
|
||||||
|
omp_set_max_active_levels(prev_levels);
|
||||||
|
|
||||||
|
CHECK(mats_equal(Sref, S), "nested inplace_transpose_square_omp mismatch");
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
+156
-171
@@ -4,208 +4,193 @@
|
|||||||
|
|
||||||
using utils::Vf; using utils::Vd; using utils::Vi;
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||||
|
|
||||||
//
|
// ---------- helpers ----------
|
||||||
// ---------- Basic construction & access ----------
|
template <typename T>
|
||||||
//
|
static bool vec_equal_exact(const utils::Vector<T>& a, const utils::Vector<T>& b) {
|
||||||
TEST_CASE(Vector_Construct_Size_Access) {
|
if (a.size() != b.size()) return false;
|
||||||
Vd a; // default
|
for (std::uint64_t i=0;i<a.size();++i) if (a[i]!=b[i]) return false;
|
||||||
CHECK(a.size() == 0, "default size must be 0");
|
return true;
|
||||||
|
|
||||||
Vf b(3, 1.0f); // (n, fill)
|
|
||||||
CHECK(b.size() == 3, "size wrong");
|
|
||||||
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");
|
|
||||||
|
|
||||||
b[1] = 2.5f;
|
|
||||||
CHECK(b[1] == 2.5f, "operator[] write failed");
|
|
||||||
|
|
||||||
// resize (grow + value)
|
|
||||||
b.resize(5, 7.0f);
|
|
||||||
CHECK(b.size() == 5, "resize grow size wrong");
|
|
||||||
CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
|
|
||||||
"resize grow values wrong");
|
|
||||||
|
|
||||||
// resize (shrink)
|
|
||||||
b.resize(2);
|
|
||||||
CHECK(b.size() == 2, "resize shrink size wrong");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE(Vector_Clear_PushBack) {
|
|
||||||
Vi v(0, 0);
|
|
||||||
v.push_back(10);
|
|
||||||
v.push_back(20);
|
|
||||||
CHECK(v.size() == 2, "push_back size wrong");
|
|
||||||
CHECK(v[0] == 10 && v[1] == 20, "push_back values wrong");
|
|
||||||
|
|
||||||
v.clear();
|
// ---------- tests ----------
|
||||||
CHECK(v.size() == 0, "clear failed");
|
|
||||||
|
TEST_CASE(Vector_Construct_Size_Fill) {
|
||||||
|
utils::Vd v0;
|
||||||
|
CHECK(v0.size()==0, "default size should be 0");
|
||||||
|
|
||||||
|
utils::Vd v1(5, 3.5);
|
||||||
|
CHECK(v1.size()==5, "size 5");
|
||||||
|
for (std::uint64_t i=0;i<5;++i) CHECK(v1[i]==3.5, "filled value 3.5");
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// ---------- Equality / Inequality (tolerant for float/double) ----------
|
|
||||||
//
|
|
||||||
TEST_CASE(Vector_Equality_Tolerant) {
|
|
||||||
Vd a(3, 1.0), b(3, 1.0);
|
|
||||||
CHECK(a == b, "== identical failed");
|
|
||||||
CHECK(!(a != b), "!= identical failed");
|
|
||||||
|
|
||||||
// Tiny perturbation within eps (1e-6 default)
|
TEST_CASE(Vector_PushBack_Resize) {
|
||||||
b[1] += 1e-7;
|
utils::Vi v;
|
||||||
CHECK(a == b, "== tolerant failed");
|
v.push_back(1); v.push_back(2);
|
||||||
|
CHECK(v.size()==2, "push_back size");
|
||||||
|
CHECK(v[0]==1 && v[1]==2, "push_back contents");
|
||||||
|
|
||||||
// Larger perturbation should fail equality
|
v.resize(5, 7);
|
||||||
b[1] += 1e-4;
|
CHECK(v.size()==5, "resize size");
|
||||||
CHECK(a != b, "!= with difference failed");
|
CHECK(v[2]==7 && v[3]==7 && v[4]==7, "resize fill value");
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
|
TEST_CASE(Vector_Data_ReadWrite) {
|
||||||
//
|
utils::Vd v(4, 0.0);
|
||||||
|
double* p = v.data();
|
||||||
|
for (std::uint64_t i=0;i<4;++i) p[i] = double(i+1);
|
||||||
|
CHECK(v[0]==1.0 && v[3]==4.0, "write via data()");
|
||||||
|
|
||||||
|
const utils::Vd& cv = v;
|
||||||
|
const double* cp = cv.data();
|
||||||
|
double s=0.0; for (std::uint64_t i=0;i<cv.size();++i) s += cp[i];
|
||||||
|
CHECK(std::fabs(s-10.0) < 1e-12, "read via const data()");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vector_Equality_and_NearlyEqual) {
|
||||||
|
utils::Vd a(3, 1.0), b(3, 1.0);
|
||||||
|
CHECK(a==b, "operator== equal");
|
||||||
|
b[1] += 5e-7;
|
||||||
|
CHECK(a==b, "operator== within eps (1e-6)");
|
||||||
|
b[1] += 2e-6;
|
||||||
|
CHECK(!(a==b), "operator== beyond eps");
|
||||||
|
|
||||||
|
utils::Vd c = a;
|
||||||
|
c[2] += 1e-10;
|
||||||
|
CHECK(c.nearly_equal_vec(a, 1e-9), "nearly_equal_vec within tol");
|
||||||
|
c[2] += 1e-6;
|
||||||
|
CHECK(!c.nearly_equal_vec(a, 1e-9), "nearly_equal_vec beyond tol");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_CASE(Vector_Scalar_Arithmetic) {
|
TEST_CASE(Vector_Scalar_Arithmetic) {
|
||||||
Vf a(3, 1.0f);
|
utils::Vi v(3, 10); // [10,10,10]
|
||||||
|
|
||||||
// inplace
|
auto vadd = v + 2; // [12,12,12]
|
||||||
a.inplace_add(2); // int convertible to float
|
CHECK(vadd[0]==12 && vadd[2]==12, "scalar +");
|
||||||
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_add failed");
|
|
||||||
|
|
||||||
a.inplace_subtract(1.5f);
|
v += 3; // [13,13,13]
|
||||||
CHECK(std::fabs(a[0] - 1.5f) < 1e-6f &&
|
CHECK(v[1]==13, "scalar +=");
|
||||||
std::fabs(a[1] - 1.5f) < 1e-6f &&
|
|
||||||
std::fabs(a[2] - 1.5f) < 1e-6f, "inplace_subtract failed");
|
|
||||||
|
|
||||||
a.inplace_multiply(4.0);
|
auto vsub = v - 5; // [8,8,8]
|
||||||
CHECK(a[0] == 6.0f && a[1] == 6.0f && a[2] == 6.0f, "inplace_multiply failed");
|
CHECK(vsub[2]==8, "scalar -");
|
||||||
|
|
||||||
a.inplace_divide(2);
|
v -= 3; // [10,10,10]
|
||||||
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_divide failed");
|
CHECK(v[0]==10, "scalar -=");
|
||||||
|
|
||||||
// returning
|
auto vmul = v * 2; // [20,20,20]
|
||||||
auto b = a + 1.0f;
|
CHECK(vmul[0]==20 && vmul[1]==20, "scalar *");
|
||||||
CHECK(b[0] == 4.0f && b[1] == 4.0f && b[2] == 4.0f, "operator+(scalar) failed");
|
|
||||||
|
|
||||||
b = a - 2.0f;
|
v *= 3; // [30,30,30]
|
||||||
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "operator-(scalar) failed");
|
CHECK(v[2]==30, "scalar *=");
|
||||||
|
|
||||||
b = a * 10; // int -> float
|
auto vdiv = v / 3; // [10,10,10]
|
||||||
CHECK(b[0] == 30.0f && b[1] == 30.0f && b[2] == 30.0f, "operator*(scalar) failed");
|
CHECK(vdiv[0]==10 && vdiv[1]==10, "scalar /");
|
||||||
|
|
||||||
b = a / 3.0f;
|
v /= 2; // [15,15,15]
|
||||||
CHECK(std::fabs(b[0] - 1.0f) < 1e-6f &&
|
CHECK(v[0]==15 && v[2]==15, "scalar /=");
|
||||||
std::fabs(b[1] - 1.0f) < 1e-6f &&
|
|
||||||
std::fabs(b[2] - 1.0f) < 1e-6f, "operator/(scalar) failed");
|
|
||||||
|
|
||||||
// scalar on the left (friends implemented for + and *)
|
|
||||||
Vf c(3, 2.0f);
|
|
||||||
auto d = 5 + c; // friend operator+(U, Vector<T>)
|
|
||||||
CHECK(d[0] == 7.0f && d[1] == 7.0f && d[2] == 7.0f, "scalar + vector failed");
|
|
||||||
|
|
||||||
d = 3 * c; // friend operator*(U, Vector<T>)
|
|
||||||
CHECK(d[0] == 6.0f && d[1] == 6.0f && d[2] == 6.0f, "scalar * vector failed");
|
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
|
|
||||||
//
|
|
||||||
TEST_CASE(Vector_Vector_Arithmetic) {
|
TEST_CASE(Vector_Vector_Arithmetic) {
|
||||||
Vd a(3, 1.0), b(3, 2.0);
|
utils::Vi a(4, 1), b(4, 2);
|
||||||
|
|
||||||
// returning
|
auto c = a + b; // [3,3,3,3]
|
||||||
auto c = a + b;
|
CHECK(c[0]==3 && c[3]==3, "v+v");
|
||||||
CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");
|
|
||||||
|
|
||||||
c = b - a;
|
a += b; // [3,3,3,3]
|
||||||
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");
|
CHECK(a[1]==3, "v+=v");
|
||||||
|
|
||||||
c = a * b;
|
auto d = a - b; // [1,1,1,1]
|
||||||
CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");
|
CHECK(d[2]==1, "v-v");
|
||||||
|
|
||||||
c = b / b;
|
a -= b; // [1,1,1,1]
|
||||||
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");
|
CHECK(a[0]==1, "v-=v");
|
||||||
|
|
||||||
// inplace
|
auto e = a * b; // [2,2,2,2]
|
||||||
a = Vd(3, 1.0);
|
CHECK(e[1]==2, "v*v (elemwise)");
|
||||||
a += b;
|
|
||||||
CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
|
|
||||||
a -= b;
|
|
||||||
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
|
|
||||||
a *= b;
|
|
||||||
CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
|
|
||||||
a /= b;
|
|
||||||
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");
|
|
||||||
}
|
|
||||||
//
|
|
||||||
// ---------- Size mismatch error paths ----------
|
|
||||||
//
|
|
||||||
TEST_CASE(Vector_SizeMismatch_Throws) {
|
|
||||||
Vd a(3, 1.0), b(4, 2.0);
|
|
||||||
|
|
||||||
bool threw = false;
|
a *= b; // [2,2,2,2]
|
||||||
try { auto c = a + b; (void)c; } catch (const std::runtime_error&) { threw = true; }
|
CHECK(a[3]==2, "v*=v");
|
||||||
CHECK(threw, "add should throw on size mismatch");
|
|
||||||
|
|
||||||
threw = false;
|
auto f = e / b; // [1,1,1,1]
|
||||||
try { a.inplace_subtract(b); } catch (const std::runtime_error&) { threw = true; }
|
CHECK(f[0]==1 && f[3]==1, "v/v (elemwise)");
|
||||||
CHECK(threw, "inplace_subtract should throw on size mismatch");
|
|
||||||
|
|
||||||
threw = false;
|
e /= b; // [1,1,1,1]
|
||||||
try { auto d = a * b; (void)d; } catch (const std::runtime_error&) { threw = true; }
|
CHECK(e[2]==1, "v/=v");
|
||||||
CHECK(threw, "multiply should throw on size mismatch");
|
|
||||||
|
|
||||||
threw = false;
|
|
||||||
try { auto s = a.dot(b); (void)s; } catch (const std::runtime_error&) { threw = true; }
|
|
||||||
CHECK(threw, "dot should throw on size mismatch");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
TEST_CASE(Vector_Friend_Scalar_Left) {
|
||||||
// ---------- Power / sqrt ----------
|
utils::Vd v(3, 2.0); // [2,2,2]
|
||||||
//
|
auto s1 = 3.0 + v; // [5,5,5]
|
||||||
TEST_CASE(Vector_Power_Sqrt) {
|
CHECK(s1[0]==5.0 && s1[2]==5.0, "left scalar +");
|
||||||
Vd a(3, 2.0); // [2,2,2]
|
|
||||||
auto b = a.power(3.0); // [8,8,8]
|
|
||||||
CHECK(b[0]==8.0 && b[1]==8.0 && b[2]==8.0, "scalar power failed");
|
|
||||||
|
|
||||||
Vd p(3, 3.0); // [3,3,3]
|
auto s2 = 4.0 * v; // [8,8,8]
|
||||||
auto c = b.power(p); // 8^3 = 512
|
CHECK(s2[1]==8.0, "left scalar *");
|
||||||
CHECK(c[0]==512.0 && c[1]==512.0 && c[2]==512.0, "vector power failed");
|
|
||||||
|
|
||||||
Vd d(3, 9.0);
|
|
||||||
auto e = d.sqrt(); // [3,3,3]
|
|
||||||
CHECK(e[0]==3.0 && e[1]==3.0 && e[2]==3.0, "sqrt failed");
|
|
||||||
|
|
||||||
// inplace
|
|
||||||
d.inplace_sqrt(); // becomes [3,3,3]
|
|
||||||
CHECK(d == e, "inplace_sqrt failed");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
TEST_CASE(Vector_Power_and_Sqrt) {
|
||||||
// ---------- Dot / Sum / Norm / Normalize ----------
|
utils::Vd v(3, 4.0); // [4,4,4]
|
||||||
//
|
auto p = v.power(2.0); // [16,16,16]
|
||||||
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
|
CHECK(p[0]==16.0 && p[2]==16.0, "power scalar");
|
||||||
Vd a(3, 0.0);
|
|
||||||
a[0]=1.0; a[1]=2.0; a[2]=2.0;
|
|
||||||
|
|
||||||
CHECK(a.sum() == 5.0, "sum failed");
|
v.inplace_sqrt(); // sqrt([4,4,4]) -> [2,2,2]
|
||||||
CHECK(a.dot(a) == 9.0, "dot self failed");
|
CHECK(v[0]==2.0 && v[1]==2.0, "inplace_sqrt");
|
||||||
|
|
||||||
auto n = a.norm();
|
|
||||||
CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");
|
|
||||||
|
|
||||||
auto b = a.normalize();
|
|
||||||
CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
|
|
||||||
|
|
||||||
// inplace normalize
|
|
||||||
a.inplace_normalize();
|
|
||||||
CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
|
|
||||||
|
|
||||||
// zero-norm error
|
|
||||||
Vd z(3, 0.0);
|
|
||||||
bool threw = false;
|
|
||||||
try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
|
|
||||||
CHECK(threw, "normalize zero vector must throw");
|
|
||||||
}
|
}
|
||||||
//
|
|
||||||
// ---------- Stream output (basic sanity) ----------
|
TEST_CASE(Vector_Dot_Sum_Norm) {
|
||||||
//
|
utils::Vd a(3, 0.0), b(3, 0.0);
|
||||||
TEST_CASE(Vector_StreamOutput) {
|
a[0]=1.0; a[1]=2.0; a[2]=3.0; // a = [1,2,3]
|
||||||
Vi a(3, 2);
|
b[0]=4.0; b[1]=5.0; b[2]=6.0; // b = [4,5,6]
|
||||||
std::ostringstream oss;
|
|
||||||
oss << a;
|
double dot = a.dot(b); // 1*4 + 2*5 + 3*6 = 32
|
||||||
auto s = oss.str();
|
CHECK(std::fabs(dot - 32.0) < 1e-12, "dot");
|
||||||
CHECK(s == "[2, 2, 2]", "ostream<< wrong format");
|
|
||||||
|
double s = a.sum(); // 6
|
||||||
|
CHECK(std::fabs(s - 6.0) < 1e-12, "sum");
|
||||||
|
|
||||||
|
double n = a.norm(); // sqrt(14)
|
||||||
|
CHECK(std::fabs(n - std::sqrt(14.0)) < 1e-12, "norm");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vector_Normalize_and_Throws) {
|
||||||
|
utils::Vd v(3, 0.0);
|
||||||
|
v[0]=3.0; v[1]=4.0; v[2]=0.0; // norm = 5
|
||||||
|
auto u = v.normalize(); // returns new vector
|
||||||
|
CHECK(std::fabs(u.norm() - 1.0) < 1e-12, "normalize() unit length");
|
||||||
|
|
||||||
|
v.inplace_normalize();
|
||||||
|
CHECK(std::fabs(v.norm() - 1.0) < 1e-12, "inplace_normalize unit length");
|
||||||
|
|
||||||
|
utils::Vd z(3, 0.0);
|
||||||
|
bool threw=false;
|
||||||
|
try { z.inplace_normalize(); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "normalize should throw on zero vector");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size mismatch throws (elementwise ops)
|
||||||
|
TEST_CASE(Vector_Size_Mismatch_Throws) {
|
||||||
|
utils::Vi a(3,1), b(4,2);
|
||||||
|
|
||||||
|
bool threw=false;
|
||||||
|
try { (void)a.dot(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "dot size mismatch should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { a.inplace_add(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "add size mismatch should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { a.inplace_subtract(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "subtract size mismatch should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { a.inplace_multiply(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "multiply size mismatch should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { a.inplace_divide(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "divide size mismatch should throw");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try { a.inplace_power(b); } catch(const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "power size mismatch should throw");
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user