Sync public subset from Flux
This commit is contained in:
127
test/test_matmul.cpp
Normal file
127
test/test_matmul.cpp
Normal file
@@ -0,0 +1,127 @@
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
#include "./numerics/matmul.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
// ---------- helpers ----------
|
||||
template <typename T>
|
||||
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y, double tol = 0.0) {
|
||||
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
|
||||
if (std::is_floating_point<T>::value) {
|
||||
for (std::uint64_t i=0;i<X.rows();++i)
|
||||
for (std::uint64_t j=0;j<X.cols();++j)
|
||||
if (std::fabs(double(X(i,j)) - double(Y(i,j))) > tol) return false;
|
||||
} else {
|
||||
for (std::uint64_t i=0;i<X.rows();++i)
|
||||
for (std::uint64_t j=0;j<X.cols();++j)
|
||||
if (X(i,j) != Y(i,j)) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void fill_seq(utils::Matrix<T>& M, T start = T(0), T step = T(1)) {
|
||||
std::uint64_t k = 0;
|
||||
for (std::uint64_t i=0;i<M.rows();++i)
|
||||
for (std::uint64_t j=0;j<M.cols();++j,++k)
|
||||
M(i,j) = start + step * static_cast<T>(k);
|
||||
}
|
||||
// ---------- tests ----------
|
||||
|
||||
// Small known example: (3x2) · (2x3)
|
||||
TEST_CASE(Matmul_Small_Known) {
|
||||
utils::Mi A(3,2,0), B(2,3,0);
|
||||
// A = [1 2; 3 4; 5 6]
|
||||
A(0,0)=1; A(0,1)=2;
|
||||
A(1,0)=3; A(1,1)=4;
|
||||
A(2,0)=5; A(2,1)=6;
|
||||
// B = [7 8 9; 10 11 12]
|
||||
B(0,0)=7; B(0,1)=8; B(0,2)=9;
|
||||
B(1,0)=10; B(1,1)=11; B(1,2)=12;
|
||||
|
||||
auto C = numerics::matmul(A,B);
|
||||
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
|
||||
|
||||
// Expected C:
|
||||
// [27 30 33]
|
||||
// [61 68 75]
|
||||
// [95 106 117]
|
||||
CHECK(C(0,0)==27 && C(0,1)==30 && C(0,2)==33, "row 0 wrong");
|
||||
CHECK(C(1,0)==61 && C(1,1)==68 && C(1,2)==75, "row 1 wrong");
|
||||
CHECK(C(2,0)==95 && C(2,1)==106 && C(2,2)==117, "row 2 wrong");
|
||||
}
|
||||
|
||||
TEST_CASE(Matmul_DimMismatch_Throws) {
|
||||
utils::Md A(2,3,1.0), B(4,2,2.0); // A.cols()!=B.rows()
|
||||
bool threw=false;
|
||||
try { (void)numerics::matmul(A,B); } catch(const std::runtime_error&) { threw=true; }
|
||||
CHECK(threw, "matmul should throw on dim mismatch");
|
||||
}
|
||||
|
||||
// Compare all variants vs serial on a moderate size
|
||||
TEST_CASE(Matmul_Variants_Equal_Int) {
|
||||
const std::uint64_t m=32, n=24, p=16;
|
||||
utils::Mi A(m,n,0), B(n,p,0);
|
||||
|
||||
// deterministic fill (no randomness)
|
||||
fill_seq(A, int64_t(1), int64_t(1));
|
||||
fill_seq(B, int64_t(2), int64_t(3));
|
||||
|
||||
auto C_ref = numerics::matmul(A,B);
|
||||
|
||||
auto C_rows = numerics::matmul_rows_omp(A,B);
|
||||
auto C_collapse = numerics::matmul_collapse_omp(A,B);
|
||||
auto C_auto = numerics::matmul_auto(A,B);
|
||||
|
||||
CHECK(mats_equal(C_rows, C_ref), "rows_omp != serial");
|
||||
CHECK(mats_equal(C_collapse, C_ref), "collapse_omp != serial");
|
||||
CHECK(mats_equal(C_auto, C_ref), "auto != serial");
|
||||
}
|
||||
|
||||
TEST_CASE(Matmul_Variants_Equal_Double) {
|
||||
const std::uint64_t m=33, n=17, p=19;
|
||||
utils::Md A(m,n,0.0), B(n,p,0.0);
|
||||
|
||||
fill_seq(A, 0.1, 0.01);
|
||||
fill_seq(B, 1.0, 0.02);
|
||||
|
||||
auto C_ref = numerics::matmul(A,B);
|
||||
auto C_rows = numerics::matmul_rows_omp(A,B);
|
||||
auto C_collapse = numerics::matmul_collapse_omp(A,B);
|
||||
auto C_auto = numerics::matmul_auto(A,B);
|
||||
|
||||
CHECK(mats_equal(C_rows, C_ref, 1e-9), "rows_omp != serial (double)");
|
||||
CHECK(mats_equal(C_collapse, C_ref, 1e-9), "collapse_omp != serial (double)");
|
||||
CHECK(mats_equal(C_auto, C_ref, 1e-9), "auto != serial (double)");
|
||||
}
|
||||
|
||||
// Nested callsite sanity: call OMP variant from within an outer region
|
||||
#ifdef _OPENMP
|
||||
TEST_CASE(Matmul_OMP_Nested_Callsite) {
|
||||
const std::uint64_t m=48, n=24, p=32;
|
||||
utils::Mi A(m,n,0), B(n,p,0);
|
||||
fill_seq(A, int64_t(1), int64_t(2));
|
||||
fill_seq(B, int64_t(3), int64_t(1));
|
||||
|
||||
auto C_ref = numerics::matmul(A,B);
|
||||
|
||||
int prev_levels = omp_get_max_active_levels();
|
||||
omp_set_max_active_levels(2);
|
||||
|
||||
utils::Mi C_nested;
|
||||
#pragma omp parallel num_threads(2)
|
||||
{
|
||||
#pragma omp single
|
||||
{
|
||||
// either variant is fine; collapse(2) has more parallelism
|
||||
C_nested = numerics::matmul_collapse_omp(A,B);
|
||||
}
|
||||
}
|
||||
|
||||
omp_set_max_active_levels(prev_levels);
|
||||
|
||||
CHECK(mats_equal(C_nested, C_ref), "nested collapse_omp result mismatch");
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user