109 lines
3.3 KiB
C++
109 lines
3.3 KiB
C++
|
|
#include "test_common.h"
|
|
#include "./numerics/matequal.h"
|
|
|
|
using utils::Vf; using utils::Vd; using utils::Vi;
|
|
using utils::Mf; using utils::Md; using utils::Mi;
|
|
|
|
|
|
// ---------- helpers ----------
|
|
template <typename T>
|
|
static void fill_seq(utils::Matrix<T>& M, T start = T{0}, T step = T{1}) {
|
|
std::uint64_t k = 0;
|
|
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
|
for (std::uint64_t j = 0; j < M.cols(); ++j, ++k)
|
|
M(i,j) = start + step * static_cast<T>(k);
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ---------- tests ----------
|
|
|
|
TEST_CASE(matequal_shape_mismatch) {
    // Same row count, different column count: every comparison variant
    // must report "not equal" regardless of element values.
    utils::Mi lhs(3, 3, 0);
    utils::Mi rhs(3, 4, 0);

    const bool serial_differs = !numerics::matequal(lhs, rhs);
    CHECK(serial_differs, "shape mismatch should be false (serial)");

#ifdef _OPENMP
    const bool omp_differs = !numerics::matequal_omp(lhs, rhs);
    CHECK(omp_differs, "shape mismatch should be false (omp)");
#endif

    const bool auto_differs = !numerics::matequal_auto(lhs, rhs);
    CHECK(auto_differs, "shape mismatch should be false (auto)");
}
|
|
|
|
TEST_CASE(matequal_int_true_false) {
    // Identical integer matrices must compare equal on every code path, and a
    // single differing element must make every path report "not equal".
    utils::Mi A(4,5,0), B(4,5,0);
    fill_seq(A, std::int64_t(0), std::int64_t(1));
    fill_seq(B, std::int64_t(0), std::int64_t(1));

    CHECK(numerics::matequal(A,B), "ints equal (serial)");
#ifdef _OPENMP
    CHECK(numerics::matequal_omp(A,B), "ints equal (omp)");
#endif
    CHECK(numerics::matequal_auto(A,B), "ints equal (auto)");

    // Perturb one element: B(2,3) goes from 13 (its row-major index) to 14.
    B(2,3) += 1;

    CHECK(!numerics::matequal(A,B), "ints differ (serial)");
#ifdef _OPENMP
    // Regression guard: this check is known to fail when the OMP branch's
    // comparison/reduction is written with '!='.
    CHECK(!numerics::matequal_omp(A,B), "ints differ (omp)");
#endif
    CHECK(!numerics::matequal_auto(A,B), "ints differ (auto)");
}
|
|
|
|
TEST_CASE(matequal_double_tolerance) {
    // Floating-point comparison: differences below the tolerance are equal,
    // differences above it are not.
    utils::Md A(3,3,0.0), B(3,3,0.0);
    fill_seq(A, double(1.0), double(0.125));
    fill_seq(B, double(1.0), double(0.125));

    // Tiny perturbation (B(1,1): 1.5 -> 1.5 + 1e-12).
    B(1,1) += 1e-12;

    // Default-tolerance checks. These assume the default tol is >= 1e-12;
    // the default is defined in matequal.h, which is not visible here.
    CHECK(numerics::matequal(A,B), "double within tol (serial)");
#ifdef _OPENMP
    CHECK(numerics::matequal_omp(A,B), "double within tol (omp)");
#endif
    CHECK(numerics::matequal_auto(A,B), "double within tol (auto)");

    // Same expectation with an explicit tolerance, independent of the default.
    CHECK(numerics::matequal(A,B, 1e-9), "double within explicit tol (serial)");
#ifdef _OPENMP
    CHECK(numerics::matequal_omp(A,B, 1e-9), "double within explicit tol (omp)");
#endif
    CHECK(numerics::matequal_auto(A,B, 1e-9), "double within explicit tol (auto)");

    // Larger perturbation (B(0,2): 1.25 -> 1.25 + 1e-6) exceeds tol = 1e-9.
    B(0,2) += 1e-6;

    CHECK(!numerics::matequal(A,B, 1e-9), "double exceeds tol (serial)");
#ifdef _OPENMP
    CHECK(!numerics::matequal_omp(A,B, 1e-9), "double exceeds tol (omp)");
#endif
    CHECK(!numerics::matequal_auto(A,B, 1e-9), "double exceeds tol (auto)");
}
|
|
|
|
TEST_CASE(matequal_auto_agrees) {
    // 256x256 is intended to be large enough that matequal_auto likely picks
    // the OpenMP path when it is compiled in. Only the result is checked here,
    // not which path actually ran.
    utils::Md lhs(256, 256, 0.0);
    utils::Md rhs(256, 256, 0.0);
    fill_seq(lhs, double(0.0), double(0.01));
    fill_seq(rhs, double(0.0), double(0.01));

    const bool same_before = numerics::matequal_auto(lhs, rhs);
    CHECK(same_before, "auto equal");

    // A single element moved by 1e-3, far beyond the explicit 1e-9 tolerance.
    rhs(5, 7) += 1e-3;
    const bool differs_after = !numerics::matequal_auto(lhs, rhs, 1e-9);
    CHECK(differs_after, "auto detects mismatch");
}
|
|
|
|
#ifdef _OPENMP
// omp_get_max_active_levels / omp_set_max_active_levels are declared here;
// include it explicitly rather than relying on a transitive include.
#include <omp.h>

// NOTE: the test name keeps its existing "mateequal" spelling so that any
// external filters selecting tests by name keep working.
TEST_CASE(mateequal_omp_nested_callsite) {
    // Verify matequal_omp is correct when called inside an outer parallel region.
    utils::Mi A(128,128,0), B(128,128,0);
    fill_seq(A, std::int64_t(0), std::int64_t(1));
    fill_seq(B, std::int64_t(0), std::int64_t(1));

    // Allow one nested level: the parallel region inside matequal_omp may then
    // spawn its own team. The previous setting is restored below.
    const int prev_levels = omp_get_max_active_levels();
    omp_set_max_active_levels(2);

    bool ok_equal = false;
    bool ok_diff = false;

    #pragma omp parallel num_threads(2) shared(ok_equal, ok_diff)
    {
        // Only one thread of the outer team performs the calls and the mutation.
        #pragma omp single
        {
            ok_equal = numerics::matequal_omp(A,B);
            B(10,10) += 1; // introduce a mismatch
            ok_diff = !numerics::matequal_omp(A,B);
        }
    }

    // Restore the previous setting before the CHECKs so later tests see it unchanged.
    omp_set_max_active_levels(prev_levels);

    CHECK(ok_equal, "nested equal should be true");
    CHECK(ok_diff, "nested mismatch should be false");
}
#endif
|