#include "test_common.h"
#include "./utils/utils.h"
#include "./numerics/matmul.h"

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <stdexcept>

#ifdef _OPENMP
#include <omp.h>
#endif

// Returns true iff invoking `f` throws std::runtime_error.
// Any other exception type propagates to the caller unchanged.
template <typename F>
static bool throws_runtime_error(F&& f) {
    try {
        f();
    } catch (const std::runtime_error&) {
        return true;
    }
    return false;
}

// ============ Basic correctness ============

TEST_CASE(Matmul_Serial_Simple3x3) {
    utils::Md A(3, 3, 0.0), B(3, 3, 0.0);
    // A = [[1,2,3],[4,5,6],[7,8,9]]
    double v = 1.0;
    for (uint64_t i = 0; i < 3; ++i)
        for (uint64_t j = 0; j < 3; ++j) A(i, j) = v++;
    // B = [[9,8,7],[6,5,4],[3,2,1]]
    double w = 9.0;
    for (uint64_t i = 0; i < 3; ++i)
        for (uint64_t j = 0; j < 3; ++j) B(i, j) = w--;

    auto C = numerics::matmul<double>(A, B);

    // Hand-checked product (all entries are small integers, so == is exact):
    //   row0: [ 30,  24, 18]   e.g. c00 = 1*9 + 2*6 + 3*3
    //   row1: [ 84,  69, 54]   e.g. c10 = 4*9 + 5*6 + 6*3
    //   row2: [138, 114, 90]   e.g. c20 = 7*9 + 8*6 + 9*3
    CHECK(C.rows() == 3 && C.cols() == 3, "shape 3x3 wrong");
    CHECK(C(0, 0) == 30.0 && C(0, 1) == 24.0 && C(0, 2) == 18.0, "first row wrong");
    CHECK(C(1, 0) == 84.0 && C(1, 1) == 69.0 && C(1, 2) == 54.0, "second row wrong");
    CHECK(C(2, 0) == 138.0 && C(2, 1) == 114.0 && C(2, 2) == 90.0, "third row wrong");
}

TEST_CASE(Matmul_OMP_Equals_Serial) {
    const uint64_t m = 4, k = 5, p = 3;
    utils::Md A(m, k, 0.0), B(k, p, 0.0);
    // Fill deterministic, non-symmetric values (so row/column mix-ups show up).
    // Entries are multiples of 0.5 / 0.25 with small magnitude, so every
    // product and partial sum is exact in double regardless of summation order.
    for (uint64_t i = 0; i < m; ++i)
        for (uint64_t j = 0; j < k; ++j)
            A(i, j) = 0.5 * static_cast<double>(i + 1) + static_cast<double>(j);
    for (uint64_t i = 0; i < k; ++i)
        for (uint64_t j = 0; j < p; ++j)
            B(i, j) = static_cast<double>(i) - 0.25 * static_cast<double>(j + 1);

    auto Cs = numerics::matmul<double>(A, B);
    auto Cr = numerics::matmul_rows_omp<double>(A, B);
    auto Cc = numerics::matmul_collapse_omp<double>(A, B);
    auto Ca = numerics::matmul_auto<double>(A, B);

    CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
    CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
    CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
}

// ============ Dimension mismatch ============

TEST_CASE(Matmul_DimensionMismatch_Throws) {
    // A is 2x3 but B has 4 rows: the inner dimensions disagree.
    utils::Md A(2, 3, 0.0), B(4, 2, 0.0);

    const bool serial_threw =
        throws_runtime_error([&] { (void)numerics::matmul<double>(A, B); });
    CHECK(serial_threw, "serial should throw on dim mismatch");

    const bool rows_threw =
        throws_runtime_error([&] { (void)numerics::matmul_rows_omp<double>(A, B); });
    CHECK(rows_threw, "rows_omp should throw on dim mismatch");

    const bool collapse_threw =
        throws_runtime_error([&] { (void)numerics::matmul_collapse_omp<double>(A, B); });
    CHECK(collapse_threw, "collapse_omp should throw on dim mismatch");
}

// ============ Edge cases ============

TEST_CASE(Matmul_Edges_ZeroDims) {
    // (0xK) * (KxP) -> (0xP)
    utils::Md A0(0, 5, 0.0), B1(5, 3, 0.0);
    auto C0 = numerics::matmul<double>(A0, B1);
    CHECK(C0.rows() == 0 && C0.cols() == 3, "0xK * KxP shape wrong");

    // (MxK) * (Kx0) -> (Mx0)
    utils::Md A2(7, 4, 0.0), B0(4, 0, 0.0);
    auto C1 = numerics::matmul<double>(A2, B0);
    CHECK(C1.rows() == 7 && C1.cols() == 0, "MxK * Kx0 shape wrong");
}

// ============ Identity sanity ============

TEST_CASE(Matmul_Identity) {
    const uint64_t n = 5;
    utils::Md I(n, n, 0.0), A(n, n, 0.0);
    for (uint64_t i = 0; i < n; ++i) {
        I(i, i) = 1.0;
        // Integer-valued entries: multiplying by 1 and adding zeros is exact,
        // so the strict == comparisons below are safe.
        for (uint64_t j = 0; j < n; ++j)
            A(i, j) = static_cast<double>(i * n + j) - 7.0;
    }

    auto L = numerics::matmul<double>(I, A);
    auto R = numerics::matmul<double>(A, I);

    CHECK(L == A, "I*A != A");
    CHECK(R == A, "A*I != A");
}

// ============ Perf sanity (same kernel: 1 thread vs many) ============

// Wall-clock seconds spent on `iters` back-to-back calls of `f`.
// steady_clock is used because it is guaranteed monotonic.
template <typename F>
static double time_it(F&& f, int iters = 1) {
    const auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i) f();
    const auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(t1 - t0).count();
}

// Fastest of `reps` single-call timings of `f`, in seconds. Taking the
// minimum filters out one-off noise (scheduling, other load) that would
// otherwise make a single-sample comparison flaky.
template <typename F>
static double best_time(F&& f, int reps) {
    double best = time_it(f, 1);
    for (int r = 1; r < reps; ++r) best = std::min(best, time_it(f, 1));
    return best;
}

#ifdef _OPENMP
namespace {

// Number of timed runs per thread configuration.
constexpr int kPerfReps = 3;

// Written with each perf result so the timed work stays observable.
volatile double g_perf_sink = 0.0;

// Saves omp_get_max_threads() on construction and restores it on
// destruction, so the thread count is put back on every exit path
// (including exceptions thrown by the timed kernel).
class OmpThreadCountGuard {
public:
    OmpThreadCountGuard() : saved_(omp_get_max_threads()) {}
    ~OmpThreadCountGuard() { omp_set_num_threads(saved_); }
    OmpThreadCountGuard(const OmpThreadCountGuard&) = delete;
    OmpThreadCountGuard& operator=(const OmpThreadCountGuard&) = delete;

private:
    int saved_;
};

// Best-of-N wall times (seconds) for one kernel at 1 thread and at N threads.
struct SerialVsParallel {
    double serial;
    double parallel;
};

// Times `kernel` with a single OpenMP thread, then with `threads` threads.
template <typename F>
SerialVsParallel time_serial_vs_parallel(F&& kernel, int threads, int reps) {
    OmpThreadCountGuard restore;
    omp_set_num_threads(1);
    const double t1 = best_time(kernel, reps);
    omp_set_num_threads(threads);
    const double tN = best_time(kernel, reps);
    return {t1, tN};
}

// Fills `M` (rows x cols) with deterministic, mildly irregular values.
// `salt` lets A and B differ.
void fill_perf_input(utils::Md& M, uint64_t rows, uint64_t cols, uint64_t salt) {
    for (uint64_t i = 0; i < rows; ++i)
        for (uint64_t j = 0; j < cols; ++j)
            M(i, j) = static_cast<double>((i * 31 + j * 17 + salt) % 13) * 0.125;
}

} // namespace
#endif

TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
#ifndef _OPENMP
    return;
#else
    const int hw = omp_get_max_threads();
    if (hw <= 1) return;

    const uint64_t m = 512, k = 512, p = 512;  // ~134M MACs; adjust if needed
    utils::Md A(m, k, 0.0), B(k, p, 0.0);
    fill_perf_input(A, m, k, 1);
    fill_perf_input(B, k, p, 2);

    auto kernel = [&] {
        auto C = numerics::matmul_rows_omp<double>(A, B);
        g_perf_sink = C(0, 0);
    };
    kernel();  // warm-up run; keeps one-time first-call costs out of the timings

    const SerialVsParallel t = time_serial_vs_parallel(kernel, hw, kPerfReps);

    // Must not be notably slower with many threads
    CHECK(t.parallel <= t.serial * 1.05, "rows_omp: multi-thread slower than single-thread");
#endif
}

TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
#ifndef _OPENMP
    return;
#else
    const int hw = omp_get_max_threads();
    if (hw <= 1) return;

    const uint64_t m = 512, k = 512, p = 512;
    utils::Md A(m, k, 0.0), B(k, p, 0.0);
    fill_perf_input(A, m, k, 1);
    fill_perf_input(B, k, p, 2);

    auto kernel = [&] {
        auto C = numerics::matmul_collapse_omp<double>(A, B);
        g_perf_sink = C(0, 0);
    };
    kernel();  // warm-up run; keeps one-time first-call costs out of the timings

    const SerialVsParallel t = time_serial_vs_parallel(kernel, hw, kPerfReps);

    // Must not be notably slower with many threads
    CHECK(t.parallel <= t.serial * 1.05, "collapse_omp: multi-thread slower than single-thread");
#endif
}