Files
Flux-openbuild/include/numerics/matargmax.h
2025-10-08 16:08:04 +00:00

63 lines
1.2 KiB
C++

#pragma once
#include "./utils/matrix.h"
namespace numerics{
/// @brief For every row i of A, stores in b[i] the column index of the row's
///        largest entry (compared with operator<).
///
/// Ties are resolved in favour of the smallest column index, because only a
/// strictly greater value replaces the current best. If A has zero columns,
/// every b[i] is set to Ti{0}.
///
/// @param A  input matrix (read only).
/// @param b  output vector; resized to A.rows() (new entries Ti{0}) when its
///           size differs, then every entry is overwritten.
template <typename Ti, typename Td>
void inplace_matargmax_row(const utils::Matrix<Td>& A, utils::Vector<Ti>& b){
    if (b.size() != A.rows()){
        b.resize(A.rows(), Ti{0});
    }
    for (uint64_t i = 0; i < A.rows(); ++i){
        Ti best_index = Ti{0};
        if (A.cols() > 0){
            // Seed with the row's first entry, not Td{0}: otherwise rows whose
            // entries are all <= 0 would never update b[i] and could keep a
            // stale index from a previous call.
            Td best_value = A(i, 0);
            for (uint64_t j = 1; j < A.cols(); ++j){
                if (best_value < A(i, j)){
                    best_value = A(i, j);
                    best_index = static_cast<Ti>(j);
                }
            }
        }
        // Always written, so results never depend on b's previous contents.
        b[i] = best_index;
    }
}
/// @brief For every column j of A, stores in b[j] the row index of the
///        column's largest entry (compared with operator<).
///
/// Ties are resolved in favour of the smallest row index, because only a
/// strictly greater value replaces the current best. If A has zero rows,
/// every b[j] is set to Ti{0}.
///
/// @param A  input matrix (read only).
/// @param b  output vector; resized to A.cols() (new entries Ti{0}) when its
///           size differs, then every entry is overwritten.
template <typename Ti, typename Td>
void inplace_matargmax_col(const utils::Matrix<Td>& A, utils::Vector<Ti>& b){
    if (b.size() != A.cols()){
        // Was `b(A.cols(), Ti{0})`, which calls operator() rather than resizing.
        b.resize(A.cols(), Ti{0});
    }
    for (uint64_t j = 0; j < A.cols(); ++j){
        Ti best_index = Ti{0};
        if (A.rows() > 0){
            // Seed with the column's first entry, not Td{0}, so columns whose
            // entries are all <= 0 still get the correct index.
            Td best_value = A(0, j);
            // Walk the rows of column j (the bound must be rows(), not cols(),
            // or non-square matrices are read out of range / incompletely).
            for (uint64_t i = 1; i < A.rows(); ++i){
                if (best_value < A(i, j)){
                    best_value = A(i, j);
                    best_index = static_cast<Ti>(i);
                }
            }
        }
        b[j] = best_index;
    }
}
/// @brief Returns a freshly allocated vector whose i-th entry is the column
///        index of the maximum of row i of A.
///
/// Ti cannot be deduced from the arguments and must be given explicitly,
/// e.g. `matargmax_row<uint64_t>(A)`; Td is deduced from A.
template <typename Ti, typename Td>
utils::Vector<Ti> matargmax_row(const utils::Matrix<Td>& A){
    utils::Vector<Ti> indices(A.rows(), Ti{0});
    inplace_matargmax_row<Ti, Td>(A, indices);
    return indices;
}
/// @brief Returns a freshly allocated vector whose j-th entry is the row
///        index of the maximum of column j of A.
///
/// Ti cannot be deduced from the arguments and must be given explicitly,
/// e.g. `matargmax_col<uint64_t>(A)`; Td is deduced from A.
template <typename Ti, typename Td>
utils::Vector<Ti> matargmax_col(const utils::Matrix<Td>& A){
    // One entry per column; sizing by rows() was wrong for non-square A.
    utils::Vector<Ti> indices(A.cols(), Ti{0});
    inplace_matargmax_col(A, indices);
    return indices;
}
} // namespace numerics