#include "test_common.h"
#include "./utils/utils.h"
#include "./numerics/matmul.h"

#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <type_traits>

#ifdef _OPENMP
#include <omp.h>
#endif

// ---------- helpers ----------

namespace {

// Wraps T so that a parameter spelled `typename non_deduced<T>::type` does not
// take part in template argument deduction (same role as C++20
// std::type_identity). Lets fill_seq deduce T from the matrix argument alone.
template <typename T>
struct non_deduced { using type = T; };

} // namespace

// Element-wise equality of two matrices with the same element type.
//
// Returns false when the shapes differ. For floating-point T, two elements
// match when |X(i,j) - Y(i,j)| <= tol (absolute tolerance); a NaN on either
// side is always a mismatch. For every other T the comparison is exact and
// `tol` is ignored.
template <typename T>
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y, double tol = 0.0)
{
    if (X.rows() != Y.rows() || X.cols() != Y.cols()) return false;

    for (std::uint64_t i = 0; i < X.rows(); ++i) {
        for (std::uint64_t j = 0; j < X.cols(); ++j) {
            if (std::is_floating_point<T>::value) {
                const double diff =
                    std::abs(static_cast<double>(X(i, j)) - static_cast<double>(Y(i, j)));
                // Written as !(diff <= tol) rather than (diff > tol): every
                // comparison involving NaN is false, so the latter form would
                // silently accept a NaN result as "equal".
                if (!(diff <= tol)) return false;
            } else {
                if (X(i, j) != Y(i, j)) return false;
            }
        }
    }
    return true;
}

// Fills M in row-major order with start, start+step, start+2*step, ...
// i.e. element (i,j) receives start + step * (i*M.cols() + j). Gives
// deterministic, non-trivial test data without any randomness.
// T is deduced from M only, so start/step may be passed as any type that
// converts to the element type (e.g. std::int64_t literals for a matrix whose
// element type is `long long`).
template <typename T>
static void fill_seq(utils::Matrix<T>& M,
                     typename non_deduced<T>::type start = T(0),
                     typename non_deduced<T>::type step = T(1))
{
    std::uint64_t k = 0;
    for (std::uint64_t i = 0; i < M.rows(); ++i)
        for (std::uint64_t j = 0; j < M.cols(); ++j, ++k)
            M(i, j) = start + step * static_cast<T>(k);
}

// ---------- tests ----------

// Small known example: (3x2) * (2x3) -> (3x3), expected values worked by hand.
TEST_CASE(Matmul_Small_Known)
{
    utils::Mi A(3, 2, 0), B(2, 3, 0);

    // A = [1 2; 3 4; 5 6]
    A(0,0) = 1; A(0,1) = 2;
    A(1,0) = 3; A(1,1) = 4;
    A(2,0) = 5; A(2,1) = 6;

    // B = [7 8 9; 10 11 12]
    B(0,0) = 7;  B(0,1) = 8;  B(0,2) = 9;
    B(1,0) = 10; B(1,1) = 11; B(1,2) = 12;

    auto C = numerics::matmul(A, B);
    CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");

    // Expected C:
    // [27  30  33]
    // [61  68  75]
    // [95 106 117]
    CHECK(C(0,0)==27 && C(0,1)==30 && C(0,2)==33, "row 0 wrong");
    CHECK(C(1,0)==61 && C(1,1)==68 && C(1,2)==75, "row 1 wrong");
    CHECK(C(2,0)==95 && C(2,1)==106 && C(2,2)==117, "row 2 wrong");
}

// Inner dimensions disagree (A is 2x3, B is 4x2): matmul must throw
// std::runtime_error rather than compute anything.
TEST_CASE(Matmul_DimMismatch_Throws)
{
    utils::Md A(2, 3, 1.0), B(4, 2, 2.0); // A.cols()!=B.rows()

    bool threw = false;
    try {
        (void)numerics::matmul(A, B);
    } catch (const std::runtime_error&) {
        threw = true;
    }
    CHECK(threw, "matmul should throw on dim mismatch");
}

// Compare all variants vs serial on a moderate size. Integer data, so the
// results must match exactly regardless of how the work is split.
TEST_CASE(Matmul_Variants_Equal_Int)
{
    const std::uint64_t m = 32, n = 24, p = 16;
    utils::Mi A(m, n, 0), B(n, p, 0);

    // deterministic fill (no randomness)
    fill_seq(A, std::int64_t(1), std::int64_t(1));
    fill_seq(B, std::int64_t(2), std::int64_t(3));

    auto C_ref      = numerics::matmul(A, B);
    auto C_rows     = numerics::matmul_rows_omp(A, B);
    auto C_collapse = numerics::matmul_collapse_omp(A, B);
    auto C_auto     = numerics::matmul_auto(A, B);

    CHECK(mats_equal(C_rows, C_ref), "rows_omp != serial");
    CHECK(mats_equal(C_collapse, C_ref), "collapse_omp != serial");
    CHECK(mats_equal(C_auto, C_ref), "auto != serial");
}

// Same comparison with double data and deliberately non-power-of-two sizes.
// Uses an absolute tolerance of 1e-9; entries here stay well under 1e3, so
// this leaves ample room for rounding differences between variants.
TEST_CASE(Matmul_Variants_Equal_Double)
{
    const std::uint64_t m = 33, n = 17, p = 19;
    utils::Md A(m, n, 0.0), B(n, p, 0.0);

    fill_seq(A, 0.1, 0.01);
    fill_seq(B, 1.0, 0.02);

    auto C_ref      = numerics::matmul(A, B);
    auto C_rows     = numerics::matmul_rows_omp(A, B);
    auto C_collapse = numerics::matmul_collapse_omp(A, B);
    auto C_auto     = numerics::matmul_auto(A, B);

    CHECK(mats_equal(C_rows, C_ref, 1e-9), "rows_omp != serial (double)");
    CHECK(mats_equal(C_collapse, C_ref, 1e-9), "collapse_omp != serial (double)");
    CHECK(mats_equal(C_auto, C_ref, 1e-9), "auto != serial (double)");
}

// Nested callsite sanity: call OMP variant from within an outer region.
// The outer region has 2 threads and one of them (omp single) makes the call.
// max-active-levels is raised to 2 for the duration so a parallel region
// inside matmul_collapse_omp (presumably it opens one) is allowed to fork,
// then restored to its previous value before checking the result.
#ifdef _OPENMP
TEST_CASE(Matmul_OMP_Nested_Callsite)
{
    const std::uint64_t m = 48, n = 24, p = 32;
    utils::Mi A(m, n, 0), B(n, p, 0);

    fill_seq(A, std::int64_t(1), std::int64_t(2));
    fill_seq(B, std::int64_t(3), std::int64_t(1));

    auto C_ref = numerics::matmul(A, B);

    const int prev_levels = omp_get_max_active_levels();
    omp_set_max_active_levels(2);

    utils::Mi C_nested;
    #pragma omp parallel num_threads(2)
    {
        #pragma omp single
        {
            // either variant is fine; collapse(2) has more parallelism
            C_nested = numerics::matmul_collapse_omp(A, B);
        }
    }

    omp_set_max_active_levels(prev_levels);

    CHECK(mats_equal(C_nested, C_ref), "nested collapse_omp result mismatch");
}
#endif