Sync public subset from Flux
This commit is contained in:
108
test/test_matequal.cpp
Normal file
108
test/test_matequal.cpp
Normal file
@@ -0,0 +1,108 @@
|
||||
|
||||
#include "test_common.h"
|
||||
#include "./numerics/matequal.h"
|
||||
|
||||
using utils::Vf; using utils::Vd; using utils::Vi;
|
||||
using utils::Mf; using utils::Md; using utils::Mi;
|
||||
|
||||
|
||||
// ---------- helpers ----------
|
||||
template <typename T>
|
||||
static void fill_seq(utils::Matrix<T>& M, T start = T{0}, T step = T{1}) {
|
||||
std::uint64_t k = 0;
|
||||
for (std::uint64_t i = 0; i < M.rows(); ++i)
|
||||
for (std::uint64_t j = 0; j < M.cols(); ++j, ++k)
|
||||
M(i,j) = start + step * static_cast<T>(k);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// ---------- tests ----------
|
||||
|
||||
TEST_CASE(matequal_shape_mismatch) {
|
||||
utils::Mi A(3,3,0), B(3,4,0);
|
||||
CHECK(!numerics::matequal(A,B), "shape mismatch should be false (serial)");
|
||||
|
||||
#ifdef _OPENMP
|
||||
CHECK(!numerics::matequal_omp(A,B), "shape mismatch should be false (omp)");
|
||||
#endif
|
||||
CHECK(!numerics::matequal_auto(A,B), "shape mismatch should be false (auto)");
|
||||
}
|
||||
|
||||
TEST_CASE(matequal_int_true_false) {
|
||||
utils::Mi A(4,5,0), B(4,5,0);
|
||||
fill_seq(A, int64_t(0), int64_t(1));
|
||||
fill_seq(B, int64_t(0), int64_t(1));
|
||||
CHECK(numerics::matequal(A,B), "ints equal (serial)");
|
||||
#ifdef _OPENMP
|
||||
CHECK(numerics::matequal_omp(A,B), "ints equal (omp)");
|
||||
#endif
|
||||
// flip one element
|
||||
B(2,3) += 1;
|
||||
CHECK(!numerics::matequal(A,B), "ints differ (serial)");
|
||||
#ifdef _OPENMP
|
||||
CHECK(!numerics::matequal_omp(A,B), "ints differ (omp)"); // will FAIL if your omp branch uses '!='
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_CASE(matequal_double_tolerance) {
|
||||
utils::Md A(3,3,0.0), B(3,3,0.0);
|
||||
fill_seq(A, double(1.0), double(0.125));
|
||||
fill_seq(B, double(1.0), double(0.125));
|
||||
// tiny perturbation within default tol
|
||||
B(1,1) += 1e-12;
|
||||
CHECK(numerics::matequal(A,B), "double within tol (serial)");
|
||||
#ifdef _OPENMP
|
||||
CHECK(numerics::matequal_omp(A,B), "double within tol (omp)");
|
||||
#endif
|
||||
// larger perturbation exceeds tol
|
||||
B(0,2) += 1e-6;
|
||||
CHECK(!numerics::matequal(A,B, 1e-9), "double exceeds tol (serial)");
|
||||
#ifdef _OPENMP
|
||||
CHECK(!numerics::matequal_omp(A,B, 1e-9), "double exceeds tol (omp)");
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_CASE(matequal_auto_agrees) {
|
||||
// Choose size so auto likely takes the OMP path when available,
|
||||
// but this test only checks correctness, not which path was taken.
|
||||
utils::Md A(256,256,0.0), B(256,256,0.0);
|
||||
fill_seq(A, double(0.0), double(0.01));
|
||||
fill_seq(B, double(0.0), double(0.01));
|
||||
CHECK(numerics::matequal_auto(A,B), "auto equal");
|
||||
|
||||
B(5,7) += 1e-3;
|
||||
CHECK(!numerics::matequal_auto(A,B, 1e-9), "auto detects mismatch");
|
||||
}
|
||||
|
||||
#ifdef _OPENMP
|
||||
TEST_CASE(mateequal_omp_nested_callsite) {
|
||||
// Verify correctness when called inside an outer parallel region.
|
||||
utils::Mi A(128,128,0), B(128,128,0);
|
||||
fill_seq(A, int64_t(0), int64_t(1));
|
||||
fill_seq(B, int64_t(0), int64_t(1));
|
||||
|
||||
// allow one nested level; inner region inside mateequal_omp may spawn a team
|
||||
int prev_levels = omp_get_max_active_levels();
|
||||
omp_set_max_active_levels(2);
|
||||
|
||||
bool ok_equal = false, ok_diff = false;
|
||||
|
||||
#pragma omp parallel num_threads(2) shared(ok_equal, ok_diff)
|
||||
{
|
||||
#pragma omp single
|
||||
{
|
||||
ok_equal = numerics::matequal_omp(A,B);
|
||||
B(10,10) += 1; // introduce a mismatch
|
||||
ok_diff = !numerics::matequal_omp(A,B);
|
||||
}
|
||||
}
|
||||
|
||||
omp_set_max_active_levels(prev_levels);
|
||||
|
||||
CHECK(ok_equal, "nested equal should be true");
|
||||
CHECK(ok_diff, "nested mismatch should be false");
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user