#pragma once #include //uint64_t //#include // std::runtime_error #include "utils/vector.h" #include "utils/matrix.h" namespace numerics::detail{ // ---------------- Matrix -> Scalar ---------------- template utils::Vector argmax_serial(const utils::Matrix& A) { const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); utils::Vector idx(2, 0); T max = A(0,0); for (uint64_t i = 0; i < rows; ++i){ for (uint64_t j = 0; j < cols; ++j){ if (max < A(i,j)){ max = A(i,j); idx[0] = i; idx[1] = j; } } } return idx; } // ---------------- Vector -> Scalar ---------------- template uint64_t argmax_serial(const utils::Vector& v) { const uint64_t N = v.size(); uint64_t idx = 0; T max = v[0]; for (uint64_t i = 1; i < N; ++i){ if (max < v[i]){ max = v[i]; idx = i; } } return idx; } // ---------------- Matrix -> Vector ---------------- template utils::Vector argmax_rowwise_serial(const utils::Matrix& A) { const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); utils::Vector max(rows, T{0}); utils::Vector idx(rows, 0); for (uint64_t i = 0; i < rows; ++i){ max[i] = A(i,0); for (uint64_t j = 1; j < cols; ++j){ if (max[i] < A(i,j)){ max[i] = A(i,j); idx[i] = j; } } } return idx; } template utils::Vector argmax_colwise_serial(const utils::Matrix& A) { const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); utils::Vector max(cols, T{0}); utils::Vector idx(cols, 0); for (uint64_t j = 0; j < cols; ++j){ max[j] = A(0, j); for (uint64_t i = 1; i < rows; ++i){ if (max[j] < A(i,j)){ max[j] = A(i,j); idx[j] = i; } } } return idx; } } // namespace numerics