Highly Efficient FFT for Exascale: HeFFTe v2.3
heffte_stock_algos.h
1 /*
2  -- heFFTe --
3  Univ. of Tennessee, Knoxville
4  @date
5 */
6 
7 #ifndef HEFFTE_STOCK_ALGOS_H
8 #define HEFFTE_STOCK_ALGOS_H
9 
10 #include <cmath>
11 
12 #include "heffte_stock_complex.h"
13 #include "heffte_stock_allocator.h"
14 #include "heffte_common.h"
15 
17 #define HEFFTE_STOCK_THRESHOLD 1
18 
19 namespace heffte {
21 inline int direction_sign(direction dir) {
22  return (dir == direction::forward) ? -1 : 1;
23 }
24 
25 namespace stock {
26 // Need forward declaration for the using directive
27 template<typename F, int L>
28 struct biFuncNode;
29 
30 // Functions in the algos implementation file facing externally
34 template<typename F, int L>
35 struct omega {
36  static inline Complex<F, L> get(size_t power, size_t N, direction dir) {
37  F a = 2.*M_PI*((F) power)/((F) N);
38  return Complex<F,L>(cos(a), direction_sign(dir)*sin(a));
39  }
40 };
41 
// Algorithm selected for a call-graph node; dispatched by Fourier_Transform::operator().
enum fft_type {
    pow2,       // pow2_FFT
    pow3,       // pow3_FFT
    pow4,       // pow4_FFT
    composite,  // composite_FFT
    discrete,   // DFT
    rader       // rader_FFT
};
44 
49 template<typename F, int L>
51  protected:
52  fft_type type;
53  size_t root = 0;
54  size_t root_inv = 0;
55  public:
56  explicit Fourier_Transform(fft_type fft): type(fft) {}
57  explicit Fourier_Transform(size_t a, size_t ainv): type(fft_type::rader), root(a), root_inv(ainv) { }
58  void operator()(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
59  switch(type) {
60  case fft_type::pow2: pow2_FFT(x, y, s_in, s_out, sRoot, dir); break;
61  case fft_type::pow3: pow3_FFT(x, y, s_in, s_out, sRoot, dir); break;
62  case fft_type::pow4: pow4_FFT(x, y, s_in, s_out, sRoot, dir); break;
63  case fft_type::composite: composite_FFT(x, y, s_in, s_out, sRoot, dir); break;
64  case fft_type::discrete: DFT(x, y, s_in, s_out, sRoot, dir); break;
65  case fft_type::rader: rader_FFT(x, y, s_in, s_out, sRoot, dir, root, root_inv); break;
66  default: throw std::runtime_error("Invalid Fourier Transform Form!\n");
67  }
68  }
69 };
70 
77 template<typename F, int L>
78 struct biFuncNode {
79  size_t sz = 0; // Size of FFT
80  Fourier_Transform<F,L> fptr; // FFT for this call
81  size_t left = 0; // Offset in array until left child
82  size_t right = 0; // Offset in array until right child
83  complex_vector<F,L> workspace; // Workspace
84  biFuncNode(): fptr(fft_type::discrete) {};
85  biFuncNode(size_t sig_size, fft_type type): sz(sig_size), fptr(type), workspace(sig_size) {}; // Create default constructor
86  biFuncNode(size_t sig_size, size_t a, size_t ainv): fptr(a,ainv), workspace(sig_size) {};
87 };
88 
89 // Internal helper function to perform a DFT
90 template<typename F, int L>
91 inline void DFT_helper(size_t size, Complex<F,L>* sig_in, Complex<F,L>* sig_out, size_t s_in, size_t s_out, direction dir) {
92  if(size == 1) {
93  sig_out[0] = sig_in[0];
94  return;
95  }
96 
97  // Twiddle with smallest numerator
98  Complex<F,L> w0 = omega<F,L>::get(1, size, dir);
99 
100  // Base twiddle for each outer iteration
101  Complex<F,L> wk = w0;
102 
103  // Twiddle for inner iterations
104  Complex<F,L> wkn = w0;
105 
106  // Calculate first element of output
107  Complex<F,L> tmp = sig_in[0];
108  for(size_t n = 1; n < size; n++) {
109  tmp += sig_in[n*s_in];
110  }
111  sig_out[0] = tmp;
112 
113  // Initialize rest of output
114  for(size_t k = 1; k < size; k++) {
115  // Initialize kth output
116  tmp = sig_in[0];
117 
118  // Calculate kth output
119  for(size_t n = 1; n < size; n++) {
120  tmp = wkn.fmadd(sig_in[n*s_in], tmp);
121  wkn *= wk;
122  }
123  sig_out[k*s_out] = tmp;
124 
125  // "Increment" wk and "reset" wkn
126  wk *= w0;
127  wkn = wk;
128  }
129 }
130 
131 // External-facing function to properly call the internal DFT function
132 template<typename F, int L>
133 inline void DFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sLeaf, direction dir) {
134  DFT_helper(sLeaf->sz, x, y, s_in, s_out, dir);
135 }
136 
137 // Recursive helper function implementing a classic C-T FFT
138 template<typename F, int L>
139 inline void pow2_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir) {
140  // Trivial case
141  if(N == 2) {
142  y[ 0] = x[0] + x[s_in];
143  y[s_out] = x[0] - x[s_in];
144  return;
145  }
146 
147  // Size of sub-problem
148  int m = N/2;
149 
150  // Divide into two sub-problems
151  pow2_FFT_helper(m, x, y, s_in*2, s_out, dir);
152  pow2_FFT_helper(m, x+s_in, y+s_out*m, s_in*2, s_out, dir);
153 
154  // Twiddle Factor
155  Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
156  Complex<F,L> wj = w1;
157  Complex<F,L> y_j = y[0];
158 
159  y[0] += y[m*s_out];
160  y[m*s_out] = y_j - y[m*s_out];
161 
162  // Conquer larger problem accordingly
163  for(int j = 1; j < m; j++) {
164  int j_stride = j*s_out;
165  int jm_stride = (j+m)*s_out;
166  y_j = y[j_stride];
167  y[j_stride] = y_j + wj*y[jm_stride];
168  y[jm_stride] = y_j - wj*y[jm_stride];
169  wj *= w1;
170  }
171 }
172 
173 // External function to call the C-T radix-2 FFT
174 template<typename F, int L>
175 inline void pow2_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
176  const size_t N = sRoot->sz; // Size of problem
177  pow2_FFT_helper(N, x, y, s_in, s_out, dir); // Call the radix-2 FFT
178 }
179 
180 // Recursive helper function implementing a classic C-T FFT
181 template<typename F, int L>
182 inline void pow4_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir) {
183  // Trivial case
184  if(N == 4) {
185  if(dir == direction::forward) {
186  y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
187  y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
188  y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
189  y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
190  } else {
191  y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
192  y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
193  y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
194  y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
195  }
196  return;
197  }
198 
199  // Size of sub-problem
200  int m = N/4;
201 
202  // Divide into two sub-problems
203  pow4_FFT_helper(m, x , y , s_in*4, s_out, dir);
204  pow4_FFT_helper(m, x + s_in, y + s_out*m, s_in*4, s_out, dir);
205  pow4_FFT_helper(m, x + 2*s_in, y + 2*s_out*m, s_in*4, s_out, dir);
206  pow4_FFT_helper(m, x + 3*s_in, y + 3*s_out*m, s_in*4, s_out, dir);
207 
208  // Twiddle Factors
209  Complex<F,L> w1 = omega<F,L>::get(1,N,dir);
210  Complex<F,L> w2 = w1*w1;
211  Complex<F,L> w3 = w2*w1;
212  Complex<F,L> wk1 = w1; Complex<F,L> wk2 = w2; Complex<F,L> wk3 = w3;
213  int k0 = 0;
214  int k1 = m*s_out;
215  int k2 = 2*m*s_out;
216  int k3 = 3*m*s_out;
217  Complex<F,L> y_k0 = y[k0];
218  Complex<F,L> y_k1 = y[k1];
219  Complex<F,L> y_k2 = y[k2];
220  Complex<F,L> y_k3 = y[k3];
221  // Conquer larger problem accordingly
222  if(dir == direction::forward) {
223  y[k0] = y_k0 + y_k2 + (y_k1 + y_k3);
224  y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
225  y[k2] = y_k0 + y_k2 - (y_k1 + y_k3);
226  y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
227  for(int k = 1; k < m; k++) {
228  k0 = (k )*s_out;
229  k1 = (k + m)*s_out;
230  k2 = (k + 2*m)*s_out;
231  k3 = (k + 3*m)*s_out;
232  y_k0 = y[k0];
233  y_k1 = y[k1];
234  y_k2 = y[k2];
235  y_k3 = y[k3];
236  y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
237  y[k1] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
238  y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
239  y[k3] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
240  wk1 *= w1; wk2 *= w2; wk3 *= w3;
241  }
242  }
243  else {
244  y[k0] = y_k0 + y_k2 + y_k1 + y_k3;
245  y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
246  y[k2] = y_k0 + y_k2 - y_k1 - y_k3;
247  y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
248  for(int k = 1; k < m; k++) {
249  k0 = (k )*s_out;
250  k1 = (k + m)*s_out;
251  k2 = (k + 2*m)*s_out;
252  k3 = (k + 3*m)*s_out;
253  y_k0 = y[k0];
254  y_k1 = y[k1];
255  y_k2 = y[k2];
256  y_k3 = y[k3];
257  y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
258  y[k1] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
259  y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
260  y[k3] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
261  wk1 *= w1; wk2 *= w2; wk3 *= w3;
262  }
263  }
264 }
265 
266 // External function to call the C-T radix-4 FFT
267 template<typename F, int L>
268 inline void pow4_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
269  const size_t N = sRoot->sz; // Size of problem
270  pow4_FFT_helper(N, x, y, s_in, s_out, dir); // Call the radix-2 FFT
271 }
272 
273 // External & Internal function for radix-N1 C-T FFTs
274 template<typename F, int L>
275 inline void composite_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
276  // Retrieve N
277  size_t N = sRoot->sz;
278 
279  // Find the children on the call-graph
280  biFuncNode<F,L>* left = sRoot + sRoot->left;
281  biFuncNode<F,L>* right = sRoot + sRoot->right;
282 
283  // Get the size of the sub-problems
284  size_t N1 = left->sz;
285  size_t N2 = right->sz;
286 
287  // I'm currently using a temporary storage space malloc'd in recursive calls.
288  // This isn't optimal and will change as the engine develops
289  Complex<F,L>* z = sRoot->workspace.data();
290  // Find the FFT of the "rows" of the input signal and twiddle them accordingly
291  Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
292  Complex<F,L> wj1 = Complex<F,L>(1., 0.);
293  for(size_t j1 = 0; j1 < N1; j1++) {
294  Complex<F,L> wk2 = wj1;
295  right->fptr(&x[j1*s_in], &z[N2*j1], N1*s_in, 1, right, dir);
296  if(j1 > 0) {
297  for(size_t k2 = 1; k2 < N2; k2++) {
298  z[j1*N2 + k2] *= wk2;
299  wk2 *= wj1;
300  }
301  }
302  wj1 *= w1;
303  }
304 
305  /* Take the FFT of the "columns" of the output from the above "row" FFTs.
306  * Don't need j1 for the second transform as it's just the indexer.
307  * Take strides of N2 since z is allocated on the fly in this function for N.
308  */
309  for(size_t k2 = 0; k2 < N2; k2++) {
310  left->fptr(&z[k2], &y[k2*s_out], N2, N2*s_out, left, dir);
311  }
312 }
313 
// A factoring function for the reference composite FFT.
// Returns 2 for even f, otherwise the smallest odd divisor k >= 3 of f (which is prime);
// returns f itself when no such divisor with k*k <= f exists, i.e. when f is prime (or f == 1).
inline size_t referenceFactor(const size_t f) {
    // Check if it's even
    if((f & 0x1) == 0) return 2;

    // Check odd candidates up to and including sqrt(f). The bound must be inclusive so
    // that squares of primes (9, 25, 49, ...) are factored; `k <= f / k` is `k*k <= f`
    // without the risk of k*k overflowing.
    for(size_t k = 3; k <= f / k; k += 2) {
        if(f % k == 0) return k;
    }

    // return f if no factor was found
    return f;
}
327 
328 // Implementation for Rader's Algorithm
329 template<typename F, int L>
330 inline void rader_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir, size_t a, size_t ainv) {
331  // Size of the problem
332  size_t p = sRoot->sz;
333 
334  // Find the children on the call-graph
335  biFuncNode<F,L>* subFFT = sRoot + sRoot->left;
336 
337  // Temporary workspace
338  Complex<F,L>* z = sRoot->workspace.data();
339  // Loop variables
340  int ak = 1;
341  int akinv = 1;
342  Complex<F,L> y0 = x[0];
343 
344  // First, "invert" the order of x
345  for(size_t k = 0; k < (p-1); k++) {
346  y[k*s_out] = x[akinv*s_in];
347  y0 = y0 + y[k*s_out];
348  ak = (ak*a) % p;
349  akinv = (akinv*ainv) % p;
350  }
351 
352  // Convolve the resulting vector with twiddle vector
353 
354  // First fft the resulting shuffled vector
355  subFFT->fptr(y, &z[0], s_out, 1, subFFT, dir);
356 
357  // Perform cyclic convolution
358  for(size_t m = 0; m < (p-1); m++) {
359  Complex<F,L> Cm = omega<F,L>::get(1, p, dir);
360  ak = a;
361  for(size_t k = 1; k < (p-1); k++) {
362  Cm = Cm + omega<F,L>::get(p*(k*m+ak) - ak, p*(p-1), dir);
363  ak = (ak*a) % p;
364  }
365  y[m*s_out] = z[m]*Cm;
366  }
367 
368  // Bring back into signal domain
369  subFFT->fptr(y, &z[0], s_out, 1, subFFT, (direction) (-1*((int) dir)));
370 
371  // Shuffle as needed
372  ak = 1;
373  y[0] = y0;
374  for(size_t m = 0; m < (p-1); m++) {
375  y[ak*s_out] = x[0] + (z[m]/((double) (p-1)));
376  ak = (ak*a) % p;
377  }
378 }
379 
380 // Internal recursive helper-function that calculates the FFT of a signal with length 3^k
381 template<typename F, int L>
382 inline void pow3_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir, Complex<F,L>& plus120, Complex<F,L>& minus120) {
383 
384  // Calculate the DFT manually if necessary
385  if(N == 3) {
386  y[0] = x[0] + x[s_in] + x[2*s_in];
387  y[s_out] = x[0] + plus120*x[s_in] + minus120*x[2*s_in];
388  y[2*s_out] = x[0] + minus120*x[s_in] + plus120*x[2*s_in];
389  return;
390  }
391 
392  // Calculate the size of the sub-problem
393  size_t Nprime = N/3;
394 
395  // Divide into sub-problems
396  pow3_FFT_helper(Nprime, x, y, s_in*3, s_out, dir, plus120, minus120);
397  pow3_FFT_helper(Nprime, x+s_in, y+Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
398  pow3_FFT_helper(Nprime, x+2*s_in, y+2*Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
399 
400  // Combine the sub-problem solutions
401  Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
402  Complex<F,L> w2 = w1*w1;
403  Complex<F,L> wk1 = w1;
404  Complex<F,L> wk2 = w2;
405 
406  int k1 = 0;
407  int k2 = Nprime * s_out;
408  int k3 = 2*Nprime * s_out;
409 
410  Complex<F,L> tmpk = y[k1];
411  Complex<F,L> tmpk_p_1 = y[k2];
412  Complex<F,L> tmpk_p_2 = y[k3];
413 
414  y[k1] = tmpk_p_2 + tmpk_p_1+ tmpk;
415  y[k2] = minus120.fmadd(tmpk_p_2, plus120.fmadd( tmpk_p_1, tmpk));
416  y[k3] = plus120.fmadd( tmpk_p_2, minus120.fmadd(tmpk_p_1, tmpk));
417 
418  for(size_t k = 1; k < Nprime; k++) {
419  // Index calculation
420  k1 = k * s_out;
421  k2 = ( Nprime + k) * s_out;
422  k3 = (2*Nprime + k) * s_out;
423 
424  // Storing temporary variables
425  tmpk = y[k1];
426  tmpk_p_1 = y[k2];
427  tmpk_p_2 = y[k3];
428 
429  // Reassigning the output
430  y[k1] = wk2.fmadd( tmpk_p_2, wk1.fmadd( tmpk_p_1, tmpk));
431  y[k2] = wk2.fmadd(minus120 * tmpk_p_2, wk1.fmadd(plus120 * tmpk_p_1, tmpk));
432  y[k3] = wk2.fmadd(plus120 * tmpk_p_2, wk1.fmadd(minus120 * tmpk_p_1, tmpk));
433 
434  // Twiddle factors
435  wk1 *= w1; wk2 *= w2;
436  }
437 }
438 
439 // External-facing function for performing an FFT on signal with length N = 3^k
440 template<typename F, int L>
441 inline void pow3_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
442  const size_t N = sRoot->sz;
443  Complex<F,L> plus120 (-0.5, -sqrt(3)/2.);
444  Complex<F,L> minus120 (-0.5, sqrt(3)/2.);
445  switch(dir) {
446  case direction::forward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, plus120, minus120); break;
447  case direction::backward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, minus120, plus120); break;
448  }
449 }
450 
451 }
452 }
453 
454 #endif // END HEFFTE_STOCK_ALGOS_H
Custom complex type taking advantage of vectorization. A Complex Type intrinsic to HeFFTe that takes a...
Definition: heffte_stock_complex.h:29
Complex< F, L > fmadd(Complex< F, L > const &y, Complex< F, L > const &z)
Fused multiply add.
Definition: heffte_stock_complex.h:155
Functor class to represent any Fourier Transform. A class to use lambdas to switch between what FFT sh...
Definition: heffte_stock_algos.h:50
direction
Indicates the direction of the FFT (internal use only).
Definition: heffte_common.h:535
@ backward
Inverse DFT transform.
@ forward
Forward DFT transform.
Namespace containing all HeFFTe methods and classes.
Definition: heffte_backend_cuda.h:38
int direction_sign(direction dir)
Find the sign given a direction.
Definition: heffte_stock_algos.h:21
Class to represent the call-graph nodes.
Definition: heffte_stock_algos.h:78
Create a stock Complex representation of a twiddle factor.
Definition: heffte_stock_algos.h:35