63 lines
1.2 KiB
C++
63 lines
1.2 KiB
C++
#pragma once
|
|
|
|
#include "./utils/matrix.h"
|
|
|
|
namespace numerics{
|
|
|
|
template <typename Ti, typename Td>
|
|
void inplace_matargmax_row(const utils::Matrix<Td>& A, utils::Vector<Ti>& b){
|
|
|
|
if (b.size() != A.rows()){
|
|
b.resize(A.rows(), Ti{0});
|
|
}
|
|
Td value;
|
|
|
|
for (uint64_t i = 0; i < A.rows(); ++i){
|
|
value = Td{0};
|
|
for (uint64_t j = 0; j < A.cols(); ++j){
|
|
if (value < A(i,j)){
|
|
value = A(i,j);
|
|
b[i] = j;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Ti, typename Td>
|
|
void inplace_matargmax_col(const utils::Matrix<Td>& A, utils::Vector<Ti>& b){
|
|
|
|
if (b.size() != A.cols()){
|
|
b(A.cols(), Ti{0});
|
|
}
|
|
Td value;
|
|
|
|
for (uint64_t j = 0; j < A.cols(); ++j){
|
|
value = Td{0};
|
|
for (uint64_t i = 0; i < A.cols(); ++i){
|
|
if (value < A(i,j)){
|
|
value = A(i,j);
|
|
b[j] = i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Ti, typename Td>
|
|
utils::Vector<Ti> matargmax_row(const utils::Matrix<Td>& A){
|
|
utils::Vector<Ti> b(A.rows(), Ti{0});
|
|
inplace_matargmax_row(A, b);
|
|
return b;
|
|
|
|
}
|
|
|
|
template <typename Ti, typename Td>
|
|
utils::Vector<Ti> matargmax_col(const utils::Matrix<Td>& A){
|
|
utils::Vector<Ti> b(A.rows(), Ti{0});
|
|
inplace_matargmax_col(A, b);
|
|
return b;
|
|
|
|
}
|
|
|
|
} // namespace numerics
|
|
|