Finishing up and starting lu decomp
This commit is contained in:
BIN
Binary file not shown.
+42
-23
@@ -4,31 +4,50 @@
|
|||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
|
|
||||||
|
|
||||||
// Configure OpenMP behavior at runtime.
|
namespace omp_config{
|
||||||
inline void omp_configure(int max_active_levels,
|
|
||||||
bool dynamic_threads,
|
|
||||||
const std::vector<int>& threads_per_level = {},
|
|
||||||
bool bind_close = true)
|
|
||||||
{
|
|
||||||
// 1) Allow nested parallel regions (levels of teams)
|
|
||||||
// Example: outer #pragma omp parallel ... { inner #pragma omp parallel ... }
|
|
||||||
omp_set_max_active_levels(max_active_levels); // 1 = only top-level; 2+ enables nesting
|
|
||||||
|
|
||||||
// 2) Let the runtime shrink/grow thread counts if it thinks it should
|
// Configure OpenMP behavior at runtime.
|
||||||
// (helps avoid oversubscription when you accidentally ask for too many threads)
|
inline void omp_configure(int max_active_levels,
|
||||||
omp_set_dynamic(dynamic_threads ? 1 : 0);
|
bool dynamic_threads,
|
||||||
|
const std::vector<int>& threads_per_level = {},
|
||||||
|
bool bind_close = true)
|
||||||
|
{
|
||||||
|
// 1) Allow nested parallel regions (levels of teams)
|
||||||
|
// Example: outer #pragma omp parallel ... { inner #pragma omp parallel ... }
|
||||||
|
omp_set_max_active_levels(max_active_levels); // 1 = only top-level; 2+ enables nesting
|
||||||
|
|
||||||
// 3) Thread binding (keep threads near their cores) is controlled via env vars,
|
// 2) Let the runtime shrink/grow thread counts if it thinks it should
|
||||||
// so here we just *recommend* a good default (see below). You *can* setenv()
|
// (helps avoid oversubscription when you accidentally ask for too many threads)
|
||||||
// in code, but it’s cleaner to do it outside the program.
|
omp_set_dynamic(dynamic_threads ? 1 : 0);
|
||||||
(void)bind_close; // documented below in env var section
|
|
||||||
|
|
||||||
// 4) Top-level default thread count (inner levels are usually set per region)
|
// 3) Thread binding (keep threads near their cores) is controlled via env vars,
|
||||||
if (!threads_per_level.empty()) {
|
// so here we just *recommend* a good default (see below). You *can* setenv()
|
||||||
omp_set_num_threads(threads_per_level[0]); // e.g. 16 for the outermost team
|
// in code, but it’s cleaner to do it outside the program.
|
||||||
// Inner levels:
|
(void)bind_close; // documented below in env var section
|
||||||
// - Use num_threads(threads_per_level[L]) on the inner #pragma omp parallel
|
|
||||||
// - or set OMP_NUM_THREADS="outer,inner,inner2" as an environment variable
|
// 4) Top-level default thread count (inner levels are usually set per region)
|
||||||
|
if (!threads_per_level.empty()) {
|
||||||
|
omp_set_num_threads(threads_per_level[0]); // e.g. 16 for the outermost team
|
||||||
|
// Inner levels:
|
||||||
|
// - Use num_threads(threads_per_level[L]) on the inner #pragma omp parallel
|
||||||
|
// - or set OMP_NUM_THREADS="outer,inner,inner2" as an environment variable
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// ---------- Helper: may we create another team? ----------
|
||||||
|
inline bool omp_parallel_allowed() {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
// If we’re not in parallel, we can spawn.
|
||||||
|
if (!omp_in_parallel()) return true;
|
||||||
|
|
||||||
|
// Already inside parallel: allow only if nesting is enabled and not at limit.
|
||||||
|
int level = omp_get_active_level(); // 0 outside parallel, 1 inside, ...
|
||||||
|
int maxlv = omp_get_max_active_levels(); // user/runtime cap
|
||||||
|
return static_cast<bool>(level < maxlv);
|
||||||
|
#else
|
||||||
|
return false; // no OpenMP → no extra teams
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace omp_config
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "./decomp/lu.h"
|
||||||
|
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "./utils/vector.h"
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
|
namespace decomp{
|
||||||
|
|
||||||
|
// Stores PA = LU with partial pivoting (row permutations).
|
||||||
|
template <typename T>
|
||||||
|
struct LU{
|
||||||
|
utils::Matrix<T> LUmat; // combined L (unit diagonal implied) and U
|
||||||
|
std::vector<uint64_t> piv; // pivot indices (row permutations)
|
||||||
|
bool singular= false; // set true if zero (or near-zero) pivots encountered
|
||||||
|
|
||||||
|
LU() = default;
|
||||||
|
|
||||||
|
explicit LU(const utils::Matrix<T>& A) {
|
||||||
|
factor(A);
|
||||||
|
}
|
||||||
|
|
||||||
|
void factor(const utils::Matrix<T>&A){
|
||||||
|
|
||||||
|
const uint64_t rows = A.rows();
|
||||||
|
if (rows != A.cols()){
|
||||||
|
throw std::runtime_error("LU: factor non-square");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rows == 0){
|
||||||
|
LUmat = A;
|
||||||
|
piv.clear();
|
||||||
|
singular = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
T big, tmp;
|
||||||
|
|
||||||
|
LUmat = A;
|
||||||
|
piv.resize(rows); // piv stores the implicit scaling of each row.
|
||||||
|
//double d = 1.0; // No row interchanges yet.
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < rows; ++i){ // Loop over rows to get the implicit scaling information.
|
||||||
|
big = T{0};
|
||||||
|
for (uint64_t j = 0; j < rows; ++j){
|
||||||
|
tmp=std::abs(LUmat(i,j));
|
||||||
|
if (tmp > big){
|
||||||
|
big = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (big == T{0}){
|
||||||
|
throw std::runtime_error("Singular matrix in LU.factor()");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // struct LU
|
||||||
|
|
||||||
|
typedef LU<float> LUf;
|
||||||
|
typedef LU<double> LUd;
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace decomp
|
||||||
+17
-78
@@ -5,100 +5,39 @@
|
|||||||
#include "./utils/vector.h"
|
#include "./utils/vector.h"
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
|
#include "./numerics/inverse/inverse_gauss_jordan.h"
|
||||||
|
#include "./numerics/inverse/inverse_lu.h"
|
||||||
|
|
||||||
|
#include <omp.h>
|
||||||
|
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void inplace_inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
|
void inplace_inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
|
||||||
|
|
||||||
|
if (A.rows() != A.cols()) {
|
||||||
|
throw std::runtime_error("inplace_inverse: non-square matrix");
|
||||||
|
}
|
||||||
|
|
||||||
if (method == "Gauss-Jordan"){
|
if (method == "Gauss-Jordan"){
|
||||||
|
inverse_gj(A);
|
||||||
utils::Matrix<T> B(A.rows(),A.cols(), T{0});
|
}
|
||||||
|
|
||||||
|
|
||||||
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
|
|
||||||
double big, dum, pivinv, temp;
|
|
||||||
utils::Vi indxc(rows,0), indxr(rows,0), ipiv(rows,0);
|
|
||||||
|
|
||||||
//for (uint64_t j = 0; j < N; ++j){ ipiv[j] = 0;}
|
|
||||||
for (uint64_t i = 0; i < rows; i++){
|
|
||||||
big = 0.0;
|
|
||||||
for (uint64_t j = 0; j < rows; j++){
|
|
||||||
if (ipiv[j] != 1){
|
|
||||||
for (uint64_t k = 0; k < rows; k++){
|
|
||||||
if (ipiv[k] == 0){
|
|
||||||
if (abs(A(j,k)) >= big){
|
|
||||||
big = abs(A(j,k));
|
|
||||||
irow = j;
|
|
||||||
icol = k;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ipiv[icol]++;
|
|
||||||
if (irow != icol){
|
|
||||||
for (uint64_t l = 0; l < rows; l++){ // SWAP
|
|
||||||
temp = A(irow,l);
|
|
||||||
A(irow,l) = A(icol,l);
|
|
||||||
A(icol,l) = temp;
|
|
||||||
}
|
|
||||||
for (uint64_t l = 0; l < cols; l++){ // SWAP temp matrix
|
|
||||||
temp = B(irow,l);
|
|
||||||
B(irow,l) = B(icol,l);
|
|
||||||
B(icol,l) = temp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
indxr[i] = irow;
|
|
||||||
indxc[i] = icol;
|
|
||||||
if (A(icol,icol) == 0.0){
|
|
||||||
throw std::runtime_error("utill:inplace_inverse('Gauss-Jordan' - Singular Matrix");
|
|
||||||
}
|
|
||||||
pivinv= 1.0/A(icol,icol);
|
|
||||||
A(icol,icol)=1.0;
|
|
||||||
|
|
||||||
for (uint64_t l = 0; l < rows; l++){
|
|
||||||
A(icol,l) *= pivinv;
|
|
||||||
}
|
|
||||||
for (uint64_t l = 0; l < cols; l++){
|
|
||||||
B(icol,l) *= pivinv;
|
|
||||||
}
|
|
||||||
for (uint64_t ll = 0; ll < rows; ll++){
|
|
||||||
if (ll != icol){
|
|
||||||
dum = A(ll,icol);
|
|
||||||
A(ll,icol) = 0;
|
|
||||||
for (uint64_t l = 0; l < rows; l++){
|
|
||||||
A(ll,l) -= A(icol,l)*dum;
|
|
||||||
}
|
|
||||||
for (uint64_t l = 0; l < rows; l++){
|
|
||||||
B(ll,l) -= B(icol,l)*dum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
//m = temp_m;
|
|
||||||
for (int64_t l = rows-1; l >= 0; l--){
|
|
||||||
if (indxr[l] != indxc[l]){
|
|
||||||
for (uint64_t k = 0; k < rows; k++){
|
|
||||||
temp = A(k,indxr[l]);
|
|
||||||
A(k,indxr[l]) = A(k,indxc[l]);
|
|
||||||
A(k,indxc[l]) = temp;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else{
|
else{
|
||||||
throw std::runtime_error("numerics::inplace_inverse(" + method + ") - Not implemented yet \r \nImplemented: 'Gauss-Jordan',");
|
throw std::runtime_error("numerics::inplace_inverse(" + method + ") - Not implemented yet \r \nImplemented: 'Gauss-Jordan',");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
utils::Matrix<T> inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
|
utils::Matrix<T> inverse(utils::Matrix<T>& A, std::string method = "Gauss-Jordan"){
|
||||||
|
|
||||||
|
|
||||||
utils::Matrix<T> B = A;
|
utils::Matrix<T> B = A;
|
||||||
|
|
||||||
inplace_inverse(B, method);
|
inplace_inverse(B, method);
|
||||||
|
|
||||||
return B;
|
return B;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,94 @@
|
|||||||
|
#ifndef _inverse_gj_n_
|
||||||
|
#define _inverse_gj_n_
|
||||||
|
|
||||||
|
|
||||||
|
#include "./utils/vector.h"
|
||||||
|
#include "./utils/matrix.h"
|
||||||
|
|
||||||
|
#include <omp.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace numerics{
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void inverse_gj(utils::Matrix<T>& A){
|
||||||
|
utils::Matrix<T> B(A.rows(),A.cols(), T{0});
|
||||||
|
|
||||||
|
|
||||||
|
uint64_t icol{0}, irow{0}, rows{A.rows()}, cols{A.cols()};
|
||||||
|
double big, dum, pivinv, temp;
|
||||||
|
utils::Vi indxc(rows,0), indxr(rows,0), ipiv(rows,0);
|
||||||
|
|
||||||
|
//for (uint64_t j = 0; j < N; ++j){ ipiv[j] = 0;}
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < rows; i++){
|
||||||
|
big = 0.0;
|
||||||
|
for (uint64_t j = 0; j < rows; j++){
|
||||||
|
if (ipiv[j] != 1){
|
||||||
|
for (uint64_t k = 0; k < rows; k++){
|
||||||
|
if (ipiv[k] == 0){
|
||||||
|
if (abs(A(j,k)) >= big){
|
||||||
|
big = abs(A(j,k));
|
||||||
|
irow = j;
|
||||||
|
icol = k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ipiv[icol]++;
|
||||||
|
if (irow != icol){
|
||||||
|
for (uint64_t l = 0; l < rows; l++){ // SWAP
|
||||||
|
temp = A(irow,l);
|
||||||
|
A(irow,l) = A(icol,l);
|
||||||
|
A(icol,l) = temp;
|
||||||
|
}
|
||||||
|
for (uint64_t l = 0; l < cols; l++){ // SWAP temp matrix
|
||||||
|
temp = B(irow,l);
|
||||||
|
B(irow,l) = B(icol,l);
|
||||||
|
B(icol,l) = temp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
indxr[i] = irow;
|
||||||
|
indxc[i] = icol;
|
||||||
|
if (A(icol,icol) == 0.0){
|
||||||
|
throw std::runtime_error("utill:inplace_inverse('Gauss-Jordan' - Singular Matrix");
|
||||||
|
}
|
||||||
|
pivinv= 1.0/A(icol,icol);
|
||||||
|
A(icol,icol)=1.0;
|
||||||
|
|
||||||
|
for (uint64_t l = 0; l < rows; l++){
|
||||||
|
A(icol,l) *= pivinv;
|
||||||
|
}
|
||||||
|
for (uint64_t l = 0; l < cols; l++){
|
||||||
|
B(icol,l) *= pivinv;
|
||||||
|
}
|
||||||
|
for (uint64_t ll = 0; ll < rows; ll++){
|
||||||
|
if (ll != icol){
|
||||||
|
dum = A(ll,icol);
|
||||||
|
A(ll,icol) = 0;
|
||||||
|
for (uint64_t l = 0; l < rows; l++){
|
||||||
|
A(ll,l) -= A(icol,l)*dum;
|
||||||
|
}
|
||||||
|
for (uint64_t l = 0; l < rows; l++){
|
||||||
|
B(ll,l) -= B(icol,l)*dum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
//m = temp_m;
|
||||||
|
for (int64_t l = rows-1; l >= 0; l--){
|
||||||
|
if (indxr[l] != indxc[l]){
|
||||||
|
for (uint64_t k = 0; k < rows; k++){
|
||||||
|
temp = A(k,indxr[l]);
|
||||||
|
A(k,indxr[l]) = A(k,indxc[l]);
|
||||||
|
A(k,indxc[l]) = temp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace numerics
|
||||||
|
|
||||||
|
#endif // _inverse_gj_n_
|
||||||
@@ -3,10 +3,12 @@
|
|||||||
|
|
||||||
|
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|
||||||
|
// ---------------- Serial baseline ----------------
|
||||||
template <typename T>
|
template <typename T>
|
||||||
utils::Matrix<T> matmul(const utils::Matrix<T>& A, const utils::Matrix<T>& B){
|
utils::Matrix<T> matmul(const utils::Matrix<T>& A, const utils::Matrix<T>& B){
|
||||||
|
|
||||||
@@ -19,10 +21,8 @@ namespace numerics{
|
|||||||
const uint64_t p = B.cols();
|
const uint64_t p = B.cols();
|
||||||
T tmp;
|
T tmp;
|
||||||
|
|
||||||
utils::Matrix<T> C(m, n, T{0});
|
utils::Matrix<T> C(m, p, T{0});
|
||||||
|
|
||||||
//#pragma omp parallel for collapse(2) schedule(static)
|
|
||||||
#pragma omp parallel for
|
|
||||||
for (uint64_t i = 0; i < m; ++i){
|
for (uint64_t i = 0; i < m; ++i){
|
||||||
for (uint64_t j = 0; j < n; ++j){
|
for (uint64_t j = 0; j < n; ++j){
|
||||||
tmp = A(i,j);
|
tmp = A(i,j);
|
||||||
@@ -34,6 +34,85 @@ namespace numerics{
|
|||||||
return C;
|
return C;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------- Rows-only OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> matmul_rows_omp(const utils::Matrix<T>& A,
|
||||||
|
const utils::Matrix<T>& B) {
|
||||||
|
if (A.cols() != B.rows()) throw std::runtime_error("matmul_rows_omp: dim mismatch");
|
||||||
|
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
|
||||||
|
|
||||||
|
utils::Matrix<T> C(m, p, T{0});
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (uint64_t i=0;i<m;++i) {
|
||||||
|
for (uint64_t j=0;j<p;++j) {
|
||||||
|
T acc=T{0};
|
||||||
|
for (uint64_t k=0;k<n;++k) {
|
||||||
|
acc += A(i,k)*B(k,j);
|
||||||
|
}
|
||||||
|
C(i,j)=acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return C;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------- Collapse(2) OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> matmul_collapse_omp(const utils::Matrix<T>& A,
|
||||||
|
const utils::Matrix<T>& B) {
|
||||||
|
if (A.cols() != B.rows()) throw std::runtime_error("matmul_collapse_omp: dim mismatch");
|
||||||
|
const uint64_t m=A.rows(), n=A.cols(), p=B.cols();
|
||||||
|
utils::Matrix<T> C(m, p, T{0});
|
||||||
|
|
||||||
|
#pragma omp parallel for collapse(2) schedule(static)
|
||||||
|
for (uint64_t i=0;i<m;++i) {
|
||||||
|
for (uint64_t j=0;j<p;++j) {
|
||||||
|
T acc=T{0};
|
||||||
|
for (uint64_t k=0;k<n;++k){
|
||||||
|
acc += A(i,k)*B(k,j);
|
||||||
|
}
|
||||||
|
C(i,j)=acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return C;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -------------------- Auto selector ---------------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Matrix<T> matmul_auto(const utils::Matrix<T>& A,
|
||||||
|
const utils::Matrix<T>& B) {
|
||||||
|
const uint64_t m=A.rows(), p=B.cols();
|
||||||
|
const uint64_t work = m * p;
|
||||||
|
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int threads = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
int threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
// Tiny problems: serial is cheapest.
|
||||||
|
if (!can_parallel || work < static_cast<uint64_t>(threads)*4ull) {
|
||||||
|
return matmul(A,B);
|
||||||
|
}
|
||||||
|
// Plenty of (i,j) work → collapse(2) is a great default.
|
||||||
|
else if (work >= 8ull * static_cast<uint64_t>(threads)) {
|
||||||
|
return matmul_collapse_omp(A,B);
|
||||||
|
}
|
||||||
|
// Many rows and very few columns → rows-only cheaper overhead.
|
||||||
|
else if (m >= static_cast<uint64_t>(threads) && p <= 4) {
|
||||||
|
return matmul_rows_omp(A,B);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
return matmul(A,B);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,11 +3,13 @@
|
|||||||
|
|
||||||
|
|
||||||
#include "./utils/matrix.h"
|
#include "./utils/matrix.h"
|
||||||
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
namespace numerics{
|
namespace numerics{
|
||||||
|
|
||||||
// y = A * x, where A is (m×n) and x is length n and y is length m
|
// =================================================
|
||||||
|
// y = A * x (Matrix–Vector product)
|
||||||
|
// =================================================
|
||||||
template <typename T>
|
template <typename T>
|
||||||
utils::Vector<T> matvec(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
|
utils::Vector<T> matvec(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
|
||||||
if (A.cols() != x.size()) {
|
if (A.cols() != x.size()) {
|
||||||
@@ -18,6 +20,27 @@ namespace numerics{
|
|||||||
const uint64_t n = A.cols();
|
const uint64_t n = A.cols();
|
||||||
|
|
||||||
utils::Vector<T> y(m, T{0});
|
utils::Vector<T> y(m, T{0});
|
||||||
|
|
||||||
|
for (uint64_t i = 0; i < m; ++i) {
|
||||||
|
for (uint64_t j = 0; j < n; ++j) {
|
||||||
|
y[i] += A(i, j) * x[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
// -------------- Collapse(2) OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Vector<T> matvec_omp(const utils::Matrix<T>& A, const utils::Vector<T>& x) {
|
||||||
|
if (A.cols() != x.size()) {
|
||||||
|
throw std::runtime_error("matvec: dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint64_t m = A.rows();
|
||||||
|
const uint64_t n = A.cols();
|
||||||
|
|
||||||
|
utils::Vector<T> y(m, T{0});
|
||||||
|
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
for (uint64_t i = 0; i < m; ++i) {
|
for (uint64_t i = 0; i < m; ++i) {
|
||||||
T acc = T{0};
|
T acc = T{0};
|
||||||
for (uint64_t j = 0; j < n; ++j) {
|
for (uint64_t j = 0; j < n; ++j) {
|
||||||
@@ -28,7 +51,34 @@ namespace numerics{
|
|||||||
return y;
|
return y;
|
||||||
}
|
}
|
||||||
|
|
||||||
// y = x * A, where x is length m and A is (m×n) -> y is length n
|
// -------------- Auto OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Vector<T> matvec_auto(const utils::Matrix<T>& A,
|
||||||
|
const utils::Vector<T>& x) {
|
||||||
|
|
||||||
|
|
||||||
|
uint64_t work = A.rows() * A.cols();
|
||||||
|
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int threads = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
int threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel || work > static_cast<uint64_t>(threads) * 4ull) {
|
||||||
|
return matvec_omp(A,x);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
return matvec(A,x);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// =================================================
|
||||||
|
// y = x * A (Vector–Matrix product)
|
||||||
|
// =================================================
|
||||||
template <typename T>
|
template <typename T>
|
||||||
utils::Vector<T> vecmat(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
|
utils::Vector<T> vecmat(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
|
||||||
if (x.size() != A.rows()) {
|
if (x.size() != A.rows()) {
|
||||||
@@ -38,6 +88,26 @@ namespace numerics{
|
|||||||
const uint64_t n = A.cols();
|
const uint64_t n = A.cols();
|
||||||
|
|
||||||
utils::Vector<T> y(n, T{0});
|
utils::Vector<T> y(n, T{0});
|
||||||
|
|
||||||
|
for (uint64_t j = 0; j < n; ++j) {
|
||||||
|
for (uint64_t i = 0; i < m; ++i) {
|
||||||
|
y[j] += x[i] * A(i, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return y;
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------- Collapse(2) OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Vector<T> vecmat_omp(const utils::Vector<T>& x, const utils::Matrix<T>& A) {
|
||||||
|
if (x.size() != A.rows()) {
|
||||||
|
throw std::runtime_error("vecmat: dimension mismatch");
|
||||||
|
}
|
||||||
|
const uint64_t m = A.rows();
|
||||||
|
const uint64_t n = A.cols();
|
||||||
|
|
||||||
|
utils::Vector<T> y(n, T{0});
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
for (uint64_t j = 0; j < n; ++j) {
|
for (uint64_t j = 0; j < n; ++j) {
|
||||||
T acc = T{0};
|
T acc = T{0};
|
||||||
for (uint64_t i = 0; i < m; ++i) {
|
for (uint64_t i = 0; i < m; ++i) {
|
||||||
@@ -48,6 +118,30 @@ namespace numerics{
|
|||||||
return y;
|
return y;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -------------- Auto OpenMP ----------------
|
||||||
|
template <typename T>
|
||||||
|
utils::Vector<T> vecmat_auto(const utils::Vector<T>& x,
|
||||||
|
const utils::Matrix<T>& A) {
|
||||||
|
|
||||||
|
uint64_t work = A.rows() * A.cols();
|
||||||
|
|
||||||
|
bool can_parallel = omp_config::omp_parallel_allowed();
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int threads = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
int threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (can_parallel || work > static_cast<uint64_t>(threads) * 4ull) {
|
||||||
|
return vecmat_omp(x,A);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
// Safe fallback
|
||||||
|
return vecmat(x,A);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace numerics
|
} // namespace numerics
|
||||||
|
|
||||||
|
|||||||
@@ -43,28 +43,6 @@ namespace numerics{
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
} // namespace numerics
|
} // namespace numerics
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#endif // _transpose_n_
|
#endif // _transpose_n_
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
#ifndef _numerics_n_
|
|
||||||
#define _numerics_n_
|
|
||||||
|
|
||||||
|
|
||||||
namespace utils{
|
|
||||||
|
|
||||||
double random(const double& min, const double& max);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // _numerics_n_
|
|
||||||
@@ -66,6 +66,20 @@ public:
|
|||||||
|
|
||||||
bool operator!=(const Vector<T>& a) const { return !(*this == a); }
|
bool operator!=(const Vector<T>& a) const { return !(*this == a); }
|
||||||
|
|
||||||
|
//##################################################
|
||||||
|
//# VECTOR: nearly_equal_vec #
|
||||||
|
//##################################################
|
||||||
|
|
||||||
|
bool nearly_equal_vec(const Vector<T>& a, double tol=1e-12) const {
|
||||||
|
if (a.size() != v.size()) return false;
|
||||||
|
for (uint64_t i=0;i<a.size();++i) {
|
||||||
|
if (std::fabs(a[i]-v[i])>tol) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//##################################################
|
//##################################################
|
||||||
//# VECTOR: Scalar Addition #
|
//# VECTOR: Scalar Addition #
|
||||||
//##################################################
|
//##################################################
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ SRC_DIR := src
|
|||||||
INC_DIR := include
|
INC_DIR := include
|
||||||
OBJ_DIR := obj
|
OBJ_DIR := obj
|
||||||
BIN_DIR := bin
|
BIN_DIR := bin
|
||||||
|
TEST_BIN := $(BIN_DIR)/tests
|
||||||
|
|
||||||
|
# All test sources
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@@ -23,13 +28,30 @@ SRCS := $(shell find $(SRC_DIR) -name '*.cpp')
|
|||||||
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(SRCS))
|
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(SRCS))
|
||||||
|
|
||||||
|
|
||||||
|
# === Test sources ===
|
||||||
|
TEST_SRCS := $(shell find test -name 'test_*.cpp')
|
||||||
|
TEST_OBJS := $(patsubst test/%.cpp, $(OBJ_DIR)/test/%.o, $(TEST_SRCS))
|
||||||
|
# The single file that defines TEST_MAIN / main()
|
||||||
|
TEST_MAIN := $(OBJ_DIR)/test/test_all.o
|
||||||
|
|
||||||
|
|
||||||
# === OpenMP runtime configuration (override-able) ===
|
# === OpenMP runtime configuration (override-able) ===
|
||||||
OMP_PROC_BIND ?= close # close|spread|master
|
OMP_PROC_BIND ?= close # close|spread|master
|
||||||
OMP_PLACES ?= cores # cores|threads|sockets
|
OMP_PLACES ?= cores # cores|threads|sockets
|
||||||
OMP_MAX_LEVELS ?= 1 # 1 = no nested teams; set 2+ to allow nesting
|
OMP_MAX_LEVELS ?= 2 # 1 = no nested teams; set 2+ to allow nesting
|
||||||
OMP_THREADS ?= 16 # e.g. "16" or "8,4" for nested (outer,inner)
|
OMP_THREADS ?= 16 # e.g. "16" or "8,4" for nested (outer,inner)
|
||||||
OMP_DYNAMIC ?= true # true/false: let runtime adjust threads
|
OMP_DYNAMIC ?= TRUE # TRUE/FALSE: let runtime adjust threads
|
||||||
OMP_DISPLAY_ENV ?= FALSE # TRUE to print runtime config at startup
|
OMP_SCHEDULE ?= STATIC # STATIC recommended for matvec/matmul
|
||||||
|
OMP_DISPLAY_ENV ?= TRUE # TRUE to print runtime config at startup
|
||||||
|
|
||||||
|
# Export OMP defaults so child makes or tools see them (not strictly required)
|
||||||
|
export OMP_PROC_BIND
|
||||||
|
export OMP_PLACES
|
||||||
|
export OMP_MAX_LEVELS
|
||||||
|
export OMP_THREADS
|
||||||
|
export OMP_DYNAMIC
|
||||||
|
export OMP_SCHEDULE
|
||||||
|
export OMP_DISPLAY_ENV
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -49,11 +71,19 @@ $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp
|
|||||||
# === Run with OpenMP env set only for the run ===
|
# === Run with OpenMP env set only for the run ===
|
||||||
.PHONY: run
|
.PHONY: run
|
||||||
run: $(TARGET)
|
run: $(TARGET)
|
||||||
OMP_PROC_BIND=$(OMP_PROC_BIND) \
|
@echo ">>> OMP_PROC_BIND=$(OMP_PROC_BIND)"
|
||||||
|
@echo ">>> OMP_PLACES=$(OMP_PLACES)"
|
||||||
|
@echo ">>> OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS)"
|
||||||
|
@echo ">>> OMP_NUM_THREADS=$(OMP_THREADS)"
|
||||||
|
@echo ">>> OMP_DYNAMIC=$(OMP_DYNAMIC)"
|
||||||
|
@echo ">>> OMP_SCHEDULE=$(OMP_SCHEDULE)"
|
||||||
|
@echo ">>> OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV)"
|
||||||
|
@OMP_PROC_BIND=$(OMP_PROC_BIND) \
|
||||||
OMP_PLACES=$(OMP_PLACES) \
|
OMP_PLACES=$(OMP_PLACES) \
|
||||||
OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS) \
|
OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS) \
|
||||||
OMP_NUM_THREADS="$(OMP_THREADS)" \
|
OMP_NUM_THREADS="$(OMP_THREADS)" \
|
||||||
OMP_DYNAMIC=$(OMP_DYNAMIC) \
|
OMP_DYNAMIC=$(OMP_DYNAMIC) \
|
||||||
|
OMP_SCHEDULE=$(OMP_SCHEDULE) \
|
||||||
OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV) \
|
OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV) \
|
||||||
./$(TARGET)
|
./$(TARGET)
|
||||||
|
|
||||||
@@ -78,3 +108,30 @@ info:
|
|||||||
@echo "Object files: $(OBJS)"
|
@echo "Object files: $(OBJS)"
|
||||||
@echo "CXXFLAGS: $(CXXFLAGS)"
|
@echo "CXXFLAGS: $(CXXFLAGS)"
|
||||||
@echo "LDFLAGS: $(LDFLAGS)"
|
@echo "LDFLAGS: $(LDFLAGS)"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: $(TEST_BIN)
|
||||||
|
@echo ">>> OMP_PROC_BIND=$(OMP_PROC_BIND)"
|
||||||
|
@echo ">>> OMP_PLACES=$(OMP_PLACES)"
|
||||||
|
@echo ">>> OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS)"
|
||||||
|
@echo ">>> OMP_NUM_THREADS=$(OMP_THREADS)"
|
||||||
|
@echo ">>> OMP_DYNAMIC=$(OMP_DYNAMIC)"
|
||||||
|
@echo ">>> OMP_SCHEDULE=$(OMP_SCHEDULE)"
|
||||||
|
@echo ">>> OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV)"
|
||||||
|
@OMP_PROC_BIND=$(OMP_PROC_BIND) \
|
||||||
|
OMP_PLACES=$(OMP_PLACES) \
|
||||||
|
OMP_MAX_ACTIVE_LEVELS=$(OMP_MAX_LEVELS) \
|
||||||
|
OMP_NUM_THREADS="$(OMP_THREADS)" \
|
||||||
|
OMP_DYNAMIC=$(OMP_DYNAMIC) \
|
||||||
|
OMP_SCHEDULE=$(OMP_SCHEDULE) \
|
||||||
|
OMP_DISPLAY_ENV=$(OMP_DISPLAY_ENV) \
|
||||||
|
$(TEST_BIN)
|
||||||
|
|
||||||
|
$(TEST_BIN): $(TEST_OBJS) $(TEST_MAIN)
|
||||||
|
@mkdir -p $(BIN_DIR)
|
||||||
|
$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
|
||||||
|
|
||||||
|
$(OBJ_DIR)/test/%.o: test/%.cpp
|
||||||
|
@mkdir -p $(dir $@)
|
||||||
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
BIN
Binary file not shown.
+6
-553
@@ -1,5 +1,6 @@
|
|||||||
#include "./utils/utils.h"
|
#include "./utils/utils.h"
|
||||||
#include "./numerics/numerics.h"
|
#include "./numerics/numerics.h"
|
||||||
|
#include "./decomp/decomp.h"
|
||||||
#include "./core/omp_config.h"
|
#include "./core/omp_config.h"
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
@@ -7,566 +8,18 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define CHECK(cond, msg) \
|
|
||||||
do { if (!(cond)) throw std::runtime_error(msg); } while (0)
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void expect_throw(F&& f, const char* msg_if_no_throw) {
|
|
||||||
try {
|
|
||||||
f();
|
|
||||||
throw std::runtime_error(msg_if_no_throw);
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
// ok: an exception was thrown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int main(int argc, char const *argv[])
|
int main(int argc, char const *argv[])
|
||||||
{
|
{
|
||||||
|
|
||||||
|
utils::Md A;
|
||||||
|
decomp::LUd lu;
|
||||||
|
lu.factor(A);
|
||||||
|
|
||||||
// Single-level, 16 threads, runtime may adjust
|
// Single-level, 16 threads, runtime may adjust
|
||||||
omp_configure(/*max_levels=*/1, /*dynamic=*/true, /*threads_per_level=*/{16});
|
//omp_configure(/*max_levels=*/1, /*dynamic=*/true, /*threads_per_level=*/{16});
|
||||||
|
|
||||||
using utils::Vi;
|
|
||||||
using utils::Vf;
|
|
||||||
using utils::Vd;
|
|
||||||
using utils::Mi;
|
|
||||||
using utils::Mf;
|
|
||||||
using utils::Md;
|
|
||||||
|
|
||||||
// ---------------- Equality / Inequality ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 1.0f); // [1,1,1]
|
|
||||||
Vf b(3, 1.0f); // [1,1,1]
|
|
||||||
Vf c(3, 2.0f); // [2,2,2]
|
|
||||||
|
|
||||||
CHECK(a == b, "a should equal b");
|
|
||||||
CHECK(!(a != b), "a should not be != b");
|
|
||||||
CHECK(a != c, "a should not equal c");
|
|
||||||
|
|
||||||
// mutate one element
|
|
||||||
a[1] = 5.0f;
|
|
||||||
CHECK(a != b, "after mutation, a should differ from b");
|
|
||||||
a[1] = 1.0f; // restore
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Vector + Vector (and +=) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 1.0f); // [1,1,1]
|
|
||||||
Vf b(3, 2.0f); // [2,2,2]
|
|
||||||
Vf expect(3, 3.0f); // [3,3,3]
|
|
||||||
|
|
||||||
Vf c = a + b;
|
|
||||||
CHECK(c == expect, "a + b should be [3,3,3]");
|
|
||||||
|
|
||||||
a += b; // a becomes [3,3,3]
|
|
||||||
CHECK(a == expect, "a += b should produce [3,3,3]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Vector - Vector (and -=) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 5.0f); // [5,5,5]
|
|
||||||
Vf b(3, 2.0f); // [2,2,2]
|
|
||||||
Vf expect(3, 3.0f); // [3,3,3]
|
|
||||||
|
|
||||||
Vf c = a - b;
|
|
||||||
CHECK(c == expect, "a - b should be [3,3,3]");
|
|
||||||
|
|
||||||
a -= b; // a becomes [3,3,3]
|
|
||||||
CHECK(a == expect, "a -= b should produce [3,3,3]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Elementwise Multiply (and *=) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
Vf b(3, 3.0f); // [3,3,3]
|
|
||||||
Vf expect(3, 6.0f); // [6,6,6]
|
|
||||||
|
|
||||||
Vf c = a * b;
|
|
||||||
CHECK(c == expect, "a * b (elemwise) should be [6,6,6]");
|
|
||||||
|
|
||||||
a *= b; // a becomes [6,6,6]
|
|
||||||
CHECK(a == expect, "a *= b should produce [6,6,6]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Elementwise Divide (and /=) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 6.0f); // [6,6,6]
|
|
||||||
Vf b(3, 2.0f); // [2,2,2]
|
|
||||||
Vf expect(3, 3.0f); // [3,3,3]
|
|
||||||
|
|
||||||
Vf c = a / b;
|
|
||||||
CHECK(c == expect, "a / b (elemwise) should be [3,3,3]");
|
|
||||||
|
|
||||||
a /= b; // a becomes [3,3,3]
|
|
||||||
CHECK(a == expect, "a /= b should produce [3,3,3]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Scalar + Vector (and +=) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 1.0f); // [1,1,1]
|
|
||||||
Vf expect1(3, 6.0f); // [6,6,6]
|
|
||||||
Vf expect2(3, 3.0f); // [3,3,3]
|
|
||||||
|
|
||||||
Vf c = a + 5.0f; // v + s
|
|
||||||
CHECK(c == expect1, "a + 5 should be [6,6,6]");
|
|
||||||
|
|
||||||
Vf d = 2.0f + a; // s + v (friend operator)
|
|
||||||
CHECK(d == expect2, "2 + a should be [3,3,3]");
|
|
||||||
|
|
||||||
a += 2.0f; // a becomes [3,3,3]
|
|
||||||
CHECK(a == expect2, "a += 2 should produce [3,3,3]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Scalar - Vector / Vector - Scalar ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 5.0f); // [5,5,5]
|
|
||||||
Vf expect1(3, 3.0f); // [3,3,3]
|
|
||||||
Vf expect2(3, -3.0f); // [ -3,-3,-3 ] if 2 - a (only if you've implemented it)
|
|
||||||
|
|
||||||
Vf c = a - 2.0f; // v - s
|
|
||||||
CHECK(c == expect1, "a - 2 should be [3,3,3]");
|
|
||||||
|
|
||||||
// NOTE: Your friend operator-(U a, const Vector<T> b) currently returns (b - a),
|
|
||||||
// which means `2 - a` computes `a - 2`. That's a bit unusual.
|
|
||||||
// We'll avoid asserting 2 - a here to match your current implementation choice.
|
|
||||||
(void)expect2; // silence unused warning
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Scalar * Vector / Vector * Scalar ----------------
|
|
||||||
{
|
|
||||||
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
uint64_t b = 3; // 3
|
|
||||||
Vf expect(3, 6.0f); // [6,6,6]
|
|
||||||
|
|
||||||
Vf c = a * b; // v * s
|
|
||||||
CHECK(c == expect, "a * 3 should be [6,6,6]");
|
|
||||||
|
|
||||||
Vf d = b * a; // s * v (friend)
|
|
||||||
CHECK(d == expect, "3 * a should be [6,6,6]");
|
|
||||||
|
|
||||||
a *= b; // a becomes [6,6,6]
|
|
||||||
CHECK(a == expect, "a *= 3 should produce [6,6,6]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Vector / Scalar (and /= scalar) ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 6.0f); // [6,6,6]
|
|
||||||
uint64_t b = 2; // 3
|
|
||||||
Vf expect(3, 3.0f); // [3,3,3]
|
|
||||||
|
|
||||||
Vf c = a / b; // v / s
|
|
||||||
CHECK(c == expect, "a / 2 should be [3,3,3]");
|
|
||||||
|
|
||||||
a /= b; // a becomes [3,3,3]
|
|
||||||
CHECK(a == expect, "a /= 2 should produce [3,3,3]");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- sum ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
CHECK(a.sum() == 6.0f, "sum failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- dot ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
Vf b(3, 3.0f); // [3,3,3]
|
|
||||||
CHECK(a.dot(b) == 18.0f, "dot failed"); // 2*3 * 3 = 18
|
|
||||||
Vf c(4, 1.0f);
|
|
||||||
expect_throw([&]{ (void)a.dot(c); }, "dot should throw on size mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- norm ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
float n = a.norm();
|
|
||||||
CHECK(std::fabs(n - std::sqrt(12.0f)) < 1e-6f, "norm failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- normalize ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 3.0f); // [3,3,3], norm = sqrt(27)
|
|
||||||
Vf b = a.normalize();
|
|
||||||
float n = b.norm();
|
|
||||||
CHECK(std::fabs(n - 1.0f) < 1e-6f, "normalize failed");
|
|
||||||
expect_throw([&]{ Vf z(3, 0.0f); z.inplace_normalize(); }, "normalize should throw on zero norm");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- scalar power ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
Vf c = a.power(3); // [8,8,8]
|
|
||||||
CHECK(c == Vf(3, 8.0f), "power(scalar) failed");
|
|
||||||
|
|
||||||
Vf d = a; d.inplace_power(4); // [16,16,16]
|
|
||||||
CHECK(d == Vf(3, 16.0f), "inplace_power(scalar) failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- vector power ----------
|
|
||||||
{
|
|
||||||
Vf base(3, 2.0f); // [2,2,2]
|
|
||||||
Vf exps; exps.v = {1.0f, 2.0f, 3.0f}; // explicit construction for clarity
|
|
||||||
Vf out = base.power(exps); // [2^1, 2^2, 2^3] = [2,4,8]
|
|
||||||
Vf expect; expect.v = {2.0f, 4.0f, 8.0f};
|
|
||||||
CHECK(out == expect, "power(vector) failed");
|
|
||||||
|
|
||||||
expect_throw([&]{ Vf bad(2, 1.0f); (void)base.power(bad); },
|
|
||||||
"power(vector) should throw on size mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- square ----------
|
|
||||||
{
|
|
||||||
Vf a; a.v = {4.0f, 9.0f, 16.0f};
|
|
||||||
Vf b = a.sqrt(); // [4,9,16]
|
|
||||||
Vf expect; expect.v = {2.0f, 3.0f, 4.0f};
|
|
||||||
CHECK(b == expect, "sqrt failed");
|
|
||||||
|
|
||||||
a.inplace_sqrt(); // mutate a to [4,9,16]
|
|
||||||
CHECK(a == expect, "inplace_square failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- scalar commutative friends (s + v, s * v) ----------
|
|
||||||
{
|
|
||||||
Vf a(3, 2.0f); // [2,2,2]
|
|
||||||
Vf b = 3.0f + a; // [5,5,5]
|
|
||||||
Vf c = a + 3.0f; // [5,5,5]
|
|
||||||
CHECK(b == c, "s+v commutative failed");
|
|
||||||
|
|
||||||
Vf d = 4.0f * a; // [8,8,8]
|
|
||||||
Vf e = a * 4.0f; // [8,8,8]
|
|
||||||
CHECK(d == e, "s*v commutative failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------- Size mismatch throws ----------------
|
|
||||||
{
|
|
||||||
Vf a(3, 1.0f);
|
|
||||||
Vf b(4, 2.0f);
|
|
||||||
|
|
||||||
expect_throw([&] { a.inplace_add(b); },
|
|
||||||
"inplace_add should throw on size mismatch");
|
|
||||||
expect_throw([&] { (void)(a + b); },
|
|
||||||
"operator+ should throw (through add) on size mismatch");
|
|
||||||
expect_throw([&] { a.inplace_subtract(b); },
|
|
||||||
"inplace_subtract should throw on size mismatch");
|
|
||||||
expect_throw([&] { (void)(a - b); },
|
|
||||||
"operator- should throw (through subtract) on size mismatch");
|
|
||||||
expect_throw([&] { a.inplace_multiply(b); },
|
|
||||||
"inplace_multiply should throw on size mismatch");
|
|
||||||
expect_throw([&] { (void)(a * b); },
|
|
||||||
"operator* should throw (through multiply) on size mismatch");
|
|
||||||
expect_throw([&] { a.inplace_divide(b); },
|
|
||||||
"inplace_divide should throw on size mismatch");
|
|
||||||
expect_throw([&] { (void)(a / b); },
|
|
||||||
"operator/ should throw (through divide) on size mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
|
|
||||||
auto* a = new utils::Vf(3, 1.0f); // constructor runs
|
|
||||||
delete a; // <- calls ~Vector() and frees memory
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Vf a(2, 1.0f); // a = [1, 1]
|
|
||||||
Vf b(2, 1.0f); // b = [1, 1]
|
|
||||||
|
|
||||||
a.clear(); // a = []
|
|
||||||
CHECK(a.size() == 0, "clear() did not empty vector");
|
|
||||||
|
|
||||||
a.resize(2, 1.0f); // a = [1, 1]
|
|
||||||
|
|
||||||
CHECK(a == b, "clear/resize lifecycle failed");
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "All Vector tests passed ✅\n";
|
|
||||||
|
|
||||||
|
|
||||||
// shape + element access
|
|
||||||
{
|
|
||||||
Mf M(3, 4, 0.0f);
|
|
||||||
CHECK(M.rows()==3 && M.cols()==4, "shape failed");
|
|
||||||
|
|
||||||
M(1,1) = 5.0f;
|
|
||||||
CHECK(M(1,1) == 5.0f, "write/read element failed");
|
|
||||||
|
|
||||||
// ensure independence of other cells
|
|
||||||
CHECK(M(0,0) == 0.0f && M(2,3) == 0.0f, "unexpected element modified");
|
|
||||||
}
|
|
||||||
|
|
||||||
// set/get row (with size checks)
|
|
||||||
{
|
|
||||||
Mf M(2, 3, 0.0f); // 2x3
|
|
||||||
|
|
||||||
Vf r(3, 0.0f);
|
|
||||||
r[0]=1; r[1]=2; r[2]=3;
|
|
||||||
M.set_row(1, r);
|
|
||||||
|
|
||||||
Vf g = M.get_row(1);
|
|
||||||
CHECK(g.size()==3, "get_row size wrong");
|
|
||||||
CHECK(g[0]==1 && g[1]==2 && g[2]==3, "get_row values wrong");
|
|
||||||
|
|
||||||
// size mismatch should throw
|
|
||||||
bool threw=false;
|
|
||||||
try {
|
|
||||||
Vf bad(2, 9.0f);
|
|
||||||
M.set_row(0, bad);
|
|
||||||
} catch (const std::exception&) { threw=true; }
|
|
||||||
CHECK(threw, "set_row should throw on size mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
// set/get col (with size checks)
|
|
||||||
{
|
|
||||||
Mf M(3, 2, 0.0f); // 3x2
|
|
||||||
|
|
||||||
Vf c(3, 0.0f);
|
|
||||||
c[0]=4; c[1]=5; c[2]=6;
|
|
||||||
M.set_col(1, c);
|
|
||||||
|
|
||||||
Vf h = M.get_col(1);
|
|
||||||
CHECK(h.size()==3, "get_col size wrong");
|
|
||||||
CHECK(h[0]==4 && h[1]==5 && h[2]==6, "get_col values wrong");
|
|
||||||
|
|
||||||
bool threw=false;
|
|
||||||
try {
|
|
||||||
Vf bad(2, 7.0f);
|
|
||||||
M.set_col(0, bad);
|
|
||||||
} catch (const std::exception&) { threw=true; }
|
|
||||||
CHECK(threw, "set_col should throw on size mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
// swap_rows / swap_cols
|
|
||||||
{
|
|
||||||
Mf M(3, 3, 0.0f);
|
|
||||||
// set rows to [1,2,3], [4,5,6], [7,8,9]
|
|
||||||
for (uint64_t j=0;j<3;++j) M(0,j) = 1.0f + j;
|
|
||||||
for (uint64_t j=0;j<3;++j) M(1,j) = 4.0f + j;
|
|
||||||
for (uint64_t j=0;j<3;++j) M(2,j) = 7.0f + j;
|
|
||||||
|
|
||||||
M.swap_rows(0,2);
|
|
||||||
CHECK(M(0,0)==7 && M(0,1)==8 && M(0,2)==9, "swap_rows top row wrong");
|
|
||||||
CHECK(M(2,0)==1 && M(2,1)==2 && M(2,2)==3, "swap_rows bottom row wrong");
|
|
||||||
|
|
||||||
M.swap_cols(0,2);
|
|
||||||
// after col swap: first row should be [9,8,7]
|
|
||||||
CHECK(M(0,0)==9 && M(0,1)==8 && M(0,2)==7, "swap_cols first row wrong");
|
|
||||||
// bottom row should be [3,2,1]
|
|
||||||
CHECK(M(2,0)==3 && M(2,1)==2 && M(2,2)==1, "swap_cols last row wrong");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exact integer comparison / Floating-point exact equality / Floating-point with small perturbation
|
|
||||||
{
|
|
||||||
|
|
||||||
Mi A(2,2,0);
|
|
||||||
A(0,0)=1; A(0,1)=2;
|
|
||||||
A(1,0)=3; A(1,1)=4;
|
|
||||||
|
|
||||||
Mi B(2,2,0);
|
|
||||||
B(0,0)=1; B(0,1)=2;
|
|
||||||
B(1,0)=3; B(1,1)=4;
|
|
||||||
|
|
||||||
Mi C(2,2,0);
|
|
||||||
C(0,0)=9; C(0,1)=9;
|
|
||||||
C(1,0)=9; C(1,1)=9;
|
|
||||||
|
|
||||||
CHECK(A == B, "Matrix == failed on identical int matrices");
|
|
||||||
CHECK(!(A != B), "Matrix != failed on identical int matrices");
|
|
||||||
CHECK(A != C, "Matrix != failed on different int matrices");
|
|
||||||
|
|
||||||
// Floating-point exact equality
|
|
||||||
|
|
||||||
|
|
||||||
Md F1(2,2,0.0);
|
|
||||||
F1(0,0)=1.0; F1(0,1)=2.0;
|
|
||||||
F1(1,0)=3.0; F1(1,1)=4.0;
|
|
||||||
|
|
||||||
Md F2(2,2,0.0);
|
|
||||||
F2(0,0)=1.0; F2(0,1)=2.0;
|
|
||||||
F2(1,0)=3.0; F2(1,1)=4.0;
|
|
||||||
|
|
||||||
CHECK(F1 == F2, "Matrix == failed on identical float matrices");
|
|
||||||
|
|
||||||
// Floating-point with small perturbation
|
|
||||||
|
|
||||||
Md F3 = F1;
|
|
||||||
F3(1,1) += 1e-10; // tiny difference
|
|
||||||
|
|
||||||
CHECK(!(F1 == F3), "Matrix == should fail on exact compare with perturbation");
|
|
||||||
CHECK(F1.nearly_equal(F3, 1e-9), "Matrix nearly_equal failed with tolerance");
|
|
||||||
|
|
||||||
// Larger perturbation
|
|
||||||
F3(1,1) += 1e-3;
|
|
||||||
CHECK(!F1.nearly_equal(F3, 1e-6), "Matrix nearly_equal should fail when tolerance too small");
|
|
||||||
CHECK(F1.nearly_equal(F3, 1e-2), "Matrix nearly_equal should pass with loose tolerance");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Matrix basic tests passed ✅\n";
|
|
||||||
|
|
||||||
// --- Test: normal transpose ---
|
|
||||||
{
|
|
||||||
Mf M(2, 3, 0.0f);
|
|
||||||
// Fill: [ [1,2,3],
|
|
||||||
// [4,5,6] ]
|
|
||||||
M(0,0)=1; M(0,1)=2; M(0,2)=3;
|
|
||||||
M(1,0)=4; M(1,1)=5; M(1,2)=6;
|
|
||||||
|
|
||||||
Mf MT = numerics::transpose(M);
|
|
||||||
|
|
||||||
// Should be shape 3x2
|
|
||||||
CHECK(MT.rows()==3 && MT.cols()==2, "transpose shape wrong");
|
|
||||||
|
|
||||||
// Values: [ [1,4], [2,5], [3,6] ]
|
|
||||||
CHECK(MT(0,0)==1 && MT(0,1)==4, "transpose value (0,*) wrong");
|
|
||||||
CHECK(MT(1,0)==2 && MT(1,1)==5, "transpose value (1,*) wrong");
|
|
||||||
CHECK(MT(2,0)==3 && MT(2,1)==6, "transpose value (2,*) wrong");
|
|
||||||
|
|
||||||
//std::cout << "Original M:\n" << M << "\n";
|
|
||||||
//std::cout << "Transposed MT:\n" << MT << "\n\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Test: inplace transpose (square only) ---
|
|
||||||
{
|
|
||||||
Mf S(3, 3, 0.0f);
|
|
||||||
// Fill with row-major increasing
|
|
||||||
float val = 1.0f;
|
|
||||||
for (uint64_t i=0;i<S.rows();++i) {
|
|
||||||
for (uint64_t j=0;j<S.cols();++j) {
|
|
||||||
S(i,j) = val++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// S =
|
|
||||||
// [1,2,3]
|
|
||||||
// [4,5,6]
|
|
||||||
// [7,8,9]
|
|
||||||
numerics::inplace_transpose(S);
|
|
||||||
|
|
||||||
// Expected after transpose:
|
|
||||||
// [1,4,7]
|
|
||||||
// [2,5,8]
|
|
||||||
// [3,6,9]
|
|
||||||
CHECK(S(0,1)==4 && S(0,2)==7, "inplace_transpose first row wrong");
|
|
||||||
CHECK(S(1,0)==2 && S(1,2)==8, "inplace_transpose second row wrong");
|
|
||||||
CHECK(S(2,0)==3 && S(2,1)==6, "inplace_transpose third row wrong");
|
|
||||||
|
|
||||||
//std::cout << "Square matrix after inplace_transpose:\n" << S << "\n\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Test: inplace transpose throws on non-square ---
|
|
||||||
{
|
|
||||||
Mf Rect(2, 3, 1.0f);
|
|
||||||
bool threw = false;
|
|
||||||
try {
|
|
||||||
numerics::inplace_transpose(Rect);
|
|
||||||
} catch (const std::runtime_error&) {
|
|
||||||
threw = true;
|
|
||||||
}
|
|
||||||
CHECK(threw, "inplace_transpose should throw on non-square matrix");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Transpose tests passed ✅\n";
|
|
||||||
|
|
||||||
// matmul test
|
|
||||||
{
|
|
||||||
Md A(2,2,0.0);
|
|
||||||
A(0,0) = 1; A(0,1) = 2;
|
|
||||||
A(1,0) = 3; A(1,1) = 4;
|
|
||||||
|
|
||||||
Md B(2,2,0.0);
|
|
||||||
B(0,0) = 2; B(0,1) = 0;
|
|
||||||
B(1,0) = 1; B(1,1) = 2;
|
|
||||||
|
|
||||||
Md C = numerics::matmul(A, B);
|
|
||||||
|
|
||||||
// Expected result:
|
|
||||||
// [1*2+2*1, 1*0+2*2] = [4, 4]
|
|
||||||
// [3*2+4*1, 3*0+4*2] = [10, 8]
|
|
||||||
CHECK(C(0,0)==4 && C(0,1)==4, "matmul: first row wrong");
|
|
||||||
CHECK(C(1,0)==10 && C(1,1)==8, "matmul: second row wrong");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Matmul test passed ✅\n";
|
|
||||||
|
|
||||||
// matvec test
|
|
||||||
{
|
|
||||||
// A = [[1,2,3],
|
|
||||||
// [4,5,6]] (2x3)
|
|
||||||
Md A(2,3,0.0);
|
|
||||||
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
|
||||||
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
|
||||||
|
|
||||||
// x = [7,8,9]
|
|
||||||
Vd x(3,0.0);
|
|
||||||
x[0]=7; x[1]=8; x[2]=9;
|
|
||||||
|
|
||||||
// y = A*x = [50, 122]
|
|
||||||
Vd y = numerics::matvec<double>(A, x);
|
|
||||||
CHECK(y.size()==2, "matvec size wrong");
|
|
||||||
CHECK(y[0]==50 && y[1]==122, "matvec values wrong");
|
|
||||||
|
|
||||||
// dimension mismatch should throw
|
|
||||||
bool threw = false;
|
|
||||||
try {
|
|
||||||
Vd bad(4,1.0);
|
|
||||||
(void)numerics::matvec<double>(A, bad);
|
|
||||||
} catch (const std::runtime_error&) { threw = true; }
|
|
||||||
CHECK(threw, "matvec: expected throw on dim mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "matvec tests passed ✅\n";
|
|
||||||
|
|
||||||
// vecmat test
|
|
||||||
{
|
|
||||||
// A = [[1,2],
|
|
||||||
// [3,4]] (2x2)
|
|
||||||
Md A(2,2,0.0);
|
|
||||||
A(0,0)=1; A(0,1)=2;
|
|
||||||
A(1,0)=3; A(1,1)=4;
|
|
||||||
|
|
||||||
// x^T = [5,6]
|
|
||||||
Vd x(2,0.0);
|
|
||||||
x[0]=5; x[1]=6;
|
|
||||||
|
|
||||||
// y = x^T * A = [5*1+6*3, 5*2+6*4] = [23, 34]
|
|
||||||
Vd y = numerics::vecmat<double>(x, A);
|
|
||||||
CHECK(y.size()==2, "vecmat size wrong");
|
|
||||||
CHECK(y[0]==23 && y[1]==34, "vecmat values wrong");
|
|
||||||
|
|
||||||
// mismatch should throw
|
|
||||||
bool threw = false;
|
|
||||||
try {
|
|
||||||
Md B(3,2,0.0); // 3x2, doesn't match x size 2
|
|
||||||
(void)numerics::vecmat<double>(x, B);
|
|
||||||
} catch (const std::runtime_error&) { threw = true; }
|
|
||||||
CHECK(threw, "vecmat: expected throw on dim mismatch");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "vecmat tests passed ✅\n";
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Inverse 'Gauss-Jordan' tests
|
|
||||||
{
|
|
||||||
Md A(2,2,0.0);
|
|
||||||
A(0,0)=4; A(0,1)=7;
|
|
||||||
A(1,0)=2; A(1,1)=6;
|
|
||||||
|
|
||||||
Md Ai = numerics::inverse(A, "Gauss-Jordan");
|
|
||||||
|
|
||||||
Md I1 = numerics::matmul(A, Ai);
|
|
||||||
Md I2 = numerics::matmul(Ai, A);
|
|
||||||
|
|
||||||
Md I(2,2,0.0);
|
|
||||||
I(0,0)=1; I(1,1)=1;
|
|
||||||
|
|
||||||
CHECK(I1.nearly_equal(I), "A*inv(A) != I");
|
|
||||||
CHECK(I2.nearly_equal(I), "inv(A)*A != I");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "Inverse 'Gauss-Jordan' tests passed ✅\n";
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
#define TEST_MAIN
|
||||||
|
#include "test_common.h"
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
#ifndef _test_common_n_
|
||||||
|
#define _test_common_n_
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <iostream>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
struct TestFailure : public std::runtime_error {
|
||||||
|
using std::runtime_error::runtime_error;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#define CHECK(cond, msg) do { if (!(cond)) throw TestFailure(msg); } while (0)
|
||||||
|
#define CHECK_EQ(a,b,msg) do { if (!((a)==(b))) { throw TestFailure(std::string(msg) + " (" #a " != " #b ")"); } } while (0)
|
||||||
|
|
||||||
|
#define TEST_CASE(name) \
|
||||||
|
static void name(); \
|
||||||
|
struct name##_registrar { name##_registrar(){ TestRegistry::add(#name, &name);} } name##_registrar_instance; \
|
||||||
|
static void name()
|
||||||
|
|
||||||
|
struct TestRegistry {
|
||||||
|
using Fn = void(*)();
|
||||||
|
static std::vector<std::pair<std::string, Fn>>& list() {
|
||||||
|
static std::vector<std::pair<std::string, Fn>> v; return v;
|
||||||
|
}
|
||||||
|
static void add(const std::string& name, Fn fn) { list().push_back({name, fn}); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Default test runner main()
|
||||||
|
#ifdef TEST_MAIN
|
||||||
|
int main() {
|
||||||
|
int fails = 0;
|
||||||
|
for (auto& t : TestRegistry::list()) {
|
||||||
|
try {
|
||||||
|
t.second();
|
||||||
|
std::cout << "[PASS] " << t.first << "\n";
|
||||||
|
} catch (const TestFailure& e) {
|
||||||
|
std::cerr << "[FAIL] " << t.first << " -> " << e.what() << "\n";
|
||||||
|
++fails;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
std::cerr << "[ERROR] " << t.first << " -> " << e.what() << "\n";
|
||||||
|
++fails;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << (fails ? "Some tests failed ❌\n" : "All tests passed ✅\n");
|
||||||
|
return fails ? 1 : 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // _test_common_n_
|
||||||
@@ -0,0 +1,126 @@
|
|||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h"
|
||||||
|
#include "./numerics/inverse.h"
|
||||||
|
#include "./numerics/matmul.h"
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_2x2_WellConditioned) {
|
||||||
|
using T = double;
|
||||||
|
// A = [[4,7],[2,6]] inverse = (1/10) * [[6,-7],[-2,4]]
|
||||||
|
utils::Matrix<T> A(2,2, T{0});
|
||||||
|
A(0,0)=4; A(0,1)=7;
|
||||||
|
A(1,0)=2; A(1,1)=6;
|
||||||
|
|
||||||
|
auto Ainv = numerics::inverse<T>(A); // out-of-place
|
||||||
|
|
||||||
|
// Check A * Ainv ≈ I and Ainv * A ≈ I
|
||||||
|
auto Ileft = numerics::matmul(A, Ainv);
|
||||||
|
auto Iright = numerics::matmul(Ainv, A);
|
||||||
|
|
||||||
|
utils::Md Iref(2,2, T{0});
|
||||||
|
for (uint64_t i=0;i<Iref.rows();++i) Iref(i,i)=T{1};
|
||||||
|
|
||||||
|
//auto Iref = eye<T>(2);
|
||||||
|
|
||||||
|
CHECK((Ileft.nearly_equal(Iref, 1e-12)), "A * inverse(A) ≠ I");
|
||||||
|
CHECK((Iright.nearly_equal(Iref, 1e-12)), "inverse(A) * A ≠ I");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_InPlace_Equals_OutOfPlace) {
|
||||||
|
using T = double;
|
||||||
|
utils::Matrix<T> A(3,3, T{0});
|
||||||
|
// A = [[3, 0, 2],
|
||||||
|
// [2, 0, -2],
|
||||||
|
// [0, 1, 1]]
|
||||||
|
A(0,0)=3; A(0,1)=0; A(0,2)= 2;
|
||||||
|
A(1,0)=2; A(1,1)=0; A(1,2)=-2;
|
||||||
|
A(2,0)=0; A(2,1)=1; A(2,2)= 1;
|
||||||
|
|
||||||
|
auto Ainv_ref = numerics::inverse<T>(A); // copy path
|
||||||
|
|
||||||
|
auto A_inp = A;
|
||||||
|
numerics::inplace_inverse<T>(A_inp); // in-place path
|
||||||
|
|
||||||
|
CHECK((A_inp.nearly_equal(Ainv_ref, 1e-12)), "in-place inverse differs from out-of-place");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_Singular_Throws) {
|
||||||
|
using T = double;
|
||||||
|
utils::Matrix<T> S(2,2, T{0});
|
||||||
|
// Singular: rows are multiples → det = 0
|
||||||
|
S(0,0)=1; S(0,1)=2;
|
||||||
|
S(1,0)=2; S(1,1)=4;
|
||||||
|
|
||||||
|
bool threw=false;
|
||||||
|
try {
|
||||||
|
auto _ = numerics::inverse<T>(S);
|
||||||
|
(void)_;
|
||||||
|
} catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "inverse should throw on singular matrix");
|
||||||
|
|
||||||
|
threw=false;
|
||||||
|
try {
|
||||||
|
numerics::inplace_inverse<T>(S);
|
||||||
|
} catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "inplace_inverse should throw on singular matrix");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_RoundTrip_DiagonallyDominant_5x5) {
|
||||||
|
// Build a well-conditioned 5x5: diagonally dominant
|
||||||
|
utils::Md A(5,5,0.0);
|
||||||
|
for (uint64_t i=0;i<5;++i) {
|
||||||
|
double rowsum = 0.0;
|
||||||
|
for (uint64_t j=0;j<5;++j) {
|
||||||
|
if (i==j) continue;
|
||||||
|
A(i,j) = 0.01 * double(1 + ((i+1)*(j+3)) % 7);
|
||||||
|
rowsum += std::fabs(A(i,j));
|
||||||
|
}
|
||||||
|
A(i,i) = rowsum + 1.0; // strictly diagonally dominant
|
||||||
|
}
|
||||||
|
|
||||||
|
utils::Md A_copy = A; // ensure wrapper doesn't mutate input
|
||||||
|
utils::Md Ainv = numerics::inverse<double>(A);
|
||||||
|
|
||||||
|
// Input must be unchanged by the non-inplace wrapper
|
||||||
|
CHECK(A.nearly_equal(A_copy, 0.0), "inverse wrapper modified input");
|
||||||
|
|
||||||
|
|
||||||
|
utils::Md I(5,5, 0);
|
||||||
|
for (uint64_t i=0;i<I.rows();++i) I(i,i)=1;
|
||||||
|
|
||||||
|
|
||||||
|
auto L = numerics::matmul<double>(A, Ainv);
|
||||||
|
auto R = numerics::matmul<double>(Ainv, A);
|
||||||
|
|
||||||
|
CHECK(L.nearly_equal(I, 1e-10), "A * Ainv not close to I");
|
||||||
|
CHECK(R.nearly_equal(I, 1e-10), "Ainv * A not close to I");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_NonSquare_Throws) {
|
||||||
|
// Non-square: 2x3 — algorithm expects square; should throw
|
||||||
|
utils::Md A(2,3,0.0);
|
||||||
|
bool threw = false;
|
||||||
|
try {
|
||||||
|
numerics::inplace_inverse<double>(A);
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
threw = true;
|
||||||
|
} catch (...) {
|
||||||
|
threw = true; // any failure is fine; must not silently succeed
|
||||||
|
}
|
||||||
|
CHECK(threw, "inplace_inverse should throw on non-square matrix");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(Inverse_Unknown_Method_Throws) {
|
||||||
|
|
||||||
|
utils::Md A(3,3, 0);
|
||||||
|
for (uint64_t i=0;i<A.rows();++i) A(i,i)=1;
|
||||||
|
|
||||||
|
bool threw = false;
|
||||||
|
try {
|
||||||
|
numerics::inplace_inverse<double>(A, "NotARealMethod");
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
threw = true;
|
||||||
|
}
|
||||||
|
CHECK(threw, "should throw for unknown inverse method");
|
||||||
|
}
|
||||||
@@ -0,0 +1,164 @@
|
|||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h"
|
||||||
|
#include "./numerics/matmul.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
|
||||||
|
// ============ Basic correctness ============
|
||||||
|
TEST_CASE(Matmul_Serial_Simple3x3) {
|
||||||
|
utils::Md A(3,3,0.0), B(3,3,0.0);
|
||||||
|
// A = [[1,2,3],[4,5,6],[7,8,9]]
|
||||||
|
double v=1.0;
|
||||||
|
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) A(i,j)=v++;
|
||||||
|
// B = [[9,8,7],[6,5,4],[3,2,1]]
|
||||||
|
double w=9.0;
|
||||||
|
for (uint64_t i=0;i<3;++i) for (uint64_t j=0;j<3;++j) B(i,j)=w--;
|
||||||
|
|
||||||
|
auto C = numerics::matmul<double>(A,B);
|
||||||
|
// Hand-checked first row:
|
||||||
|
// row0 dot columns:
|
||||||
|
// c00 = 1*9 + 2*6 + 3*3 = 30
|
||||||
|
// c01 = 1*8 + 2*5 + 3*2 = 24
|
||||||
|
// c02 = 1*7 + 2*4 + 3*1 = 18
|
||||||
|
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
|
||||||
|
CHECK(C(0,0)==30.0 && C(0,1)==24.0 && C(0,2)==18.0, "first row wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matmul_OMP_Equals_Serial) {
|
||||||
|
utils::Md A(4,5,0.0), B(5,3,0.0);
|
||||||
|
// Fill deterministic
|
||||||
|
for (uint64_t i=0;i<A.rows();++i)
|
||||||
|
for (uint64_t j=0;j<A.cols();++j)
|
||||||
|
A(i,j) = 0.1*(1 + (i*17 + j*19)%10);
|
||||||
|
for (uint64_t i=0;i<B.rows();++i)
|
||||||
|
for (uint64_t j=0;j<B.cols();++j)
|
||||||
|
B(i,j) = 0.2*(1 + (i*23 + j*29)%10);
|
||||||
|
|
||||||
|
auto Cs = numerics::matmul<double>(A,B);
|
||||||
|
auto Cr = numerics::matmul_rows_omp<double>(A,B);
|
||||||
|
auto Cc = numerics::matmul_collapse_omp<double>(A,B);
|
||||||
|
auto Ca = numerics::matmul_auto<double>(A,B);
|
||||||
|
|
||||||
|
CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
|
||||||
|
CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
|
||||||
|
CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Dimension mismatch ============
|
||||||
|
TEST_CASE(Matmul_DimensionMismatch_Throws) {
|
||||||
|
utils::Md A(2,3,0.0), B(4,2,0.0);
|
||||||
|
bool threw=false;
|
||||||
|
try { auto _ = numerics::matmul<double>(A,B); (void)_; }
|
||||||
|
catch (const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "serial should throw on dim mismatch");
|
||||||
|
|
||||||
|
threw=false; try { auto _ = numerics::matmul_rows_omp<double>(A,B); (void)_; }
|
||||||
|
catch (const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "rows_omp should throw on dim mismatch");
|
||||||
|
|
||||||
|
threw=false; try { auto _ = numerics::matmul_collapse_omp<double>(A,B); (void)_; }
|
||||||
|
catch (const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "collapse_omp should throw on dim mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Edge cases ============
|
||||||
|
TEST_CASE(Matmul_Edges_ZeroDims) {
|
||||||
|
// (0xK) * (KxP) -> (0xP)
|
||||||
|
utils::Md A0(0,5,0.0), B1(5,3,0.0);
|
||||||
|
auto C0 = numerics::matmul<double>(A0,B1);
|
||||||
|
CHECK(C0.rows()==0 && C0.cols()==3, "0xK * KxP shape wrong");
|
||||||
|
|
||||||
|
// (MxK) * (Kx0) -> (Mx0)
|
||||||
|
utils::Md A2(7,4,0.0), B0(4,0,0.0);
|
||||||
|
auto C1 = numerics::matmul<double>(A2,B0);
|
||||||
|
CHECK(C1.rows()==7 && C1.cols()==0, "MxK * Kx0 shape wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Identity sanity ============
|
||||||
|
TEST_CASE(Matmul_Identity) {
|
||||||
|
const uint64_t n=5;
|
||||||
|
utils::Md I(n,n,0.0), A(n,n,0.0);
|
||||||
|
for (uint64_t i=0;i<n;++i) I(i,i)=1.0;
|
||||||
|
for (uint64_t i=0;i<n;++i)
|
||||||
|
for (uint64_t j=0;j<n;++j)
|
||||||
|
A(i,j) = (i==j)? 2.0 : ( (i<j)? 1.0 : -1.0 );
|
||||||
|
|
||||||
|
auto L = numerics::matmul<double>(I,A);
|
||||||
|
auto R = numerics::matmul<double>(A,I);
|
||||||
|
CHECK(L == A, "I*A != A");
|
||||||
|
CHECK(R == A, "A*I != A");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Perf sanity (same kernel: 1 thread vs many) ============
|
||||||
|
template <class F>
|
||||||
|
static double time_it(F&& f, int iters=1) {
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
for (int i=0;i<iters;++i) f();
|
||||||
|
auto t1 = std::chrono::high_resolution_clock::now();
|
||||||
|
return std::chrono::duration<double>(t1 - t0).count();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
|
||||||
|
#ifndef _OPENMP
|
||||||
|
return;
|
||||||
|
#else
|
||||||
|
int hw = omp_get_max_threads();
|
||||||
|
if (hw <= 1) return;
|
||||||
|
|
||||||
|
const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
|
||||||
|
utils::Md A(m,k,0.0), B(k,p,0.0);
|
||||||
|
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i+j%7)*0.001;
|
||||||
|
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i*3+j%5)*0.0005;
|
||||||
|
|
||||||
|
// Warm-up
|
||||||
|
(void) numerics::matmul_rows_omp<double>(A,B);
|
||||||
|
|
||||||
|
int prev = omp_get_max_threads();
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
omp_set_num_threads(1);
|
||||||
|
numerics::matmul_rows_omp<double>(A,B);
|
||||||
|
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(hw);
|
||||||
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
numerics::matmul_rows_omp<double>(A,B);
|
||||||
|
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(prev);
|
||||||
|
|
||||||
|
// Must not be notably slower with many threads
|
||||||
|
CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
|
||||||
|
#ifndef _OPENMP
|
||||||
|
return;
|
||||||
|
#else
|
||||||
|
int hw = omp_get_max_threads();
|
||||||
|
if (hw <= 1) return;
|
||||||
|
|
||||||
|
const uint64_t m=512, k=512, p=512;
|
||||||
|
utils::Md A(m,k,0.0), B(k,p,0.0);
|
||||||
|
for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i*7+j%11)*0.0003;
|
||||||
|
for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i%13+j)*0.0002;
|
||||||
|
|
||||||
|
(void) numerics::matmul_collapse_omp<double>(A,B); // warm-up
|
||||||
|
|
||||||
|
int prev = omp_get_max_threads();
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
omp_set_num_threads(1);
|
||||||
|
numerics::matmul_collapse_omp<double>(A,B);
|
||||||
|
double t1 = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(hw);
|
||||||
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
numerics::matmul_collapse_omp<double>(A,B);
|
||||||
|
double tN = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
omp_set_num_threads(prev);
|
||||||
|
|
||||||
|
CHECK(tN <= t1 * 1.05, "collapse_omp: multi-thread slower than single-thread");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
|
||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h"
|
||||||
|
|
||||||
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||||
|
using utils::Mf; using utils::Md; using utils::Mi;
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- Construction & element access ----------
|
||||||
|
TEST_CASE(Matrix_Construct_Access) {
|
||||||
|
Md M; // default
|
||||||
|
CHECK(M.rows()==0 && M.cols()==0, "default ctor dims wrong");
|
||||||
|
|
||||||
|
Mf A(2,3, 1.0f);
|
||||||
|
CHECK(A.rows()==2 && A.cols()==3, "ctor dims wrong");
|
||||||
|
CHECK(A(0,0)==1.0f && A(1,2)==1.0f, "fill wrong");
|
||||||
|
|
||||||
|
A(0,1)=2.5f; A(1,0)=3.5f;
|
||||||
|
CHECK(A(0,1)==2.5f && A(1,0)==3.5f, "operator() set/get failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Equality, inequality, nearly_equal ----------
|
||||||
|
TEST_CASE(Matrix_Equality) {
|
||||||
|
Mi A(2,2,0), B(2,2,0), C(2,2,1);
|
||||||
|
A(0,0)=1; A(1,1)=1; // A = I
|
||||||
|
B(0,0)=1; B(1,1)=1; // B = I
|
||||||
|
|
||||||
|
CHECK(A == B, "== failed identical");
|
||||||
|
CHECK(!(A != B), "!= failed identical");
|
||||||
|
CHECK(A != C, "!= failed different");
|
||||||
|
|
||||||
|
Md F1(2,2,0.0), F2(2,2,0.0);
|
||||||
|
F1(0,0)=1.0; F1(1,1)=2.0;
|
||||||
|
F2(0,0)=1.0; F2(1,1)=2.0 + 5e-10; // tiny perturbation
|
||||||
|
CHECK(!(F1 == F2), "operator== is exact; should differ");
|
||||||
|
CHECK(F1.nearly_equal(F2, 1e-9), "nearly_equal should accept tiny delta");
|
||||||
|
CHECK(!F1.nearly_equal(F2, 1e-12), "nearly_equal too strict should fail");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Row helpers ----------
|
||||||
|
TEST_CASE(Matrix_Row_Get_Set) {
|
||||||
|
Mf M(3,4, 0.0f);
|
||||||
|
Vf r(4, 0.0f);
|
||||||
|
for (uint64_t j=0;j<4;++j) r[j] = float(j+1); // [1,2,3,4]
|
||||||
|
|
||||||
|
M.set_row(1, r);
|
||||||
|
auto out = M.get_row(1);
|
||||||
|
CHECK(out == r, "set_row/get_row mismatch");
|
||||||
|
|
||||||
|
// size mismatch should throw
|
||||||
|
bool threw=false;
|
||||||
|
try { Vf bad(3, 9.0f); M.set_row(2, bad); } catch (const std::exception&) { threw=true; }
|
||||||
|
CHECK(threw, "set_row should throw on size mismatch");
|
||||||
|
|
||||||
|
// out of range
|
||||||
|
threw=false;
|
||||||
|
try { (void)M.get_row(3); } catch (const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "get_row should throw on OOB index");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Column helpers ----------
|
||||||
|
TEST_CASE(Matrix_Col_Get_Set) {
|
||||||
|
Md M(3,2, 0.0);
|
||||||
|
Vd c(3, 0.0);
|
||||||
|
c[0]=10; c[1]=20; c[2]=30;
|
||||||
|
|
||||||
|
M.set_col(1, c);
|
||||||
|
auto out = M.get_col(1);
|
||||||
|
CHECK(out == c, "set_col/get_col mismatch");
|
||||||
|
|
||||||
|
// size mismatch should throw
|
||||||
|
bool threw=false;
|
||||||
|
try { Vd bad(2, 9.0); M.set_col(0, bad); } catch (const std::exception&) { threw=true; }
|
||||||
|
CHECK(threw, "set_col should throw on size mismatch");
|
||||||
|
|
||||||
|
// out of range
|
||||||
|
threw=false;
|
||||||
|
try { (void)M.get_col(2); } catch (const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "get_col should throw on OOB index");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- swap_rows / swap_cols ----------
|
||||||
|
TEST_CASE(Matrix_Swap_Rows_Cols) {
|
||||||
|
Mi M(2,3,0);
|
||||||
|
// Row 0: [1,2,3], Row 1: [4,5,6]
|
||||||
|
M(0,0)=1; M(0,1)=2; M(0,2)=3;
|
||||||
|
M(1,0)=4; M(1,1)=5; M(1,2)=6;
|
||||||
|
|
||||||
|
M.swap_rows(0,1);
|
||||||
|
CHECK(M(0,0)==4 && M(0,1)==5 && M(0,2)==6, "swap_rows row0 wrong");
|
||||||
|
CHECK(M(1,0)==1 && M(1,1)==2 && M(1,2)==3, "swap_rows row1 wrong");
|
||||||
|
|
||||||
|
// swap back via cols
|
||||||
|
M.swap_cols(0,2);
|
||||||
|
// After swapping col0<->col2:
|
||||||
|
// Row0: [6,5,4], Row1: [3,2,1]
|
||||||
|
CHECK(M(0,0)==6 && M(0,1)==5 && M(0,2)==4, "swap_cols row0 wrong");
|
||||||
|
CHECK(M(1,0)==3 && M(1,1)==2 && M(1,2)==1, "swap_cols row1 wrong");
|
||||||
|
|
||||||
|
// no-op swap (a==b) should not crash or change
|
||||||
|
M.swap_rows(1,1);
|
||||||
|
M.swap_cols(2,2);
|
||||||
|
|
||||||
|
// OOB checks
|
||||||
|
bool threw=false;
|
||||||
|
try { M.swap_rows(5,1); } catch (const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "swap_rows should throw on OOB");
|
||||||
|
threw=false;
|
||||||
|
try { M.swap_cols(0,9); } catch (const std::out_of_range&) { threw=true; }
|
||||||
|
CHECK(threw, "swap_cols should throw on OOB");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- data() layout (contiguous row-major) ----------
|
||||||
|
TEST_CASE(Matrix_Data_Layout) {
|
||||||
|
Md M(2,3, 0.0);
|
||||||
|
// Fill increasing sequence
|
||||||
|
double val=1.0;
|
||||||
|
for (uint64_t i=0;i<M.rows();++i)
|
||||||
|
for (uint64_t j=0;j<M.cols();++j)
|
||||||
|
M(i,j) = val++;
|
||||||
|
|
||||||
|
const double* p = M.data();
|
||||||
|
// Expect row-major: [1,2,3,4,5,6]
|
||||||
|
CHECK(p[0]==1.0 && p[1]==2.0 && p[2]==3.0 && p[3]==4.0 && p[4]==5.0 && p[5]==6.0,
|
||||||
|
"data() row-major layout wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------- Stream output ----------
|
||||||
|
TEST_CASE(Matrix_StreamOutput) {
|
||||||
|
Mf M(2,2,0.0f);
|
||||||
|
M(0,0)=1.0f; M(0,1)=2.0f;
|
||||||
|
M(1,0)=3.0f; M(1,1)=4.0f;
|
||||||
|
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << M;
|
||||||
|
const std::string s = oss.str();
|
||||||
|
// Format example:
|
||||||
|
// [[1.000, 2.000],
|
||||||
|
// [3.000, 4.000]]
|
||||||
|
CHECK(s.find("[[1.000, 2.000],") != std::string::npos, "ostream first row format");
|
||||||
|
CHECK(s.find("[3.000, 4.000]]") != std::string::npos, "ostream second row format");
|
||||||
|
}
|
||||||
@@ -0,0 +1,237 @@
|
|||||||
|
|
||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h" // matrix.h, vector.h
|
||||||
|
#include "./numerics/matvec.h" // numerics::matvec / inplace_transpose
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
using utils::Vi; using utils::Vf; using utils::Vd;
|
||||||
|
using utils::Mi; using utils::Mf; using utils::Md;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// ------------------------------------------------------------
|
||||||
|
// matvec: y = A * x
|
||||||
|
// ------------------------------------------------------------
|
||||||
|
TEST_CASE(Matvec_Serial_Simple) {
|
||||||
|
// A = [[1,2,3],
|
||||||
|
// [4,5,6]]
|
||||||
|
Md A(2,3,0.0);
|
||||||
|
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||||
|
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||||
|
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
|
||||||
|
|
||||||
|
auto y = numerics::matvec<double>(A,x); // [ 1*7+2*8+3*9 , 4*7+5*8+6*9 ] = [50, 122]
|
||||||
|
CHECK(y.size()==2, "matvec size wrong");
|
||||||
|
CHECK(y[0]==50.0 && y[1]==122.0, "matvec values wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matvec_OMP_Equals_Serial) {
|
||||||
|
Md A(3,3,0.0);
|
||||||
|
// A = I * 2
|
||||||
|
for (uint64_t i=0;i<3;++i) A(i,i)=2.0;
|
||||||
|
Vd x(3,0.0); x[0]=1; x[1]=2; x[2]=3;
|
||||||
|
|
||||||
|
auto ys = numerics::matvec<double>(A,x);
|
||||||
|
auto yp = numerics::matvec_omp<double>(A,x);
|
||||||
|
|
||||||
|
CHECK((ys.nearly_equal_vec(yp)), "matvec_omp != matvec");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matvec_Auto_Equals_Serial) {
|
||||||
|
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=0.5; A(1,1)=3;
|
||||||
|
Vd x(2,0.0); x[0]=4; x[1]=5;
|
||||||
|
|
||||||
|
auto ys = numerics::matvec<double>(A,x);
|
||||||
|
auto ya = numerics::matvec_auto<double>(A,x);
|
||||||
|
|
||||||
|
CHECK((ys.nearly_equal_vec(ya)), "matvec_auto != serial");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matvec_DimensionMismatch_Throws) {
|
||||||
|
Md A(2,3,0.0);
|
||||||
|
Vd x(4,0.0);
|
||||||
|
bool threw=false;
|
||||||
|
try { auto _ = numerics::matvec<double>(A,x); (void)_; }
|
||||||
|
catch (const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "matvec must throw on dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matvec_Zero_Edges) {
|
||||||
|
Md A(0,3,0.0); // 0x3
|
||||||
|
Vd x(3,1.0);
|
||||||
|
auto y = numerics::matvec<double>(A,x);
|
||||||
|
CHECK(y.size()==0, "0xN * x should return size 0 vector");
|
||||||
|
|
||||||
|
Md B(2,0,0.0); // 2x0
|
||||||
|
Vd z(0,0.0);
|
||||||
|
auto y2 = numerics::matvec<double>(B,z);
|
||||||
|
CHECK(y2.size()==2 && y2[0]==0.0 && y2[1]==0.0, "N×0 * 0 should return zeros of size N");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Matvec_Float_Tolerance) {
|
||||||
|
Mf A(2,2,0.0f); A(0,0)=1.0f; A(0,1)=2.0f; A(1,0)=3.0f; A(1,1)=4.0f;
|
||||||
|
Vf x(2,0.0f); x[0]=0.1f; x[1]=0.2f;
|
||||||
|
|
||||||
|
auto y1 = numerics::matvec<float>(A,x);
|
||||||
|
auto y2 = numerics::matvec_omp<float>(A,x);
|
||||||
|
|
||||||
|
CHECK((y1.nearly_equal_vec(y2,1e-6f)), "matvec float omp mismatch");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
|
||||||
|
// We just check correctness; performance is environment-dependent.
|
||||||
|
//
|
||||||
|
TEST_CASE(Matvec_Auto_Inside_Outer_Parallel_Correctness) {
|
||||||
|
const uint64_t m=64, n=64;
|
||||||
|
Md A(m,n,1.0); Vd x(n,2.0);
|
||||||
|
//fill_deterministic(A); fill_deterministic(x);
|
||||||
|
Vd ref = numerics::matvec<double>(A,x);
|
||||||
|
|
||||||
|
// Call auto inside an outer team
|
||||||
|
#ifdef _OPENMP
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
#endif
|
||||||
|
for (int rep=0; rep<32; ++rep) {
|
||||||
|
auto y = numerics::matvec_auto<double>(A,x);
|
||||||
|
// Each thread checks its own result equals reference
|
||||||
|
if (!(y.nearly_equal_vec(ref))) {
|
||||||
|
throw TestFailure("matvec_auto wrong under outer parallel region");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TEST_CASE(Matvec_Speed_Sanity) {
|
||||||
|
const uint64_t m=4096, n=4096; // ~16M MACs; adjust if needed
|
||||||
|
Md A(m,n,1.0); Vd x(n,2.0);
|
||||||
|
//fill_deterministic(A); fill_deterministic(x);
|
||||||
|
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
auto yS = numerics::matvec(A,x);
|
||||||
|
double tp = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int threads = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
int threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
auto yP = numerics::matvec_omp(A,x);
|
||||||
|
double ts = std::chrono::duration<double>(t0 - std::chrono::high_resolution_clock::now()).count();
|
||||||
|
|
||||||
|
CHECK((yS.nearly_equal_vec(yP)), "matvec_omp != matvec_serial (large)");
|
||||||
|
// Only enforce basic sanity if we *can* use >1 threads:
|
||||||
|
if (threads > 1) {
|
||||||
|
// Be generous: just require not significantly slower.
|
||||||
|
CHECK(tp <= ts, "matvec_omp unexpectedly much slower than serial");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------------------------------------------------
|
||||||
|
// vecmat: y = x * A
|
||||||
|
// ------------------------------------------------------------
|
||||||
|
TEST_CASE(Vecmat_Serial_Simple) {
|
||||||
|
// A = [[1,2],
|
||||||
|
// [3,4],
|
||||||
|
// [5,6]] (3x2)
|
||||||
|
Md A(3,2,0.0);
|
||||||
|
A(0,0)=1; A(0,1)=2;
|
||||||
|
A(1,0)=3; A(1,1)=4;
|
||||||
|
A(2,0)=5; A(2,1)=6;
|
||||||
|
|
||||||
|
Vd x(3,0.0); x[0]=7; x[1]=8; x[2]=9;
|
||||||
|
|
||||||
|
auto y = numerics::vecmat<double>(x,A); // 1*7+3*8+5*9= 76 ; 2*7+4*8+6*9=100
|
||||||
|
CHECK(y.size()==2, "vecmat size wrong");
|
||||||
|
CHECK(y[0]==76.0 && y[1]==100.0, "vecmat values wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vecmat_OMP_Equals_Serial) {
|
||||||
|
Md A(2,2,0.0); A(0,0)=2; A(0,1)=1; A(1,0)=5; A(1,1)=-1;
|
||||||
|
Vd x(2,0.0); x[0]=0.5; x[1]=1.5;
|
||||||
|
|
||||||
|
auto ys = numerics::vecmat<double>(x,A);
|
||||||
|
auto yp = numerics::vecmat_omp<double>(x,A);
|
||||||
|
|
||||||
|
CHECK((ys.nearly_equal_vec(yp)), "vecmat_omp != vecmat");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vecmat_Auto_Equals_Serial) {
|
||||||
|
Md A(2,3,0.0);
|
||||||
|
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||||
|
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||||
|
Vd x(2,0.0); x[0]=1; x[1]=2;
|
||||||
|
|
||||||
|
auto ys = numerics::vecmat<double>(x,A);
|
||||||
|
auto ya = numerics::vecmat_auto<double>(x,A);
|
||||||
|
|
||||||
|
CHECK((ys.nearly_equal_vec(ya)), "vecmat_auto != serial");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vecmat_DimensionMismatch_Throws) {
|
||||||
|
Md A(2,2,0.0);
|
||||||
|
Vd x(3,0.0);
|
||||||
|
bool threw=false;
|
||||||
|
try { auto _ = numerics::vecmat<double>(x,A); (void)_; }
|
||||||
|
catch (const std::runtime_error&) { threw=true; }
|
||||||
|
CHECK(threw, "vecmat must throw on dimension mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vecmat_Zero_Edges) {
|
||||||
|
Md A(0,3,0.0);
|
||||||
|
Vd x(0,0.0);
|
||||||
|
auto y = numerics::vecmat<double>(x,A); // 0×N times N×M → 0×M
|
||||||
|
CHECK(y.size()==3 && y[0]==0.0 && y[1]==0.0 && y[2]==0.0, "0-length x times A wrong");
|
||||||
|
|
||||||
|
Md B(3,0,0.0);
|
||||||
|
Vd z(3,1.0);
|
||||||
|
auto y2 = numerics::vecmat<double>(z,B); // 1x3 * 3x0 → 1x0
|
||||||
|
CHECK(y2.size()==0, "vecmat with N×0 result size wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Auto inside an outer parallel region (no accidental nested teams) ----------
|
||||||
|
// We just check correctness; performance is environment-dependent.
|
||||||
|
//
|
||||||
|
TEST_CASE(Vecmat_Auto_Inside_Outer_Parallel_Correctness) {
|
||||||
|
const uint64_t m=64, n=64;
|
||||||
|
Md A(m,n,1.0); Vd x(m,2.0);
|
||||||
|
//fill_deterministic(A); fill_deterministic(x);
|
||||||
|
Vd ref = numerics::vecmat<double>(x,A);
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
#endif
|
||||||
|
for (int rep=0; rep<32; ++rep) {
|
||||||
|
auto y = numerics::vecmat_auto<double>(x,A);
|
||||||
|
if (!(y.nearly_equal_vec(ref))) {
|
||||||
|
throw TestFailure("vecmat_auto wrong under outer parallel region");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_CASE(Vecmat_Speed_Sanity) {
|
||||||
|
const uint64_t m=4096, n=4096;
|
||||||
|
Md A(m,n,1.0); Vd x(m,2.0);
|
||||||
|
//fill_deterministic(A); fill_deterministic(x);
|
||||||
|
|
||||||
|
auto t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
auto yS = numerics::vecmat<double>(x,A);
|
||||||
|
double ts = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
#ifdef _OPENMP
|
||||||
|
int threads = omp_get_max_threads();
|
||||||
|
#else
|
||||||
|
int threads = 1;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
t0 = std::chrono::high_resolution_clock::now();
|
||||||
|
auto yP = numerics::vecmat_omp<double>(x,A);
|
||||||
|
double tp = std::chrono::duration<double>(std::chrono::high_resolution_clock::now() - t0).count();
|
||||||
|
|
||||||
|
CHECK((yS.nearly_equal_vec(yP)), "vecmat_omp != vecmat_serial (large)");
|
||||||
|
if (threads > 1) {
|
||||||
|
CHECK(tp <= ts, "vecmat_omp unexpectedly much slower than serial");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
|
||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h" // matrix.h, vector.h
|
||||||
|
#include "./numerics/transpose.h" // numerics::transpose / inplace_transpose
|
||||||
|
|
||||||
|
using utils::Mi; using utils::Mf; using utils::Md;
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Out-of-place transpose (rectangular) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Transpose_Rectangular_OutOfPlace) {
|
||||||
|
// A = [ [1, 2, 3],
|
||||||
|
// [4, 5, 6] ] (2x3)
|
||||||
|
Md A(2,3,0.0);
|
||||||
|
A(0,0)=1; A(0,1)=2; A(0,2)=3;
|
||||||
|
A(1,0)=4; A(1,1)=5; A(1,2)=6;
|
||||||
|
|
||||||
|
auto AT = numerics::transpose(A); // (3x2)
|
||||||
|
|
||||||
|
CHECK(AT.rows()==3 && AT.cols()==2, "shape wrong after transpose");
|
||||||
|
CHECK(AT(0,0)==1 && AT(1,0)==2 && AT(2,0)==3, "first column wrong");
|
||||||
|
CHECK(AT(0,1)==4 && AT(1,1)==5 && AT(2,1)==6, "second column wrong");
|
||||||
|
|
||||||
|
// Involution: T(T(A)) == A
|
||||||
|
auto ATT = numerics::transpose(AT);
|
||||||
|
CHECK(ATT == A, "transpose(transpose(A)) != A");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- In-place transpose (square) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Transpose_Square_InPlace) {
|
||||||
|
// 3x3 with distinct values
|
||||||
|
Mf S(3,3,0.0f);
|
||||||
|
float val = 1.0f;
|
||||||
|
for (uint64_t i=0;i<3;++i)
|
||||||
|
for (uint64_t j=0;j<3;++j)
|
||||||
|
S(i,j) = val++;
|
||||||
|
|
||||||
|
// Make an explicit transpose to compare against
|
||||||
|
auto ST = numerics::transpose(S);
|
||||||
|
|
||||||
|
// In-place should match the out-of-place result
|
||||||
|
numerics::inplace_transpose(S);
|
||||||
|
CHECK(S == ST, "inplace_transpose result mismatch");
|
||||||
|
|
||||||
|
// Involution: applying in-place again should return original
|
||||||
|
numerics::inplace_transpose(S);
|
||||||
|
// Now S should equal the original pre-inplace matrix (which was transposed above)
|
||||||
|
// We can reconstruct original by transposing ST:
|
||||||
|
auto orig = numerics::transpose(ST);
|
||||||
|
CHECK(S == orig, "inplace transpose twice should restore original");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- In-place transpose must throw on non-square ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Transpose_InPlace_Throws_On_Rectangular) {
|
||||||
|
Md R(2,3,0.0); // rectangular
|
||||||
|
bool threw = false;
|
||||||
|
try {
|
||||||
|
numerics::inplace_transpose(R);
|
||||||
|
} catch (const std::runtime_error&) {
|
||||||
|
threw = true;
|
||||||
|
}
|
||||||
|
CHECK(threw, "inplace_transpose must throw on non-square matrices");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Edge cases: 0x0 and 1x1 ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Transpose_Edge_0x0_1x1) {
|
||||||
|
// 0x0 should be fine both ways
|
||||||
|
Md E; // 0x0
|
||||||
|
auto ET = numerics::transpose(E);
|
||||||
|
CHECK(ET.rows()==0 && ET.cols()==0, "0x0 transpose shape wrong");
|
||||||
|
// in-place on 0x0 (rows==cols) should not throw
|
||||||
|
numerics::inplace_transpose(E);
|
||||||
|
CHECK(E.rows()==0 && E.cols()==0, "0x0 inplace transpose changed shape");
|
||||||
|
|
||||||
|
// 1x1 should be a no-op and not throw
|
||||||
|
Mi I(1,1,0);
|
||||||
|
I(0,0) = 42;
|
||||||
|
auto IT = numerics::transpose(I);
|
||||||
|
CHECK(IT.rows()==1 && IT.cols()==1 && IT(0,0)==42, "1x1 transpose wrong");
|
||||||
|
numerics::inplace_transpose(I);
|
||||||
|
CHECK(I(0,0)==42, "1x1 inplace transpose changed value");
|
||||||
|
}
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
|
||||||
|
#include "test_common.h"
|
||||||
|
#include "./utils/utils.h"
|
||||||
|
|
||||||
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Basic construction & access ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Construct_Size_Access) {
|
||||||
|
Vd a; // default
|
||||||
|
CHECK(a.size() == 0, "default size must be 0");
|
||||||
|
|
||||||
|
Vf b(3, 1.0f); // (n, fill)
|
||||||
|
CHECK(b.size() == 3, "size wrong");
|
||||||
|
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "fill wrong");
|
||||||
|
|
||||||
|
b[1] = 2.5f;
|
||||||
|
CHECK(b[1] == 2.5f, "operator[] write failed");
|
||||||
|
|
||||||
|
// resize (grow + value)
|
||||||
|
b.resize(5, 7.0f);
|
||||||
|
CHECK(b.size() == 5, "resize grow size wrong");
|
||||||
|
CHECK(b[0] == 1.0f && b[1] == 2.5f && b[2] == 1.0f && b[3] == 7.0f && b[4] == 7.0f,
|
||||||
|
"resize grow values wrong");
|
||||||
|
|
||||||
|
// resize (shrink)
|
||||||
|
b.resize(2);
|
||||||
|
CHECK(b.size() == 2, "resize shrink size wrong");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE(Vector_Clear_PushBack) {
|
||||||
|
Vi v(0, 0);
|
||||||
|
v.push_back(10);
|
||||||
|
v.push_back(20);
|
||||||
|
CHECK(v.size() == 2, "push_back size wrong");
|
||||||
|
CHECK(v[0] == 10 && v[1] == 20, "push_back values wrong");
|
||||||
|
|
||||||
|
v.clear();
|
||||||
|
CHECK(v.size() == 0, "clear failed");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Equality / Inequality (tolerant for float/double) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Equality_Tolerant) {
|
||||||
|
Vd a(3, 1.0), b(3, 1.0);
|
||||||
|
CHECK(a == b, "== identical failed");
|
||||||
|
CHECK(!(a != b), "!= identical failed");
|
||||||
|
|
||||||
|
// Tiny perturbation within eps (1e-6 default)
|
||||||
|
b[1] += 1e-7;
|
||||||
|
CHECK(a == b, "== tolerant failed");
|
||||||
|
|
||||||
|
// Larger perturbation should fail equality
|
||||||
|
b[1] += 1e-4;
|
||||||
|
CHECK(a != b, "!= with difference failed");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Scalar arithmetic: +, -, *, / (inplace and returning) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Scalar_Arithmetic) {
|
||||||
|
Vf a(3, 1.0f);
|
||||||
|
|
||||||
|
// inplace
|
||||||
|
a.inplace_add(2); // int convertible to float
|
||||||
|
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_add failed");
|
||||||
|
|
||||||
|
a.inplace_subtract(1.5f);
|
||||||
|
CHECK(std::fabs(a[0] - 1.5f) < 1e-6f &&
|
||||||
|
std::fabs(a[1] - 1.5f) < 1e-6f &&
|
||||||
|
std::fabs(a[2] - 1.5f) < 1e-6f, "inplace_subtract failed");
|
||||||
|
|
||||||
|
a.inplace_multiply(4.0);
|
||||||
|
CHECK(a[0] == 6.0f && a[1] == 6.0f && a[2] == 6.0f, "inplace_multiply failed");
|
||||||
|
|
||||||
|
a.inplace_divide(2);
|
||||||
|
CHECK(a[0] == 3.0f && a[1] == 3.0f && a[2] == 3.0f, "inplace_divide failed");
|
||||||
|
|
||||||
|
// returning
|
||||||
|
auto b = a + 1.0f;
|
||||||
|
CHECK(b[0] == 4.0f && b[1] == 4.0f && b[2] == 4.0f, "operator+(scalar) failed");
|
||||||
|
|
||||||
|
b = a - 2.0f;
|
||||||
|
CHECK(b[0] == 1.0f && b[1] == 1.0f && b[2] == 1.0f, "operator-(scalar) failed");
|
||||||
|
|
||||||
|
b = a * 10; // int -> float
|
||||||
|
CHECK(b[0] == 30.0f && b[1] == 30.0f && b[2] == 30.0f, "operator*(scalar) failed");
|
||||||
|
|
||||||
|
b = a / 3.0f;
|
||||||
|
CHECK(std::fabs(b[0] - 1.0f) < 1e-6f &&
|
||||||
|
std::fabs(b[1] - 1.0f) < 1e-6f &&
|
||||||
|
std::fabs(b[2] - 1.0f) < 1e-6f, "operator/(scalar) failed");
|
||||||
|
|
||||||
|
// scalar on the left (friends implemented for + and *)
|
||||||
|
Vf c(3, 2.0f);
|
||||||
|
auto d = 5 + c; // friend operator+(U, Vector<T>)
|
||||||
|
CHECK(d[0] == 7.0f && d[1] == 7.0f && d[2] == 7.0f, "scalar + vector failed");
|
||||||
|
|
||||||
|
d = 3 * c; // friend operator*(U, Vector<T>)
|
||||||
|
CHECK(d[0] == 6.0f && d[1] == 6.0f && d[2] == 6.0f, "scalar * vector failed");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Vector arithmetic: +, -, *, / (elementwise) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Vector_Arithmetic) {
|
||||||
|
Vd a(3, 1.0), b(3, 2.0);
|
||||||
|
|
||||||
|
// returning
|
||||||
|
auto c = a + b;
|
||||||
|
CHECK(c[0]==3.0 && c[1]==3.0 && c[2]==3.0, "vec + vec failed");
|
||||||
|
|
||||||
|
c = b - a;
|
||||||
|
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec - vec failed");
|
||||||
|
|
||||||
|
c = a * b;
|
||||||
|
CHECK(c[0]==2.0 && c[1]==2.0 && c[2]==2.0, "vec * vec failed");
|
||||||
|
|
||||||
|
c = b / b;
|
||||||
|
CHECK(c[0]==1.0 && c[1]==1.0 && c[2]==1.0, "vec / vec failed");
|
||||||
|
|
||||||
|
// inplace
|
||||||
|
a = Vd(3, 1.0);
|
||||||
|
a += b;
|
||||||
|
CHECK(a[0]==3.0 && a[1]==3.0 && a[2]==3.0, "inplace vec + vec failed");
|
||||||
|
a -= b;
|
||||||
|
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec - vec failed");
|
||||||
|
a *= b;
|
||||||
|
CHECK(a[0]==2.0 && a[1]==2.0 && a[2]==2.0, "inplace vec * vec failed");
|
||||||
|
a /= b;
|
||||||
|
CHECK(a[0]==1.0 && a[1]==1.0 && a[2]==1.0, "inplace vec / vec failed");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Size mismatch error paths ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_SizeMismatch_Throws) {
|
||||||
|
Vd a(3, 1.0), b(4, 2.0);
|
||||||
|
|
||||||
|
bool threw = false;
|
||||||
|
try { auto c = a + b; (void)c; } catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "add should throw on size mismatch");
|
||||||
|
|
||||||
|
threw = false;
|
||||||
|
try { a.inplace_subtract(b); } catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "inplace_subtract should throw on size mismatch");
|
||||||
|
|
||||||
|
threw = false;
|
||||||
|
try { auto d = a * b; (void)d; } catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "multiply should throw on size mismatch");
|
||||||
|
|
||||||
|
threw = false;
|
||||||
|
try { auto s = a.dot(b); (void)s; } catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "dot should throw on size mismatch");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Power / sqrt ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Power_Sqrt) {
|
||||||
|
Vd a(3, 2.0); // [2,2,2]
|
||||||
|
auto b = a.power(3.0); // [8,8,8]
|
||||||
|
CHECK(b[0]==8.0 && b[1]==8.0 && b[2]==8.0, "scalar power failed");
|
||||||
|
|
||||||
|
Vd p(3, 3.0); // [3,3,3]
|
||||||
|
auto c = b.power(p); // 8^3 = 512
|
||||||
|
CHECK(c[0]==512.0 && c[1]==512.0 && c[2]==512.0, "vector power failed");
|
||||||
|
|
||||||
|
Vd d(3, 9.0);
|
||||||
|
auto e = d.sqrt(); // [3,3,3]
|
||||||
|
CHECK(e[0]==3.0 && e[1]==3.0 && e[2]==3.0, "sqrt failed");
|
||||||
|
|
||||||
|
// inplace
|
||||||
|
d.inplace_sqrt(); // becomes [3,3,3]
|
||||||
|
CHECK(d == e, "inplace_sqrt failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// ---------- Dot / Sum / Norm / Normalize ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_Dot_Sum_Norm_Normalize) {
|
||||||
|
Vd a(3, 0.0);
|
||||||
|
a[0]=1.0; a[1]=2.0; a[2]=2.0;
|
||||||
|
|
||||||
|
CHECK(a.sum() == 5.0, "sum failed");
|
||||||
|
CHECK(a.dot(a) == 9.0, "dot self failed");
|
||||||
|
|
||||||
|
auto n = a.norm();
|
||||||
|
CHECK(std::fabs(n - 3.0) < 1e-12, "norm failed");
|
||||||
|
|
||||||
|
auto b = a.normalize();
|
||||||
|
CHECK(std::fabs(b.norm() - 1.0) < 1e-12, "normalize() not unit");
|
||||||
|
|
||||||
|
// inplace normalize
|
||||||
|
a.inplace_normalize();
|
||||||
|
CHECK(std::fabs(a.norm() - 1.0) < 1e-12, "inplace_normalize not unit");
|
||||||
|
|
||||||
|
// zero-norm error
|
||||||
|
Vd z(3, 0.0);
|
||||||
|
bool threw = false;
|
||||||
|
try { z.inplace_normalize(); } catch (const std::runtime_error&) { threw = true; }
|
||||||
|
CHECK(threw, "normalize zero vector must throw");
|
||||||
|
}
|
||||||
|
//
|
||||||
|
// ---------- Stream output (basic sanity) ----------
|
||||||
|
//
|
||||||
|
TEST_CASE(Vector_StreamOutput) {
|
||||||
|
Vi a(3, 2);
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << a;
|
||||||
|
auto s = oss.str();
|
||||||
|
CHECK(s == "[2, 2, 2]", "ostream<< wrong format");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user