It works as expected when used as a PureParallel object, but since it does not call load_same or store_same, I do not know how I could use it within a Computation.
Here is the code:
import numpy as np
import reikna.cluda.functions
from reikna.core import Transformation, Parameter, Annotation, Type

def get_two_fft(arr_in):
    # conjugate module
    conj = reikna.cluda.functions.conj(arr_in.dtype)
    # div module
    div_by_double = reikna.cluda.functions.div(in_dtype1=arr_in.dtype,
                                               in_dtype2=np.double,
                                               out_dtype=arr_in.dtype)
    # mul module
    mul_by_j = reikna.cluda.functions.mul(arr_in.dtype, arr_in.dtype,
                                          out_dtype=arr_in.dtype)
    arr_type = Type(dtype=arr_in.dtype, shape=arr_in.shape)
    nrows, ncols = arr_in.shape
    j = np.complex64(1j)
    return Transformation(
        [Parameter('in_cplx', Annotation(arr_type, 'i')),
         Parameter('out_cplx_1', Annotation(arr_type, 'o')),
         Parameter('out_cplx_2', Annotation(arr_type, 'o')),
         ],
        """
        ## get the 2D indices
        VSIZE_T idx_col = threadIdx.x + blockDim.x * blockIdx.x;
        VSIZE_T idx_row = threadIdx.y + blockDim.y * blockIdx.y;
        ## get the reversed indices
        int idx_new_row;
        int idx_new_col;
        if (idx_row == 0)
            idx_new_row = 0;
        else
            idx_new_row = ${nrows} - idx_row;
        if (idx_col == 0)
            idx_new_col = 0;
        else
            idx_new_col = ${ncols} - idx_col;
        ## conjugate of the index-reversed FFT
        ${in_cplx.ctype} rev_fft = ${conj}( ${in_cplx.load_idx}(idx_new_row, idx_new_col) );
        ## initial FFT
        ${in_cplx.ctype} init_fft = ${in_cplx.load_idx}(idx_row, idx_col);
        ## sum and divide
        ${in_cplx.ctype} sum_div = ${div}(init_fft + rev_fft, 2.0);
        ## subtract and divide
        ${in_cplx.ctype} sub_div = ${div}(init_fft - rev_fft, 2.0);
        ## FFT of the first real-valued input
        ${out_cplx_1.store_idx}(idx_row, idx_col, sum_div);
        ## FFT of the second real-valued input
        ${out_cplx_2.store_idx}(idx_row, idx_col, -${mul}(${j}, sub_div));
        """,
        render_kwds=dict(conj=conj, div=div_by_double, mul=mul_by_j, j=j,
                         nrows=nrows, ncols=ncols)
    )
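The kernel uses the usual trick of packing two real inputs x and y into one complex array z = x + i·y, taking a single FFT Z, and then separating the two spectra through Hermitian symmetry: X[k] = (Z[k] + conj(Z[-k]))/2 and Y[k] = -i·(Z[k] - conj(Z[-k]))/2, where -k is taken modulo the size of each dimension (0 stays 0, any other index becomes n - idx, as in the kernel). A NumPy reference like the sketch below can be handy for checking the GPU output with `np.allclose`; `split_two_real_ffts` is a hypothetical name, not part of reikna.

```python
import numpy as np


def split_two_real_ffts(z_fft):
    """Separate FFT2(x + 1j*y) into FFT2(x) and FFT2(y) for real 2D x, y.

    Hypothetical helper mirroring the kernel: index (r, c) is paired with
    (-r mod nrows, -c mod ncols), i.e. 0 for index 0 and n - idx otherwise.
    """
    nrows, ncols = z_fft.shape
    rev_rows = (-np.arange(nrows)) % nrows
    rev_cols = (-np.arange(ncols)) % ncols
    # conjugate of the index-reversed FFT
    rev_fft = np.conj(z_fft[np.ix_(rev_rows, rev_cols)])
    # spectrum of the first real input: (Z + conj(Z_rev)) / 2
    fft_1 = (z_fft + rev_fft) / 2.0
    # spectrum of the second real input: -i * (Z - conj(Z_rev)) / 2
    fft_2 = -1j * (z_fft - rev_fft) / 2.0
    return fft_1, fft_2


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6))
y = rng.standard_normal((4, 6))
fft_x, fft_y = split_two_real_ffts(np.fft.fft2(x + 1j * y))
print(np.allclose(fft_x, np.fft.fft2(x)), np.allclose(fft_y, np.fft.fft2(y)))  # → True True
```

Only one complex FFT is computed for two real inputs; the separation step is a cheap elementwise pass, which is why it is worth doing on the GPU.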
However, contrary to what I wrote in my first post, I cannot get mul() to accept j (a complex number) as an input.
It raises an error saying that CUDA cannot work with complex integral numbers.
The only workaround I found was to pass a (1, 1) complex array containing j and then read it with ${arr_j.load_idx}(0, 0).
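A likely explanation for the error (my reading, not confirmed in the thread): Mako substitutes `${j}` with Python's `str()`, so the kernel source receives the literal `1j`, which the CUDA compiler most likely parses as a GNU-style imaginary integer constant and rejects as a complex integral type. The sketch below only shows the text that would be inserted, next to a hypothetical `complex_ctr_literal()` helper that spells the value out with explicit real and imaginary parts in the `COMPLEX_CTR(...)(re, im)` form; it is an illustration, not reikna's implementation.

```python
import numpy as np


def naive_literal(val):
    # Mako's ${val} calls str(), which yields Python syntax such as "1j"
    return str(val)


def complex_ctr_literal(val, ctype):
    # Hypothetical helper: emit explicit real/imaginary parts in the
    # COMPLEX_CTR(ctype)(re, im) form used by reikna's CLUDA templates
    val = complex(val)
    return "COMPLEX_CTR({})({!r}, {!r})".format(ctype, val.real, val.imag)


j = np.complex64(1j)
print(naive_literal(j))                  # → 1j
print(complex_ctr_literal(j, "float2"))  # → COMPLEX_CTR(float2)(0.0, 1.0)
```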
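The constant can be rendered directly into the kernel with `dtypes.c_constant()`:
${out_cplx_2.store_idx}(idx_row, idx_col, -${mul}(${dtypes.c_constant(j)}, sub_div));
This turns j into a proper C constant (`COMPLEX_CTR(double2)(0, 1)` for a double-precision value). Alternatively, you can pass j to the transformation as a scalar parameter.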
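If the constant is only ever j, a further option (not suggested in the thread) is to avoid it entirely: multiplying by -i just swaps the components, -i·(a + ib) = b - ia. In the kernel that could be written as something like `COMPLEX_CTR(...)(sub_div.y, -sub_div.x)`, assuming the usual `float2`/`double2` layout with `.x` as the real part and `.y` as the imaginary part, which would also make the `mul` module unnecessary. The Python sketch below only checks the identity numerically; `times_minus_i` is a hypothetical name.

```python
import numpy as np


def times_minus_i(z):
    # -i * (a + ib) = b - ia: the new real part is the old imaginary part,
    # the new imaginary part is the negated old real part
    return z.imag - 1j * z.real


rng = np.random.default_rng(1)
sub_div = (rng.standard_normal(8) + 1j * rng.standard_normal(8)).astype(np.complex64)
print(np.allclose(times_minus_i(sub_div), -1j * sub_div))  # → True
```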