// Unit tests for numerics::matequal and its OpenMP / auto-dispatch variants
// (matequal_omp, matequal_auto).
//
// Tests are registered with TEST_CASE and report through CHECK(cond, msg),
// both provided by test_common.h. Checks that call matequal_omp directly are
// compiled only when _OPENMP is defined.

#include <cstdint>
#ifdef _OPENMP
#include <omp.h>  // omp_get_max_active_levels / omp_set_max_active_levels
#endif

#include "test_common.h"
#include "./numerics/matequal.h"

using utils::Vf;
using utils::Vd;
using utils::Vi;
using utils::Mf;
using utils::Md;
using utils::Mi;

// ---------- helpers ----------

/// @brief Overwrite M, in row-major order, with an arithmetic sequence.
///
/// Element (i, j) receives start + step * k, where k = i * M.cols() + j is the
/// element's row-major position. Filling two matrices with the same arguments
/// gives them identical, non-trivial contents for the equality checks below.
///
/// @param M      matrix whose every element is assigned
/// @param start  value written to element (0, 0)
/// @param step   increment between consecutive row-major elements
template <typename T>
static void fill_seq(utils::Matrix<T>& M, T start = T{0}, T step = T{1})
{
    std::uint64_t k = 0;
    for (std::uint64_t i = 0; i < M.rows(); ++i)
        for (std::uint64_t j = 0; j < M.cols(); ++j, ++k)
            M(i, j) = start + step * static_cast<T>(k);
}

// ---------- tests ----------

// Matrices of different shapes must not compare equal on any code path.
TEST_CASE(matequal_shape_mismatch)
{
    utils::Mi A(3, 3, 0), B(3, 4, 0);
    CHECK(!numerics::matequal(A, B), "shape mismatch should be false (serial)");
#ifdef _OPENMP
    CHECK(!numerics::matequal_omp(A, B), "shape mismatch should be false (omp)");
#endif
    CHECK(!numerics::matequal_auto(A, B), "shape mismatch should be false (auto)");
}

// Integer matrices: identical contents compare equal, and a single element
// changed by one must be detected.
TEST_CASE(matequal_int_true_false)
{
    utils::Mi A(4, 5, 0), B(4, 5, 0);
    fill_seq(A, std::int64_t(0), std::int64_t(1));
    fill_seq(B, std::int64_t(0), std::int64_t(1));
    CHECK(numerics::matequal(A, B), "ints equal (serial)");
#ifdef _OPENMP
    CHECK(numerics::matequal_omp(A, B), "ints equal (omp)");
#endif

    // flip one element
    B(2, 3) += 1;
    CHECK(!numerics::matequal(A, B), "ints differ (serial)");
#ifdef _OPENMP
    // Regression guard (author's note): this check fails if the omp branch
    // uses '!=' in its comparison.
    CHECK(!numerics::matequal_omp(A, B), "ints differ (omp)");
#endif
}

// Floating-point matrices: a difference below the tolerance is ignored, and a
// difference above it is reported.
TEST_CASE(matequal_double_tolerance)
{
    utils::Md A(3, 3, 0.0), B(3, 3, 0.0);
    fill_seq(A, double(1.0), double(0.125));
    fill_seq(B, double(1.0), double(0.125));

    // Tiny perturbation within the default tolerance.
    // NOTE(review): assumes matequal's default tolerance is larger than 1e-12;
    // confirm against matequal.h.
    B(1, 1) += 1e-12;
    CHECK(numerics::matequal(A, B), "double within tol (serial)");
#ifdef _OPENMP
    CHECK(numerics::matequal_omp(A, B), "double within tol (omp)");
#endif

    // Larger perturbation exceeds the explicit tolerance of 1e-9.
    B(0, 2) += 1e-6;
    CHECK(!numerics::matequal(A, B, 1e-9), "double exceeds tol (serial)");
#ifdef _OPENMP
    CHECK(!numerics::matequal_omp(A, B, 1e-9), "double exceeds tol (omp)");
#endif
}

// matequal_auto must give correct answers whichever path it dispatches to.
TEST_CASE(matequal_auto_agrees)
{
    // The size is chosen so that auto likely takes the OMP path when it is
    // available. This test only checks correctness, not which path was taken.
    utils::Md A(256, 256, 0.0), B(256, 256, 0.0);
    fill_seq(A, double(0.0), double(0.01));
    fill_seq(B, double(0.0), double(0.01));
    CHECK(numerics::matequal_auto(A, B), "auto equal");

    B(5, 7) += 1e-3;
    CHECK(!numerics::matequal_auto(A, B, 1e-9), "auto detects mismatch");
}

#ifdef _OPENMP
// matequal_omp must stay correct when called from inside an outer parallel
// region.
TEST_CASE(matequal_omp_nested_callsite)
{
    utils::Mi A(128, 128, 0), B(128, 128, 0);
    fill_seq(A, std::int64_t(0), std::int64_t(1));
    fill_seq(B, std::int64_t(0), std::int64_t(1));

    // Allow one nested level so that any parallel region inside matequal_omp
    // may spawn its own team. The previous setting is restored below so later
    // tests see the original value.
    const int prev_levels = omp_get_max_active_levels();
    omp_set_max_active_levels(2);

    bool ok_equal = false, ok_diff = false;
#pragma omp parallel num_threads(2) shared(ok_equal, ok_diff)
    {
        // Only one thread of the outer team runs this section. It therefore
        // performs both calls and the mutation of B without racing the other
        // thread.
#pragma omp single
        {
            ok_equal = numerics::matequal_omp(A, B);
            B(10, 10) += 1;  // introduce a mismatch
            ok_diff = !numerics::matequal_omp(A, B);
        }
    }
    omp_set_max_active_levels(prev_levels);

    CHECK(ok_equal, "nested equal should be true");
    CHECK(ok_diff, "nested mismatch should be false");
}
#endif