Sync public subset from Flux

This commit is contained in:
Gitea CI
2025-10-06 20:21:40 +00:00
parent b2d00af0e1
commit 8892d58e66
15 changed files with 1825 additions and 0 deletions

127
test/test_matmul.cpp Normal file
View File

@@ -0,0 +1,127 @@
#include "test_common.h"
#include "./utils/utils.h"
#include "./numerics/matmul.h"
#include <chrono>
// ---------- helpers ----------
template <typename T>
static bool mats_equal(const utils::Matrix<T>& X, const utils::Matrix<T>& Y, double tol = 0.0) {
if (X.rows()!=Y.rows() || X.cols()!=Y.cols()) return false;
if (std::is_floating_point<T>::value) {
for (std::uint64_t i=0;i<X.rows();++i)
for (std::uint64_t j=0;j<X.cols();++j)
if (std::fabs(double(X(i,j)) - double(Y(i,j))) > tol) return false;
} else {
for (std::uint64_t i=0;i<X.rows();++i)
for (std::uint64_t j=0;j<X.cols();++j)
if (X(i,j) != Y(i,j)) return false;
}
return true;
}
template <typename T>
static void fill_seq(utils::Matrix<T>& M, T start = T(0), T step = T(1)) {
std::uint64_t k = 0;
for (std::uint64_t i=0;i<M.rows();++i)
for (std::uint64_t j=0;j<M.cols();++j,++k)
M(i,j) = start + step * static_cast<T>(k);
}
// ---------- tests ----------
// Small known example: (3x2) · (2x3)
TEST_CASE(Matmul_Small_Known) {
utils::Mi A(3,2,0), B(2,3,0);
// A = [1 2; 3 4; 5 6]
A(0,0)=1; A(0,1)=2;
A(1,0)=3; A(1,1)=4;
A(2,0)=5; A(2,1)=6;
// B = [7 8 9; 10 11 12]
B(0,0)=7; B(0,1)=8; B(0,2)=9;
B(1,0)=10; B(1,1)=11; B(1,2)=12;
auto C = numerics::matmul(A,B);
CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
// Expected C:
// [27 30 33]
// [61 68 75]
// [95 106 117]
CHECK(C(0,0)==27 && C(0,1)==30 && C(0,2)==33, "row 0 wrong");
CHECK(C(1,0)==61 && C(1,1)==68 && C(1,2)==75, "row 1 wrong");
CHECK(C(2,0)==95 && C(2,1)==106 && C(2,2)==117, "row 2 wrong");
}
TEST_CASE(Matmul_DimMismatch_Throws) {
utils::Md A(2,3,1.0), B(4,2,2.0); // A.cols()!=B.rows()
bool threw=false;
try { (void)numerics::matmul(A,B); } catch(const std::runtime_error&) { threw=true; }
CHECK(threw, "matmul should throw on dim mismatch");
}
// Compare all variants vs serial on a moderate size
TEST_CASE(Matmul_Variants_Equal_Int) {
const std::uint64_t m=32, n=24, p=16;
utils::Mi A(m,n,0), B(n,p,0);
// deterministic fill (no randomness)
fill_seq(A, int64_t(1), int64_t(1));
fill_seq(B, int64_t(2), int64_t(3));
auto C_ref = numerics::matmul(A,B);
auto C_rows = numerics::matmul_rows_omp(A,B);
auto C_collapse = numerics::matmul_collapse_omp(A,B);
auto C_auto = numerics::matmul_auto(A,B);
CHECK(mats_equal(C_rows, C_ref), "rows_omp != serial");
CHECK(mats_equal(C_collapse, C_ref), "collapse_omp != serial");
CHECK(mats_equal(C_auto, C_ref), "auto != serial");
}
TEST_CASE(Matmul_Variants_Equal_Double) {
const std::uint64_t m=33, n=17, p=19;
utils::Md A(m,n,0.0), B(n,p,0.0);
fill_seq(A, 0.1, 0.01);
fill_seq(B, 1.0, 0.02);
auto C_ref = numerics::matmul(A,B);
auto C_rows = numerics::matmul_rows_omp(A,B);
auto C_collapse = numerics::matmul_collapse_omp(A,B);
auto C_auto = numerics::matmul_auto(A,B);
CHECK(mats_equal(C_rows, C_ref, 1e-9), "rows_omp != serial (double)");
CHECK(mats_equal(C_collapse, C_ref, 1e-9), "collapse_omp != serial (double)");
CHECK(mats_equal(C_auto, C_ref, 1e-9), "auto != serial (double)");
}
// Nested callsite sanity: call OMP variant from within an outer region
#ifdef _OPENMP
TEST_CASE(Matmul_OMP_Nested_Callsite) {
const std::uint64_t m=48, n=24, p=32;
utils::Mi A(m,n,0), B(n,p,0);
fill_seq(A, int64_t(1), int64_t(2));
fill_seq(B, int64_t(3), int64_t(1));
auto C_ref = numerics::matmul(A,B);
int prev_levels = omp_get_max_active_levels();
omp_set_max_active_levels(2);
utils::Mi C_nested;
#pragma omp parallel num_threads(2)
{
#pragma omp single
{
// either variant is fine; collapse(2) has more parallelism
C_nested = numerics::matmul_collapse_omp(A,B);
}
}
omp_set_max_active_levels(prev_levels);
CHECK(mats_equal(C_nested, C_ref), "nested collapse_omp result mismatch");
}
#endif