#pragma once #include "./numerics/min.h" #include "./numerics/max.h" #include "./numerics/abs.h" #include "./utils/vector.h" namespace numerics{ template struct Base_interp{ int64_t n, mm; int64_t jsav, dj; bool cor; const T *xx, *yy; Base_interp(const utils::Vector& x, const T *y, uint64_t m) :n(x.size()), mm(m), jsav(0), cor(false), xx(&x[0]), yy(y){ //dj = numerics::min(static_cast(1), static_cast(std::pow(static_cast(n), 0.25))); // from NR dj = numerics::max(static_cast(1), static_cast(std::pow(static_cast(n), 0.25))); // from chatbot if (mm < 2 || n < mm) throw std::invalid_argument("Base_interp: invalid mm or n"); if (!xx || !yy) throw std::invalid_argument("Base_interp: null data pointers"); if (n < 2) throw std::invalid_argument("Base_interp: need at least 2 points"); bool asc = false; if (xx[0] < xx[1]){ asc = true; } for (int64_t i = 1; i < n; ++i){ if (!(xx[i] > xx[i-1]) && asc) { throw std::invalid_argument("x must be strictly increasing"); } else if (!(xx[i] < xx[i-1]) && !asc){ throw std::invalid_argument("x must be strictly decreasing"); } } } T interp(T x){ int64_t jlo; if (cor){ jlo = hunt(x); } else{ jlo = locate(x); } return rawinterp(jlo,x); } // Derived classes provide this as the actual interpolation method. T virtual rawinterp(int64_t jlo, T x) = 0; int64_t locate(const T x){ int64_t ju, jl; int64_t jm; if (n < 2 || mm < 2 || mm > n){ throw std::runtime_error("Interpolate: locate size error"); } bool ascnd = (xx[n-1] >= xx[0]); // True if ascending order of table, false otherwise. jl = 0; // Initialize lower ju = n-1; // and upper limits. while (ju - jl > 1) { // If we are not yet done, jm = (ju+jl) >> 1; // compute a midpoint, if ((x >= xx[jm]) == ascnd){ jl=jm; // and replace either the lower limit }else{ ju=jm; // or the upper limit, as appropriate. } } // Repeat until the test condition is satisfied. if (std::abs(jl - jsav) > dj){ // Decide whether to use hunt or locate next time. cor = false; }else{ cor = true; } jsav = jl; return numerics::max(static_cast(0), numerics::min(n-mm, jl-((mm-2)>>1))); } int64_t hunt(const T x){ int64_t jl=jsav, jm, ju, inc=1; if (n < 2 || mm < 2 || mm > n){ throw std::runtime_error("Interpolate: hunt size error"); } bool ascnd=(xx[n-1] >= xx[0]); // True if ascending order of table, false otherwise. if (jl < 0 || jl > n-1) { // Input guess not useful. Go immediately to bisection. jl=0; ju=n-1; }else{ if ((x >= xx[jl]) == ascnd){ // Hunt up: for (;;){ ju = jl + inc; if (ju >= n-1){ ju = n-1; break; // Off end of table. }else if((x < xx[ju]) == ascnd){ break; // Found bracket. }else{ // Not done, so double the increment and try again. jl = ju; inc += inc; } } }else{ // Hunt down: ju = jl; for (;;){ jl = jl - inc; if (jl <= 0){ //Off end of table. jl = 0; break; }else if((x >= xx[jl]) == ascnd){ break; // Found bracket. } else{ // Not done, so double the increment and try again. ju = jl; inc += inc; } } } } while(ju-jl > 1){ // Hunt is done, so begin the final bisection phase: jm = (ju+jl) >> 1; if ((x >= xx[jm]) == ascnd){ jl =jm; }else{ ju=jm; } } if (numerics::abs(jl-jsav) > dj){ cor = false; }else{ cor = true; } jsav = jl; return numerics::max(static_cast(0), numerics::min(n-mm, jl-((mm-2)>>1))); } }; } // namespace numerics