Highly Efficient FFT for Exascale: HeFFTe v2.4
Loading...
Searching...
No Matches
heffte_stock_algos.h
1/*
2 -- heFFTe --
3 Univ. of Tennessee, Knoxville
4 @date
5*/
6
7#ifndef HEFFTE_STOCK_ALGOS_H
8#define HEFFTE_STOCK_ALGOS_H
9
10#include <cmath>
11
12#include "heffte_stock_complex.h"
13#include "heffte_stock_allocator.h"
14#include "heffte_common.h"
15
17#define HEFFTE_STOCK_THRESHOLD 1
18
19namespace heffte {
21inline int direction_sign(direction dir) {
22 return (dir == direction::forward) ? -1 : 1;
23}
24
25namespace stock {
26// Need forward declaration for the using directive
27template<typename F, int L>
28struct biFuncNode;
29
30// Functions in the algos implementation file facing externally
34template<typename F, int L>
35struct omega {
36 static inline Complex<F, L> get(size_t power, size_t N, direction dir) {
37 F a = 2.*M_PI*((F) power)/((F) N);
38 return Complex<F,L>(cos(a), direction_sign(dir)*sin(a));
39 }
40};
41
// Tags selecting which kernel a Fourier_Transform dispatches to.
enum fft_type {
    pow2,       // dispatches to pow2_FFT
    pow3,       // dispatches to pow3_FFT
    pow4,       // dispatches to pow4_FFT
    composite,  // dispatches to composite_FFT
    discrete,   // dispatches to DFT
    rader       // dispatches to rader_FFT
};
44
49template<typename F, int L>
50class Fourier_Transform {
51 protected:
52 fft_type type;
53 size_t root = 0;
54 size_t root_inv = 0;
55 public:
56 explicit Fourier_Transform(fft_type fft): type(fft) {}
57 explicit Fourier_Transform(size_t a, size_t ainv): type(fft_type::rader), root(a), root_inv(ainv) { }
58 void operator()(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
59 switch(type) {
60 case fft_type::pow2: pow2_FFT(x, y, s_in, s_out, sRoot, dir); break;
61 case fft_type::pow3: pow3_FFT(x, y, s_in, s_out, sRoot, dir); break;
62 case fft_type::pow4: pow4_FFT(x, y, s_in, s_out, sRoot, dir); break;
63 case fft_type::composite: composite_FFT(x, y, s_in, s_out, sRoot, dir); break;
64 case fft_type::discrete: DFT(x, y, s_in, s_out, sRoot, dir); break;
65 case fft_type::rader: rader_FFT(x, y, s_in, s_out, sRoot, dir, root, root_inv); break;
66 default: throw std::runtime_error("Invalid Fourier Transform Form!\n");
67 }
68 }
69};
70
77template<typename F, int L>
78struct biFuncNode {
79 size_t sz = 0; // Size of FFT
80 Fourier_Transform<F,L> fptr; // FFT for this call
81 size_t left = 0; // Offset in array until left child
82 size_t right = 0; // Offset in array until right child
83 complex_vector<F,L> workspace; // Workspace
84 biFuncNode(): fptr(fft_type::discrete) {};
85 biFuncNode(size_t sig_size, fft_type type): sz(sig_size), fptr(type), workspace(sig_size) {}; // Create default constructor
86 biFuncNode(size_t sig_size, size_t a, size_t ainv): fptr(a,ainv), workspace(sig_size) {};
87};
88
89// Internal helper function to perform a DFT
90template<typename F, int L>
91inline void DFT_helper(size_t size, Complex<F,L>* sig_in, Complex<F,L>* sig_out, size_t s_in, size_t s_out, direction dir) {
92 if(size == 1) {
93 sig_out[0] = sig_in[0];
94 return;
95 }
96
97 // Twiddle with smallest numerator
98 Complex<F,L> w0 = omega<F,L>::get(1, size, dir);
99
100 // Base twiddle for each outer iteration
101 Complex<F,L> wk = w0;
102
103 // Twiddle for inner iterations
104 Complex<F,L> wkn = w0;
105
106 // Calculate first element of output
107 Complex<F,L> tmp = sig_in[0];
108 for(size_t n = 1; n < size; n++) {
109 tmp += sig_in[n*s_in];
110 }
111 sig_out[0] = tmp;
112
113 // Initialize rest of output
114 for(size_t k = 1; k < size; k++) {
115 // Initialize kth output
116 tmp = sig_in[0];
117
118 // Calculate kth output
119 for(size_t n = 1; n < size; n++) {
120 tmp = wkn.fmadd(sig_in[n*s_in], tmp);
121 wkn *= wk;
122 }
123 sig_out[k*s_out] = tmp;
124
125 // "Increment" wk and "reset" wkn
126 wk *= w0;
127 wkn = wk;
128 }
129}
130
131// External-facing function to properly call the internal DFT function
132template<typename F, int L>
133inline void DFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sLeaf, direction dir) {
134 DFT_helper(sLeaf->sz, x, y, s_in, s_out, dir);
135}
136
137// Recursive helper function implementing a classic C-T FFT
138template<typename F, int L>
139inline void pow2_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir) {
140 // Trivial case
141 if(N == 2) {
142 y[ 0] = x[0] + x[s_in];
143 y[s_out] = x[0] - x[s_in];
144 return;
145 }
146
147 // Size of sub-problem
148 int m = N/2;
149
150 // Divide into two sub-problems
151 pow2_FFT_helper(m, x, y, s_in*2, s_out, dir);
152 pow2_FFT_helper(m, x+s_in, y+s_out*m, s_in*2, s_out, dir);
153
154 // Twiddle Factor
155 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
156 Complex<F,L> wj = w1;
157 Complex<F,L> y_j = y[0];
158
159 y[0] += y[m*s_out];
160 y[m*s_out] = y_j - y[m*s_out];
161
162 // Conquer larger problem accordingly
163 for(int j = 1; j < m; j++) {
164 int j_stride = j*s_out;
165 int jm_stride = (j+m)*s_out;
166 y_j = y[j_stride];
167 y[j_stride] = y_j + wj*y[jm_stride];
168 y[jm_stride] = y_j - wj*y[jm_stride];
169 wj *= w1;
170 }
171}
172
173// External function to call the C-T radix-2 FFT
174template<typename F, int L>
175inline void pow2_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
176 const size_t N = sRoot->sz; // Size of problem
177 pow2_FFT_helper(N, x, y, s_in, s_out, dir); // Call the radix-2 FFT
178}
179
180// Recursive helper function implementing a classic C-T FFT
181template<typename F, int L>
182inline void pow4_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir) {
183 // Trivial case
184 if(N == 4) {
185 if(dir == direction::forward) {
186 y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
187 y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
188 y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
189 y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
190 } else {
191 y[0*s_out] = x[0] + x[2*s_in] + (x[s_in] + x[3*s_in]);
192 y[1*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_i();
193 y[2*s_out] = x[0] + x[2*s_in] - (x[s_in] + x[3*s_in]);
194 y[3*s_out] = x[0] - x[2*s_in] + (x[s_in] - x[3*s_in]).__mul_neg_i();
195 }
196 return;
197 }
198
199 // Size of sub-problem
200 int m = N/4;
201
202 // Divide into two sub-problems
203 pow4_FFT_helper(m, x , y , s_in*4, s_out, dir);
204 pow4_FFT_helper(m, x + s_in, y + s_out*m, s_in*4, s_out, dir);
205 pow4_FFT_helper(m, x + 2*s_in, y + 2*s_out*m, s_in*4, s_out, dir);
206 pow4_FFT_helper(m, x + 3*s_in, y + 3*s_out*m, s_in*4, s_out, dir);
207
208 // Twiddle Factors
209 Complex<F,L> w1 = omega<F,L>::get(1,N,dir);
210 Complex<F,L> w2 = w1*w1;
211 Complex<F,L> w3 = w2*w1;
212 Complex<F,L> wk1 = w1; Complex<F,L> wk2 = w2; Complex<F,L> wk3 = w3;
213 int k0 = 0;
214 int k1 = m*s_out;
215 int k2 = 2*m*s_out;
216 int k3 = 3*m*s_out;
217 Complex<F,L> y_k0 = y[k0];
218 Complex<F,L> y_k1 = y[k1];
219 Complex<F,L> y_k2 = y[k2];
220 Complex<F,L> y_k3 = y[k3];
221 // Conquer larger problem accordingly
222 if(dir == direction::forward) {
223 y[k0] = y_k0 + y_k2 + (y_k1 + y_k3);
224 y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
225 y[k2] = y_k0 + y_k2 - (y_k1 + y_k3);
226 y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
227 for(int k = 1; k < m; k++) {
228 k0 = (k )*s_out;
229 k1 = (k + m)*s_out;
230 k2 = (k + 2*m)*s_out;
231 k3 = (k + 3*m)*s_out;
232 y_k0 = y[k0];
233 y_k1 = y[k1];
234 y_k2 = y[k2];
235 y_k3 = y[k3];
236 y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
237 y[k1] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
238 y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
239 y[k3] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
240 wk1 *= w1; wk2 *= w2; wk3 *= w3;
241 }
242 }
243 else {
244 y[k0] = y_k0 + y_k2 + y_k1 + y_k3;
245 y[k1] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_i();
246 y[k2] = y_k0 + y_k2 - y_k1 - y_k3;
247 y[k3] = y_k0 - y_k2 + (y_k1 - y_k3).__mul_neg_i();
248 for(int k = 1; k < m; k++) {
249 k0 = (k )*s_out;
250 k1 = (k + m)*s_out;
251 k2 = (k + 2*m)*s_out;
252 k3 = (k + 3*m)*s_out;
253 y_k0 = y[k0];
254 y_k1 = y[k1];
255 y_k2 = y[k2];
256 y_k3 = y[k3];
257 y[k0] = wk2.fmadd( y_k2, y_k0) + wk1.fmadd(y_k1, wk3*y_k3);
258 y[k1] = wk2.fmadd(-y_k2, y_k0) + wk1.fmsub(y_k1, wk3*y_k3).__mul_i();
259 y[k2] = wk2.fmadd( y_k2, y_k0) - wk1.fmadd(y_k1, wk3*y_k3);
260 y[k3] = wk2.fmadd(-y_k2, y_k0) + wk3.fmsub(y_k3, wk1*y_k1).__mul_i();
261 wk1 *= w1; wk2 *= w2; wk3 *= w3;
262 }
263 }
264}
265
266// External function to call the C-T radix-4 FFT
267template<typename F, int L>
268inline void pow4_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
269 const size_t N = sRoot->sz; // Size of problem
270 pow4_FFT_helper(N, x, y, s_in, s_out, dir); // Call the radix-2 FFT
271}
272
273// External & Internal function for radix-N1 C-T FFTs
274template<typename F, int L>
275inline void composite_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
276 // Retrieve N
277 size_t N = sRoot->sz;
278
279 // Find the children on the call-graph
280 biFuncNode<F,L>* left = sRoot + sRoot->left;
281 biFuncNode<F,L>* right = sRoot + sRoot->right;
282
283 // Get the size of the sub-problems
284 size_t N1 = left->sz;
285 size_t N2 = right->sz;
286
287 // I'm currently using a temporary storage space malloc'd in recursive calls.
288 // This isn't optimal and will change as the engine develops
289 Complex<F,L>* z = sRoot->workspace.data();
290 // Find the FFT of the "rows" of the input signal and twiddle them accordingly
291 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
292 Complex<F,L> wj1 = Complex<F,L>(1., 0.);
293 for(size_t j1 = 0; j1 < N1; j1++) {
294 Complex<F,L> wk2 = wj1;
295 right->fptr(&x[j1*s_in], &z[N2*j1], N1*s_in, 1, right, dir);
296 if(j1 > 0) {
297 for(size_t k2 = 1; k2 < N2; k2++) {
298 z[j1*N2 + k2] *= wk2;
299 wk2 *= wj1;
300 }
301 }
302 wj1 *= w1;
303 }
304
305 /* Take the FFT of the "columns" of the output from the above "row" FFTs.
306 * Don't need j1 for the second transform as it's just the indexer.
307 * Take strides of N2 since z is allocated on the fly in this function for N.
308 */
309 for(size_t k2 = 0; k2 < N2; k2++) {
310 left->fptr(&z[k2], &y[k2*s_out], N2, N2*s_out, left, dir);
311 }
312}
313
// A factoring function for the reference composite FFT.
// Returns 2 for every even f (including 0), otherwise the smallest odd divisor
// k >= 3 of f, and f itself when no such divisor exists.
inline size_t referenceFactor(const size_t f) {
    // Check if it's even
    if((f & 0x1) == 0) return 2;

    // Trial division by odd candidates up to and including sqrt(f).
    // `k <= f/k` is the overflow-free form of `k*k <= f`; the bound must be
    // inclusive so that squares of odd primes (9, 25, 49, ...) are factored.
    for(size_t k = 3; k <= f/k; k += 2) {
        if(f % k == 0) return k;
    }

    // return f if no factor was found
    return f;
}
327
328// Implementation for Rader's Algorithm
329template<typename F, int L>
330inline void rader_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir, size_t a, size_t ainv) {
331 // Size of the problem
332 size_t p = sRoot->sz;
333
334 // Find the children on the call-graph
335 biFuncNode<F,L>* subFFT = sRoot + sRoot->left;
336
337 // Temporary workspace
338 Complex<F,L>* z = sRoot->workspace.data();
339 // Loop variables
340 int ak = 1;
341 int akinv = 1;
342 Complex<F,L> y0 = x[0];
343
344 // First, "invert" the order of x
345 for(size_t k = 0; k < (p-1); k++) {
346 y[k*s_out] = x[akinv*s_in];
347 y0 = y0 + y[k*s_out];
348 ak = (ak*a) % p;
349 akinv = (akinv*ainv) % p;
350 }
351
352 // Convolve the resulting vector with twiddle vector
353
354 // First fft the resulting shuffled vector
355 subFFT->fptr(y, &z[0], s_out, 1, subFFT, dir);
356
357 // Perform cyclic convolution
358 for(size_t m = 0; m < (p-1); m++) {
359 Complex<F,L> Cm = omega<F,L>::get(1, p, dir);
360 ak = a;
361 for(size_t k = 1; k < (p-1); k++) {
362 Cm = Cm + omega<F,L>::get(p*(k*m+ak) - ak, p*(p-1), dir);
363 ak = (ak*a) % p;
364 }
365 y[m*s_out] = z[m]*Cm;
366 }
367
368 // Bring back into signal domain
369 subFFT->fptr(y, &z[0], s_out, 1, subFFT, (direction) (-1*((int) dir)));
370
371 // Shuffle as needed
372 ak = 1;
373 y[0] = y0;
374 for(size_t m = 0; m < (p-1); m++) {
375 y[ak*s_out] = x[0] + (z[m]/((double) (p-1)));
376 ak = (ak*a) % p;
377 }
378}
379
380// Internal recursive helper-function that calculates the FFT of a signal with length 3^k
381template<typename F, int L>
382inline void pow3_FFT_helper(size_t N, Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, direction dir, Complex<F,L>& plus120, Complex<F,L>& minus120) {
383
384 // Calculate the DFT manually if necessary
385 if(N == 3) {
386 y[0] = x[0] + x[s_in] + x[2*s_in];
387 y[s_out] = x[0] + plus120*x[s_in] + minus120*x[2*s_in];
388 y[2*s_out] = x[0] + minus120*x[s_in] + plus120*x[2*s_in];
389 return;
390 }
391
392 // Calculate the size of the sub-problem
393 size_t Nprime = N/3;
394
395 // Divide into sub-problems
396 pow3_FFT_helper(Nprime, x, y, s_in*3, s_out, dir, plus120, minus120);
397 pow3_FFT_helper(Nprime, x+s_in, y+Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
398 pow3_FFT_helper(Nprime, x+2*s_in, y+2*Nprime*s_out, s_in*3, s_out, dir, plus120, minus120);
399
400 // Combine the sub-problem solutions
401 Complex<F,L> w1 = omega<F,L>::get(1, N, dir);
402 Complex<F,L> w2 = w1*w1;
403 Complex<F,L> wk1 = w1;
404 Complex<F,L> wk2 = w2;
405
406 int k1 = 0;
407 int k2 = Nprime * s_out;
408 int k3 = 2*Nprime * s_out;
409
410 Complex<F,L> tmpk = y[k1];
411 Complex<F,L> tmpk_p_1 = y[k2];
412 Complex<F,L> tmpk_p_2 = y[k3];
413
414 y[k1] = tmpk_p_2 + tmpk_p_1+ tmpk;
415 y[k2] = minus120.fmadd(tmpk_p_2, plus120.fmadd( tmpk_p_1, tmpk));
416 y[k3] = plus120.fmadd( tmpk_p_2, minus120.fmadd(tmpk_p_1, tmpk));
417
418 for(size_t k = 1; k < Nprime; k++) {
419 // Index calculation
420 k1 = k * s_out;
421 k2 = ( Nprime + k) * s_out;
422 k3 = (2*Nprime + k) * s_out;
423
424 // Storing temporary variables
425 tmpk = y[k1];
426 tmpk_p_1 = y[k2];
427 tmpk_p_2 = y[k3];
428
429 // Reassigning the output
430 y[k1] = wk2.fmadd( tmpk_p_2, wk1.fmadd( tmpk_p_1, tmpk));
431 y[k2] = wk2.fmadd(minus120 * tmpk_p_2, wk1.fmadd(plus120 * tmpk_p_1, tmpk));
432 y[k3] = wk2.fmadd(plus120 * tmpk_p_2, wk1.fmadd(minus120 * tmpk_p_1, tmpk));
433
434 // Twiddle factors
435 wk1 *= w1; wk2 *= w2;
436 }
437}
438
439// External-facing function for performing an FFT on signal with length N = 3^k
440template<typename F, int L>
441inline void pow3_FFT(Complex<F,L>* x, Complex<F,L>* y, size_t s_in, size_t s_out, biFuncNode<F,L>* sRoot, direction dir) {
442 const size_t N = sRoot->sz;
443 Complex<F,L> plus120 (-0.5, -sqrt(3)/2.);
444 Complex<F,L> minus120 (-0.5, sqrt(3)/2.);
445 switch(dir) {
446 case direction::forward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, plus120, minus120); break;
447 case direction::backward: pow3_FFT_helper(N, x, y, s_in, s_out, dir, minus120, plus120); break;
448 }
449}
450
451}
452}
453
454#endif // END HEFFTE_STOCK_ALGOS_H
Custom complex type taking advantage of vectorization A Complex Type intrinsic to HeFFTe that takes a...
Definition heffte_stock_complex.h:29
Complex< F, L > fmadd(Complex< F, L > const &y, Complex< F, L > const &z)
Fused multiply add.
Definition heffte_stock_complex.h:155
Functor class to represent any Fourier Transform A class to use lambdas to switch between what FFT sh...
Definition heffte_stock_algos.h:50
direction
Indicates the direction of the FFT (internal use only).
Definition heffte_common.h:652
@ backward
Inverse DFT transform.
Definition heffte_common.h:656
@ forward
Forward DFT transform.
Definition heffte_common.h:654
Namespace containing all HeFFTe methods and classes.
Definition heffte_backend_cuda.h:38
int direction_sign(direction dir)
Find the sign given a direction.
Definition heffte_stock_algos.h:21
Class to represent the call-graph nodes.
Definition heffte_stock_algos.h:78
Create a stock Complex representation of a twiddle factor.
Definition heffte_stock_algos.h:35