# Baseline: a plain complex-to-complex FFT along axis 0 of a random cube.
# reikna's FFT is a complex transform, so the real-valued random data is cast
# to complex128 *before* the FFT is declared. Casting after the transform has
# run (as a trailing `astype` would) has no effect on the result.
# Imports are local so this section runs even though it precedes the
# file's import block.
import numpy as np
from reikna import cluda
from reikna.fft import FFT

api = cluda.get_api('cuda')
dev = api.get_platforms()[0].get_devices()[0]  # first device of first CUDA platform
thr = api.Thread(dev)
rand = np.random.random((10, 10, 10)).astype(np.complex128)
# The FFT signature (dtype, shape) is taken from the complex array itself.
fft = FFT(rand, axes=(0,))
fftc = fft.compile(thr, fast_math=True)
two = 2 * rand
three = 3 * rand
two_dev = thr.to_device(two)
three_dev = thr.to_device(three)
# Write into separate output buffers rather than aliasing input and output.
# NOTE(review): in-place operation is not relied upon here; confirm against the
# reikna FFT docs before switching back to fftc(x, x).
two_res_dev = thr.empty_like(fft.parameter.output)
three_res_dev = thr.empty_like(fft.parameter.output)
fftc(two_res_dev, two_dev)
fftc(three_res_dev, three_dev)
two_res = two_res_dev.get()
three_res = three_res_dev.get()
# Cross-check against numpy's FFT over the same single axis.
assert np.allclose(two_res, np.fft.fftn(two, axes=(0,)))
assert np.allclose(three_res, np.fft.fftn(three, axes=(0,)))
import numpy as np
from reikna import cluda
from reikna.fft import FFT
from reikna.cluda.dtypes import complex_for
from reikna.core import Type
from reikna.transformations import combine_complex, broadcast_const
# Real-input FFT: rather than casting the data to complex on the host, attach
# transformations to the FFT's input so it is fed a real array plus a constant
# imaginary part. The result is checked against numpy below.
api = cluda.get_api('cuda')
dev = api.get_platforms()[0].get_devices()[0]  # first device of the first CUDA platform
thr = api.Thread(dev)
rand = np.random.random((10, 10, 10))  # real-valued random data
# Declare the FFT for the complex counterpart of rand's dtype, with rand's shape,
# transforming along axis 0 only.
fft = FFT(Type(complex_for(rand.dtype), rand.shape), axes=(0,))
# Transformation whose `output` matches the FFT input and which builds it from
# two real-valued parameters, `real` and `imag`, of the same shape.
cc = combine_complex(fft.parameter.input)
# Transformation whose `output` is the constant 0 broadcast to the type of cc.imag.
bc = broadcast_const(cc.imag, 0)
# Replace the FFT's complex `input` with cc, exposing cc's real/imag parameters
# on the FFT as `real_input` and `imag_input`.
fft.parameter.input.connect(cc, cc.output, real_input=cc.real, imag_input=cc.imag)
# `imag_input` only exists after the connect above. Feeding it from bc pins the
# imaginary part to 0, so the caller no longer passes it.
fft.parameter.imag_input.connect(bc, bc.output)
# Compile only after all transformations are attached.
fftc = fft.compile(thr, fast_math=True)
two = 2 * rand
three = 3 * rand
two_dev = thr.to_device(two)
three_dev = thr.to_device(three)
# Complex output buffers typed from the FFT's `output` parameter.
two_res_dev = thr.empty_like(fft.parameter.output)
three_res_dev = thr.empty_like(fft.parameter.output)
# After the wiring above, the compiled call takes (output, real input), as used here.
fftc(two_res_dev, two_dev)
fftc(three_res_dev, three_dev)
two_res = two_res_dev.get()
three_res = three_res_dev.get()
# Cross-check against numpy's FFT over the same single axis of the real data.
assert np.allclose(two_res, np.fft.fftn(two, axes=(0,)))
assert np.allclose(three_res, np.fft.fftn(three, axes=(0,)))