#pragma once #include //uint64_t #include // std::runtime_error #include "utils/vector.h" #include "utils/matrix.h" namespace numerics::detail{ // ------------- Helper Functions ----------- template inline T clip_low_value(T x, const T low) { if (x < low) { return low; } return x; } template inline T clip_high_value(T x, const T high) { if (x > high) { return high; } return x; } template inline T clip_value(T x, const T low, const T high) { if (x < low) { return low; } if (x > high) { return high; } return x; } //--------------------------------------------- // ----------------- Clip --------------- //--------------------------------------------- // ---------------- Elemenwise ---------------- template inline void inplace_clip_scalar_serial(T& c, const T low, const T high) { if (low > high) { throw std::runtime_error("inplace_clip_scalar_serial: lower clip > higher clip"); } c = detail::clip_value(c, low, high); } template inline void inplace_clip_elementwise_serial(utils::Matrix& A, const T low, const T high) { if (low > high) { throw std::runtime_error("inplace_clip_elementwise_serial: lower clip > higher clip"); } const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); for (uint64_t i = 0; i < rows; ++i){ for (uint64_t j = 0; j < cols; ++j){ A(i,j) = detail::clip_value(A(i,j), low, high); } } } template inline void inplace_clip_elementwise_serial(utils::Vector& v, const T low, const T high) { if (low > high) { throw std::runtime_error("inplace_clip_elementwise_serial: lower clip > higher clip"); } for (uint64_t i = 0; i < v.size(); ++i){ v[i] = detail::clip_value(v[i], low, high); } } //--------------------------------------------- // ----------------- Clip Low --------------- //--------------------------------------------- // ---------------- Elemenwise ---------------- template inline void inplace_clip_low_scalar_serial(T& c, const T low) { c = detail::clip_low_value(c, low); } template inline void inplace_clip_low_elementwise_serial(utils::Matrix& A, const T low) { const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); for (uint64_t i = 0; i < rows; ++i){ for (uint64_t j = 0; j < cols; ++j){ A(i,j) = detail::clip_low_value(A(i,j), low); } } } template inline void inplace_clip_low_elementwise_serial(utils::Vector& v, const T low) { for (uint64_t i = 0; i < v.size(); ++i){ v[i] = detail::clip_low_value(v[i], low); } } //--------------------------------------------- // ----------------- Clip High --------------- //--------------------------------------------- // ---------------- Elemenwise ---------------- template inline void inplace_clip_high_scalar_serial(T& c, const T high) { c = detail::clip_high_value(c, high); } template inline void inplace_clip_high_elementwise_serial(utils::Matrix& A, const T high) { const uint64_t rows = A.rows(); const uint64_t cols = A.cols(); for (uint64_t i = 0; i < rows; ++i){ for (uint64_t j = 0; j < cols; ++j){ A(i,j) = detail::clip_high_value(A(i,j), high); } } } template inline void inplace_clip_high_elementwise_serial(utils::Vector& v, const T high) { for (uint64_t i = 0; i < v.size(); ++i){ v[i] = detail::clip_high_value(v[i], high); } } } // namespace numerics