152 lines
3.9 KiB
C++
152 lines
3.9 KiB
C++
#pragma once
|
|
|
|
#include "./numerics/min.h"
|
|
#include "./numerics/max.h"
|
|
#include "./numerics/abs.h"
|
|
|
|
#include "./utils/vector.h"
|
|
|
|
|
|
namespace numerics{
|
|
|
|
template <typename T>
|
|
struct Base_interp{
|
|
|
|
int64_t n, mm;
|
|
int64_t jsav, dj;
|
|
bool cor;
|
|
const T *xx, *yy;
|
|
|
|
|
|
Base_interp(const utils::Vector<T>& x, const T *y, uint64_t m)
|
|
:n(x.size()), mm(m), jsav(0), cor(false), xx(&x[0]), yy(y){
|
|
//dj = numerics::min(static_cast<int64_t>(1), static_cast<int64_t>(std::pow(static_cast<T>(n), 0.25))); // from NR
|
|
dj = numerics::max(static_cast<int64_t>(1), static_cast<int64_t>(std::pow(static_cast<T>(n), 0.25))); // from chatbot
|
|
|
|
if (mm < 2 || n < mm) throw std::invalid_argument("Base_interp: invalid mm or n");
|
|
if (!xx || !yy) throw std::invalid_argument("Base_interp: null data pointers");
|
|
if (n < 2) throw std::invalid_argument("Base_interp: need at least 2 points");
|
|
|
|
bool asc = false;
|
|
if (xx[0] < xx[1]){
|
|
asc = true;
|
|
}
|
|
for (int64_t i = 1; i < n; ++i){
|
|
if (!(xx[i] > xx[i-1]) && asc) {
|
|
throw std::invalid_argument("x must be strictly increasing");
|
|
} else if (!(xx[i] < xx[i-1]) && !asc){
|
|
throw std::invalid_argument("x must be strictly decreasing");
|
|
}
|
|
}
|
|
}
|
|
|
|
T interp(T x){
|
|
int64_t jlo;
|
|
if (cor){
|
|
jlo = hunt(x);
|
|
}
|
|
else{
|
|
jlo = locate(x);
|
|
}
|
|
return rawinterp(jlo,x);
|
|
}
|
|
|
|
// Derived classes provide this as the actual interpolation method.
|
|
T virtual rawinterp(int64_t jlo, T x) = 0;
|
|
|
|
|
|
int64_t locate(const T x){
|
|
int64_t ju, jl;
|
|
int64_t jm;
|
|
|
|
if (n < 2 || mm < 2 || mm > n){
|
|
throw std::runtime_error("Interpolate: locate size error");
|
|
}
|
|
|
|
bool ascnd = (xx[n-1] >= xx[0]); // True if ascending order of table, false otherwise.
|
|
jl = 0; // Initialize lower
|
|
ju = n-1; // and upper limits.
|
|
while (ju - jl > 1) { // If we are not yet done,
|
|
jm = (ju+jl) >> 1; // compute a midpoint,
|
|
if ((x >= xx[jm]) == ascnd){
|
|
jl=jm; // and replace either the lower limit
|
|
}else{
|
|
ju=jm; // or the upper limit, as appropriate.
|
|
}
|
|
} // Repeat until the test condition is satisfied.
|
|
|
|
if (std::abs(jl - jsav) > dj){ // Decide whether to use hunt or locate next time.
|
|
cor = false;
|
|
}else{
|
|
cor = true;
|
|
}
|
|
jsav = jl;
|
|
return numerics::max(static_cast<int64_t>(0), numerics::min(n-mm, jl-((mm-2)>>1)));
|
|
}
|
|
|
|
int64_t hunt(const T x){
|
|
int64_t jl=jsav, jm, ju, inc=1;
|
|
|
|
if (n < 2 || mm < 2 || mm > n){
|
|
throw std::runtime_error("Interpolate: hunt size error");
|
|
}
|
|
bool ascnd=(xx[n-1] >= xx[0]); // True if ascending order of table, false otherwise.
|
|
if (jl < 0 || jl > n-1) { // Input guess not useful. Go immediately to bisection.
|
|
jl=0;
|
|
ju=n-1;
|
|
}else{
|
|
if ((x >= xx[jl]) == ascnd){ // Hunt up:
|
|
for (;;){
|
|
ju = jl + inc;
|
|
if (ju >= n-1){
|
|
ju = n-1;
|
|
break; // Off end of table.
|
|
}else if((x < xx[ju]) == ascnd){
|
|
break; // Found bracket.
|
|
}else{ // Not done, so double the increment and try again.
|
|
jl = ju;
|
|
inc += inc;
|
|
}
|
|
}
|
|
}else{ // Hunt down:
|
|
ju = jl;
|
|
for (;;){
|
|
jl = jl - inc;
|
|
if (jl <= 0){ //Off end of table.
|
|
jl = 0;
|
|
break;
|
|
}else if((x >= xx[jl]) == ascnd){
|
|
break; // Found bracket.
|
|
}
|
|
else{ // Not done, so double the increment and try again.
|
|
ju = jl;
|
|
inc += inc;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
while(ju-jl > 1){ // Hunt is done, so begin the final bisection phase:
|
|
jm = (ju+jl) >> 1;
|
|
if ((x >= xx[jm]) == ascnd){
|
|
jl =jm;
|
|
}else{
|
|
ju=jm;
|
|
}
|
|
}
|
|
if (numerics::abs(jl-jsav) > dj){
|
|
cor = false;
|
|
}else{
|
|
cor = true;
|
|
}
|
|
jsav = jl;
|
|
return numerics::max(static_cast<int64_t>(0), numerics::min(n-mm, jl-((mm-2)>>1)));
|
|
|
|
}
|
|
|
|
};
|
|
|
|
} // namespace numerics
|
|
|