Finishing up and starting lu decomp
This commit is contained in:
@@ -0,0 +1,164 @@
|
||||
#include "test_common.h"
|
||||
#include "./utils/utils.h"
|
||||
#include "./numerics/matmul.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
|
||||
// ============ Basic correctness ============
|
||||
// Serial kernel smoke test on a fixed 3x3 product.
//   A = [[1,2,3],[4,5,6],[7,8,9]]   (row-major, counting up from 1)
//   B = [[9,8,7],[6,5,4],[3,2,1]]   (row-major, counting down from 9)
// Only the result shape and the first row of A*B are verified.
TEST_CASE(Matmul_Serial_Simple3x3) {
    utils::Md A(3, 3, 0.0);
    utils::Md B(3, 3, 0.0);

    // Fill both operands in a single row-major sweep.
    double ascending = 1.0;
    double descending = 9.0;
    for (uint64_t row = 0; row < 3; ++row) {
        for (uint64_t col = 0; col < 3; ++col) {
            A(row, col) = ascending;
            B(row, col) = descending;
            ascending += 1.0;
            descending -= 1.0;
        }
    }

    auto C = numerics::matmul<double>(A, B);

    // Expected first row (row 0 of A dotted with each column of B):
    //   c00 = 1*9 + 2*6 + 3*3 = 30
    //   c01 = 1*8 + 2*5 + 3*2 = 24
    //   c02 = 1*7 + 2*4 + 3*1 = 18
    // All operands are small integers, so exact floating-point comparison is safe.
    CHECK(C.rows()==3 && C.cols()==3, "shape 3x3 wrong");
    CHECK(C(0,0)==30.0 && C(0,1)==24.0 && C(0,2)==18.0, "first row wrong");
}
|
||||
|
||||
// Cross-checks every matmul variant called here (rows_omp, collapse_omp, auto)
// against the serial kernel on a non-square (4x5)*(5x3) product, using an
// absolute tolerance of 1e-12 passed to nearly_equal.
TEST_CASE(Matmul_OMP_Equals_Serial) {
    utils::Md A(4, 5, 0.0);
    utils::Md B(5, 3, 0.0);

    // Deterministic fill: each entry is a multiple (1..10) of 0.1 for A and
    // of 0.2 for B, selected by a small index hash.
    for (uint64_t row = 0; row < A.rows(); ++row) {
        for (uint64_t col = 0; col < A.cols(); ++col) {
            const uint64_t bucket = (row*17 + col*19) % 10;
            A(row, col) = 0.1 * (1 + bucket);
        }
    }
    for (uint64_t row = 0; row < B.rows(); ++row) {
        for (uint64_t col = 0; col < B.cols(); ++col) {
            const uint64_t bucket = (row*23 + col*29) % 10;
            B(row, col) = 0.2 * (1 + bucket);
        }
    }

    // Serial result is the reference; the others must agree with it.
    auto Cs = numerics::matmul<double>(A, B);
    auto Cr = numerics::matmul_rows_omp<double>(A, B);
    auto Cc = numerics::matmul_collapse_omp<double>(A, B);
    auto Ca = numerics::matmul_auto<double>(A, B);

    CHECK((Cs.nearly_equal(Cr, 1e-12)), "rows_omp != serial");
    CHECK((Cs.nearly_equal(Cc, 1e-12)), "collapse_omp != serial");
    CHECK((Cs.nearly_equal(Ca, 1e-12)), "auto != serial");
}
|
||||
|
||||
// ============ Dimension mismatch ============
|
||||
// Invokes `call` and reports whether it raised std::runtime_error.
// Any other exception type is not caught here and propagates to the caller.
template <class Call>
static bool raises_runtime_error(Call&& call) {
    try {
        call();
    } catch (const std::runtime_error&) {
        return true;
    }
    return false;
}

// A is 2x3 and B is 4x2, so the inner dimensions (3 vs 4) disagree. The
// serial, rows_omp and collapse_omp kernels must each reject the product by
// throwing std::runtime_error.
TEST_CASE(Matmul_DimensionMismatch_Throws) {
    utils::Md A(2, 3, 0.0);
    utils::Md B(4, 2, 0.0);

    bool threw = raises_runtime_error([&] {
        auto product = numerics::matmul<double>(A, B);
        (void)product;
    });
    CHECK(threw, "serial should throw on dim mismatch");

    threw = raises_runtime_error([&] {
        auto product = numerics::matmul_rows_omp<double>(A, B);
        (void)product;
    });
    CHECK(threw, "rows_omp should throw on dim mismatch");

    threw = raises_runtime_error([&] {
        auto product = numerics::matmul_collapse_omp<double>(A, B);
        (void)product;
    });
    CHECK(threw, "collapse_omp should throw on dim mismatch");
}
|
||||
|
||||
// ============ Edge cases ============
|
||||
// Degenerate shapes for the serial kernel: a product with zero rows on the
// left or zero columns on the right must still report the correct
// (possibly empty) result shape.
TEST_CASE(Matmul_Edges_ZeroDims) {
    // Case 1: (0xK) * (KxP) -> (0xP), with K=5, P=3.
    utils::Md empty_rows(0, 5, 0.0);
    utils::Md rhs(5, 3, 0.0);
    auto C0 = numerics::matmul<double>(empty_rows, rhs);
    CHECK(C0.rows()==0 && C0.cols()==3, "0xK * KxP shape wrong");

    // Case 2: (MxK) * (Kx0) -> (Mx0), with M=7, K=4.
    utils::Md lhs(7, 4, 0.0);
    utils::Md empty_cols(4, 0, 0.0);
    auto C1 = numerics::matmul<double>(lhs, empty_cols);
    CHECK(C1.rows()==7 && C1.cols()==0, "MxK * Kx0 shape wrong");
}
|
||||
|
||||
// ============ Identity sanity ============
|
||||
// Multiplying by the 5x5 identity on either side must return the operand
// unchanged. A has 2 on the diagonal, 1 above it and -1 below it, so the
// products are compared with operator== rather than a tolerance.
TEST_CASE(Matmul_Identity) {
    const uint64_t n = 5;
    utils::Md eye(n, n, 0.0);
    utils::Md A(n, n, 0.0);

    for (uint64_t row = 0; row < n; ++row) {
        eye(row, row) = 1.0;
        for (uint64_t col = 0; col < n; ++col) {
            if (row == col) {
                A(row, col) = 2.0;   // diagonal
            } else if (row < col) {
                A(row, col) = 1.0;   // strictly upper triangle
            } else {
                A(row, col) = -1.0;  // strictly lower triangle
            }
        }
    }

    auto L = numerics::matmul<double>(eye, A);  // I*A
    auto R = numerics::matmul<double>(A, eye);  // A*I
    CHECK(L == A, "I*A != A");
    CHECK(R == A, "A*I != A");
}
|
||||
|
||||
// ============ Perf sanity (same kernel: 1 thread vs many) ============
|
||||
// Measures the wall-clock time needed to invoke `f` `iters` times in a row.
//
// Parameters:
//   f      - callable invoked with no arguments; any return value is discarded.
//   iters  - number of back-to-back invocations (default 1); for iters <= 0
//            the callable is never invoked.
// Returns:
//   Total elapsed seconds across all invocations (not a per-call average).
//
// The interval is taken with std::chrono::steady_clock because it is
// guaranteed monotonic. high_resolution_clock may be an alias of
// system_clock, whose readings can jump when the system time is adjusted,
// which would corrupt (or even make negative) a measured duration.
template <class F>
static double time_it(F&& f, int iters=1) {
    using clock = std::chrono::steady_clock;

    const auto start = clock::now();
    for (int i = 0; i < iters; ++i) {
        f();
    }
    const auto stop = clock::now();

    return std::chrono::duration<double>(stop - start).count();
}
|
||||
|
||||
// Perf sanity check for matmul_rows_omp: running the same kernel with all
// available OpenMP threads must not be notably slower (more than 5%) than
// running it with a single thread. The test is a no-op when the build has no
// OpenMP support or when only one thread is available.
//
// Timing notes:
//   - The thread count is switched BEFORE each timed region so the cost of
//     omp_set_num_threads() is never part of a measurement.
//   - Each configuration is timed several times and the fastest run is kept,
//     which filters out scheduler noise that a single sample would let
//     through and keeps the 5% tolerance meaningful.
TEST_CASE(Matmul_Perf_Sanity_RowOMP) {
#ifndef _OPENMP
    return;
#else
    const int hw = omp_get_max_threads();
    if (hw <= 1) return;

    const uint64_t m=512, k=512, p=512; // ~134M MACs; adjust if needed
    utils::Md A(m,k,0.0), B(k,p,0.0);
    for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i+j%7)*0.001;
    for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i*3+j%5)*0.0005;

    // Warm-up
    (void) numerics::matmul_rows_omp<double>(A,B);

    // Thread count to restore once the measurements are done.
    const int prev = omp_get_max_threads();

    // The kernel under test; the product itself is irrelevant here.
    const auto kernel = [&] { (void) numerics::matmul_rows_omp<double>(A,B); };

    // Fastest of `repeats` timed runs, in seconds.
    const auto best_of = [&](int repeats) -> double {
        double best = time_it(kernel);
        for (int r = 1; r < repeats; ++r) {
            const double sample = time_it(kernel);
            if (sample < best) best = sample;
        }
        return best;
    };
    const int repeats = 3;

    omp_set_num_threads(1);
    const double t1 = best_of(repeats);

    omp_set_num_threads(hw);
    const double tN = best_of(repeats);

    omp_set_num_threads(prev);

    // Must not be notably slower with many threads
    CHECK(tN <= t1 * 1.05, "rows_omp: multi-thread slower than single-thread");
#endif
}
|
||||
|
||||
// Perf sanity check for matmul_collapse_omp: running the same kernel with all
// available OpenMP threads must not be notably slower (more than 5%) than
// running it with a single thread. The test is a no-op when the build has no
// OpenMP support or when only one thread is available.
//
// Timing notes:
//   - The thread count is switched BEFORE each timed region so the cost of
//     omp_set_num_threads() is never part of a measurement.
//   - Each configuration is timed several times and the fastest run is kept,
//     which filters out scheduler noise that a single sample would let
//     through and keeps the 5% tolerance meaningful.
TEST_CASE(Matmul_Perf_Sanity_CollapseOMP) {
#ifndef _OPENMP
    return;
#else
    const int hw = omp_get_max_threads();
    if (hw <= 1) return;

    const uint64_t m=512, k=512, p=512;
    utils::Md A(m,k,0.0), B(k,p,0.0);
    for (uint64_t i=0;i<m;++i) for (uint64_t j=0;j<k;++j) A(i,j)= (i*7+j%11)*0.0003;
    for (uint64_t i=0;i<k;++i) for (uint64_t j=0;j<p;++j) B(i,j)= (i%13+j)*0.0002;

    (void) numerics::matmul_collapse_omp<double>(A,B); // warm-up

    // Thread count to restore once the measurements are done.
    const int prev = omp_get_max_threads();

    // The kernel under test; the product itself is irrelevant here.
    const auto kernel = [&] { (void) numerics::matmul_collapse_omp<double>(A,B); };

    // Fastest of `repeats` timed runs, in seconds.
    const auto best_of = [&](int repeats) -> double {
        double best = time_it(kernel);
        for (int r = 1; r < repeats; ++r) {
            const double sample = time_it(kernel);
            if (sample < best) best = sample;
        }
        return best;
    };
    const int repeats = 3;

    omp_set_num_threads(1);
    const double t1 = best_of(repeats);

    omp_set_num_threads(hw);
    const double tN = best_of(repeats);

    omp_set_num_threads(prev);

    CHECK(tN <= t1 * 1.05, "collapse_omp: multi-thread slower than single-thread");
#endif
}
|
||||
Reference in New Issue
Block a user