7 #ifndef HEFFTE_STOCK_ALGOS_H 
    8 #define HEFFTE_STOCK_ALGOS_H 
   12 #include "heffte_stock_complex.h" 
   13 #include "heffte_stock_allocator.h" 
   14 #include "heffte_common.h" 
   17 #define HEFFTE_STOCK_THRESHOLD 1 
   27 template<
typename F, 
int L>
 
   34 template<
typename F, 
int L>
 
   37         F a = 2.*M_PI*((F) power)/((F) N);
 
   43 enum fft_type {pow2, pow3, pow4, composite, discrete, rader};
 
   49 template<
typename F, 
int L>
 
   57         explicit Fourier_Transform(
size_t a, 
size_t ainv): type(fft_type::rader), root(a), root_inv(ainv) { }
 
   60                 case fft_type::pow2: pow2_FFT(x, y, s_in, s_out, sRoot, dir); 
break;
 
   61                 case fft_type::pow3: pow3_FFT(x, y, s_in, s_out, sRoot, dir); 
break;
 
   62                 case fft_type::pow4: pow4_FFT(x, y, s_in, s_out, sRoot, dir); 
break;
 
   63                 case fft_type::composite: composite_FFT(x, y, s_in, s_out, sRoot, dir); 
break;
 
   64                 case fft_type::discrete: DFT(x, y, s_in, s_out, sRoot, dir); 
break;
 
   65                 case fft_type::rader: rader_FFT(x, y, s_in, s_out, sRoot, dir, root, root_inv); 
break;
 
   66                 default: 
throw std::runtime_error(
"Invalid Fourier Transform Form!\n");
 
   77 template<
typename F, 
int L>
 
   83     complex_vector<F,L> workspace; 
 
   85     biFuncNode(
size_t sig_size, fft_type type): sz(sig_size), fptr(type), workspace(sig_size) {}; 
 
   86     biFuncNode(
size_t sig_size, 
size_t a, 
size_t ainv): fptr(a,ainv), workspace(sig_size) {};
 
   90 template<
typename F, 
int L>
 
   93         sig_out[0] = sig_in[0];
 
   98     Complex<F,L> w0 = omega<F,L>::get(1, size, dir);
 
  101     Complex<F,L> wk = w0;
 
  104     Complex<F,L> wkn = w0;
 
  107     Complex<F,L> tmp = sig_in[0];
 
  108     for(
size_t n = 1; n < size; n++) {
 
  109         tmp += sig_in[n*s_in];
 
  114     for(
size_t k = 1; k < size; k++) {
 
  119         for(
size_t n = 1; n < size; n++) {
 
  120             tmp = wkn.
fmadd(sig_in[n*s_in], tmp);
 
  123         sig_out[k*s_out] = tmp;
 
  132 template<
typename F, 
int L>
 
  133 inline void DFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sLeaf, 
direction dir) {
 
  134     DFT_helper(sLeaf->sz, x, y, s_in, s_out, dir);
 
  138 template<
typename F, 
int L>
 
  139 inline void pow2_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, 
direction dir) {
 
  142         y[    0] = x[0] + x[s_in];
 
  143         y[s_out] = x[0] - x[s_in];
 
  151     pow2_FFT_helper(m, x, y, s_in*2, s_out, dir);
 
  152     pow2_FFT_helper(m, x+s_in, y+s_out*m, s_in*2, s_out, dir);
 
  155     Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
 
  156     Complex<F,L> wj = w1;
 
  157     Complex<F,L> y_j = y[0];
 
  160     y[m*s_out] = y_j - y[m*s_out];
 
  163     for(
int j = 1; j < m; j++) {
 
  164         int j_stride = j*s_out;
 
  165         int jm_stride = (j+m)*s_out;
 
  167         y[j_stride]  = y_j + wj*y[jm_stride];
 
  168         y[jm_stride] = y_j - wj*y[jm_stride];
 
  174 template<
typename F, 
int L>
 
  175 inline void pow2_FFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sRoot, 
direction dir) {
 
  176     const size_t N = sRoot->sz; 
 
  177     pow2_FFT_helper(N, x, y, s_in, s_out, dir); 
 
  181 template<
typename F, 
int L>
 
  182 inline void pow4_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, 
direction dir) {
 
  186             y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
 
  187             y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
 
  188             y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
 
  189             y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
 
  191             y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
 
  192             y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
 
  193             y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
 
  194             y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
 
  203     pow4_FFT_helper(m, x         , y            , s_in*4, s_out, dir);
 
  204     pow4_FFT_helper(m, x +   s_in, y +   s_out*m, s_in*4, s_out, dir);
 
  205     pow4_FFT_helper(m, x + 2*s_in, y + 2*s_out*m, s_in*4, s_out, dir);
 
  206     pow4_FFT_helper(m, x + 3*s_in, y + 3*s_out*m, s_in*4, s_out, dir);
 
  209     Complex<F,L> w1 = omega<F,L>::get(1,N,dir);
 
  210     Complex<F,L> w2 = w1*w1;
 
  211     Complex<F,L> w3 = w2*w1;
 
  212     Complex<F,L> wk1 = w1; Complex<F,L> wk2 = w2; Complex<F,L> wk3 = w3;
 
  217     Complex<F,L> y_k0 = y[k0];
 
  218     Complex<F,L> y_k1 = y[k1];
 
  219     Complex<F,L> y_k2 = y[k2];
 
  220     Complex<F,L> y_k3 = y[k3];
 
  223         y[k0] = y_k0 + y_k2 + (y_k1 + y_k3);
 
  224         y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
 
  225         y[k2] = y_k0 + y_k2 - (y_k1 + y_k3);
 
  226         y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
 
  227         for(
int k = 1; k < m; k++) {
 
  230             k2 = (k + 2*m)*s_out;
 
  231             k3 = (k + 3*m)*s_out;
 
  236             y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
 
  237             y[k1] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
 
  238             y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
 
  239             y[k3] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
 
  240             wk1 *= w1; wk2 *= w2; wk3 *= w3;
 
  244         y[k0] = y_k0 + y_k2 +  y_k1 + y_k3;
 
  245         y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
 
  246         y[k2] = y_k0 + y_k2 -  y_k1 - y_k3;
 
  247         y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
 
  248         for(
int k = 1; k < m; k++) {
 
  251             k2 = (k + 2*m)*s_out;
 
  252             k3 = (k + 3*m)*s_out;
 
  257             y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
 
  258             y[k1] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
 
  259             y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
 
  260             y[k3] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
 
  261             wk1 *= w1; wk2 *= w2; wk3 *= w3;
 
  267 template<
typename F, 
int L>
 
  268 inline void pow4_FFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sRoot, 
direction dir) {
 
  269     const size_t N = sRoot->sz; 
 
  270     pow4_FFT_helper(N, x, y, s_in, s_out, dir); 
 
  274 template<
typename F, 
int L>
 
  275 inline void composite_FFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sRoot, 
direction dir) {
 
  277     size_t N = sRoot->sz;
 
  280     biFuncNode<F,L>* left = sRoot + sRoot->left;
 
  281     biFuncNode<F,L>* right = sRoot + sRoot->right;
 
  284     size_t N1 = left->sz;
 
  285     size_t N2 = right->sz;
 
  289     Complex<F,L>* z  = sRoot->workspace.data();
 
  291     Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
 
  292     Complex<F,L> wj1 = Complex<F,L>(1., 0.);
 
  293     for(
size_t j1 = 0; j1 < N1; j1++) {
 
  294         Complex<F,L> wk2 = wj1;
 
  295         right->fptr(&x[j1*s_in], &z[N2*j1], N1*s_in, 1, right, dir);
 
  297             for(
size_t k2 = 1; k2 < N2; k2++) {
 
  298                 z[j1*N2 + k2] *= wk2;
 
  309     for(
size_t k2 = 0; k2 < N2; k2++) {
 
  310         left->fptr(&z[k2], &y[k2*s_out], N2, N2*s_out, left, dir);
 
  315 inline size_t referenceFactor(
const size_t f) {
 
  317     if((f & 0x1) == 0) 
return 2;
 
  320     for(
size_t k = 3; k*k < f; k+=2) {
 
  321         if( f % k == 0 ) 
return k;
 
  329 template<
typename F, 
int L>
 
  330 inline void rader_FFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sRoot, 
direction dir, 
size_t a, 
size_t ainv) {
 
  332     size_t p = sRoot->sz;
 
  335     biFuncNode<F,L>* subFFT = sRoot + sRoot->left;
 
  338     Complex<F,L>* z = sRoot->workspace.data();
 
  342     Complex<F,L> y0 = x[0];
 
  345     for(
size_t k = 0; k < (p-1); k++) {
 
  346         y[k*s_out] = x[akinv*s_in];
 
  347         y0 = y0 + y[k*s_out];
 
  349         akinv = (akinv*ainv) % p;
 
  355     subFFT->fptr(y, &z[0], s_out, 1, subFFT, dir);
 
  358     for(
size_t m = 0; m < (p-1); m++) {
 
  359         Complex<F,L> Cm = omega<F,L>::get(1, p, dir);
 
  361         for(
size_t k = 1; k < (p-1); k++) {
 
  362             Cm = Cm + omega<F,L>::get(p*(k*m+ak) - ak, p*(p-1), dir);
 
  365         y[m*s_out] = z[m]*Cm;
 
  369     subFFT->fptr(y, &z[0], s_out, 1, subFFT, (
direction) (-1*((int) dir)));
 
  374     for(
size_t m = 0; m < (p-1); m++) {
 
  375         y[ak*s_out] = x[0] + (z[m]/((double) (p-1)));
 
  381 template<
typename F, 
int L>
 
  382 inline void pow3_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, 
direction dir, Complex<F,L>& plus120, Complex<F,L>& minus120) {
 
  386         y[0] = x[0] + x[s_in] + x[2*s_in];
 
  387         y[s_out] = x[0] + plus120*x[s_in] + minus120*x[2*s_in];
 
  388         y[2*s_out] = x[0] + minus120*x[s_in] + plus120*x[2*s_in];
 
  396     pow3_FFT_helper(Nprime, x, y, s_in*3, s_out, dir, plus120, minus120);
 
  397     pow3_FFT_helper(Nprime, x+s_in, y+Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
 
  398     pow3_FFT_helper(Nprime, x+2*s_in, y+2*Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
 
  401     Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
 
  402     Complex<F,L> w2 = w1*w1;
 
  403     Complex<F,L> wk1 = w1;
 
  404     Complex<F,L> wk2 = w2;
 
  407     int k2 =   Nprime * s_out;
 
  408     int k3 = 2*Nprime * s_out;
 
  410     Complex<F,L> tmpk     = y[k1];
 
  411     Complex<F,L> tmpk_p_1 = y[k2];
 
  412     Complex<F,L> tmpk_p_2 = y[k3];
 
  414     y[k1] =                tmpk_p_2 +               tmpk_p_1+ tmpk;
 
  415     y[k2] = minus120.fmadd(tmpk_p_2, plus120.fmadd( tmpk_p_1, tmpk));
 
  416     y[k3] = plus120.fmadd( tmpk_p_2, minus120.fmadd(tmpk_p_1, tmpk));
 
  418     for(
size_t k = 1; k < Nprime; k++) {
 
  421         k2 = (  Nprime + k) * s_out;
 
  422         k3 = (2*Nprime + k) * s_out;
 
  430         y[k1] = wk2.fmadd(           tmpk_p_2, wk1.fmadd(           tmpk_p_1, tmpk));
 
  431         y[k2] = wk2.fmadd(minus120 * tmpk_p_2, wk1.fmadd(plus120  * tmpk_p_1, tmpk));
 
  432         y[k3] = wk2.fmadd(plus120  * tmpk_p_2, wk1.fmadd(minus120 * tmpk_p_1, tmpk));
 
  435         wk1 *= w1; wk2 *= w2;
 
  440 template<
typename F, 
int L>
 
  441 inline void pow3_FFT(Complex<F,L>* x, Complex<F,L>* y, 
size_t s_in, 
size_t s_out, biFuncNode<F,L>* sRoot, 
direction dir) {
 
  442     const size_t N = sRoot->sz;
 
  443     Complex<F,L> plus120 (-0.5, -sqrt(3)/2.);
 
  444     Complex<F,L> minus120 (-0.5, sqrt(3)/2.);
 
  446         case direction::forward:  pow3_FFT_helper(N, x, y, s_in, s_out, dir, plus120, minus120); 
break;
 
  447         case direction::backward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, minus120, plus120); 
break;
 
Custom complex type taking advantage of vectorization A Complex Type intrinsic to HeFFTe that takes a...
Definition: heffte_stock_complex.h:29
Complex< F, L > fmadd(Complex< F, L > const &y, Complex< F, L > const &z)
Fused multiply add.
Definition: heffte_stock_complex.h:155
direction
Indicates the direction of the FFT (internal use only).
Definition: heffte_common.h:535
@ backward
Inverse DFT transform.
@ forward
Forward DFT transform.
Namespace containing all HeFFTe methods and classes.
Definition: heffte_backend_cuda.h:38
int direction_sign(direction dir)
Find the sign given a direction.
Definition: heffte_stock_algos.h:21
Class to represent the call-graph nodes.
Definition: heffte_stock_algos.h:78
Create a stock Complex representation of a twiddle factor.
Definition: heffte_stock_algos.h:35