Files
Flux-openbuild/test/test_eye.cpp
2025-10-06 20:21:40 +00:00

156 lines
4.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "test_common.h"
#include "./utils/matrix.h"
#include "./numerics/initializers/eye.h"
#include <chrono>
// Test helper: reports whether M is square, with T{1} on the main diagonal and
// T{0} in every off-diagonal slot. A 0x0 matrix trivially satisfies this.
template <typename T>
static bool is_identity(const utils::Matrix<T>& M) {
    const auto n = M.rows();
    if (M.cols() != n) return false;
    for (std::uint64_t r = 0; r < n; ++r) {
        for (std::uint64_t c = 0; c < n; ++c) {
            const T expected = (r == c) ? T{1} : T{0};
            if (M(r, c) != expected) return false;
        }
    }
    return true;
}
// ============ Basic correctness ============
// eye<double>(5) should yield a 5x5 double identity matrix.
TEST_CASE(Generate_EYE_double) {
    const utils::Matrix<double> identity = numerics::eye<double>(5);
    CHECK_EQ(identity.rows(), 5, "rows should be 5");
    CHECK_EQ(identity.cols(), 5, "cols should be 5");
    CHECK(is_identity(identity), "eye<double>(5) is not identity");
}
// eye<uint64_t>(3) should yield a 3x3 identity for an unsigned integer element type.
TEST_CASE(Generate_EYE_int) {
    utils::Matrix<uint64_t> I = numerics::eye<uint64_t>(3);
    CHECK_EQ(I.rows(), 3, "rows should be 3");
    CHECK_EQ(I.cols(), 3, "cols should be 3");
    // The message names the element type actually instantiated above.
    CHECK(is_identity(I), "eye<uint64_t>(3) is not identity");
}
// Starting from an empty matrix, inplace_eye(A, 7) is expected to resize A
// to 7x7 and fill it with the identity.
TEST_CASE(Inplace_EYE_resize_creates_identity) {
    utils::Md target;                   // default-constructed (empty)
    numerics::inplace_eye(target, 7);   // grow to 7x7, then write the identity
    CHECK_EQ(target.rows(), 7, "rows should be 7");
    CHECK_EQ(target.cols(), 7, "cols should be 7");
    CHECK(is_identity(target), "inplace_eye resize did not create identity");
}
// inplace_eye(A) called without a size (the N==0 path) on an already-square
// matrix must overwrite it with the identity while keeping its 4x4 shape.
TEST_CASE(Inplace_EYE_inplace_on_square) {
    utils::Md A(4,4, 42.0); // junk values that a correct reset must clear
    numerics::inplace_eye(A); // N==0 -> must be square, zero + diag
    // is_identity() also accepts a 0x0 matrix, so pin the shape explicitly.
    CHECK_EQ(A.rows(), 4, "rows should stay 4");
    CHECK_EQ(A.cols(), 4, "cols should stay 4");
    CHECK(is_identity(A), "inplace_eye on square did not produce identity");
}
// inplace_eye(A) without a size on a non-square (2x3) matrix is expected to
// throw std::runtime_error (or a type derived from it).
TEST_CASE(Inplace_EYE_throws_on_non_square) {
    utils::Md A(2,3, 5.0);
    bool threw = false;
    bool wrong_type = false;
    try {
        numerics::inplace_eye(A); // N==0 and non-square -> throws
    } catch (const std::runtime_error&) {
        threw = true;
    } catch (...) {
        // Without this handler, any other exception type would propagate out
        // of the test case (how the runner reacts is defined in test_common.h);
        // report it as a test failure instead.
        wrong_type = true;
    }
    CHECK(!wrong_type, "inplace_eye threw an exception not derived from std::runtime_error");
    CHECK(threw, "inplace_eye should throw on non-square when N==0");
}
// ============ OpenMP variants (compiled only when -fopenmp is used) ============
#ifdef _OPENMP
// eye_omp must agree element-wise with the serial eye, and inplace_eye_omp
// must turn a square, junk-filled matrix into the identity without resizing it.
TEST_CASE(EYE_OMP_matches_serial) {
    utils::Md I1 = numerics::eye<double>(64);
    utils::Md I2 = numerics::eye_omp<double>(64);
    CHECK_EQ(I2.rows(), I1.rows(), "eye_omp rows != eye rows");
    CHECK_EQ(I2.cols(), I1.cols(), "eye_omp cols != eye cols");
    // Compare elements only when the shapes agree. Otherwise I2(i,j) could
    // index past I2's storage (assuming operator() is unchecked).
    if (I1.rows() == I2.rows() && I1.cols() == I2.cols()) {
        bool all_equal = true;
        for (uint64_t i = 0; all_equal && i < I1.rows(); ++i) {
            for (uint64_t j = 0; j < I1.cols(); ++j) {
                if (I1(i,j) != I2(i,j)) { all_equal = false; break; }
            }
        }
        // One aggregated check instead of one CHECK per element.
        CHECK(all_equal, "eye_omp != eye");
    }
    utils::Md A(64,64, 3.14);
    numerics::inplace_eye_omp(A); // N==0 -> must be square
    // is_identity() also accepts 0x0, so pin the shape explicitly.
    CHECK_EQ(A.rows(), 64, "inplace_eye_omp rows should stay 64");
    CHECK_EQ(A.cols(), 64, "inplace_eye_omp cols should stay 64");
    CHECK(is_identity(A), "inplace_eye_omp did not produce identity");
}
// Rough speed check: when more than one OpenMP thread is available,
// inplace_eye_omp should not be slower than the serial inplace_eye on a large
// uint8_t matrix. Both results are also checked for correctness, so a no-op
// implementation cannot "win" the timing comparison.
TEST_CASE(EYE_OMP_speed) {
    // omp_get_max_threads()/omp_set_num_threads() work with int; keep the saved value as int.
    const int prev = omp_get_max_threads();
    if (prev <= 1) return;
    // Requesting more threads than there are processors oversubscribes the CPU
    // and makes the timing comparison flaky; cap the request (at most 16).
    const int procs = omp_get_num_procs();
    const int threads = procs < 16 ? procs : 16;
    if (threads <= 1) return;
    omp_set_num_threads(threads);
    uint64_t N = 16384;
    utils::Matrix<uint8_t> I1(N,N,1), I2(N,N,1);
    // steady_clock is monotonic; high_resolution_clock may alias a wall clock.
    auto t0 = std::chrono::steady_clock::now();
    numerics::inplace_eye<uint8_t>(I1);
    const double t1 = std::chrono::duration<double>(std::chrono::steady_clock::now() - t0).count();
    t0 = std::chrono::steady_clock::now();
    numerics::inplace_eye_omp<uint8_t>(I2);
    const double t2 = std::chrono::duration<double>(std::chrono::steady_clock::now() - t0).count();
    omp_set_num_threads(prev);
    CHECK(is_identity(I1), "inplace_eye did not produce identity");
    CHECK(is_identity(I2), "inplace_eye_omp did not produce identity");
    CHECK(t2 <= t1, "eye_omp: multi-thread slower than single-thread");
}
// Runs inplace_eye_omp in two configurations and checks that both produce the
// identity:
//   1. baseline: nesting disabled, a 1-thread team, no enclosing region;
//   2. nested:   called from inside an outer parallel region with nesting enabled.
// The OpenMP settings changed here are restored before the checks.
TEST_CASE(EYE_OMP_nested_ok) {
    // Save the settings this test modifies (both getters return int).
    const int prev_levels = omp_get_max_active_levels();
    const int prev = omp_get_max_threads();
    // Team size for the non-nested baseline run.
    const int baseline_threads = 1;
    // Outer team size. It must be > 1: a parallel region executed by a single
    // thread is inactive, so with 1 outer thread no nesting would be exercised.
    const int outer_threads = 2;
    // Inner team size to request during the nested-enabled run
    const int inner_threads = 12;
    // NOTE(review): N*N bytes per matrix (1 GiB each at N = 32768); consider a
    // smaller N if CI memory is limited.
    uint64_t N = 4096*8;
    utils::Matrix<uint8_t> I1(N, N, 7);
    utils::Matrix<uint8_t> I2(N, N, 7);
    // ---------- baseline: single-thread team, no outer region ----------
    omp_set_max_active_levels(1); // disable nesting for baseline
    omp_set_num_threads(baseline_threads);
    for (uint64_t i = 0; i < 2; ++i){
        numerics::inplace_eye_omp<uint8_t>(I1);
    }
    // ---------- nested: outer x inner (only one outer thread launches inner) ----------
    omp_set_max_active_levels(2); // allow one nested level
    #pragma omp parallel num_threads(outer_threads)
    {
        #pragma omp single // only one outer thread writes I2, so there is no race on it
        {
            omp_set_num_threads(inner_threads);
            for (uint64_t i = 0; i < 2; ++i){
                numerics::inplace_eye_omp<uint8_t>(I2);
            }
        }
    }
    omp_set_max_active_levels(prev_levels);
    omp_set_num_threads(prev);
    CHECK(is_identity(I1), "EYE_OMP_nested did not produce identity");
    CHECK(is_identity(I2), "EYE_OMP_nested did not produce identity");
}
// eye_omp_auto<double>(32) should yield a 32x32 identity matrix.
TEST_CASE(EYE_OMP_auto_is_identity) {
    auto I = numerics::eye_omp_auto<double>(32);
    // is_identity() also accepts an empty (0x0) matrix, so pin the shape first.
    CHECK_EQ(I.rows(), 32, "rows should be 32");
    CHECK_EQ(I.cols(), 32, "cols should be 32");
    CHECK(is_identity(I), "eye_omp_auto did not produce identity");
}
#endif