Finishing up and starting lu decomp

This commit is contained in:
2025-09-13 21:44:20 +02:00
parent 320436ce98
commit 88087ea6a6
24 changed files with 1502 additions and 699 deletions
BIN
View File
Binary file not shown.
+42 -23
View File
@@ -4,31 +4,50 @@
#include <omp.h>
// Configure OpenMP behavior at runtime.
inline void omp_configure(int max_active_levels,
bool dynamic_threads,
const std::vector<int>& threads_per_level = {},
bool bind_close = true)
{
// 1) Allow nested parallel regions (levels of teams)
// Example: outer #pragma omp parallel ... { inner #pragma omp parallel ... }
omp_set_max_active_levels(max_active_levels); // 1 = only top-level; 2+ enables nesting
namespace omp_config{
// 2) Let the runtime shrink/grow thread counts if it thinks it should
// (helps avoid oversubscription when you accidentally ask for too many threads)
omp_set_dynamic(dynamic_threads ? 1 : 0);
// Configure OpenMP behavior at runtime.
inline void omp_configure(int max_active_levels,
bool dynamic_threads,
const std::vector<int>& threads_per_level = {},
bool bind_close = true)
{
// 1) Allow nested parallel regions (levels of teams)
// Example: outer #pragma omp parallel ... { inner #pragma omp parallel ... }
omp_set_max_active_levels(max_active_levels); // 1 = only top-level; 2+ enables nesting
// 3) Thread binding (keep threads near their cores) is controlled via env vars,
// so here we just *recommend* a good default (see below). You *can* setenv()
// in code, but its cleaner to do it outside the program.
(void)bind_close; // documented below in env var section
// 2) Let the runtime shrink/grow thread counts if it thinks it should
// (helps avoid oversubscription when you accidentally ask for too many threads)
omp_set_dynamic(dynamic_threads ? 1 : 0);
// 4) Top-level default thread count (inner levels are usually set per region)
if (!threads_per_level.empty()) {
omp_set_num_threads(threads_per_level[0]); // e.g. 16 for the outermost team
// Inner levels:
// - Use num_threads(threads_per_level[L]) on the inner #pragma omp parallel
// - or set OMP_NUM_THREADS="outer,inner,inner2" as an environment variable
// 3) Thread binding (keep threads near their cores) is controlled via env vars,
// so here we just *recommend* a good default (see below). You *can* setenv()
// in code, but its cleaner to do it outside the program.
(void)bind_close; // documented below in env var section
// 4) Top-level default thread count (inner levels are usually set per region)
if (!threads_per_level.empty()) {
omp_set_num_threads(threads_per_level[0]); // e.g. 16 for the outermost team
// Inner levels:
// - Use num_threads(threads_per_level[L]) on the inner #pragma omp parallel
// - or set OMP_NUM_THREADS="outer,inner,inner2" as an environment variable
}
}
}
// ---------- Helper: may we create another team? ----------
inline bool omp_parallel_allowed() {
#ifdef _OPENMP
// If were not in parallel, we can spawn.
if (!omp_in_parallel()) return true;
// Already inside parallel: allow only if nesting is enabled and not at limit.
int level = omp_get_active_level(); // 0 outside parallel, 1 inside, ...
int maxlv = omp_get_max_active_levels(); // user/runtime cap
return static_cast<bool>(level < maxlv);
#else
return false; // no OpenMP → no extra teams
#endif
}
} // namespace omp_config
View File
+4
View File
@@ -0,0 +1,4 @@
#pragma once
#include "./decomp/lu.h"
+61
View File
@@ -0,0 +1,61 @@
#pragma once
#include "./utils/vector.h"
#include "./utils/matrix.h"
namespace decomp{
// Stores PA = LU with partial pivoting (row permutations).
template <typename T>
struct LU{
utils::Matrix<T> LUmat; // combined L (unit diagonal implied) and U
std::vector<uint64_t> piv; // pivot indices (row permutations)
bool singular= false; // set true if zero (or near-zero) pivots encountered
LU() = default;
explicit LU(const utils::Matrix<T>& A) {
factor(A);
}
void factor(const utils::Matrix<T>&A){
const uint64_t rows = A.rows();
if (rows != A.cols()){
throw std::runtime_error("LU: factor non-square");
}
if (rows == 0){
LUmat = A;
piv.clear();
singular = false;
return;
}
T big, tmp;
LUmat = A;
piv.resize(rows); // piv stores the implicit scaling of each row.
//double d = 1.0; // No row interchanges yet.
for (uint64_t i = 0; i < rows; ++i){ // Loop over rows to get the implicit scaling information.
big = T{0};
for (uint64_t j = 0; j < rows; ++j){
tmp=std::abs(LUmat(i,j));
if (tmp > big){
big = tmp;
}
}
if (big == T{0}){
throw std::runtime_error("Singular matrix in LU.factor()");
}
}
}
}; // struct LU
typedef LU<float> LUf;
typedef LU<double> LUd;
} // namespace decomp
+17 -78
View File
@@ -5,100 +5,39 @@
#include "./utils/vector.h"
#include "./utils/matrix.h"
#include "./numerics/inverse/inverse_gauss_jordan.h"
#include "./numerics/inverse/inverse_lu.h"
#include <omp.h>
namespace numerics{
template <typename T>
void inplace_inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
if (A.rows() != A.cols()) {
throw std::runtime_error("inplace_inverse: non-square matrix");
}
if (method == "Gauss-Jordan"){
utils::Matrix<T> B(A.rows(),A.cols(), T{0});
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
double big, dum, pivinv, temp;
utils::Vi indxc(rows,0), indxr(rows,0), ipiv(rows,0);
//for (uint64_t j = 0; j < N; ++j){ ipiv[j] = 0;}
for (uint64_t i = 0; i < rows; i++){
big = 0.0;
for (uint64_t j = 0; j < rows; j++){
if (ipiv[j] != 1){
for (uint64_t k = 0; k < rows; k++){
if (ipiv[k] == 0){
if (abs(A(j,k)) >= big){
big = abs(A(j,k));
irow = j;
icol = k;
}
}
}
}
}
ipiv[icol]++;
if (irow != icol){
for (uint64_t l = 0; l < rows; l++){ // SWAP
temp = A(irow,l);
A(irow,l) = A(icol,l);
A(icol,l) = temp;
}
for (uint64_t l = 0; l < cols; l++){ // SWAP temp matrix
temp = B(irow,l);
B(irow,l) = B(icol,l);
B(icol,l) = temp;
}
}
indxr[i] = irow;
indxc[i] = icol;
if (A(icol,icol) == 0.0){
throw std::runtime_error("utill:inplace_inverse('Gauss-Jordan' - Singular Matrix");
}
pivinv= 1.0/A(icol,icol);
A(icol,icol)=1.0;
for (uint64_t l = 0; l < rows; l++){
A(icol,l) *= pivinv;
}
for (uint64_t l = 0; l < cols; l++){
B(icol,l) *= pivinv;
}
for (uint64_t ll = 0; ll < rows; ll++){
if (ll != icol){
dum = A(ll,icol);
A(ll,icol) = 0;
for (uint64_t l = 0; l < rows; l++){
A(ll,l) -= A(icol,l)*dum;
}
for (uint64_t l = 0; l < rows; l++){
B(ll,l) -= B(icol,l)*dum;
}
}
}
}
//m = temp_m;
for (int64_t l = rows-1; l >= 0; l--){
if (indxr[l] != indxc[l]){
for (uint64_t k = 0; k < rows; k++){
temp = A(k,indxr[l]);
A(k,indxr[l]) = A(k,indxc[l]);
A(k,indxc[l]) = temp;
}
}
}
}
inverse_gj(A);
}
else{
throw std::runtime_error("numerics::inplace_inverse(" + method + ") - Not implemented yet \r \nImplemented: 'Gauss-Jordan',");
}
}
}
template <typename T>
utils::Matrix<T> inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
utils::Matrix<T> B = A;
inplace_inverse(B, method);
return B;
}
@@ -0,0 +1,94 @@
#ifndef _inverse_gj_n_
#define _inverse_gj_n_
#include "./utils/vector.h"
#include "./utils/matrix.h"
#include <omp.h>
namespace numerics{
template <typename T>
void inverse_gj(utils::Matrix<T>& A){
utils::Matrix<T> B(A.rows(),A.cols(), T{0});
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
double big, dum, pivinv, temp;
utils::Vi indxc(rows,0), indxr(rows,0), ipiv(rows,0);
//for (uint64_t j = 0; j < N; ++j){ ipiv[j] = 0;}
for (uint64_t i = 0; i < rows; i++){
big = 0.0;
for (uint64_t j = 0; j < rows; j++){
if (ipiv[j] != 1){
for (uint64_t k = 0; k < rows; k++){
if (ipiv[k] == 0){
if (abs(A(j,k)) >= big){
big = abs(A(j,k));
irow = j;
icol = k;
}
}
}
}
}
ipiv[icol]++;
if (irow != icol){
for (uint64_t l = 0; l < rows; l++){ // SWAP
temp = A(irow,l);
A(irow,l) = A(icol,l);
A(icol,l) = temp;
}
for (uint64_t l = 0; l < cols; l++){ // SWAP temp matrix
temp = B(irow,l);
B(irow,l) = B(icol,l);
B(icol,l) = temp;
}
}
indxr[i] = irow;
indxc[i] = icol;
if (A(icol,icol) == 0.0){
throw std::runtime_error("utill:inplace_inverse('Gauss-Jordan' - Singular Matrix");
}
pivinv= 1.0/A(icol,icol);
A(icol,icol)=1.0;
for (uint64_t l = 0; l < rows; l++){
A(icol,l) *= pivinv;
}
for (uint64_t l = 0; l < cols; l++){
B(icol,l) *= pivinv;
}
for (uint64_t ll = 0; ll < rows; ll++){
if (ll != icol){
dum = A(ll,icol);
A(ll,icol) = 0;
for (uint64_t l = 0; l < rows; l++){
A(ll,l) -= A(icol,l)*dum;
}
for (uint64_t l = 0; l < rows; l++){
B(ll,l) -= B(icol,l)*dum;
}
}
}
}
//m = temp_m;
for (int64_t l = rows-1; l >= 0; l--){
if (indxr[l] != indxc[l]){
for (uint64_t k = 0; k < rows; k++){
temp = A(k,indxr[l]);
A(k,indxr[l]) = A(k,indxc[l]);
A(k,indxc[l]) = temp;
}
}
}
}
} // namespace numerics
#endif // _inverse_gj_n_
+83 -4
View File
@@ -3,10 +3,12 @@
#include "./utils/matrix.h"
#include "./core/omp_config.h"
namespace numerics{
// ---------------- Serial baseline ----------------
template <typename T>
utils::Matrix<T> matmul(const utils::Matrix<T>& A, const utils::Matrix<T>& B){
@@ -19,10 +21,8 @@ namespace numerics{
const uint64_t p = B.cols();
T tmp;
utils::Matrix<T> C(m, n, T{0});
utils::Matrix<T> C(m, p, T{0});
//#pragma omp parallel for collapse(2) schedule(static)
#pragma omp parallel for
for (uint64_t i = 0; i < m; ++i){
for (uint64_t j = 0; j < n; ++j){
tmp = A(i,j);
@@ -34,6 +34,85 @@ namespace numerics{
return C;
}
// ---------------- Rows-only OpenMP ----------------
template <typename T>
utils::Matrix<T> matmul_rows_omp(const utils::Matrix<T>& A,
const utils::Matrix<T>& B) {
if (A.cols() != B.rows()) throw std::runtime_error("matmul_rows_omp: dim mismatch");
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
utils::Matrix<T> C(m, p, T{0});
#pragma omp parallel for schedule(static)
for (uint64_t i=0;i<m;++i) {
for (uint64_t j=0;j<p;++j) {
T acc=T{0};
for (uint64_t k=0;k<n;++k) {
acc += A(i,k)*B(k,j);
}
C(i,j)=acc;
}
}
return C;
}
// -------------- Collapse(2) OpenMP ----------------
template <typename T>
utils::Matrix<T> matmul_collapse_omp(const utils::Matrix<T>& A,
const utils::Matrix<T>& B) {
if (A.cols() != B.rows()) throw std::runtime_error("matmul_collapse_omp: dim mismatch");
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
utils::Matrix<T> C(m, p, T{0});
#pragma omp parallel for collapse(2) schedule(static)
for (uint64_t i=0;i<m;++i) {
for (uint64_t j=0;j<p;++j) {
T acc=T{0};
for (uint64_t k=0;k<n;++k){
acc += A(i,k)*B(k,j);
}
C(i,j)=acc;
}
}
return C;
}
// -------------------- Auto selector ---------------------
template <typename T>
utils::Matrix<T> matmul_auto(const utils::Matrix<T>& A,
const utils::Matrix<T>& B) {
const uint64_t m=A.rows(), p=B.cols();
const uint64_t work = m * p;
bool can_parallel = omp_config::omp_parallel_allowed();
#ifdef _OPENMP
int threads = omp_get_max_threads();
#else
int threads = 1;
#endif
// Tiny problems: serial is cheapest.
if (!can_parallel || work < static_cast<uint64_t>(threads)*4ull) {
return matmul(A,B);
}
// Plenty of (i,j) work → collapse(2) is a great default.
else if (work >= 8ull * static_cast<uint64_t>(threads)) {
return matmul_collapse_omp(A,B);
}
// Many rows and very few columns → rows-only cheaper overhead.
else if (m >= static_cast<uint64_t>(threads) && p <= 4) {
return matmul_rows_omp(A,B);
}
else{
// Safe fallback
return matmul(A,B);
}
}
+97 -3
View File
@@ -3,11 +3,13 @@
#include "./utils/matrix.h"
#include "./core/omp_config.h"
namespace numerics{
// y = A * x, where A is (m×n) and x is length n and y is length m
// =================================================
// y = A * x (MatrixVector product)
// =================================================
template <typename T>
utils::Vector<T> matvec(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
if (A.cols() != x.size()) {
@@ -18,6 +20,27 @@ namespace numerics{
const uint64_t n = A.cols();
utils::Vector<T> y(m, T{0});
for (uint64_t i = 0; i < m; ++i) {
for (uint64_t j = 0; j < n; ++j) {
y[i] += A(i, j) * x[j];
}
}
return y;
}
// -------------- Collapse(2) OpenMP ----------------
template <typename T>
utils::Vector<T> matvec_omp(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
if (A.cols() != x.size()) {
throw std::runtime_error("matvec: dimension mismatch");
}
const uint64_t m = A.rows();
const uint64_t n = A.cols();
utils::Vector<T> y(m, T{0});
#pragma omp parallel for schedule(static)
for (uint64_t i = 0; i < m; ++i) {
T acc = T{0};
for (uint64_t j = 0; j < n; ++j) {
@@ -28,7 +51,34 @@ namespace numerics{
return y;
}
// y = x * A, where x is length m and A is (m×n) -> y is length n
// -------------- Auto OpenMP ----------------
template <typename T>
utils::Vector<T> matvec_auto(const utils::Matrix<T>& A,
const utils::Vector<T>& x) {
uint64_t work = A.rows() * A.cols();
bool can_parallel = omp_config::omp_parallel_allowed();
#ifdef _OPENMP
int threads = omp_get_max_threads();
#else
int threads = 1;
#endif
if (can_parallel || work > static_cast<uint64_t>(threads) * 4ull) {
return matvec_omp(A,x);
}
else{
// Safe fallback
return matvec(A,x);
}
}
// =================================================
// y = x * A (VectorMatrix product)
// =================================================
template <typename T>
utils::Vector<T> vecmat(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
if (x.size() != A.rows()) {
@@ -38,6 +88,26 @@ namespace numerics{
const uint64_t n = A.cols();
utils::Vector<T> y(n, T{0});
for (uint64_t j = 0; j < n; ++j) {
for (uint64_t i = 0; i < m; ++i) {
y[j] += x[i] * A(i, j);
}
}
return y;
}
// -------------- Collapse(2) OpenMP ----------------
template <typename T>
utils::Vector<T> vecmat_omp(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
if (x.size() != A.rows()) {
throw std::runtime_error("vecmat: dimension mismatch");
}
const uint64_t m = A.rows();
const uint64_t n = A.cols();
utils::Vector<T> y(n, T{0});
#pragma omp parallel for schedule(static)
for (uint64_t j = 0; j < n; ++j) {
T acc = T{0};
for (uint64_t i = 0; i < m; ++i) {
@@ -48,6 +118,30 @@ namespace numerics{
return y;
}
// -------------- Auto OpenMP ----------------
template <typename T>
utils::Vector<T> vecmat_auto(const utils::Vector<T>& x,
const utils::Matrix<T>& A) {
uint64_t work = A.rows() * A.cols();
bool can_parallel = omp_config::omp_parallel_allowed();
#ifdef _OPENMP
int threads = omp_get_max_threads();
#else
int threads = 1;
#endif
if (can_parallel || work > static_cast<uint64_t>(threads) * 4ull) {
return vecmat_omp(x,A);
}
else{
// Safe fallback
return vecmat(x,A);
}
}
} // namespace numerics
-22
View File
@@ -43,28 +43,6 @@ namespace numerics{
}
} // namespace numerics
#endif // _transpose_n_
-11
View File
@@ -1,11 +0,0 @@
#ifndef _numerics_n_
#define _numerics_n_
namespace utils{
double random(const double& min, const double& max);
}
#endif // _numerics_n_
+14
View File
@@ -66,6 +66,20 @@ public:
bool operator!=(const Vector<T>& a) const { return !(*this == a); }
//##################################################
//# VECTOR: nearly_equal_vec #
//##################################################
bool nearly_equal_vec(const Vector<T>& a, double tol=1e-12) const {
if (a.size() != v.size()) return false;
for (uint64_t i=0;i<a.size();++i) {
if (std::fabs(a[i]-v[i])>tol) {
return false;
}
}
return true;
}
//##################################################
//# VECTOR: Scalar Addition #
//##################################################
+62 -5
View File
@@ -10,6 +10,11 @@ SRC_DIR := src
INC_DIR := include
OBJ_DIR := obj
BIN_DIR := bin
TEST_BIN := $(BIN_DIR)/tests
# All test sources
#
@@ -23,13 +28,30 @@ SRCS := $(shell find $(SRC_DIR) -name '*.cpp')
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(SRCS))
# === Test sources ===
TEST_SRCS := $(shell find test -name 'test_*.cpp')
TEST_OBJS := $(patsubst test/%.cpp, $(OBJ_DIR)/test/%.o, $(TEST_SRCS))
# The single file that defines TEST_MAIN / main()
TEST_MAIN := $(OBJ_DIR)/test/test_all.o
# === OpenMP runtime configuration (override-able) ===
OMP_PROC_BIND ?= close # close|spread|master
OMP_PLACES ?= cores # cores|threads|sockets
OMP_MAX_LEVELS ?= 1 # 1 = no nested teams; set 2+ to allow nesting
OMP_MAX_LEVELS ?= 2 # 1 = no nested teams; set 2+ to allow nesting
OMP_THREADS ?= 16 # e.g. "16" or "8,4" for nested (outer,inner)
OMP_DYNAMIC ?= true # true/false: let runtime adjust threads
OMP_DISPLAY_ENV ?= FALSE # TRUE to print runtime config at startup
OMP_DYNAMIC ?= TRUE # TRUE/FALSE: let runtime adjust threads
OMP_SCHEDULE ?= STATIC # STATIC recommended for matvec/matmul
OMP_DISPLAY_ENV ?= TRUE # TRUE to print runtime config at startup
# Export OMP defaults so child makes or tools see them (not strictly required)
export OMP_PROC_BIND
export OMP_PLACES
export OMP_MAX_LEVELS
export OMP_THREADS
export OMP_DYNAMIC
export OMP_SCHEDULE
export OMP_DISPLAY_ENV
@@ -49,11 +71,19 @@ $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp
# === Run with OpenMP env set only for the run ===
.PHONY: run
run: $(TARGET)
OMP_PROC_BIND=$(OMP_PROC_BIND) \
@echo ">>> OMP_PROC_BIND=$(OMP_PROC_BIND)"
@echo ">>> OMP_PLACES=$(OMP_PLACES)"
@echo ">>> OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS)"
@echo ">>> OMP_NUM_THREADS=$(OMP_THREADS)"
@echo ">>> OMP_DYNAMIC=$(OMP_DYNAMIC)"
@echo ">>> OMP_SCHEDULE=$(OMP_SCHEDULE)"
@echo ">>> OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV)"
@OMP_PROC_BIND=$(OMP_PROC_BIND) \
OMP_PLACES=$(OMP_PLACES) \
OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS) \
OMP_NUM_THREADS="$(OMP_THREADS)" \
OMP_DYNAMIC=$(OMP_DYNAMIC) \
OMP_SCHEDULE=$(OMP_SCHEDULE) \
OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV) \
./$(TARGET)
@@ -77,4 +107,31 @@ info:
@echo "Source files: $(SRCS)"
@echo "Object files: $(OBJS)"
@echo "CXXFLAGS: $(CXXFLAGS)"
@echo "LDFLAGS: $(LDFLAGS)"
@echo "LDFLAGS: $(LDFLAGS)"
.PHONY: test
test: $(TEST_BIN)
@echo ">>> OMP_PROC_BIND=$(OMP_PROC_BIND)"
@echo ">>> OMP_PLACES=$(OMP_PLACES)"
@echo ">>> OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS)"
@echo ">>> OMP_NUM_THREADS=$(OMP_THREADS)"
@echo ">>> OMP_DYNAMIC=$(OMP_DYNAMIC)"
@echo ">>> OMP_SCHEDULE=$(OMP_SCHEDULE)"
@echo ">>> OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV)"
@OMP_PROC_BIND=$(OMP_PROC_BIND) \
OMP_PLACES=$(OMP_PLACES) \
OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS) \
OMP_NUM_THREADS="$(OMP_THREADS)" \
OMP_DYNAMIC=$(OMP_DYNAMIC) \
OMP_SCHEDULE=$(OMP_SCHEDULE) \
OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV) \
$(TEST_BIN)
$(TEST_BIN): $(TEST_OBJS) $(TEST_MAIN)
@mkdir -p $(BIN_DIR)
$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
$(OBJ_DIR)/test/%.o: test/%.cpp
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) -c $< -o $@
BIN
View File
Binary file not shown.
+6 -553
View File
@@ -1,5 +1,6 @@
#include "./utils/utils.h"
#include "./numerics/numerics.h"
#include "./decomp/decomp.h"
#include "./core/omp_config.h"
#include <iostream>
@@ -7,566 +8,18 @@
#define CHECK(cond, msg) \
do { if (!(cond)) throw std::runtime_error(msg); } while (0)
template <typename F>
void expect_throw(F&& f, const char* msg_if_no_throw) {
try {
f();
throw std::runtime_error(msg_if_no_throw);
} catch (const std::exception&) {
// ok: an exception was thrown
}
}
int main(int argc, char const *argv[])
{
utils::Md A;
decomp::LUd lu;
lu.factor(A);
// Single-level, 16 threads, runtime may adjust
omp_configure(/*max_levels=*/1, /*dynamic=*/true, /*threads_per_level=*/{16});
//omp_configure(/*max_levels=*/1, /*dynamic=*/true, /*threads_per_level=*/{16});
using utils::Vi;
using utils::Vf;
using utils::Vd;
using utils::Mi;
using utils::Mf;
using utils::Md;
// ---------------- Equality / Inequality ----------------
{
Vf a(3, 1.0f); // [1,1,1]
Vf b(3, 1.0f); // [1,1,1]
Vf c(3, 2.0f); // [2,2,2]
CHECK(a == b, "a should equal b");
CHECK(!(a != b), "a should not be != b");
CHECK(a != c, "a should not equal c");
// mutate one element
a[1] = 5.0f;
CHECK(a != b, "after mutation, a should differ from b");
a[1] = 1.0f; // restore
}
// ---------------- Vector + Vector (and +=) ----------------
{
Vf a(3, 1.0f); // [1,1,1]
Vf b(3, 2.0f); // [2,2,2]
Vf expect(3, 3.0f); // [3,3,3]
Vf c = a + b;
CHECK(c == expect, "a + b should be [3,3,3]");
a += b; // a becomes [3,3,3]
CHECK(a == expect, "a += b should produce [3,3,3]");
}
// ---------------- Vector - Vector (and -=) ----------------
{
Vf a(3, 5.0f); // [5,5,5]
Vf b(3, 2.0f); // [2,2,2]
Vf expect(3, 3.0f); // [3,3,3]
Vf c = a - b;
CHECK(c == expect, "a - b should be [3,3,3]");
a -= b; // a becomes [3,3,3]
CHECK(a == expect, "a -= b should produce [3,3,3]");
}
// ---------------- Elementwise Multiply (and *=) ----------------
{
Vf a(3, 2.0f); // [2,2,2]
Vf b(3, 3.0f); // [3,3,3]
Vf expect(3, 6.0f); // [6,6,6]
Vf c = a * b;
CHECK(c == expect, "a * b (elemwise) should be [6,6,6]");
a *= b; // a becomes [6,6,6]
CHECK(a == expect, "a *= b should produce [6,6,6]");
}
// ---------------- Elementwise Divide (and /=) ----------------
{
Vf a(3, 6.0f); // [6,6,6]
Vf b(3, 2.0f); // [2,2,2]
Vf expect(3, 3.0f); // [3,3,3]
Vf c = a / b;
CHECK(c == expect, "a / b (elemwise) should be [3,3,3]");
a /= b; // a becomes [3,3,3]
CHECK(a == expect, "a /= b should produce [3,3,3]");
}
// ---------------- Scalar + Vector (and +=) ----------------
{
Vf a(3, 1.0f); // [1,1,1]
Vf expect1(3, 6.0f); // [6,6,6]
Vf expect2(3, 3.0f); // [3,3,3]
Vf c = a + 5.0f; // v + s
CHECK(c == expect1, "a + 5 should be [6,6,6]");
Vf d = 2.0f + a; // s + v (friend operator)
CHECK(d == expect2, "2 + a should be [3,3,3]");
a += 2.0f; // a becomes [3,3,3]
CHECK(a == expect2, "a += 2 should produce [3,3,3]");
}
// ---------------- Scalar - Vector / Vector - Scalar ----------------
{
Vf a(3, 5.0f); // [5,5,5]
Vf expect1(3, 3.0f); // [3,3,3]
Vf expect2(3, -3.0f); // [ -3,-3,-3 ] if 2 - a (only if you've implemented it)
Vf c = a - 2.0f; // v - s
CHECK(c == expect1, "a - 2 should be [3,3,3]");
// NOTE: Your friend operator-(U a, const Vector<T> b) currently returns (b - a),
// which means `2 - a` computes `a - 2`. That's a bit unusual.
// We'll avoid asserting 2 - a here to match your current implementation choice.
(void)expect2; // silence unused warning
}
// ---------------- Scalar * Vector / Vector * Scalar ----------------
{
Vf a(3, 2.0f); // [2,2,2]
uint64_t b = 3; // 3
Vf expect(3, 6.0f); // [6,6,6]
Vf c = a * b; // v * s
CHECK(c == expect, "a * 3 should be [6,6,6]");
Vf d = b * a; // s * v (friend)
CHECK(d == expect, "3 * a should be [6,6,6]");
a *= b; // a becomes [6,6,6]
CHECK(a == expect, "a *= 3 should produce [6,6,6]");
}
// ---------------- Vector / Scalar (and /= scalar) ----------------
{
Vf a(3, 6.0f); // [6,6,6]
uint64_t b = 2; // 3
Vf expect(3, 3.0f); // [3,3,3]
Vf c = a / b; // v / s
CHECK(c == expect, "a / 2 should be [3,3,3]");
a /= b; // a becomes [3,3,3]
CHECK(a == expect, "a /= 2 should produce [3,3,3]");
}
// ---------- sum ----------
{
Vf a(3, 2.0f); // [2,2,2]
CHECK(a.sum() == 6.0f, "sum failed");
}
// ---------- dot ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf b(3, 3.0f); // [3,3,3]
CHECK(a.dot(b) == 18.0f, "dot failed"); // 2*3 * 3 = 18
Vf c(4, 1.0f);
expect_throw([&]{ (void)a.dot(c); }, "dot should throw on size mismatch");
}
// ---------- norm ----------
{
Vf a(3, 2.0f); // [2,2,2]
float n = a.norm();
CHECK(std::fabs(n - std::sqrt(12.0f)) < 1e-6f, "norm failed");
}
// ---------- normalize ----------
{
Vf a(3, 3.0f); // [3,3,3], norm = sqrt(27)
Vf b = a.normalize();
float n = b.norm();
CHECK(std::fabs(n - 1.0f) < 1e-6f, "normalize failed");
expect_throw([&]{ Vf z(3, 0.0f); z.inplace_normalize(); }, "normalize should throw on zero norm");
}
// ---------- scalar power ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf c = a.power(3); // [8,8,8]
CHECK(c == Vf(3, 8.0f), "power(scalar) failed");
Vf d = a; d.inplace_power(4); // [16,16,16]
CHECK(d == Vf(3, 16.0f), "inplace_power(scalar) failed");
}
// ---------- vector power ----------
{
Vf base(3, 2.0f); // [2,2,2]
Vf exps; exps.v = {1.0f, 2.0f, 3.0f}; // explicit construction for clarity
Vf out = base.power(exps); // [2^1, 2^2, 2^3] = [2,4,8]
Vf expect; expect.v = {2.0f, 4.0f, 8.0f};
CHECK(out == expect, "power(vector) failed");
expect_throw([&]{ Vf bad(2, 1.0f); (void)base.power(bad); },
"power(vector) should throw on size mismatch");
}
// ---------- square ----------
{
Vf a; a.v = {4.0f, 9.0f, 16.0f};
Vf b = a.sqrt(); // [4,9,16]
Vf expect; expect.v = {2.0f, 3.0f, 4.0f};
CHECK(b == expect, "sqrt failed");
a.inplace_sqrt(); // mutate a to [4,9,16]
CHECK(a == expect, "inplace_square failed");
}
// ---------- scalar commutative friends (s + v, s * v) ----------
{
Vf a(3, 2.0f); // [2,2,2]
Vf b = 3.0f + a; // [5,5,5]
Vf c = a + 3.0f; // [5,5,5]
CHECK(b == c, "s+v commutative failed");
Vf d = 4.0f * a; // [8,8,8]
Vf e = a * 4.0f; // [8,8,8]
CHECK(d == e, "s*v commutative failed");
}
// ---------------- Size mismatch throws ----------------
{
Vf a(3, 1.0f);
Vf b(4, 2.0f);
expect_throw([&] { a.inplace_add(b); },
"inplace_add should throw on size mismatch");
expect_throw([&] { (void)(a + b); },
"operator+ should throw (through add) on size mismatch");
expect_throw([&] { a.inplace_subtract(b); },
"inplace_subtract should throw on size mismatch");
expect_throw([&] { (void)(a - b); },
"operator- should throw (through subtract) on size mismatch");
expect_throw([&] { a.inplace_multiply(b); },
"inplace_multiply should throw on size mismatch");
expect_throw([&] { (void)(a * b); },
"operator* should throw (through multiply) on size mismatch");
expect_throw([&] { a.inplace_divide(b); },
"inplace_divide should throw on size mismatch");
expect_throw([&] { (void)(a / b); },
"operator/ should throw (through divide) on size mismatch");
}
{
auto* a = new utils::Vf(3, 1.0f); // constructor runs
delete a; // <- calls ~Vector() and frees memory
}
{
Vf a(2, 1.0f); // a = [1, 1]
Vf b(2, 1.0f); // b = [1, 1]
a.clear(); // a = []
CHECK(a.size() == 0, "clear() did not empty vector");
a.resize(2, 1.0f); // a = [1, 1]
CHECK(a == b, "clear/resize lifecycle failed");
}
std::cout << "All Vector tests passed ✅\n";
// shape + element access
{
Mf M(3, 4, 0.0f);
CHECK(M.rows()==3 && M.cols()==4, "shape failed");
M(1,1) = 5.0f;
CHECK(M(1,1) == 5.0f, "write/read element failed");
// ensure independence of other cells
CHECK(M(0,0) == 0.0f && M(2,3) == 0.0f, "unexpected element modified");
}
// set/get row (with size checks)
{
Mf M(2, 3, 0.0f); // 2x3
Vf r(3, 0.0f);
r[0]=1; r[1]=2; r[2]=3;
M.set_row(1, r);
Vf g = M.get_row(1);
CHECK(g.size()==3, "get_row size wrong");
CHECK(g[0]==1 && g[1]==2 && g[2]==3, "get_row values wrong");
// size mismatch should throw
bool threw=false;
try {
Vf bad(2, 9.0f);
M.set_row(0, bad);
} catch (const std::exception&) { threw=true; }
CHECK(threw, "set_row should throw on size mismatch");
}
// set/get col (with size checks)
{
Mf M(3, 2, 0.0f); // 3x2
Vf c(3, 0.0f);
c[0]=4; c[1]=5; c[2]=6;
M.set_col(1, c);
Vf h = M.get_col(1);
CHECK(h.size()==3, "get_col size wrong");
CHECK(h[0]==4 && h[1]==5 && h[2]==6, "get_col values wrong");
bool threw=false;
try {
Vf bad(2, 7.0f);
M.set_col(0, bad);
} catch (const std::exception&) { threw=true; }
CHECK(threw, "set_col should throw on size mismatch");
}
// swap_rows / swap_cols
{
Mf M(3, 3, 0.0f);
// set rows to [1,2,3], [4,5,6], [7,8,9]
for (uint64_t j=0;j<3;++j) M(0,j) = 1.0f + j;
for (uint64_t j=0;j<3;++j) M(1,j) = 4.0f + j;
for (uint64_t j=0;j<3;++j) M(2,j) = 7.0f + j;
M.swap_rows(0,2);
CHECK(M(0,0)==7 && M(0,1)==8 && M(0,2)==9, "swap_rows top row wrong");
CHECK(M(2,0)==1 && M(2,1)==2 && M(2,2)==3, "swap_rows bottom row wrong");
M.swap_cols(0,2);
// after col swap: first row should be [9,8,7]
CHECK(M(0,0)==9 && M(0,1)==8 && M(0,2)==7, "swap_cols first row wrong");
// bottom row should be [3,2,1]
CHECK(M(2,0)==3 && M(2,1)==2 && M(2,2)==1, "swap_cols last row wrong");
}
// Exact integer comparison / Floating-point exact equality / Floating-point with small perturbation
{
Mi A(2,2,0);
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
Mi B(2,2,0);
B(0,0)=1; B(0,1)=2;
B(1,0)=3; B(1,1)=4;
Mi C(2,2,0);
C(0,0)=9; C(0,1)=9;
C(1,0)=9; C(1,1)=9;
CHECK(A == B, "Matrix == failed on identical int matrices");
CHECK(!(A != B), "Matrix != failed on identical int matrices");
CHECK(A != C, "Matrix != failed on different int matrices");
// Floating-point exact equality
Md F1(2,2,0.0);
F1(0,0)=1.0; F1(0,1)=2.0;
F1(1,0)=3.0; F1(1,1)=4.0;
Md F2(2,2,0.0);
F2(0,0)=1.0; F2(0,1)=2.0;
F2(1,0)=3.0; F2(1,1)=4.0;
CHECK(F1 == F2, "Matrix == failed on identical float matrices");
// Floating-point with small perturbation
Md F3 = F1;
F3(1,1) += 1e-10; // tiny difference
CHECK(!(F1 == F3), "Matrix == should fail on exact compare with perturbation");
CHECK(F1.nearly_equal(F3, 1e-9), "Matrix nearly_equal failed with tolerance");
// Larger perturbation
F3(1,1) += 1e-3;
CHECK(!F1.nearly_equal(F3, 1e-6), "Matrix nearly_equal should fail when tolerance too small");
CHECK(F1.nearly_equal(F3, 1e-2), "Matrix nearly_equal should pass with loose tolerance");
}
std::cout << "Matrix basic tests passed ✅\n";
// --- Test: normal transpose ---
{
Mf M(2, 3, 0.0f);
// Fill: [ [1,2,3],
// [4,5,6] ]
M(0,0)=1; M(0,1)=2; M(0,2)=3;
M(1,0)=4; M(1,1)=5; M(1,2)=6;
Mf MT = numerics::transpose(M);
// Should be shape 3x2
CHECK(MT.rows()==3 && MT.cols()==2, "transpose shape wrong");
// Values: [ [1,4], [2,5], [3,6] ]
CHECK(MT(0,0)==1 && MT(0,1)==4, "transpose value (0,*) wrong");
CHECK(MT(1,0)==2 && MT(1,1)==5, "transpose value (1,*) wrong");
CHECK(MT(2,0)==3 && MT(2,1)==6, "transpose value (2,*) wrong");
//std::cout << "Original M:\n" << M << "\n";
//std::cout << "Transposed MT:\n" << MT << "\n\n";
}
// --- Test: inplace transpose (square only) ---
{
Mf S(3, 3, 0.0f);
// Fill with row-major increasing
float val = 1.0f;
for (uint64_t i=0;i<S.rows();++i) {
for (uint64_t j=0;j<S.cols();++j) {
S(i,j) = val++;
}
}
// S =
// [1,2,3]
// [4,5,6]
// [7,8,9]
numerics::inplace_transpose(S);
// Expected after transpose:
// [1,4,7]
// [2,5,8]
// [3,6,9]
CHECK(S(0,1)==4 && S(0,2)==7, "inplace_transpose first row wrong");
CHECK(S(1,0)==2 && S(1,2)==8, "inplace_transpose second row wrong");
CHECK(S(2,0)==3 && S(2,1)==6, "inplace_transpose third row wrong");
//std::cout << "Square matrix after inplace_transpose:\n" << S << "\n\n";
}
// --- Test: inplace transpose throws on non-square ---
{
Mf Rect(2, 3, 1.0f);
bool threw = false;
try {
numerics::inplace_transpose(Rect);
} catch (const std::runtime_error&) {
threw = true;
}
CHECK(threw, "inplace_transpose should throw on non-square matrix");
}
std::cout << "Transpose tests passed ✅\n";
// matmul test
{
Md A(2,2,0.0);
A(0,0) = 1; A(0,1) = 2;
A(1,0) = 3; A(1,1) = 4;
Md B(2,2,0.0);
B(0,0) = 2; B(0,1) = 0;
B(1,0) = 1; B(1,1) = 2;
Md C = numerics::matmul(A, B);
// Expected result:
// [1*2+2*1, 1*0+2*2] = [4, 4]
// [3*2+4*1, 3*0+4*2] = [10, 8]
CHECK(C(0,0)==4 && C(0,1)==4, "matmul: first row wrong");
CHECK(C(1,0)==10 && C(1,1)==8, "matmul: second row wrong");
}
std::cout << "Matmul test passed ✅\n";
// matvec test
{
// A = [[1,2,3],
// [4,5,6]] (2x3)
Md A(2,3,0.0);
A(0,0)=1; A(0,1)=2; A(0,2)=3;
A(1,0)=4; A(1,1)=5; A(1,2)=6;
// x = [7,8,9]
Vd x(3,0.0);
x[0]=7; x[1]=8; x[2]=9;
// y = A*x = [50, 122]
Vd y = numerics::matvec<double>(A, x);
CHECK(y.size()==2, "matvec size wrong");
CHECK(y[0]==50 && y[1]==122, "matvec values wrong");
// dimension mismatch should throw
bool threw = false;
try {
Vd bad(4,1.0);
(void)numerics::matvec<double>(A, bad);
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "matvec: expected throw on dim mismatch");
}
std::cout << "matvec tests passed ✅\n";
// vecmat test
{
// A = [[1,2],
// [3,4]] (2x2)
Md A(2,2,0.0);
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
// x^T = [5,6]
Vd x(2,0.0);
x[0]=5; x[1]=6;
// y = x^T * A = [5*1+6*3, 5*2+6*4] = [23, 34]
Vd y = numerics::vecmat<double>(x, A);
CHECK(y.size()==2, "vecmat size wrong");
CHECK(y[0]==23 && y[1]==34, "vecmat values wrong");
// mismatch should throw
bool threw = false;
try {
Md B(3,2,0.0); // 3x2, doesn't match x size 2
(void)numerics::vecmat<double>(x, B);
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "vecmat: expected throw on dim mismatch");
}
std::cout << "vecmat tests passed ✅\n";
// Inverse 'Gauss-Jordan' tests
{
Md A(2,2,0.0);
A(0,0)=4; A(0,1)=7;
A(1,0)=2; A(1,1)=6;
Md Ai = numerics::inverse(A, "Gauss-Jordan");
Md I1 = numerics::matmul(A, Ai);
Md I2 = numerics::matmul(Ai, A);
Md I(2,2,0.0);
I(0,0)=1; I(1,1)=1;
CHECK(I1.nearly_equal(I), "A*inv(A) != I");
CHECK(I2.nearly_equal(I), "inv(A)*A != I");
}
std::cout << "Inverse 'Gauss-Jordan' tests passed ✅\n";
return 0;
}
+2
View File
@@ -0,0 +1,2 @@
#define TEST_MAIN
#include "test_common.h"
+52
View File
@@ -0,0 +1,52 @@
#ifndef _test_common_n_
#define _test_common_n_
#pragma once
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
struct TestFailure : public std::runtime_error {
using std::runtime_error::runtime_error;
};
#define CHECK(cond, msg) do { if (!(cond)) throw TestFailure(msg); } while (0)
#define CHECK_EQ(a,b,msg) do { if (!((a)==(b))) { throw TestFailure(std::string(msg) + " (" #a " != " #b ")"); } } while (0)
#define TEST_CASE(name) \
static void name(); \
struct name##_registrar { name##_registrar(){ TestRegistry::add(#name, &name);} } name##_registrar_instance; \
static void name()
struct TestRegistry {
using Fn = void(*)();
static std::vector<std::pair<std::string, Fn>>& list() {
static std::vector<std::pair<std::string, Fn>> v; return v;
}
static void add(const std::string& name, Fn fn) { list().push_back({name, fn}); }
};
// Default test runner main()
#ifdef TEST_MAIN
int main() {
int fails = 0;
for (auto& t : TestRegistry::list()) {
try {
t.second();
std::cout << "[PASS] " << t.first << "\n";
} catch (const TestFailure& e) {
std::cerr << "[FAIL] " << t.first << " -> " << e.what() << "\n";
++fails;
} catch (const std::exception& e) {
std::cerr << "[ERROR] " << t.first << " -> " << e.what() << "\n";
++fails;
}
}
std::cout << (fails ? "Some tests failed ❌\n" : "All tests passed ✅\n");
return fails ? 1 : 0;
}
#endif
#endif // _test_common_n_
+126
View File
@@ -0,0 +1,126 @@
#include "test_common.h"
#include "./utils/utils.h"
#include "./numerics/inverse.h"
#include "./numerics/matmul.h"
TEST_CASE(Inverse_2x2_WellConditioned) {
using T = double;
// A = [[4,7],[2,6]] inverse = (1/10) * [[6,-7],[-2,4]]
utils::Matrix<T> A(2,2, T{0});
A(0,0)=4; A(0,1)=7;
A(1,0)=2; A(1,1)=6;
auto Ainv = numerics::inverse<T>(A); // out-of-place
// Check A * Ainv ≈ I and Ainv * A ≈ I
auto Ileft = numerics::matmul(A, Ainv);
auto Iright = numerics::matmul(Ainv, A);
utils::Md Iref(2,2, T{0});
for (uint64_t i=0;i<Iref.rows();++i) Iref(i,i)=T{1};
//auto Iref = eye<T>(2);
CHECK((Ileft.nearly_equal(Iref, 1e-12)), "A * inverse(A) ≠ I");
CHECK((Iright.nearly_equal(Iref, 1e-12)), "inverse(A) * A ≠ I");
}
TEST_CASE(Inverse_InPlace_Equals_OutOfPlace) {
using T = double;
utils::Matrix<T> A(3,3, T{0});
// A = [[3, 0, 2],
// [2, 0, -2],
// [0, 1, 1]]
A(0,0)=3; A(0,1)=0; A(0,2)= 2;
A(1,0)=2; A(1,1)=0; A(1,2)=-2;
A(2,0)=0; A(2,1)=1; A(2,2)= 1;
auto Ainv_ref = numerics::inverse<T>(A); // copy path
auto A_inp = A;
numerics::inplace_inverse<T>(A_inp); // in-place path
CHECK((A_inp.nearly_equal(Ainv_ref, 1e-12)), "in-place inverse differs from out-of-place");
}
TEST_CASE(Inverse_Singular_Throws) {
using T = double;
utils::Matrix<T> S(2,2, T{0});
// Singular: rows are multiples → det = 0
S(0,0)=1; S(0,1)=2;
S(1,0)=2; S(1,1)=4;
bool threw=false;
try {
auto _ = numerics::inverse<T>(S);
(void)_;
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "inverse should throw on singular matrix");
threw=false;
try {
numerics::inplace_inverse<T>(S);
} catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "inplace_inverse should throw on singular matrix");
}
TEST_CASE(Inverse_RoundTrip_DiagonallyDominant_5x5) {
// Build a well-conditioned 5x5: diagonally dominant
utils::Md A(5,5,0.0);
for (uint64_t i=0;i<5;++i) {
double rowsum = 0.0;
for (uint64_t j=0;j<5;++j) {
if (i==j) continue;
A(i,j) = 0.01 * double(1 + ((i+1)*(j+3)) % 7);
rowsum += std::fabs(A(i,j));
}
A(i,i) = rowsum + 1.0; // strictly diagonally dominant
}
utils::Md A_copy = A; // ensure wrapper doesn't mutate input
utils::Md Ainv = numerics::inverse<double>(A);
// Input must be unchanged by the non-inplace wrapper
CHECK(A.nearly_equal(A_copy, 0.0), "inverse wrapper modified input");
utils::Md I(5,5, 0);
for (uint64_t i=0;i<I.rows();++i) I(i,i)=1;
auto L = numerics::matmul<double>(A, Ainv);
auto R = numerics::matmul<double>(Ainv, A);
CHECK(L.nearly_equal(I, 1e-10), "A * Ainv not close to I");
CHECK(R.nearly_equal(I, 1e-10), "Ainv * A not close to I");
}
TEST_CASE(Inverse_NonSquare_Throws) {
// Non-square: 2x3 — algorithm expects square; should throw
utils::Md A(2,3,0.0);
bool threw = false;
try {
numerics::inplace_inverse<double>(A);
} catch (const std::runtime_error&) {
threw = true;
} catch (...) {
threw = true; // any failure is fine; must not silently succeed
}
CHECK(threw, "inplace_inverse should throw on non-square matrix");
}
TEST_CASE(Inverse_Unknown_Method_Throws) {
utils::Md A(3,3, 0);
for (uint64_t i=0;i<A.rows();++i) A(i,i)=1;
bool threw = false;
try {
numerics::inplace_inverse<double>(A, "NotARealMethod");
} catch (const std::runtime_error&) {
threw = true;
}
CHECK(threw, "should throw for unknown inverse method");
}
+164
View File
@@ -0,0 +1,164 @@
#include "test_common.h"
#include "./utils/utils.h"
#include "./numerics/matmul.h"
#include <chrono>
// ============ Basic correctness ============
TEST_CASE(Matmul_Serial_Simple3x3) {
utils::Md A(3,3,0.0), B(3,3,0.0);
// A = [[1,2,3],[4,5,6],[7,8,9]]
double v=1.0;
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) A(i,j)=v++;
// B = [[9,8,7],[6,5,4],[3,2,1]]
double w=9.0;
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) B(i,j)=w--;
auto C = numerics::matmul<double>(A,B);
// Hand-checked first row:
// row0 dot columns:
// c00 = 1*9 + 2*6 + 3*3 = 30
// c01 = 1*8 + 2*5 + 3*2 = 24
// c02 = 1*7 + 2*4 + 3*1 = 18
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
CHECK(C(0,0)==30.0 && C(0,1)==24.0 && C(0,2)==18.0, "first row wrong");
}
TEST_CASE(Matmul_OMP_Equals_Serial) {
utils::Md A(4,5,0.0), B(5,3,0.0);
// Fill deterministic
for (uint64_t i=0;i<A.rows();++i)
for (uint64_t j=0;j<A.cols();++j)
A(i,j) = 0.1*(1 + (i*17 + j*19)%10);
for (uint64_t i=0;i<B.rows();++i)
for (uint64_t j=0;j<B.cols();++j)
B(i,j) = 0.2*(1 + (i*23 + j*29)%10);
auto Cs = numerics::matmul<double>(A,B);
auto Cr = numerics::matmul_rows_omp<double>(A,B);
auto Cc = numerics::matmul_collapse_omp<double>(A,B);
auto Ca = numerics::matmul_auto<double>(A,B);
CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
}
// ============ Dimension mismatch ============
TEST_CASE(Matmul_DimensionMismatch_Throws) {
utils::Md A(2,3,0.0), B(4,2,0.0);
bool threw=false;
try { auto _ = numerics::matmul<double>(A,B); (void)_; }
catch (const std::runtime_error&) { threw=true; }
CHECK(threw, "serial should throw on dim mismatch");
threw=false; try { auto _ = numerics::matmul_rows_omp<double>(A,B); (void)_; }
catch (const std::runtime_error&) { threw=true; }
CHECK(threw, "rows_omp should throw on dim mismatch");
threw=false; try { auto _ = numerics::matmul_collapse_omp<double>(A,B); (void)_; }
catch (const std::runtime_error&) { threw=true; }
CHECK(threw, "collapse_omp should throw on dim mismatch");
}
// ============ Edge cases ============
TEST_CASE(Matmul_Edges_ZeroDims) {
// (0xK) * (KxP) -> (0xP)
utils::Md A0(0,5,0.0), B1(5,3,0.0);
auto C0 = numerics::matmul<double>(A0,B1);
CHECK(C0.rows()==0 && C0.cols()==3, "0xK * KxP shape wrong");
// (MxK) * (Kx0) -> (Mx0)
utils::Md A2(7,4,0.0), B0(4,0,0.0);
auto C1 = numerics::matmul<double>(A2,B0);
CHECK(C1.rows()==7 && C1.cols()==0, "MxK * Kx0 shape wrong");
}
// ============ Identity sanity ============
TEST_CASE(Matmul_Identity) {
const uint64_t n=5;
utils::Md I(n,n,0.0), A(n,n,0.0);
for (uint64_t i=0;i<n;++i) I(i,i)=1.0;
for (uint64_t i=0;i<n;++i)
for (uint64_t j=0;j<n;++j)
A(i,j) = (i==j)? 2.0 : ( (i<j)? 1.0 : -1.0 );
auto L = numerics::matmul<double>(I,A);
auto R = numerics::matmul<double>(A,I);
CHECK(L == A, "I*A != A");
CHECK(R == A, "A*I != A");
}
// ============ Perf sanity (same kernel: 1 thread vs many) ============
template <class F>
static double time_it(F&& f, int iters=1) {
auto t0 = std::chrono::high_resolution_clock::now();
for (int i=0;i<iters;++i) f();
auto t1 = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double>(t1 - t0).count();
}
TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
#ifndef _OPENMP
return;
#else
int hw = omp_get_max_threads();
if (hw <= 1) return;
const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
utils::Md A(m,k,0.0), B(k,p,0.0);
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i+j%7)*0.001;
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i*3+j%5)*0.0005;
// Warm-up
(void) numerics::matmul_rows_omp<double>(A,B);
int prev = omp_get_max_threads();
auto t0 = std::chrono::high_resolution_clock::now();
omp_set_num_threads(1);
numerics::matmul_rows_omp<double>(A,B);
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
omp_set_num_threads(hw);
t0 = std::chrono::high_resolution_clock::now();
numerics::matmul_rows_omp<double>(A,B);
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
omp_set_num_threads(prev);
// Must not be notably slower with many threads
CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
#endif
}
TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
#ifndef _OPENMP
return;
#else
int hw = omp_get_max_threads();
if (hw <= 1) return;
const uint64_t m=512, k=512, p=512;
utils::Md A(m,k,0.0), B(k,p,0.0);
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i*7+j%11)*0.0003;
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i%13+j)*0.0002;
(void) numerics::matmul_collapse_omp<double>(A,B); // warm-up
int prev = omp_get_max_threads();
auto t0 = std::chrono::high_resolution_clock::now();
omp_set_num_threads(1);
numerics::matmul_collapse_omp<double>(A,B);
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
omp_set_num_threads(hw);
t0 = std::chrono::high_resolution_clock::now();
numerics::matmul_collapse_omp<double>(A,B);
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
omp_set_num_threads(prev);
CHECK(tN <= t1 * 1.05, "collapse_omp: multi-thread slower than single-thread");
#endif
}
+142
View File
@@ -0,0 +1,142 @@
#include "test_common.h"
#include "./utils/utils.h"
using utils::Vf; using utils::Vd; using utils::Vi;
using utils::Mf; using utils::Md; using utils::Mi;
// ---------- Construction & element access ----------
TEST_CASE(Matrix_Construct_Access) {
Md M; // default
CHECK(M.rows()==0 && M.cols()==0, "default ctor dims wrong");
Mf A(2,3, 1.0f);
CHECK(A.rows()==2 && A.cols()==3, "ctor dims wrong");
CHECK(A(0,0)==1.0f && A(1,2)==1.0f, "fill wrong");
A(0,1)=2.5f; A(1,0)=3.5f;
CHECK(A(0,1)==2.5f && A(1,0)==3.5f, "operator() set/get failed");
}
// ---------- Equality, inequality, nearly_equal ----------
TEST_CASE(Matrix_Equality) {
Mi A(2,2,0), B(2,2,0), C(2,2,1);
A(0,0)=1; A(1,1)=1; // A = I
B(0,0)=1; B(1,1)=1; // B = I
CHECK(A == B, "== failed identical");
CHECK(!(A != B), "!= failed identical");
CHECK(A != C, "!= failed different");
Md F1(2,2,0.0), F2(2,2,0.0);
F1(0,0)=1.0; F1(1,1)=2.0;
F2(0,0)=1.0; F2(1,1)=2.0 + 5e-10; // tiny perturbation
CHECK(!(F1 == F2), "operator== is exact; should differ");
CHECK(F1.nearly_equal(F2, 1e-9), "nearly_equal should accept tiny delta");
CHECK(!F1.nearly_equal(F2, 1e-12), "nearly_equal too strict should fail");
}
// ---------- Row helpers ----------
TEST_CASE(Matrix_Row_Get_Set) {
Mf M(3,4, 0.0f);
Vf r(4, 0.0f);
for (uint64_t j=0;j<4;++j) r[j] = float(j+1); // [1,2,3,4]
M.set_row(1, r);
auto out = M.get_row(1);
CHECK(out == r, "set_row/get_row mismatch");
// size mismatch should throw
bool threw=false;
try { Vf bad(3, 9.0f); M.set_row(2, bad); } catch (const std::exception&) { threw=true; }
CHECK(threw, "set_row should throw on size mismatch");
// out of range
threw=false;
try { (void)M.get_row(3); } catch (const std::out_of_range&) { threw=true; }
CHECK(threw, "get_row should throw on OOB index");
}
// ---------- Column helpers ----------
TEST_CASE(Matrix_Col_Get_Set) {
Md M(3,2, 0.0);
Vd c(3, 0.0);
c[0]=10; c[1]=20; c[2]=30;
M.set_col(1, c);
auto out = M.get_col(1);
CHECK(out == c, "set_col/get_col mismatch");
// size mismatch should throw
bool threw=false;
try { Vd bad(2, 9.0); M.set_col(0, bad); } catch (const std::exception&) { threw=true; }
CHECK(threw, "set_col should throw on size mismatch");
// out of range
threw=false;
try { (void)M.get_col(2); } catch (const std::out_of_range&) { threw=true; }
CHECK(threw, "get_col should throw on OOB index");
}
// ---------- swap_rows / swap_cols ----------
TEST_CASE(Matrix_Swap_Rows_Cols) {
Mi M(2,3,0);
// Row 0: [1,2,3], Row 1: [4,5,6]
M(0,0)=1; M(0,1)=2; M(0,2)=3;
M(1,0)=4; M(1,1)=5; M(1,2)=6;
M.swap_rows(0,1);
CHECK(M(0,0)==4 && M(0,1)==5 && M(0,2)==6, "swap_rows row0 wrong");
CHECK(M(1,0)==1 && M(1,1)==2 && M(1,2)==3, "swap_rows row1 wrong");
// swap back via cols
M.swap_cols(0,2);
// After swapping col0<->col2:
// Row0: [6,5,4], Row1: [3,2,1]
CHECK(M(0,0)==6 && M(0,1)==5 && M(0,2)==4, "swap_cols row0 wrong");
CHECK(M(1,0)==3 && M(1,1)==2 && M(1,2)==1, "swap_cols row1 wrong");
// no-op swap (a==b) should not crash or change
M.swap_rows(1,1);
M.swap_cols(2,2);
// OOB checks
bool threw=false;
try { M.swap_rows(5,1); } catch (const std::out_of_range&) { threw=true; }
CHECK(threw, "swap_rows should throw on OOB");
threw=false;
try { M.swap_cols(0,9); } catch (const std::out_of_range&) { threw=true; }
CHECK(threw, "swap_cols should throw on OOB");
}
// ---------- data() layout (contiguous row-major) ----------
TEST_CASE(Matrix_Data_Layout) {
Md M(2,3, 0.0);
// Fill increasing sequence
double val=1.0;
for (uint64_t i=0;i<M.rows();++i)
for (uint64_t j=0;j<M.cols();++j)
M(i,j) = val++;
const double* p = M.data();
// Expect row-major: [1,2,3,4,5,6]
CHECK(p[0]==1.0 && p[1]==2.0 && p[2]==3.0 && p[3]==4.0 && p[4]==5.0 && p[5]==6.0,
"data() row-major layout wrong");
}
// ---------- Stream output ----------
TEST_CASE(Matrix_StreamOutput) {
Mf M(2,2,0.0f);
M(0,0)=1.0f; M(0,1)=2.0f;
M(1,0)=3.0f; M(1,1)=4.0f;
std::ostringstream oss;
oss << M;
const std::string s = oss.str();
// Format example:
// [[1.000, 2.000],
// [3.000, 4.000]]
CHECK(s.find("[[1.000, 2.000],") != std::string::npos, "ostream first row format");
CHECK(s.find("[3.000, 4.000]]") != std::string::npos, "ostream second row format");
}
+237
View File
@@ -0,0 +1,237 @@
#include "test_common.h"
#include "./utils/utils.h" // matrix.h, vector.h
#include "./numerics/matvec.h" // numerics::matvec / inplace_transpose
#include <chrono>
using utils::Vi; using utils::Vf; using utils::Vd;
using utils::Mi; using utils::Mf; using utils::Md;
// ------------------------------------------------------------
// matvec: y = A * x
// ------------------------------------------------------------
TEST_CASE(Matvec_Serial_Simple) {
// A = [[1,2,3],
// [4,5,6]]
Md A(2,3,0.0);
A(0,0)=1; A(0,1)=2; A(0,2)=3;
A(1,0)=4; A(1,1)=5; A(1,2)=6;
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
auto y = numerics::matvec<double>(A,x); // [ 1*7+2*8+3*9 , 4*7+5*8+6*9 ] = [50, 122]
CHECK(y.size()==2, "matvec size wrong");
CHECK(y[0]==50.0 && y[1]==122.0, "matvec values wrong");
}
TEST_CASE(Matvec_OMP_Equals_Serial) {
Md A(3,3,0.0);
// A = I * 2
for (uint64_t i=0;i<3;++i) A(i,i)=2.0;
Vd x(3,0.0); x[0]=1; x[1]=2; x[2]=3;
auto ys = numerics::matvec<double>(A,x);
auto yp = numerics::matvec_omp<double>(A,x);
CHECK((ys.nearly_equal_vec(yp)), "matvec_omp != matvec");
}
TEST_CASE(Matvec_Auto_Equals_Serial) {
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=0.5; A(1,1)=3;
Vd x(2,0.0); x[0]=4; x[1]=5;
auto ys = numerics::matvec<double>(A,x);
auto ya = numerics::matvec_auto<double>(A,x);
CHECK((ys.nearly_equal_vec(ya)), "matvec_auto != serial");
}
TEST_CASE(Matvec_DimensionMismatch_Throws) {
Md A(2,3,0.0);
Vd x(4,0.0);
bool threw=false;
try { auto _ = numerics::matvec<double>(A,x); (void)_; }
catch (const std::runtime_error&) { threw=true; }
CHECK(threw, "matvec must throw on dimension mismatch");
}
TEST_CASE(Matvec_Zero_Edges) {
Md A(0,3,0.0); // 0x3
Vd x(3,1.0);
auto y = numerics::matvec<double>(A,x);
CHECK(y.size()==0, "0xN * x should return size 0 vector");
Md B(2,0,0.0); // 2x0
Vd z(0,0.0);
auto y2 = numerics::matvec<double>(B,z);
CHECK(y2.size()==2 && y2[0]==0.0 && y2[1]==0.0, "N×0 * 0 should return zeros of size N");
}
TEST_CASE(Matvec_Float_Tolerance) {
Mf A(2,2,0.0f); A(0,0)=1.0f; A(0,1)=2.0f; A(1,0)=3.0f; A(1,1)=4.0f;
Vf x(2,0.0f); x[0]=0.1f; x[1]=0.2f;
auto y1 = numerics::matvec<float>(A,x);
auto y2 = numerics::matvec_omp<float>(A,x);
CHECK((y1.nearly_equal_vec(y2,1e-6f)), "matvec float omp mismatch");
}
//
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
// We just check correctness; performance is environment-dependent.
//
TEST_CASE(Matvec_Auto_Inside_Outer_Parallel_Correctness) {
const uint64_t m=64, n=64;
Md A(m,n,1.0); Vd x(n,2.0);
//fill_deterministic(A); fill_deterministic(x);
Vd ref = numerics::matvec<double>(A,x);
// Call auto inside an outer team
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int rep=0; rep<32; ++rep) {
auto y = numerics::matvec_auto<double>(A,x);
// Each thread checks its own result equals reference
if (!(y.nearly_equal_vec(ref))) {
throw TestFailure("matvec_auto wrong under outer parallel region");
}
}
}
TEST_CASE(Matvec_Speed_Sanity) {
const uint64_t m=4096, n=4096; // ~16M MACs; adjust if needed
Md A(m,n,1.0); Vd x(n,2.0);
//fill_deterministic(A); fill_deterministic(x);
auto t0 = std::chrono::high_resolution_clock::now();
auto yS = numerics::matvec(A,x);
double tp = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
#ifdef _OPENMP
int threads = omp_get_max_threads();
#else
int threads = 1;
#endif
t0 = std::chrono::high_resolution_clock::now();
auto yP = numerics::matvec_omp(A,x);
double ts = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
CHECK((yS.nearly_equal_vec(yP)), "matvec_omp != matvec_serial (large)");
// Only enforce basic sanity if we *can* use >1 threads:
if (threads > 1) {
// Be generous: just require not significantly slower.
CHECK(tp <= ts, "matvec_omp unexpectedly much slower than serial");
}
}
// ------------------------------------------------------------
// vecmat: y = x * A
// ------------------------------------------------------------
TEST_CASE(Vecmat_Serial_Simple) {
// A = [[1,2],
// [3,4],
// [5,6]] (3x2)
Md A(3,2,0.0);
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
A(2,0)=5; A(2,1)=6;
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
auto y = numerics::vecmat<double>(x,A); // 1*7+3*8+5*9= 76 ; 2*7+4*8+6*9=100
CHECK(y.size()==2, "vecmat size wrong");
CHECK(y[0]==76.0 && y[1]==100.0, "vecmat values wrong");
}
TEST_CASE(Vecmat_OMP_Equals_Serial) {
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=5; A(1,1)=-1;
Vd x(2,0.0); x[0]=0.5; x[1]=1.5;
auto ys = numerics::vecmat<double>(x,A);
auto yp = numerics::vecmat_omp<double>(x,A);
CHECK((ys.nearly_equal_vec(yp)), "vecmat_omp != vecmat");
}
TEST_CASE(Vecmat_Auto_Equals_Serial) {
Md A(2,3,0.0);
A(0,0)=1; A(0,1)=2; A(0,2)=3;
A(1,0)=4; A(1,1)=5; A(1,2)=6;
Vd x(2,0.0); x[0]=1; x[1]=2;
auto ys = numerics::vecmat<double>(x,A);
auto ya = numerics::vecmat_auto<double>(x,A);
CHECK((ys.nearly_equal_vec(ya)), "vecmat_auto != serial");
}
TEST_CASE(Vecmat_DimensionMismatch_Throws) {
Md A(2,2,0.0);
Vd x(3,0.0);
bool threw=false;
try { auto _ = numerics::vecmat<double>(x,A); (void)_; }
catch (const std::runtime_error&) { threw=true; }
CHECK(threw, "vecmat must throw on dimension mismatch");
}
TEST_CASE(Vecmat_Zero_Edges) {
Md A(0,3,0.0);
Vd x(0,0.0);
auto y = numerics::vecmat<double>(x,A); // 0×N times N×M → 0×M
CHECK(y.size()==3 && y[0]==0.0 && y[1]==0.0 && y[2]==0.0, "0-length x times A wrong");
Md B(3,0,0.0);
Vd z(3,1.0);
auto y2 = numerics::vecmat<double>(z,B); // 1x3 * 3x0 → 1x0
CHECK(y2.size()==0, "vecmat with N×0 result size wrong");
}
//
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
// We just check correctness; performance is environment-dependent.
//
TEST_CASE(Vecmat_Auto_Inside_Outer_Parallel_Correctness) {
const uint64_t m=64, n=64;
Md A(m,n,1.0); Vd x(m,2.0);
//fill_deterministic(A); fill_deterministic(x);
Vd ref = numerics::vecmat<double>(x,A);
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int rep=0; rep<32; ++rep) {
auto y = numerics::vecmat_auto<double>(x,A);
if (!(y.nearly_equal_vec(ref))) {
throw TestFailure("vecmat_auto wrong under outer parallel region");
}
}
}
TEST_CASE(Vecmat_Speed_Sanity) {
const uint64_t m=4096, n=4096;
Md A(m,n,1.0); Vd x(m,2.0);
//fill_deterministic(A); fill_deterministic(x);
auto t0 = std::chrono::high_resolution_clock::now();
auto yS = numerics::vecmat<double>(x,A);
double ts = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
#ifdef _OPENMP
int threads = omp_get_max_threads();
#else
int threads = 1;
#endif
t0 = std::chrono::high_resolution_clock::now();
auto yP = numerics::vecmat_omp<double>(x,A);
double tp = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
CHECK((yS.nearly_equal_vec(yP)), "vecmat_omp != vecmat_serial (large)");
if (threads > 1) {
CHECK(tp <= ts, "vecmat_omp unexpectedly much slower than serial");
}
}
+88
View File
@@ -0,0 +1,88 @@
#include "test_common.h"
#include "./utils/utils.h" // matrix.h, vector.h
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
using utils::Mi; using utils::Mf; using utils::Md;
//
// ---------- Out-of-place transpose (rectangular) ----------
//
TEST_CASE(Transpose_Rectangular_OutOfPlace) {
// A = [ [1, 2, 3],
// [4, 5, 6] ] (2x3)
Md A(2,3,0.0);
A(0,0)=1; A(0,1)=2; A(0,2)=3;
A(1,0)=4; A(1,1)=5; A(1,2)=6;
auto AT = numerics::transpose(A); // (3x2)
CHECK(AT.rows()==3 && AT.cols()==2, "shape wrong after transpose");
CHECK(AT(0,0)==1 && AT(1,0)==2 && AT(2,0)==3, "first column wrong");
CHECK(AT(0,1)==4 && AT(1,1)==5 && AT(2,1)==6, "second column wrong");
// Involution: T(T(A)) == A
auto ATT = numerics::transpose(AT);
CHECK(ATT == A, "transpose(transpose(A)) != A");
}
//
// ---------- In-place transpose (square) ----------
//
TEST_CASE(Transpose_Square_InPlace) {
// 3x3 with distinct values
Mf S(3,3,0.0f);
float val = 1.0f;
for (uint64_t i=0;i<3;++i)
for (uint64_t j=0;j<3;++j)
S(i,j) = val++;
// Make an explicit transpose to compare against
auto ST = numerics::transpose(S);
// In-place should match the out-of-place result
numerics::inplace_transpose(S);
CHECK(S == ST, "inplace_transpose result mismatch");
// Involution: applying in-place again should return original
numerics::inplace_transpose(S);
// Now S should equal the original pre-inplace matrix (which was transposed above)
// We can reconstruct original by transposing ST:
auto orig = numerics::transpose(ST);
CHECK(S == orig, "inplace transpose twice should restore original");
}
//
// ---------- In-place transpose must throw on non-square ----------
//
TEST_CASE(Transpose_InPlace_Throws_On_Rectangular) {
Md R(2,3,0.0); // rectangular
bool threw = false;
try {
numerics::inplace_transpose(R);
} catch (const std::runtime_error&) {
threw = true;
}
CHECK(threw, "inplace_transpose must throw on non-square matrices");
}
//
// ---------- Edge cases: 0x0 and 1x1 ----------
//
TEST_CASE(Transpose_Edge_0x0_1x1) {
// 0x0 should be fine both ways
Md E; // 0x0
auto ET = numerics::transpose(E);
CHECK(ET.rows()==0 && ET.cols()==0, "0x0 transpose shape wrong");
// in-place on 0x0 (rows==cols) should not throw
numerics::inplace_transpose(E);
CHECK(E.rows()==0 && E.cols()==0, "0x0 inplace transpose changed shape");
// 1x1 should be a no-op and not throw
Mi I(1,1,0);
I(0,0) = 42;
auto IT = numerics::transpose(I);
CHECK(IT.rows()==1 && IT.cols()==1 && IT(0,0)==42, "1x1 transpose wrong");
numerics::inplace_transpose(I);
CHECK(I(0,0)==42, "1x1 inplace transpose changed value");
}
+211
View File
@@ -0,0 +1,211 @@
#include "test_common.h"
#include "./utils/utils.h"
using utils::Vf; using utils::Vd; using utils::Vi;
//
// ---------- Basic construction & access ----------
//
TEST_CASE(Vector_Construct_Size_Access) {
Vd a; // default
CHECK(a.size() == 0, "default size must be 0");
Vf b(3, 1.0f); // (n, fill)
CHECK(b.size() == 3, "size wrong");
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");
b[1] = 2.5f;
CHECK(b[1] == 2.5f, "operator[] write failed");
// resize (grow + value)
b.resize(5, 7.0f);
CHECK(b.size() == 5, "resize grow size wrong");
CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
"resize grow values wrong");
// resize (shrink)
b.resize(2);
CHECK(b.size() == 2, "resize shrink size wrong");
}
TEST_CASE(Vector_Clear_PushBack) {
Vi v(0, 0);
v.push_back(10);
v.push_back(20);
CHECK(v.size() == 2, "push_back size wrong");
CHECK(v[0] == 10 && v[1] == 20, "push_back values wrong");
v.clear();
CHECK(v.size() == 0, "clear failed");
}
//
// ---------- Equality / Inequality (tolerant for float/double) ----------
//
TEST_CASE(Vector_Equality_Tolerant) {
Vd a(3, 1.0), b(3, 1.0);
CHECK(a == b, "== identical failed");
CHECK(!(a != b), "!= identical failed");
// Tiny perturbation within eps (1e-6 default)
b[1] += 1e-7;
CHECK(a == b, "== tolerant failed");
// Larger perturbation should fail equality
b[1] += 1e-4;
CHECK(a != b, "!= with difference failed");
}
//
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
//
TEST_CASE(Vector_Scalar_Arithmetic) {
Vf a(3, 1.0f);
// inplace
a.inplace_add(2); // int convertible to float
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_add failed");
a.inplace_subtract(1.5f);
CHECK(std::fabs(a[0] - 1.5f) < 1e-6f &&
std::fabs(a[1] - 1.5f) < 1e-6f &&
std::fabs(a[2] - 1.5f) < 1e-6f, "inplace_subtract failed");
a.inplace_multiply(4.0);
CHECK(a[0] == 6.0f && a[1] == 6.0f && a[2] == 6.0f, "inplace_multiply failed");
a.inplace_divide(2);
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_divide failed");
// returning
auto b = a + 1.0f;
CHECK(b[0] == 4.0f && b[1] == 4.0f && b[2] == 4.0f, "operator+(scalar) failed");
b = a - 2.0f;
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "operator-(scalar) failed");
b = a * 10; // int -> float
CHECK(b[0] == 30.0f && b[1] == 30.0f && b[2] == 30.0f, "operator*(scalar) failed");
b = a / 3.0f;
CHECK(std::fabs(b[0] - 1.0f) < 1e-6f &&
std::fabs(b[1] - 1.0f) < 1e-6f &&
std::fabs(b[2] - 1.0f) < 1e-6f, "operator/(scalar) failed");
// scalar on the left (friends implemented for + and *)
Vf c(3, 2.0f);
auto d = 5 + c; // friend operator+(U, Vector<T>)
CHECK(d[0] == 7.0f && d[1] == 7.0f && d[2] == 7.0f, "scalar + vector failed");
d = 3 * c; // friend operator*(U, Vector<T>)
CHECK(d[0] == 6.0f && d[1] == 6.0f && d[2] == 6.0f, "scalar * vector failed");
}
//
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
//
TEST_CASE(Vector_Vector_Arithmetic) {
Vd a(3, 1.0), b(3, 2.0);
// returning
auto c = a + b;
CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");
c = b - a;
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");
c = a * b;
CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");
c = b / b;
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");
// inplace
a = Vd(3, 1.0);
a += b;
CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
a -= b;
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
a *= b;
CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
a /= b;
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");
}
//
// ---------- Size mismatch error paths ----------
//
TEST_CASE(Vector_SizeMismatch_Throws) {
Vd a(3, 1.0), b(4, 2.0);
bool threw = false;
try { auto c = a + b; (void)c; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "add should throw on size mismatch");
threw = false;
try { a.inplace_subtract(b); } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "inplace_subtract should throw on size mismatch");
threw = false;
try { auto d = a * b; (void)d; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "multiply should throw on size mismatch");
threw = false;
try { auto s = a.dot(b); (void)s; } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "dot should throw on size mismatch");
}
//
// ---------- Power / sqrt ----------
//
TEST_CASE(Vector_Power_Sqrt) {
Vd a(3, 2.0); // [2,2,2]
auto b = a.power(3.0); // [8,8,8]
CHECK(b[0]==8.0 && b[1]==8.0 && b[2]==8.0, "scalar power failed");
Vd p(3, 3.0); // [3,3,3]
auto c = b.power(p); // 8^3 = 512
CHECK(c[0]==512.0 && c[1]==512.0 && c[2]==512.0, "vector power failed");
Vd d(3, 9.0);
auto e = d.sqrt(); // [3,3,3]
CHECK(e[0]==3.0 && e[1]==3.0 && e[2]==3.0, "sqrt failed");
// inplace
d.inplace_sqrt(); // becomes [3,3,3]
CHECK(d == e, "inplace_sqrt failed");
}
//
// ---------- Dot / Sum / Norm / Normalize ----------
//
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
Vd a(3, 0.0);
a[0]=1.0; a[1]=2.0; a[2]=2.0;
CHECK(a.sum() == 5.0, "sum failed");
CHECK(a.dot(a) == 9.0, "dot self failed");
auto n = a.norm();
CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");
auto b = a.normalize();
CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
// inplace normalize
a.inplace_normalize();
CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
// zero-norm error
Vd z(3, 0.0);
bool threw = false;
try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
CHECK(threw, "normalize zero vector must throw");
}
//
// ---------- Stream output (basic sanity) ----------
//
TEST_CASE(Vector_StreamOutput) {
Vi a(3, 2);
std::ostringstream oss;
oss << a;
auto s = oss.str();
CHECK(s == "[2, 2, 2]", "ostream<< wrong format");
}