Files
Flux-openbuild/include/numerics/interpolation1d/interpolation1d_base.h
2025-10-06 20:14:13 +00:00

152 lines
3.9 KiB
C++

#pragma once
#include <cmath>
#include <cstdint>
#include <stdexcept>

#include "./numerics/min.h"
#include "./numerics/max.h"
#include "./numerics/abs.h"
#include "./utils/vector.h"
namespace numerics{
/// Abstract base for one-dimensional table interpolation.
///
/// Holds non-owning pointers to a strictly monotonic table of abscissas `xx`
/// (ascending or descending) and the matching ordinates `yy`, and provides
/// the table-search machinery (`locate` by bisection, `hunt` starting from the
/// previous result). Derived classes implement `rawinterp`, which performs the
/// actual interpolation on the `mm`-point window selected by the search.
///
/// Lifetime: the object stores `&x[0]` and `y` as raw pointers and never copies
/// the data, so both arrays must outlive the interpolator.
template <typename T>
struct Base_interp {
    int64_t n, mm;      // n: number of table points; mm: points used per local interpolation.
    int64_t jsav, dj;   // jsav: bracket index from the previous search; dj: largest jump still treated as "correlated".
    bool cor;           // When true, the next interp() call searches with hunt() instead of locate().
    const T *xx, *yy;   // Non-owning views of the abscissas and ordinates.

    /// @param x Abscissas; must contain at least `m` points and be strictly
    ///          increasing or strictly decreasing.
    /// @param y Pointer to the ordinates (not dereferenced here; its length
    ///          cannot be checked and is the caller's responsibility).
    /// @param m Number of points used by the local interpolation (>= 2).
    /// @throws std::invalid_argument if `m < 2`, `x.size() < m`, a data
    ///         pointer is null, or `x` is not strictly monotonic.
    Base_interp(const utils::Vector<T>& x, const T *y, uint64_t m)
        : n(static_cast<int64_t>(x.size())),
          mm(static_cast<int64_t>(m)),
          jsav(0),
          dj(1),
          cor(false),
          // Do not index an empty vector: validation below runs only after
          // the initializer list has been evaluated.
          xx(x.size() > 0 ? &x[0] : nullptr),
          yy(y) {
        if (mm < 2 || n < mm) throw std::invalid_argument("Base_interp: invalid mm or n");
        if (!xx || !yy) throw std::invalid_argument("Base_interp: null data pointers");
        // From here on n >= mm >= 2, so xx[0] and xx[1] are valid.

        // dj = max(1, floor(n^(1/4))). The formula marked "from NR" in the
        // earlier revision used min(), which would cap dj at 1; max() lets the
        // correlation window grow with the table size.
        dj = numerics::max(static_cast<int64_t>(1),
                           static_cast<int64_t>(std::pow(static_cast<double>(n), 0.25)));

        // The direction is taken from the first pair; every following pair
        // must then agree strictly (negated comparisons also reject NaN).
        const bool asc = xx[0] < xx[1];
        for (int64_t i = 1; i < n; ++i) {
            if (asc) {
                if (!(xx[i] > xx[i - 1])) {
                    throw std::invalid_argument("x must be strictly increasing");
                }
            } else if (!(xx[i] < xx[i - 1])) {
                throw std::invalid_argument("x must be strictly decreasing");
            }
        }
    }

    // Polymorphic base: derived interpolators may be destroyed through a
    // Base_interp pointer, which requires a virtual destructor.
    virtual ~Base_interp() = default;

    /// Interpolates at `x`: finds the table window with hunt() when the
    /// previous lookups were close together (`cor`), otherwise with locate(),
    /// then delegates to rawinterp().
    T interp(T x) {
        const int64_t jlo = cor ? hunt(x) : locate(x);
        return rawinterp(jlo, x);
    }

    /// Derived classes provide this as the actual interpolation method.
    /// @param jlo Start index of the window returned by locate()/hunt().
    /// @param x   Point at which to interpolate.
    virtual T rawinterp(int64_t jlo, T x) = 0;

    /// Bisection search over the whole table.
    /// @return Start index, clamped to [0, n-mm], of the `mm`-point window
    ///         positioned around the interval that brackets `x`.
    /// @throws std::runtime_error if the stored sizes are inconsistent.
    int64_t locate(const T x) {
        if (n < 2 || mm < 2 || mm > n) {
            throw std::runtime_error("Interpolate: locate size error");
        }
        const bool ascnd = table_ascending();
        int64_t lower = 0;
        int64_t upper = n - 1;
        while (upper - lower > 1) {
            const int64_t mid = (upper + lower) >> 1;
            if ((x >= xx[mid]) == ascnd) {
                lower = mid;
            } else {
                upper = mid;
            }
        }
        return remember_and_clamp(lower);
    }

    /// Search that starts from the previous result `jsav`, expanding the
    /// bracket by doubling steps before finishing with bisection.
    /// @return Same window start index as locate().
    /// @throws std::runtime_error if the stored sizes are inconsistent.
    int64_t hunt(const T x) {
        if (n < 2 || mm < 2 || mm > n) {
            throw std::runtime_error("Interpolate: hunt size error");
        }
        const bool ascnd = table_ascending();
        int64_t jl = jsav;
        int64_t ju = 0;
        int64_t inc = 1;
        if (jl < 0 || jl > n - 1) {
            // Saved guess is unusable: fall back to bisection over the full table.
            jl = 0;
            ju = n - 1;
        } else if ((x >= xx[jl]) == ascnd) {
            // Hunt up.
            for (;;) {
                ju = jl + inc;
                if (ju >= n - 1) {          // Ran off the upper end of the table.
                    ju = n - 1;
                    break;
                }
                if ((x < xx[ju]) == ascnd) { // Bracket found.
                    break;
                }
                jl = ju;                     // Not yet: double the step and retry.
                inc += inc;
            }
        } else {
            // Hunt down.
            ju = jl;
            for (;;) {
                jl = jl - inc;
                if (jl <= 0) {               // Ran off the lower end of the table.
                    jl = 0;
                    break;
                }
                if ((x >= xx[jl]) == ascnd) { // Bracket found.
                    break;
                }
                ju = jl;                     // Not yet: double the step and retry.
                inc += inc;
            }
        }
        // Final bisection inside the bracket [jl, ju].
        while (ju - jl > 1) {
            const int64_t jm = (ju + jl) >> 1;
            if ((x >= xx[jm]) == ascnd) {
                jl = jm;
            } else {
                ju = jm;
            }
        }
        return remember_and_clamp(jl);
    }

private:
    // True when the last abscissa is not smaller than the first one.
    bool table_ascending() const { return xx[n - 1] >= xx[0]; }

    // Shared tail of locate()/hunt(): decides whether the next lookup should
    // hunt (result moved by at most dj since last time), stores the bracket
    // index, and converts it into a window start clamped to [0, n-mm].
    int64_t remember_and_clamp(int64_t jl) {
        cor = !(numerics::abs(jl - jsav) > dj);
        jsav = jl;
        return numerics::max(static_cast<int64_t>(0),
                             numerics::min(n - mm, jl - ((mm - 2) >> 1)));
    }
};
} // namespace numerics