7 #ifndef HEFFTE_STOCK_ALGOS_H
8 #define HEFFTE_STOCK_ALGOS_H
12 #include "heffte_stock_complex.h"
13 #include "heffte_stock_allocator.h"
14 #include "heffte_common.h"
17 #define HEFFTE_STOCK_THRESHOLD 1
// NOTE(review): the declaration introduced by this template header (listing
// line 27) is not visible in this listing.
27 template<
typename F,
int L>
// Twiddle-factor helper. The Doxygen summary for listing line 35 reads "Create
// a stock Complex representation of a twiddle factor"; this is presumably the
// omega<F,L>::get(power, N, dir) called by the kernels below. The angle is
// a = 2*pi*power/N. How the sign is derived from dir is not shown here.
// NOTE(review): power and N are cast to F before dividing. For F = float, not
// every integer above 2^24 is exactly representable. The very large powers
// passed by rader_FFT (p*(k*m+ak) - ak, on the order of p^3) therefore lose
// accuracy. Reducing power modulo N in integer arithmetic before the cast
// would avoid this.
// NOTE(review): M_PI is a POSIX extension rather than standard C++ (MSVC needs
// _USE_MATH_DEFINES). Confirm it is available on every supported toolchain.
34 template<
typename F,
int L>
37 F a = 2.*M_PI*((F) power)/((F) N);
// Kinds of FFT kernel that Fourier_Transform dispatches to (see its switch):
//   pow2      - radix-2 recursion
//   pow3      - radix-3 recursion
//   pow4      - radix-4 recursion
//   composite - split into left/right child transforms
//   discrete  - direct O(N^2) DFT
//   rader     - Rader's algorithm (rader_FFT, working modulo p = node size)
// NOTE(review): this is an unscoped enum, so the short enumerator names
// (pow2, composite, ...) are also injected into the enclosing namespace. The
// uses visible here all qualify them as fft_type::..., which would also
// compile with `enum class`.
43 enum fft_type {pow2, pow3, pow4, composite, discrete, rader};
// Callable kernel selector. It records an fft_type and forwards
// (x, y, s_in, s_out, sRoot, dir) to the matching kernel below. Any other
// type value throws std::runtime_error. It is presumably the type of
// biFuncNode::fptr, which is constructed with the same argument lists.
49 template<
typename F,
int L>
// Rader constructor: stores (a, ainv) as root / root_inv, which the dispatch
// below passes to rader_FFT as its a / ainv arguments.
57 explicit Fourier_Transform(
size_t a,
size_t ainv): type(fft_type::rader), root(a), root_inv(ainv) { }
60 case fft_type::pow2: pow2_FFT(x, y, s_in, s_out, sRoot, dir);
break;
61 case fft_type::pow3: pow3_FFT(x, y, s_in, s_out, sRoot, dir);
break;
62 case fft_type::pow4: pow4_FFT(x, y, s_in, s_out, sRoot, dir);
break;
63 case fft_type::composite: composite_FFT(x, y, s_in, s_out, sRoot, dir);
break;
64 case fft_type::discrete: DFT(x, y, s_in, s_out, sRoot, dir);
break;
65 case fft_type::rader: rader_FFT(x, y, s_in, s_out, sRoot, dir, root, root_inv);
break;
// NOTE(review): std::runtime_error is declared in <stdexcept>, which is not
// among the headers included at the top of this listing. Confirm it is pulled
// in transitively (e.g. via heffte_common.h).
66 default:
throw std::runtime_error(
"Invalid Fourier Transform Form!\n");
// Node of the FFT call graph (Doxygen summary: "Class to represent the
// call-graph nodes"). Kernels reach their children via sRoot + sRoot->left
// and sRoot + sRoot->right, so left/right are offsets within a contiguous
// array of nodes.
77 template<
typename F,
int L>
// Per-node scratch buffer, used as `z` by composite_FFT and rader_FFT.
83 complex_vector<F,L> workspace;
// Generic node: sz = sig_size, kernel chosen by `type`, workspace of sig_size.
85 biFuncNode(
size_t sig_size, fft_type type): sz(sig_size), fptr(type), workspace(sig_size) {};
// Rader node: the kernel is Fourier_Transform(a, ainv); workspace has sig_size.
// NOTE(review): unlike the constructor above, this one does not initialize
// `sz` from sig_size, yet rader_FFT reads p = sRoot->sz. Either confirm that
// sz has a default member initializer and is assigned by the caller, or add
// sz(sig_size) to this initializer list.
86 biFuncNode(
size_t sig_size,
size_t a,
size_t ainv): fptr(a,ainv), workspace(sig_size) {};
// Direct DFT of `size` inputs, read with stride s_in and written with stride
// s_out. This is presumably DFT_helper, which DFT() calls for discrete nodes.
// The nested k/n loops over 1..size-1 make it O(size^2).
90 template<
typename F,
int L>
// Presumably the size == 1 early-out (the guarding condition is not shown).
93 sig_out[0] = sig_in[0];
// w0 = omega(1, size, dir) seeds the running twiddle factors wk and wkn.
// Their per-step updates are not shown in this listing.
98 Complex<F,L> w0 = omega<F,L>::get(1, size, dir);
101 Complex<F,L> wk = w0;
104 Complex<F,L> wkn = w0;
// tmp first accumulates the plain sum of all inputs, i.e. the k = 0 output,
// which needs no twiddles.
107 Complex<F,L> tmp = sig_in[0];
108 for(
size_t n = 1; n < size; n++) {
109 tmp += sig_in[n*s_in];
// Outputs 1..size-1 accumulate wkn * x[n] into tmp with fused multiply-adds.
// Presumably wkn tracks w0^(k*n); the per-k reset of tmp is not shown.
114 for(
size_t k = 1; k < size; k++) {
119 for(
size_t n = 1; n < size; n++) {
120 tmp = wkn.
fmadd(sig_in[n*s_in], tmp);
123 sig_out[k*s_out] = tmp;
// Kernel for fft_type::discrete nodes: runs the direct DFT over the node's
// full size (sLeaf->sz). The node pointer is named sLeaf here, rather than
// sRoot as in the sibling kernels; only its sz is read.
132 template<
typename F,
int L>
133 inline void DFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sLeaf,
direction dir) {
134 DFT_helper(sLeaf->sz, x, y, s_in, s_out, dir);
// Recursive radix-2 decimation-in-time FFT of N points.
//   - Base case (its guard is not shown): the 2-point transform is a single
//     butterfly.
//   - Otherwise: the even- and odd-indexed inputs (input stride doubled) are
//     transformed into the first and second halves of y, then combined with
//     the twiddles w1^j.
138 template<
typename F,
int L>
139 inline void pow2_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out,
direction dir) {
142 y[ 0] = x[0] + x[s_in];
143 y[s_out] = x[0] - x[s_in];
// Half-size sub-transforms of m points (presumably m = N/2; its definition is
// not shown).
151 pow2_FFT_helper(m, x, y, s_in*2, s_out, dir);
152 pow2_FFT_helper(m, x+s_in, y+s_out*m, s_in*2, s_out, dir);
155 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
156 Complex<F,L> wj = w1;
// The j = 0 butterfly needs no twiddle multiplication.
157 Complex<F,L> y_j = y[0];
160 y[m*s_out] = y_j - y[m*s_out];
// NOTE(review): j, j_stride and jm_stride are int while s_out is size_t.
// j*s_out and (j+m)*s_out are therefore narrowed to int and can overflow for
// large sizes or strides. size_t would match the function's parameters.
163 for(
int j = 1; j < m; j++) {
164 int j_stride = j*s_out;
165 int jm_stride = (j+m)*s_out;
167 y[j_stride] = y_j + wj*y[jm_stride];
168 y[jm_stride] = y_j - wj*y[jm_stride];
// Kernel for fft_type::pow2 nodes: radix-2 FFT over the node's full size
// (sRoot->sz).
174 template<
typename F,
int L>
175 inline void pow2_FFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sRoot,
direction dir) {
176 const size_t N = sRoot->sz;
177 pow2_FFT_helper(N, x, y, s_in, s_out, dir);
// Recursive radix-4 FFT of N points.
//   - Base case: a directly computed 4-point transform.
//   - Otherwise: four sub-transforms of m points (input stride s_in*4) fill
//     the four quarters of y, which are combined with the twiddles w1^k,
//     w2^k and w3^k.
// Each step has two variants that differ only in where the +i / -i rotations
// go. The condition selecting between them is not shown (presumably dir).
181 template<
typename F,
int L>
182 inline void pow4_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out,
direction dir) {
// 4-point base case, first variant: -i on output 1, +i on output 3.
186 y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
187 y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
188 y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
189 y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
// 4-point base case, second variant: +i on output 1, -i on output 3.
191 y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
192 y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
193 y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
194 y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
// Quarter-size sub-transforms over inputs offset by 0..3 (presumably m = N/4).
203 pow4_FFT_helper(m, x , y , s_in*4, s_out, dir);
204 pow4_FFT_helper(m, x + s_in, y + s_out*m, s_in*4, s_out, dir);
205 pow4_FFT_helper(m, x + 2*s_in, y + 2*s_out*m, s_in*4, s_out, dir);
206 pow4_FFT_helper(m, x + 3*s_in, y + 3*s_out*m, s_in*4, s_out, dir);
209 Complex<F,L> w1 = omega<F,L>::get(1,N,dir);
210 Complex<F,L> w2 = w1*w1;
211 Complex<F,L> w3 = w2*w1;
212 Complex<F,L> wk1 = w1; Complex<F,L> wk2 = w2; Complex<F,L> wk3 = w3;
217 Complex<F,L> y_k0 = y[k0];
218 Complex<F,L> y_k1 = y[k1];
219 Complex<F,L> y_k2 = y[k2];
220 Complex<F,L> y_k3 = y[k3];
// First variant, k = 0 butterfly (no twiddles needed).
223 y[k0] = y_k0 + y_k2 + (y_k1 + y_k3);
224 y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
225 y[k2] = y_k0 + y_k2 - (y_k1 + y_k3);
226 y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
// NOTE(review): k is int, while the stride products (k + j*m)*s_out are
// size_t. Storing them in narrower index variables can overflow for large
// sizes or strides. The declared type of k0..k3 is not shown.
227 for(
int k = 1; k < m; k++) {
230 k2 = (k + 2*m)*s_out;
231 k3 = (k + 3*m)*s_out;
236 y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
237 y[k1] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
238 y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
239 y[k3] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
240 wk1 *= w1; wk2 *= w2; wk3 *= w3;
// Second variant, k = 0 butterfly.
// NOTE(review): unlike the first variant, y_k1 and y_k3 are not parenthesized
// here. Rounding therefore differs slightly between the two variants.
244 y[k0] = y_k0 + y_k2 + y_k1 + y_k3;
245 y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
246 y[k2] = y_k0 + y_k2 - y_k1 - y_k3;
247 y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
248 for(
int k = 1; k < m; k++) {
251 k2 = (k + 2*m)*s_out;
252 k3 = (k + 3*m)*s_out;
257 y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
258 y[k1] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
259 y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
260 y[k3] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
261 wk1 *= w1; wk2 *= w2; wk3 *= w3;
// Kernel for fft_type::pow4 nodes: radix-4 FFT over the node's full size
// (sRoot->sz).
267 template<
typename F,
int L>
268 inline void pow4_FFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sRoot,
direction dir) {
269 const size_t N = sRoot->sz;
270 pow4_FFT_helper(N, x, y, s_in, s_out, dir);
// Kernel for fft_type::composite nodes: a two-level split of N points using
// the child nodes (presumably N = N1*N2).
//   1. For each j1, the right child (N2 points, input stride N1*s_in)
//      transforms the inputs starting at x[j1*s_in] into row j1 of the
//      scratch buffer z.
//   2. Each row is multiplied by twiddle factors.
//   3. For each column k2, the left child (N1 points, input stride N2)
//      writes y starting at y[k2*s_out], with output stride N2*s_out.
// z is the node's workspace, sized sig_size by the biFuncNode constructors.
274 template<
typename F,
int L>
275 inline void composite_FFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sRoot,
direction dir) {
277 size_t N = sRoot->sz;
280 biFuncNode<F,L>* left = sRoot + sRoot->left;
281 biFuncNode<F,L>* right = sRoot + sRoot->right;
284 size_t N1 = left->sz;
285 size_t N2 = right->sz;
289 Complex<F,L>* z = sRoot->workspace.data();
291 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
292 Complex<F,L> wj1 = Complex<F,L>(1., 0.);
293 for(
size_t j1 = 0; j1 < N1; j1++) {
294 Complex<F,L> wk2 = wj1;
295 right->fptr(&x[j1*s_in], &z[N2*j1], N1*s_in, 1, right, dir);
// Column 0 is left unscaled. Presumably wk2 tracks w1^(j1*k2), which is 1
// at k2 = 0; the updates of wk2 / wj1 are not shown.
297 for(
size_t k2 = 1; k2 < N2; k2++) {
298 z[j1*N2 + k2] *= wk2;
309 for(
size_t k2 = 0; k2 < N2; k2++) {
310 left->fptr(&z[k2], &y[k2*s_out], N2, N2*s_out, left, dir);
// Trial-division search for a small factor of f.
//   - Returns 2 for even f (including f == 0).
//   - Otherwise returns the first odd k >= 3 with k*k < f that divides f.
//   - The result when no such k exists is not shown in this listing.
// NOTE(review): the bound `k*k < f` never tests k == sqrt(f). Odd squares of
// primes (9, 25, 49, 121, ...) therefore find no factor here and fall through
// as if f had no divisor (presumably treated as prime by the caller). The
// usual trial-division bound is `k*k <= f`. Confirm this and fix it together
// with the (not shown) fall-through return.
315 inline size_t referenceFactor(
const size_t f) {
317 if((f & 0x1) == 0)
return 2;
320 for(
size_t k = 3; k*k < f; k+=2) {
321 if( f % k == 0 )
return k;
// Kernel for fft_type::rader nodes: Rader's algorithm for a p-point node
// (p = sRoot->sz).
//   1. Inputs x[1..p-1] are permuted by successive powers of ainv modulo p
//      into y, and their sum is accumulated into y0.
//   2. A (p-1)-point sub-transform (left child) writes into the workspace z.
//   3. The result is multiplied pointwise by coefficients Cm.
//   4. The sub-transform is run again in the opposite direction.
//   5. x[0] + z[m]/(p-1) is scattered back to y[ak*s_out].
// y is used as scratch input for both sub-transforms. The definitions and
// updates of akinv and ak are not shown.
329 template<
typename F,
int L>
330 inline void rader_FFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sRoot,
direction dir,
size_t a,
size_t ainv) {
332 size_t p = sRoot->sz;
335 biFuncNode<F,L>* subFFT = sRoot + sRoot->left;
338 Complex<F,L>* z = sRoot->workspace.data();
342 Complex<F,L> y0 = x[0];
345 for(
size_t k = 0; k < (p-1); k++) {
346 y[k*s_out] = x[akinv*s_in];
347 y0 = y0 + y[k*s_out];
349 akinv = (akinv*ainv) % p;
355 subFFT->fptr(y, &z[0], s_out, 1, subFFT, dir);
358 for(
size_t m = 0; m < (p-1); m++) {
// NOTE(review): Cm is rebuilt from scratch for every m with p-2 calls to
// omega<F,L>::get, i.e. O(p^2) twiddle evaluations per transform. It does not
// depend on the input data, so it could be precomputed once per node.
// NOTE(review): the powers p*(k*m+ak) - ak grow to roughly p^3 and are
// converted to F inside the twiddle helper, which loses precision for
// F = float at moderate p.
359 Complex<F,L> Cm = omega<F,L>::get(1, p, dir);
361 for(
size_t k = 1; k < (p-1); k++) {
362 Cm = Cm + omega<F,L>::get(p*(k*m+ak) - ak, p*(p-1), dir);
365 y[m*s_out] = z[m]*Cm;
// Inverse sub-transform: the direction is flipped by negating its integer
// value. This assumes forward/backward are encoded as opposite integers;
// confirm against the direction enum in heffte_common.h.
369 subFFT->fptr(y, &z[0], s_out, 1, subFFT, (
direction) (-1*((int) dir)));
374 for(
size_t m = 0; m < (p-1); m++) {
// NOTE(review): the normalisation divides by a hard-coded double regardless
// of F. For F = float this relies on Complex<F,L>'s operator/ accepting a
// double; casting to F would match the element type.
375 y[ak*s_out] = x[0] + (z[m]/((double) (p-1)));
// Recursive radix-3 FFT of N points. plus120 / minus120 are the two
// non-trivial cube roots of unity, already ordered for the transform
// direction by pow3_FFT. In the lines shown, dir is only forwarded and used
// for the twiddle w1.
//   - Base case (guard not shown): a directly computed 3-point DFT.
//   - Otherwise: three sub-transforms of Nprime points (input stride s_in*3)
//     fill the thirds of y, which are combined with the twiddles w1^k, w2^k.
// NOTE(review): plus120 / minus120 are only read in the lines shown but are
// taken by non-const reference. const& would also accept temporaries.
381 template<
typename F,
int L>
382 inline void pow3_FFT_helper(
size_t N, Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out,
direction dir, Complex<F,L>& plus120, Complex<F,L>& minus120) {
386 y[0] = x[0] + x[s_in] + x[2*s_in];
387 y[s_out] = x[0] + plus120*x[s_in] + minus120*x[2*s_in];
388 y[2*s_out] = x[0] + minus120*x[s_in] + plus120*x[2*s_in];
396 pow3_FFT_helper(Nprime, x, y, s_in*3, s_out, dir, plus120, minus120);
397 pow3_FFT_helper(Nprime, x+s_in, y+Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
398 pow3_FFT_helper(Nprime, x+2*s_in, y+2*Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
401 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
402 Complex<F,L> w2 = w1*w1;
403 Complex<F,L> wk1 = w1;
404 Complex<F,L> wk2 = w2;
// NOTE(review): k2 and k3 are int but hold size_t products (Nprime * s_out),
// which are narrowed and can overflow for large sizes or strides.
407 int k2 = Nprime * s_out;
408 int k3 = 2*Nprime * s_out;
// k = 0 butterfly: same 3-point combination as the base case, no twiddles.
410 Complex<F,L> tmpk = y[k1];
411 Complex<F,L> tmpk_p_1 = y[k2];
412 Complex<F,L> tmpk_p_2 = y[k3];
414 y[k1] = tmpk_p_2 + tmpk_p_1+ tmpk;
415 y[k2] = minus120.fmadd(tmpk_p_2, plus120.fmadd( tmpk_p_1, tmpk));
416 y[k3] = plus120.fmadd( tmpk_p_2, minus120.fmadd(tmpk_p_1, tmpk));
418 for(
size_t k = 1; k < Nprime; k++) {
421 k2 = ( Nprime + k) * s_out;
422 k3 = (2*Nprime + k) * s_out;
430 y[k1] = wk2.fmadd( tmpk_p_2, wk1.fmadd( tmpk_p_1, tmpk));
431 y[k2] = wk2.fmadd(minus120 * tmpk_p_2, wk1.fmadd(plus120 * tmpk_p_1, tmpk));
432 y[k3] = wk2.fmadd(plus120 * tmpk_p_2, wk1.fmadd(minus120 * tmpk_p_1, tmpk));
435 wk1 *= w1; wk2 *= w2;
// Kernel for fft_type::pow3 nodes: radix-3 FFT over the node's full size.
// Assuming Complex(re, im), the two roots are:
//   plus120  = -1/2 - i*sqrt(3)/2 = exp(-2*pi*i/3)
//   minus120 = -1/2 + i*sqrt(3)/2 = exp(+2*pi*i/3)
// The forward case passes them in that order; the backward case swaps them.
// NOTE(review): "plus120" actually holds the -120 degree rotation (and
// minus120 the +120 degree one), which is easy to misread.
440 template<
typename F,
int L>
441 inline void pow3_FFT(Complex<F,L>* x, Complex<F,L>* y,
size_t s_in,
size_t s_out, biFuncNode<F,L>* sRoot,
direction dir) {
442 const size_t N = sRoot->sz;
443 Complex<F,L> plus120 (-0.5, -sqrt(3)/2.);
444 Complex<F,L> minus120 (-0.5, sqrt(3)/2.);
446 case direction::forward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, plus120, minus120);
break;
447 case direction::backward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, minus120, plus120);
break;
Custom complex type taking advantage of vectorization. A Complex Type intrinsic to HeFFTe that takes a...
Definition: heffte_stock_complex.h:29
Complex< F, L > fmadd(Complex< F, L > const &y, Complex< F, L > const &z)
Fused multiply add.
Definition: heffte_stock_complex.h:155
direction
Indicates the direction of the FFT (internal use only).
Definition: heffte_common.h:535
@ backward
Inverse DFT transform.
@ forward
Forward DFT transform.
Namespace containing all HeFFTe methods and classes.
Definition: heffte_backend_cuda.h:38
int direction_sign(direction dir)
Find the sign given a direction.
Definition: heffte_stock_algos.h:21
Class to represent the call-graph nodes.
Definition: heffte_stock_algos.h:78
Create a stock Complex representation of a twiddle factor.
Definition: heffte_stock_algos.h:35