156 lines
4.8 KiB
C++
156 lines
4.8 KiB
C++
#include "test_common.h"
|
||
#include "./utils/matrix.h"
|
||
#include "./numerics/initializers/eye.h"
|
||
|
||
#include <chrono>
|
||
|
||
|
||
|
||
// Tiny helper
|
||
template <typename T>
|
||
static bool is_identity(const utils::Matrix<T>& M) {
|
||
if (M.rows() != M.cols()) return false;
|
||
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
||
for (std::uint64_t j = 0; j < M.cols(); ++j)
|
||
if ((i == j && M(i,j) != T{1}) ||
|
||
(i != j && M(i,j) != T{0})) return false;
|
||
return true;
|
||
}
|
||
|
||
|
||
// ============ Basic correctness ============
|
||
// eye<double>(5) should yield a 5x5 matrix equal to the identity.
TEST_CASE(Generate_EYE_double) {
    const utils::Matrix<double> eye5 = numerics::eye<double>(5);

    // Shape first, then contents.
    CHECK_EQ(eye5.rows(), 5, "rows should be 5");
    CHECK_EQ(eye5.cols(), 5, "cols should be 5");

    CHECK(is_identity(eye5), "eye<double>(5) is not identity");
}
|
||
|
||
// eye<uint64_t>(3) should yield a 3x3 identity matrix of unsigned integers.
// (The test name says "int"; the element type actually exercised is uint64_t.)
TEST_CASE(Generate_EYE_int) {
    utils::Matrix<uint64_t> I = numerics::eye<uint64_t>(3);
    CHECK_EQ(I.rows(), 3, "rows should be 3");
    CHECK_EQ(I.cols(), 3, "cols should be 3");
    // Message names the element type really under test, so a failure is not misattributed to int.
    CHECK(is_identity(I), "eye<uint64_t>(3) is not identity");
}
|
||
|
||
// Passing an explicit size to inplace_eye on an empty matrix should resize it
// to that size and fill it with the identity.
TEST_CASE(Inplace_EYE_resize_creates_identity) {
    utils::Md grown;                    // default-constructed; expected to start empty
    numerics::inplace_eye(grown, 7);    // resize to 7x7 and set identity

    CHECK_EQ(grown.rows(), 7, "rows should be 7");
    CHECK_EQ(grown.cols(), 7, "cols should be 7");
    CHECK(is_identity(grown), "inplace_eye resize did not create identity");
}
|
||
|
||
// inplace_eye(A) with the size argument omitted must overwrite a square matrix
// in place: its 4x4 shape is kept and every pre-existing value is replaced by
// the identity pattern.
TEST_CASE(Inplace_EYE_inplace_on_square) {
    utils::Md A(4,4, 42.0);     // junk values that must all be overwritten
    numerics::inplace_eye(A);   // N==0 → must be square, zero + diag

    // is_identity() alone would also accept a 0x0 matrix, so pin the shape explicitly.
    CHECK(A.rows() == 4 && A.cols() == 4, "inplace_eye on square must keep the 4x4 shape");
    CHECK(is_identity(A), "inplace_eye on square did not produce identity");
}
|
||
|
||
// inplace_eye(A) with the size argument omitted must reject a non-square
// matrix by throwing std::runtime_error. Any other exception type is not caught
// here and propagates out of the test.
TEST_CASE(Inplace_EYE_throws_on_non_square) {
    utils::Md rect(2,3, 5.0);

    const auto raises_runtime_error = [&rect]() -> bool {
        try {
            numerics::inplace_eye(rect);   // N==0 and non-square → throws
        } catch (const std::runtime_error&) {
            return true;
        }
        return false;
    };

    // Evaluate once, outside the macro, so the call happens exactly one time.
    const bool threw = raises_runtime_error();
    CHECK(threw, "inplace_eye should throw on non-square when N==0");
}
|
||
|
||
// ============ OpenMP variants (compiled only when -fopenmp is used) ============
|
||
|
||
#ifdef _OPENMP
|
||
// eye_omp must agree element-for-element with the serial eye, and
// inplace_eye_omp must turn an existing square matrix into the identity in place.
TEST_CASE(EYE_OMP_matches_serial) {
    utils::Md I1 = numerics::eye<double>(64);
    utils::Md I2 = numerics::eye_omp<double>(64);

    // Compare shapes first: indexing I2 with I1's bounds is only safe when they match.
    const bool same_shape = I1.rows() == I2.rows() && I1.cols() == I2.cols();
    CHECK(same_shape, "eye_omp shape != eye shape");

    if (same_shape) {
        // Scan every element but report a single result instead of one CHECK per element.
        bool all_equal = true;
        for (uint64_t i = 0; i < I1.rows() && all_equal; ++i) {
            for (uint64_t j = 0; j < I1.cols() && all_equal; ++j) {
                if (!(I1(i,j) == I2(i,j))) all_equal = false;
            }
        }
        CHECK(all_equal, "eye_omp != eye");
    }

    utils::Md A(64,64, 3.14);
    numerics::inplace_eye_omp(A); // N==0 → must be square
    // is_identity() alone would also accept a 0x0 result, so pin the shape too.
    CHECK(A.rows() == 64 && A.cols() == 64, "inplace_eye_omp must keep the 64x64 shape");
    CHECK(is_identity(A), "inplace_eye_omp did not produce identity");
}
|
||
|
||
|
||
// Timing comparison: inplace_eye_omp on a large (N x N, uint8_t) matrix should
// take no longer than the serial inplace_eye on an identical matrix.
// Skipped when OpenMP reports at most one available thread.
// NOTE(review): wall-clock comparisons are inherently noisy; this test can be
// flaky on loaded machines.
TEST_CASE(EYE_OMP_speed) {
    // omp_get_max_threads()/omp_set_num_threads() use int; keep the saved value as int
    // so it round-trips unchanged when restored below.
    const int prev = omp_get_max_threads();
    if (prev <= 1) return;

    // Use up to 16 threads, but never more than the runtime offers by default:
    // oversubscribing a smaller machine slows the parallel run and makes the check flaky.
    const int threads = prev < 16 ? prev : 16;
    omp_set_num_threads(threads);

    const uint64_t N = 16384;
    utils::Matrix<uint8_t> I1(N,N,1), I2(N,N,1);

    // steady_clock is monotonic; high_resolution_clock may alias a non-steady clock.
    using Clock = std::chrono::steady_clock;

    auto t0 = Clock::now();
    numerics::inplace_eye<uint8_t>(I1);
    const double t1 = std::chrono::duration<double>(Clock::now() - t0).count();

    t0 = Clock::now();
    numerics::inplace_eye_omp<uint8_t>(I2);
    const double t2 = std::chrono::duration<double>(Clock::now() - t0).count();

    omp_set_num_threads(prev);

    CHECK(t2 <= t1, "inplace_eye_omp: multi-thread slower than single-thread");
}
|
||
|
||
|
||
// Runs inplace_eye_omp twice without any enclosing parallel region (baseline, I1),
// then twice from inside an active outer parallel region with one nested level
// allowed (I2), and checks both results are the identity. The saved OpenMP
// max-active-levels and thread-count settings are restored before checking.
TEST_CASE(EYE_OMP_nested_ok) {
    // Both ICV getters return int; keep them as int so they round-trip unchanged.
    const int prev_levels = omp_get_max_active_levels();
    const int prev = omp_get_max_threads();

    // Outer team size. It must be >= 2: a team of one thread is an *inactive*
    // parallel region, so with a single outer thread the inner region is never
    // nested inside an active one and the max-active-levels setting is not exercised.
    const int outer_threads = 2;
    // Inner team size to request during the nested-enabled run.
    const int inner_threads = 12;

    // Correctness does not depend on size. 4096^2 bytes = 16 MiB per matrix keeps the
    // test light (32768^2 would need 1 GiB per matrix).
    const uint64_t N = 4096;
    utils::Matrix<uint8_t> I1(N, N, 7);
    utils::Matrix<uint8_t> I2(N, N, 7);

    // ---------- baseline: no enclosing parallel region, nesting disabled ----------
    omp_set_max_active_levels(1);
    omp_set_num_threads(outer_threads);

    // Two calls, as in the nested run below (presumably to also cover re-applying
    // the operation to a matrix that is already the identity).
    for (uint64_t i = 0; i < 2; ++i){
        numerics::inplace_eye_omp<uint8_t>(I1);
    }

    // ---------- nested: outer x inner (only one outer thread launches inner) ----------
    omp_set_max_active_levels(2); // allow one nested active level

    #pragma omp parallel num_threads(outer_threads)
    {
        #pragma omp single // only one outer thread writes I2, so there is no race on it
        {
            // Affects only this thread's task, so the inner region requests inner_threads.
            omp_set_num_threads(inner_threads);
            for (uint64_t i = 0; i < 2; ++i){
                numerics::inplace_eye_omp<uint8_t>(I2);
            }
        }
    }

    omp_set_max_active_levels(prev_levels);
    omp_set_num_threads(prev);

    // Distinct messages so a failure says which run went wrong.
    CHECK(is_identity(I1), "EYE_OMP_nested baseline (I1) did not produce identity");
    CHECK(is_identity(I2), "EYE_OMP_nested nested run (I2) did not produce identity");
}
|
||
|
||
|
||
// eye_omp_auto<double>(32) must build a 32x32 identity matrix.
TEST_CASE(EYE_OMP_auto_is_identity) {
    auto I = numerics::eye_omp_auto<double>(32);
    // is_identity() accepts any square matrix, including 0x0, so check the size explicitly.
    CHECK(I.rows() == 32 && I.cols() == 32, "eye_omp_auto(32) should be 32x32");
    CHECK(is_identity(I), "eye_omp_auto did not produce identity");
}
|
||
|
||
#endif
|