Highly Efficient FFT for Exascale: HeFFTe v2.3
heffte_stock_vec_types.h
1 /*
2  -- heFFTe --
3  Univ. of Tennessee, Knoxville
4  @date
5 */
6 
7 #ifndef HEFFTE_STOCK_VEC_TYPES_H
8 #define HEFFTE_STOCK_VEC_TYPES_H
9 
#include "heffte_config.h"

#ifdef __AVX__
#include <immintrin.h>
#endif
#include <complex>
#include <type_traits>
16 
17 namespace heffte {
18 namespace stock {
//! \brief Alias that is true when T, with const/volatile removed, is exactly float.
template<typename T>
using is_float = std::is_same<float, typename std::remove_cv<T>::type>;
//! \brief Alias that is true when T, with const/volatile removed, is exactly double.
template<typename T>
using is_double = std::is_same<double, typename std::remove_cv<T>::type>;
//! \brief Alias that is true when T, with const/volatile removed, is exactly std::complex<float>.
template<typename T>
using is_fcomplex = std::is_same<std::complex<float>, typename std::remove_cv<T>::type>;
//! \brief Alias that is true when T, with const/volatile removed, is exactly std::complex<double>.
template<typename T>
using is_dcomplex = std::is_same<std::complex<double>, typename std::remove_cv<T>::type>;

/*!
 * \brief Struct determining whether a type is a real number (float or double).
 *
 * The trait inherits from std::integral_constant, so besides the usual ::value it also
 * provides ::type and the conversion to bool, it can be matched against
 * std::true_type / std::false_type for tag dispatch, and ::value can be odr-used
 * (e.g., bound to a const reference) without a separate out-of-class definition,
 * which a plain static constexpr member would require before C++17.
 */
template<typename T>
struct is_real : std::integral_constant<bool, is_float<T>::value || is_double<T>::value> {};

/*!
 * \brief Struct determining whether a type is a complex number
 *        (std::complex<float> or std::complex<double>).
 *
 * Inherits from std::integral_constant for the same reasons as is_real.
 */
template<typename T>
struct is_complex : std::integral_constant<bool, is_fcomplex<T>::value || is_dcomplex<T>::value> {};
41 
/*!
 * \brief Struct to retrieve the vector type associated with a scalar type T and the
 *        number of elements N stored "per unit".
 *
 * The primary template is empty; only the explicit specializations define the member
 * alias \b type, so an unsupported combination fails to compile at the point of use.
 */
template<typename T, int N>
struct pack {};

//! \brief The un-vectorized single-precision case uses one std::complex<float>.
template<>
struct pack<float, 1> {
    using type = std::complex<float>;
};

//! \brief The un-vectorized double-precision case uses one std::complex<double>.
template<>
struct pack<double, 1> {
    using type = std::complex<double>;
};
48 
49 // Some simple operations that will be useful for vectorized types.
50 
55 template<typename F, int L>
56 inline typename pack<F,L>::type mm_zero(){return 0.0;}
61 template<typename F, int L>
62 inline typename pack<F,L>::type mm_load(F const *src) { return typename pack<F,L>::type {src[0], src[1]}; }
67 template<typename F, int L>
68 inline void mm_store(F *dest, typename pack<F,L>::type const &src) {
69  dest[0] = src.real(); dest[1] = src.imag();
70 }
75 template<typename F, int L>
76 inline typename pack<F,L>::type mm_pair_set(F x, F y) { return typename pack<F,L>::type(x, y); }
81 template<typename F, int L>
82 inline typename pack<F,L>::type mm_set1(F src) { return typename pack<F,L>::type(src,src); }
87 template<typename F, int L>
88 inline typename pack<F,L>::type mm_complex_load(std::complex<F> const *src) { return *src; }
93 template<typename F, int L>
94 inline typename pack<F,L>::type mm_complex_load(std::complex<F> const *src, int) { return *src; }
95 
96 // Real basic arithmetic for the "none" case
97 
99 inline typename pack<float, 1>::type mm_add(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a + b; }
101 inline typename pack<double, 1>::type mm_add(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a + b; }
103 inline typename pack<float, 1>::type mm_sub(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a - b; }
105 inline typename pack<double, 1>::type mm_sub(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a - b; }
107 inline typename pack<float, 1>::type mm_div(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a / b.real(); }
109 inline typename pack<double, 1>::type mm_div(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a / b.real(); }
111 inline typename pack<float, 1>::type mm_neg(typename pack<float, 1>::type const &a){ return -a; }
113 inline typename pack<double, 1>::type mm_neg(typename pack<double, 1>::type const &a){ return -a; }
115 inline typename pack<float, 1>::type mm_mul(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a * b.real(); }
117 inline typename pack<double, 1>::type mm_mul(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a * b.real(); }
119 inline typename pack<float, 1>::type mm_complex_mul(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a * b; }
121 inline typename pack<double, 1>::type mm_complex_mul(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a * b; }
123 inline typename pack<float, 1>::type mm_complex_fmadd(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b, typename pack<float, 1>::type const &c){ return a * b + c; }
125 inline typename pack<double, 1>::type mm_complex_fmadd(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b, typename pack<double, 1>::type const &c){ return a * b + c; }
127 inline typename pack<float, 1>::type mm_complex_fmsub(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b, typename pack<float, 1>::type const &c){ return a * b - c; }
129 inline typename pack<double, 1>::type mm_complex_fmsub(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b, typename pack<double, 1>::type const &c){ return a * b - c; }
131 inline typename pack<float, 1>::type mm_complex_mul_i(typename pack<float, 1>::type const &a){return a * std::complex<float>{0.f,1.f}; }
133 inline typename pack<double, 1>::type mm_complex_mul_i(typename pack<double, 1>::type const &a){ return a * std::complex<double>{0.,1.}; }
135 inline typename pack<float, 1>::type mm_complex_mul_neg_i(typename pack<float, 1>::type const &a){return a * std::complex<float>{0.f,-1.f}; }
137 inline typename pack<double, 1>::type mm_complex_mul_neg_i(typename pack<double, 1>::type const &a){ return a * std::complex<double>{0.,-1.}; }
139 inline typename pack<float, 1>::type mm_complex_sq_mod(typename pack<float,1>::type const &a){ return std::complex<float>{norm(a), norm(a)}; }
141 inline typename pack<double, 1>::type mm_complex_sq_mod(typename pack<double,1>::type const &a){ return std::complex<double>{norm(a), norm(a)}; }
143 inline typename pack<float, 1>::type mm_complex_mod(typename pack<float,1>::type const &a){ return std::complex<float>{std::abs(a), std::abs(a)}; }
145 inline typename pack<double, 1>::type mm_complex_mod(typename pack<double,1>::type const &a){ return std::complex<double>{std::abs(a), std::abs(a)}; }
147 inline typename pack<float, 1>::type mm_complex_conj(typename pack<float,1>::type const &a){ return conj(a); }
149 inline typename pack<double, 1>::type mm_complex_conj(typename pack<double,1>::type const &a){ return conj(a); }
151 inline typename pack<float, 1>::type mm_complex_div(typename pack<float, 1>::type const &a, typename pack<float, 1>::type const &b){ return a / b; }
153 inline typename pack<double, 1>::type mm_complex_div(typename pack<double, 1>::type const &a, typename pack<double, 1>::type const &b){ return a / b; }
154 
156 /* Below is functionality for vector packs */
158 
159 #ifdef Heffte_ENABLE_AVX
160 
162 template<> struct pack<double, 2> { using type = __m128d; };
164 template<> struct pack<float, 4> { using type = __m128; };
166 template<> struct pack<double, 4> { using type = __m256d; };
168 template<> struct pack<float, 8> { using type = __m256; };
169 
171 /* Below are the specializations of the basic operations for pack<float, 4> */
173 
175 template<>
176 inline typename pack<float, 4>::type mm_zero<float, 4>(){ return _mm_setzero_ps(); }
177 
179 template<>
180 inline typename pack<float, 4>::type mm_load<float, 4>(float const *src) { return _mm_loadu_ps(src); }
181 
183 template<>
184 inline void mm_store<float, 4>(float *dest, pack<float, 4>::type const &src) { _mm_storeu_ps(dest, src); }
185 
187 template<>
188 inline typename pack<float, 4>::type mm_pair_set<float, 4>(float x, float y) { return _mm_setr_ps(x, y, x, y); }
189 
191 template<>
192 inline typename pack<float, 4>::type mm_set1<float,4> (float x) { return _mm_set1_ps(x); }
193 
195 template<>
196 inline typename pack<float, 4>::type mm_complex_load<float,4>(std::complex<float> const *src, int stride) {
197  return _mm_setr_ps(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
198 }
200 template<>
201 inline typename pack<float, 4>::type mm_complex_load<float,4>(std::complex<float> const *src) {
202  return mm_complex_load<float,4>(src, 1);
203 }
204 
206 /* Below are the specializations of the basic operations for pack<float, 8> */
208 
210 template<>
211 inline typename pack<float, 8>::type mm_zero<float, 8>(){ return _mm256_setzero_ps(); }
212 
214 template<>
215 inline typename pack<float, 8>::type mm_load<float, 8>(float const *src) { return _mm256_loadu_ps(src); }
216 
218 template<>
219 inline void mm_store<float, 8>(float *dest, pack<float, 8>::type const &src) { _mm256_storeu_ps(dest, src); }
220 
222 template<>
223 inline typename pack<float, 8>::type mm_pair_set<float, 8>(float x, float y) { return _mm256_setr_ps(x, y, x, y, x, y, x, y); }
224 
226 template<>
227 inline typename pack<float, 8>::type mm_set1<float,8> (float x) { return _mm256_set1_ps(x); }
228 
230 template<>
231 inline typename pack<float, 8>::type mm_complex_load<float, 8>(std::complex<float> const *src, int stride) {
232  return _mm256_setr_ps(src[0*stride].real(), src[0*stride].imag(),
233  src[1*stride].real(), src[1*stride].imag(),
234  src[2*stride].real(), src[2*stride].imag(),
235  src[3*stride].real(), src[3*stride].imag());
236 }
238 template<>
239 inline typename pack<float, 8>::type mm_complex_load<float,8>(std::complex<float> const *src) {
240  return mm_complex_load<float,8>(src, 1);
241 }
242 
244 /* Below are the specializations of the basic operations for pack<double, 2> */
246 
248 template<>
249 inline typename pack<double, 2>::type mm_zero<double, 2>(){ return _mm_setzero_pd(); }
250 
252 template<>
253 inline typename pack<double, 2>::type mm_load<double, 2>(double const *src) { return _mm_loadu_pd(src); }
254 
256 template<>
257 inline void mm_store<double, 2>(double *dest, pack<double, 2>::type const &src) { _mm_storeu_pd(dest, src); }
258 
260 template<>
261 inline typename pack<double, 2>::type mm_pair_set<double, 2>(double x, double y) { return _mm_setr_pd(x, y); }
262 
264 template<>
265 inline typename pack<double, 2>::type mm_set1<double, 2>(double x) { return _mm_set1_pd(x); }
266 
268 template<>
269 inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double> const *src, int) {
270  return _mm_setr_pd(src[0].real(), src[0].imag());
271 }
272 template<>
273 inline typename pack<double,2>::type mm_complex_load<double,2>(std::complex<double> const *src) {
274  return mm_complex_load<double,2>(src, 1);
275 }
276 
278 /* Below are the specializations of the basic operations for pack<double, 4> */
280 
282 template<>
283 inline typename pack<double, 4>::type mm_zero<double, 4>(){ return _mm256_setzero_pd(); }
284 
286 template<>
287 inline typename pack<double, 4>::type mm_load<double, 4>(double const *src) { return _mm256_loadu_pd(src); }
288 
290 template<>
291 inline void mm_store<double, 4>(double *dest, pack<double, 4>::type const &src) { _mm256_storeu_pd(dest, src); }
292 
294 template<>
295 inline typename pack<double, 4>::type mm_pair_set<double, 4>(double x, double y) { return _mm256_setr_pd(x, y, x, y); }
296 
298 template<>
299 inline typename pack<double, 4>::type mm_set1<double, 4>(double x) { return _mm256_set1_pd(x); }
300 
302 template<>
303 inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double> const *src, int stride) {
304  return _mm256_setr_pd(src[0].real(), src[0].imag(), src[stride].real(), src[stride].imag());
305 }
307 template<>
308 inline typename pack<double,4>::type mm_complex_load<double,4>(std::complex<double> const *src) {
309  return mm_complex_load<double,4>(src, 1);
310 }
311 
313 /* Elementary operations for vector packs */
315 
316 /* Addition */
317 
319 inline pack<float, 4>::type mm_add(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
320  return _mm_add_ps(x, y);
321 }
322 
324 inline pack<float, 8>::type mm_add(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
325  return _mm256_add_ps(x, y);
326 }
327 
329 inline pack<double, 2>::type mm_add(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
330  return _mm_add_pd(x, y);
331 }
332 
334 inline pack<double, 4>::type mm_add(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
335  return _mm256_add_pd(x, y);
336 }
337 
338 /* Subtraction */
339 
341 inline pack<float, 4>::type mm_sub(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
342  return _mm_sub_ps(x, y);
343 }
344 
346 inline pack<float, 8>::type mm_sub(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
347  return _mm256_sub_ps(x, y);
348 }
349 
351 inline pack<double, 2>::type mm_sub(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
352  return _mm_sub_pd(x, y);
353 }
354 
356 inline pack<double, 4>::type mm_sub(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
357  return _mm256_sub_pd(x, y);
358 }
359 
360 /* Multiplication */
361 
363 inline pack<float, 4>::type mm_mul(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
364  return _mm_mul_ps(x, y);
365 }
366 
368 inline pack<float, 8>::type mm_mul(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
369  return _mm256_mul_ps(x, y);
370 }
371 
373 inline pack<double, 2>::type mm_mul(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
374  return _mm_mul_pd(x, y);
375 }
376 
378 inline pack<double, 4>::type mm_mul(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
379  return _mm256_mul_pd(x, y);
380 }
381 
382 /* Division */
383 
385 inline pack<float, 4>::type mm_div(pack<float, 4>::type const &x,pack<float, 4>::type const &y) {
386  return _mm_div_ps(x, y);
387 }
388 
390 inline pack<float, 8>::type mm_div(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
391  return _mm256_div_ps(x, y);
392 }
393 
395 inline pack<double, 2>::type mm_div(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
396  return _mm_div_pd(x, y);
397 }
398 
400 inline pack<double, 4>::type mm_div(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
401  return _mm256_div_pd(x, y);
402 }
403 
404 
405 /* Negation */
407 inline pack<float, 4>::type mm_neg(pack<float, 4>::type const &x) {
408  return _mm_xor_ps(x, (mm_set1<float, 4>(-0.f)));
409 }
410 
412 inline pack<float, 8>::type mm_neg(pack<float, 8>::type const &x) {
413  return _mm256_xor_ps(x, (mm_set1<float, 8>(-0.f)));
414 }
415 
417 inline pack<double, 2>::type mm_neg(pack<double, 2>::type const &x) {
418  return _mm_xor_pd(x, (mm_set1<double, 2>(-0.)));
419 }
420 
422 inline pack<double, 4>::type mm_neg(pack<double, 4>::type const &x) {
423  return _mm256_xor_pd(x, (mm_set1<double, 4>(-0.)));
424 }
425 
427 /* Complex operations using vector packs */
429 
430 // Complex Multiplication
431 
433 inline pack<float,4>::type mm_complex_mul(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
434  typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
435  typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
436  typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
437  typename pack<float,4>::type dba = _mm_mul_ps(ba, dd);
438  typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
439  return mult;
440 }
441 
443 inline pack<float, 8>::type mm_complex_mul(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
444  typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
445  typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
446  typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
447  typename pack<float,8>::type dba = _mm256_mul_ps(ba, dd);
448  typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
449  return mult;
450 }
451 
453 inline pack<double, 2>::type mm_complex_mul(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
454  typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
455  typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
456  typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
457  typename pack<double,2>::type dba = _mm_mul_pd(ba, dd);
458  typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
459  return mult;
460 }
461 
463 inline pack<double, 4>::type mm_complex_mul(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
464  typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
465  typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
466  typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
467  typename pack<double,4>::type dba = _mm256_mul_pd(ba, dd);
468  typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
469  return mult;
470 }
471 
472 // Fused multiply-add
473 
475 inline pack<float,4>::type mm_complex_fmadd(pack<float, 4>::type const &x, pack<float, 4>::type const &y, pack<float, 4>::type const &z) {
476  typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
477  typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
478  typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
479  typename pack<float,4>::type dba = _mm_fmaddsub_ps(ba, dd, z);
480  typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
481  return mult;
482 }
483 
485 inline pack<float, 8>::type mm_complex_fmadd(pack<float, 8>::type const &x, pack<float, 8>::type const &y, pack<float, 8>::type const &z) {
486  typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
487  typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
488  typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
489  typename pack<float,8>::type dba = _mm256_fmaddsub_ps(ba, dd, z);
490  typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
491  return mult;
492 }
493 
495 inline pack<double, 2>::type mm_complex_fmadd(pack<double, 2>::type const &x, pack<double, 2>::type const &y, pack<double, 2>::type const &z) {
496  typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
497  typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
498  typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
499  typename pack<double,2>::type dba = _mm_fmaddsub_pd(ba, dd, z);
500  typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
501  return mult;
502 }
503 
505 inline pack<double, 4>::type mm_complex_fmadd(pack<double, 4>::type const &x, pack<double, 4>::type const &y, pack<double, 4>::type const &z) {
506  typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
507  typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
508  typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
509  typename pack<double,4>::type dba = _mm256_fmaddsub_pd(ba, dd, z);
510  typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
511  return mult;
512 }
513 
515 inline pack<float,4>::type mm_complex_fmsub(pack<float, 4>::type const &x, pack<float, 4>::type const &y, pack<float, 4>::type const &z) {
516  typename pack<float,4>::type cc = _mm_permute_ps(y, 0b10100000);
517  typename pack<float,4>::type ba = _mm_permute_ps(x, 0b10110001);
518  typename pack<float,4>::type dd = _mm_permute_ps(y, 0b11110101);
519  typename pack<float,4>::type dba = _mm_fmsubadd_ps(ba, dd, z);
520  typename pack<float,4>::type mult = _mm_fmaddsub_ps(x, cc, dba);
521  return mult;
522 }
523 
525 inline pack<float, 8>::type mm_complex_fmsub(pack<float, 8>::type const &x, pack<float, 8>::type const &y, pack<float, 8>::type const &z) {
526  typename pack<float,8>::type cc = _mm256_permute_ps(y, 0b10100000);
527  typename pack<float,8>::type ba = _mm256_permute_ps(x, 0b10110001);
528  typename pack<float,8>::type dd = _mm256_permute_ps(y, 0b11110101);
529  typename pack<float,8>::type dba = _mm256_fmsubadd_ps(ba, dd, z);
530  typename pack<float,8>::type mult = _mm256_fmaddsub_ps(x, cc, dba);
531  return mult;
532 }
533 
535 inline pack<double, 2>::type mm_complex_fmsub(pack<double, 2>::type const &x, pack<double, 2>::type const &y, pack<double, 2>::type const &z) {
536  typename pack<double,2>::type cc = _mm_permute_pd(y, 0);
537  typename pack<double,2>::type ba = _mm_permute_pd(x, 0b01);
538  typename pack<double,2>::type dd = _mm_permute_pd(y, 0b11);
539  typename pack<double,2>::type dba = _mm_fmsubadd_pd(ba, dd, z);
540  typename pack<double,2>::type mult = _mm_fmaddsub_pd(x, cc, dba);
541  return mult;
542 }
543 
545 inline pack<double, 4>::type mm_complex_fmsub(pack<double, 4>::type const &x, pack<double, 4>::type const &y, pack<double, 4>::type const &z) {
546  typename pack<double,4>::type cc = _mm256_permute_pd(y, 0b0000);
547  typename pack<double,4>::type ba = _mm256_permute_pd(x, 0b0101);
548  typename pack<double,4>::type dd = _mm256_permute_pd(y, 0b1111);
549  typename pack<double,4>::type dba = _mm256_fmsubadd_pd(ba, dd, z);
550  typename pack<double,4>::type mult = _mm256_fmaddsub_pd(x, cc, dba);
551  return mult;
552 }
553 
554 // Squared modulus of the complex numbers in a pack
555 
557 inline pack<float, 4>::type mm_complex_sq_mod(pack<float, 4>::type const &x) {
558  return _mm_or_ps(_mm_dp_ps(x, x, 0b11001100), _mm_dp_ps(x, x, 0b00110011));
559 }
560 
562 inline pack<float, 8>::type mm_complex_sq_mod(pack<float, 8>::type const &x) {
563  return _mm256_or_ps(_mm256_dp_ps(x, x, 0b11001100), _mm256_dp_ps(x, x, 0b00110011));
564 }
565 
567 inline pack<double, 2>::type mm_complex_sq_mod(pack<double, 2>::type const &x) {
568  return _mm_dp_pd(x, x, 0b11111111);
569 }
570 
572 inline pack<double, 4>::type mm_complex_sq_mod(pack<double, 4>::type const &x) {
573  typename pack<double,4>::type a = _mm256_mul_pd(x, x);
574  return _mm256_hadd_pd(a, a);
575 }
576 
577 // Moduli (with square root) of complex numbers
578 
580 inline pack<float, 4>::type mm_complex_mod(pack<float, 4>::type const &x) {
581  return _mm_sqrt_ps(mm_complex_sq_mod(x));
582 }
583 
585 inline pack<float, 8>::type mm_complex_mod(pack<float, 8>::type const &x) {
586  return _mm256_sqrt_ps(mm_complex_sq_mod(x));
587 }
588 
590 inline pack<double, 2>::type mm_complex_mod(pack<double, 2>::type const &x) {
591  return _mm_sqrt_pd(mm_complex_sq_mod(x));
592 }
593 
595 inline pack<double, 4>::type mm_complex_mod(pack<double, 4>::type const &x) {
596  return _mm256_sqrt_pd(mm_complex_sq_mod(x));
597 }
598 
600 inline pack<float, 4>::type mm_complex_conj(pack<float, 4>::type const &x) {
601  return _mm_blend_ps(x, (mm_neg(x)), 0b1010);
602 }
603 
605 inline pack<float, 8>::type mm_complex_conj(pack<float, 8>::type const &x) {
606  return _mm256_blend_ps(x, (mm_neg(x)), 0b10101010);
607 }
608 
610 inline pack<double, 2>::type mm_complex_conj(pack<double, 2>::type const &x) {
611  return _mm_blend_pd(x, (mm_neg(x)), 0b10);
612 }
613 
615 inline pack<double, 4>::type mm_complex_conj(pack<double, 4>::type const &x) {
616  return _mm256_blend_pd(x, (mm_neg(x)), 0b1010);
617 }
618 
619 // Special operation when multiplying by i and -i
621 inline pack<float, 4>::type mm_complex_mul_i(pack<float, 4>::type const &x) {
622  return _mm_permute_ps( (mm_complex_conj(x)), 0b10110001);
623 }
624 
626 inline pack<float, 8>::type mm_complex_mul_i(pack<float, 8>::type const &x) {
627  return _mm256_permute_ps( (mm_complex_conj(x)), 0b10110001);
628 }
629 
631 inline pack<double, 2>::type mm_complex_mul_i(pack<double, 2>::type const &x) {
632  return _mm_permute_pd( (mm_complex_conj(x)), 0b00000001);
633 }
634 
636 inline pack<double, 4>::type mm_complex_mul_i(pack<double, 4>::type const &x) {
637  return _mm256_permute_pd( (mm_complex_conj(x)), 0b00000101);
638 }
639 
641 inline pack<float, 4>::type mm_complex_mul_neg_i(pack<float, 4>::type const &x) {
642  return mm_complex_conj(_mm_permute_ps(x, 0b10110001));
643 }
644 
646 inline pack<float, 8>::type mm_complex_mul_neg_i(pack<float, 8>::type const &x) {
647  return mm_complex_conj(_mm256_permute_ps(x, 0b10110001));
648 }
649 
651 inline pack<double, 2>::type mm_complex_mul_neg_i(pack<double, 2>::type const &x) {
652  return mm_complex_conj(_mm_permute_pd(x, 0b0000001));
653 }
654 
656 inline pack<double, 4>::type mm_complex_mul_neg_i(pack<double, 4>::type const &x) {
657  return mm_complex_conj(_mm256_permute_pd(x, 0b00000101));
658 }
659 
660 // Complex division
661 
663 inline pack<float, 4>::type mm_complex_div(pack<float, 4>::type const &x, pack<float, 4>::type const &y) {
664  return _mm_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
665 }
666 
668 inline pack<float, 8>::type mm_complex_div(pack<float, 8>::type const &x, pack<float, 8>::type const &y) {
669  return _mm256_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
670 }
671 
673 inline pack<double, 2>::type mm_complex_div(pack<double, 2>::type const &x, pack<double, 2>::type const &y) {
674  return _mm_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
675 }
676 
678 inline pack<double, 4>::type mm_complex_div(pack<double, 4>::type const &x, pack<double, 4>::type const &y) {
679  return _mm256_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
680 }
681 
682 // Now all the implementations for types in AVX512 headers
683 #ifdef Heffte_ENABLE_AVX512
684 
686 template<> struct pack<double, 8> { using type = __m512d; };
688 template<> struct pack<float, 16> { using type = __m512; };
689 
691 /* Below are the specializations of the basic operations for pack<float, 16> */
693 
695 template<>
696 inline typename pack<float, 16>::type mm_zero<float, 16>(){ return _mm512_setzero_ps(); }
697 
699 template<>
700 inline typename pack<float, 16>::type mm_load<float, 16>(float const *src) { return _mm512_loadu_ps(src); }
701 
703 template<>
704 inline void mm_store<float, 16>(float *dest, pack<float, 16>::type const &src) { _mm512_storeu_ps(dest, src); }
705 
707 template<>
708 inline typename pack<float, 16>::type mm_pair_set<float, 16>(float x, float y) { return _mm512_setr_ps(x, y, x, y, x, y, x, y, x, y, x, y, x, y, x, y); }
709 
711 template<>
712 inline typename pack<float, 16>::type mm_set1<float, 16> (float x) { return _mm512_set1_ps(x); }
713 
715 template<>
716 inline typename pack<float, 16>::type mm_complex_load<float, 16>(std::complex<float> const *src, int stride) {
717  return _mm512_setr_ps(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
718  src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag(),
719  src[4*stride].real(), src[4*stride].imag(), src[5*stride].real(), src[5*stride].imag(),
720  src[6*stride].real(), src[6*stride].imag(), src[7*stride].real(), src[7*stride].imag());
721 }
722 
724 template<>
725 inline typename pack<float, 16>::type mm_complex_load<float, 16>(std::complex<float> const *src) {
726  return mm_complex_load<float, 16>(src, 1);
727 }
728 
730 /* Below are the specializations of the basic operations for pack<double, 8> */
732 
734 template<>
735 inline typename pack<double, 8>::type mm_zero<double, 8>(){ return _mm512_setzero_pd(); }
736 
738 template<>
739 inline typename pack<double, 8>::type mm_load<double, 8>(double const *src) { return _mm512_loadu_pd(src); }
740 
742 template<>
743 inline void mm_store<double, 8>(double *dest, pack<double, 8>::type const &src) { _mm512_storeu_pd(dest, src); }
744 
746 template<>
747 inline typename pack<double, 8>::type mm_pair_set<double, 8>(double x, double y) { return _mm512_setr_pd(x, y, x, y, x, y, x, y); }
748 
750 template<>
751 inline typename pack<double, 8>::type mm_set1<double, 8>(double x) { return _mm512_set1_pd(x); }
752 
754 template<>
755 inline typename pack<double, 8>::type mm_complex_load<double, 8>(std::complex<double> const *src, int stride) {
756  return _mm512_setr_pd(src[0*stride].real(), src[0*stride].imag(), src[1*stride].real(), src[1*stride].imag(),
757  src[2*stride].real(), src[2*stride].imag(), src[3*stride].real(), src[3*stride].imag());
758 }
760 template<>
761 inline typename pack<double, 8>::type mm_complex_load<double, 8>(std::complex<double> const *src) {
762  return mm_complex_load<double, 8>(src, 1);
763 }
764 
766 /* Elementary binary operations for vector packs */
768 
769 /* Addition */
770 
772 inline pack<float, 16>::type mm_add(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
773  return _mm512_add_ps(x, y);
774 }
775 
777 inline pack<double, 8>::type mm_add(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
778  return _mm512_add_pd(x, y);
779 }
780 
781 /* Subtraction */
782 
784 inline pack<float, 16>::type mm_sub(pack<float, 16>::type const &x,pack<float, 16>::type const &y) {
785  return _mm512_sub_ps(x, y);
786 }
787 
789 inline pack<double, 8>::type mm_sub(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
790  return _mm512_sub_pd(x, y);
791 }
792 
793 /* Multiplication */
794 
796 inline pack<float, 16>::type mm_mul(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
797  return _mm512_mul_ps(x, y);
798 }
799 
801 inline pack<double, 8>::type mm_mul(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
802  return _mm512_mul_pd(x, y);
803 }
804 
805 /* Division */
806 
808 inline pack<float, 16>::type mm_div(pack<float, 16>::type const &x,pack<float, 16>::type const &y) {
809  return _mm512_div_ps(x, y);
810 }
811 
813 inline pack<double, 8>::type mm_div(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
814  return _mm512_div_pd(x, y);
815 }
816 
817 /* Negation */
819 inline pack<float, 16>::type mm_neg(pack<float, 16>::type const &x) {
820  return _mm512_xor_ps(x, (mm_set1<float, 16>(-0.f)));
821 }
822 
824 inline pack<double, 8>::type mm_neg(pack<double, 8>::type const &x) {
825  return _mm512_xor_pd(x, (mm_set1<double, 8>(-0.f)));
826 }
827 
829 /* Complex operations using AVX512 vector packs */
831 
832 // Complex Multiplication
833 
835 inline pack<float,16>::type mm_complex_mul(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
836  typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
837  typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
838  typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
839  typename pack<float, 16>::type dba = _mm512_mul_ps(ba, dd);
840  typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
841  return mult;
842 }
843 
845 inline pack<double, 8>::type mm_complex_mul(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
846  typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
847  typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
848  typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
849  typename pack<double, 8>::type dba = _mm512_mul_pd(ba, dd);
850  typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
851  return mult;
852 }
853 
854 // Complex fused-multiply add
855 
857 inline pack<float,16>::type mm_complex_fmadd(pack<float, 16>::type const &x, pack<float, 16>::type const &y, pack<float, 16>::type const &alpha) {
858  typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
859  typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
860  typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
861  typename pack<float, 16>::type dba = _mm512_fmaddsub_ps(ba, dd, alpha);
862  typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
863  return mult;
864 }
865 
867 inline pack<double, 8>::type mm_complex_fmadd(pack<double, 8>::type const &x, pack<double, 8>::type const &y, pack<double, 8>::type const &alpha) {
868  typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
869  typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
870  typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
871  typename pack<double, 8>::type dba = _mm512_fmaddsub_pd(ba, dd, alpha);
872  typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
873  return mult;
874 }
875 
877 inline pack<float,16>::type mm_complex_fmsub(pack<float, 16>::type const &x, pack<float, 16>::type const &y, pack<float, 16>::type const &alpha) {
878  typename pack<float, 16>::type cc = _mm512_permute_ps(y, 0b10100000);
879  typename pack<float, 16>::type ba = _mm512_permute_ps(x, 0b10110001);
880  typename pack<float, 16>::type dd = _mm512_permute_ps(y, 0b11110101);
881  typename pack<float, 16>::type dba = _mm512_fmsubadd_ps(ba, dd, alpha);
882  typename pack<float, 16>::type mult = _mm512_fmaddsub_ps(x, cc, dba);
883  return mult;
884 }
885 
887 inline pack<double, 8>::type mm_complex_fmsub(pack<double, 8>::type const &x, pack<double, 8>::type const &y, pack<double, 8>::type const &alpha) {
888  typename pack<double, 8>::type cc = _mm512_permute_pd(y, 0b00000000);
889  typename pack<double, 8>::type ba = _mm512_permute_pd(x, 0b01010101);
890  typename pack<double, 8>::type dd = _mm512_permute_pd(y, 0b11111111);
891  typename pack<double, 8>::type dba = _mm512_fmsubadd_pd(ba, dd, alpha);
892  typename pack<double, 8>::type mult = _mm512_fmaddsub_pd(x, cc, dba);
893  return mult;
894 }
895 
896 // Squared modulus of the complex numbers in a pack
897 
899 inline pack<float, 16>::type mm_complex_sq_mod(pack<float, 16>::type const &x) {
900  typename pack<float, 16>::type sq = mm_mul(x, x);
901  typename pack<float, 16>::type sq_perm = _mm512_permute_ps(sq, 0b10110001);
902  typename pack<float, 16>::type mod = mm_add(sq, sq_perm);
903  return mod;
904 }
905 
907 inline pack<double, 8>::type mm_complex_sq_mod(pack<double, 8>::type const &x) {
908  typename pack<double, 8>::type sq = mm_mul(x, x);
909  typename pack<double, 8>::type sq_perm = _mm512_permute_pd(sq, 0b01010101);
910  typename pack<double, 8>::type mod = mm_add(sq, sq_perm);
911  return mod;
912 }
913 
914 // Moduli (with square root) of complex numbers
915 
917 inline pack<float, 16>::type mm_complex_mod(pack<float, 16>::type const &x) {
918  return _mm512_sqrt_ps(mm_complex_sq_mod(x));
919 }
920 
922 inline pack<double, 8>::type mm_complex_mod(pack<double, 8>::type const &x) {
923  return _mm512_sqrt_pd(mm_complex_sq_mod(x));
924 }
925 
926 // Conjugate complex numbers
927 
929 inline pack<float, 16>::type mm_complex_conj(pack<float, 16>::type const &x) {
930  return _mm512_mask_blend_ps(0b1010101010101010, x, mm_neg(x));
931 }
932 
934 inline pack<double, 8>::type mm_complex_conj(pack<double, 8>::type const &x) {
935  return _mm512_mask_blend_pd(0b10101010, x, mm_neg(x));
936 }
937 
938 // Special operation when multiplying by i and -i
940 inline pack<float, 16>::type mm_complex_mul_i(pack<float, 16>::type const &x) {
941  return _mm512_permute_ps( (mm_complex_conj(x)), 0b10110001);
942 }
943 
945 inline pack<double, 8>::type mm_complex_mul_i(pack<double, 8>::type const &x) {
946  return _mm512_permute_pd( (mm_complex_conj(x)), 0b01010101);
947 }
948 
950 inline pack<float, 16>::type mm_complex_mul_neg_i(pack<float, 16>::type const &x) {
951  return mm_complex_conj(_mm512_permute_ps(x, 0b10110001));
952 }
953 
955 inline pack<double, 8>::type mm_complex_mul_neg_i(pack<double, 8>::type const &x) {
956  return mm_complex_conj(_mm512_permute_pd(x, 0b01010101));
957 }
958 
959 // Complex division
960 
962 inline pack<float, 16>::type mm_complex_div(pack<float, 16>::type const &x, pack<float, 16>::type const &y) {
963  return _mm512_div_ps(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
964 }
965 
967 inline pack<double, 8>::type mm_complex_div(pack<double, 8>::type const &x, pack<double, 8>::type const &y) {
968  return _mm512_div_pd(mm_complex_mul(x, mm_complex_conj(y)), mm_complex_sq_mod(y));
969 }
970 
971 #endif // Heffte_ENABLE_AVX512
972 #endif // Heffte_ENABLE_AVX
973 
974 }
975 }
976 
977 #endif // HEFFTE_STOCK_VEC_TYPES_H
Namespace containing all HeFFTe methods and classes.
Definition: heffte_backend_cuda.h:38
Struct determining whether a type is a complex number.
Definition: heffte_stock_vec_types.h:38
Struct determining whether a type is a real number.
Definition: heffte_stock_vec_types.h:33
Struct to retrieve the vector type associated with the number of elements stored "per unit".
Definition: heffte_stock_vec_types.h:43