choosing axis of 3D FFT


zlff...@gmail.com

Feb 20, 2017, 03:23:49
to reikna
I tried to compile three FFT computations, one along each axis.
But only the FFT along the z-axis works correctly. The others give me zeros.

Here is a sample of my code.

###############################################
import numpy as np
from reikna import cluda
from reikna.fft import FFT

zeros = np.zeros((10,20,30), dtype=np.complex128)

api = cluda.get_api('cuda')
dev1 = api.get_platforms()[0].get_devices()[0]

thr1 = api.Thread(dev1)

storage_dev = thr1.to_device(zeros)

fftx = FFT(storage_dev, axes=(0,))
ffty = FFT(storage_dev, axes=(1,))
fftz = FFT(storage_dev, axes=(2,))

fftxc = fftx.compile(thr1, fast_math=True)
fftyc = ffty.compile(thr1, fast_math=True)
fftzc = fftz.compile(thr1, fast_math=True)

################################################


In the actual code, 'zeros' is replaced with a real data array; it is only a placeholder here.

Is this the correct usage?

zlff...@gmail.com

Feb 20, 2017, 03:45:48
to reikna
Here, "FFT along the z-axis" means

fftz = FFT(storage_dev, axes=(2,))

Bogdan Opanchuk

Feb 20, 2017, 07:01:08
to reikna
Could you post a full script reproducing the error? The following works fine for me:

import numpy as np
from reikna import cluda
from reikna.fft import FFT

data = np.random.normal(size=(10,20,30)).astype(np.complex128)

api = cluda.get_api('cuda')
dev1 = api.get_platforms()[0].get_devices()[0]
thr1 = api.Thread(dev1)

for axis in (0, 1, 2):
    storage_dev = thr1.to_device(data)

    fft = FFT(storage_dev, axes=(axis,))
    fftc = fft.compile(thr1, fast_math=True)

    # In-place transform: output and input are the same device array
    fftc(storage_dev, storage_dev)
    res = storage_dev.get()

    # CPU reference along the same single axis
    ref = np.fft.fftn(data, axes=(axis,))

    # The threshold is missing from the post; 1e-6 is an assumed tolerance
    assert np.linalg.norm(res - ref) / np.linalg.norm(ref) < 1e-6
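
For readers without a GPU at hand, the CPU reference used in the loop above can be examined on its own. The following is a minimal NumPy-only sketch (the helper name `rel_err` and the random test data are illustrative assumptions, not part of the thread). It shows that `np.fft.fftn` restricted to one axis matches a plain 1D `np.fft.fft` along that axis, and that applying the three single-axis transforms in sequence reproduces the full 3D FFT.

```python
import numpy as np

# Illustrative complex test data with the same shape as in the post
rng = np.random.default_rng(0)
shape = (10, 20, 30)
data = (rng.normal(size=shape) + 1j * rng.normal(size=shape)).astype(np.complex128)

def rel_err(a, b):
    """Relative L2 error of a with respect to the reference b."""
    return np.linalg.norm(a - b) / np.linalg.norm(b)

# fftn over a single axis is the same as a 1D fft along that axis
single_axis_errs = []
for axis in (0, 1, 2):
    one_axis = np.fft.fftn(data, axes=(axis,))
    single_axis_errs.append(rel_err(one_axis, np.fft.fft(data, axis=axis)))

# Applying the three single-axis transforms one after another
# gives the full 3D transform
step = data
for axis in (0, 1, 2):
    step = np.fft.fftn(step, axes=(axis,))
full_err = rel_err(step, np.fft.fftn(data))
```

If a GPU result disagrees with this reference for one axis only, the comparison above helps rule out the reference itself as the source of the discrepancy.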



zlff...@gmail.com

Feb 20, 2017, 22:59:51
to reikna
Since the full script is over 1200 lines, I'll post only the main part of the calculation.
The main idea is IFT( i * k * FT(f(x,y,z))),
where FT is the Fourier transform, IFT is the inverse Fourier transform,
'i' is the imaginary unit, 'k' is the wavenumber,
and f(x,y,z) is an arbitrary function of three variables.


######################################################
import numpy as np


from reikna import cluda
from reikna.fft import FFT

api = cluda.get_api('cuda')
dev1 = api.get_platforms()[0].get_devices()[0]
thr1 = api.Thread(dev1)

IEp, JEp, KEp = 64, 128, 512
totalsize = IEp * JEp * KEp

dx, dy, dz, sig = 0.05, 0.05, 0.05, 1.

dtype = np.complex128
ones = np.ones((IEp,JEp,KEp), dtype=dtype)

program = thr1.compile("""
KERNEL void MUL(
GLOBAL_MEM ${ctype} *dest,
GLOBAL_MEM ${ctype} *a,
GLOBAL_MEM ${ctype} *b)
{
SIZE_T i = get_global_id(0);

dest[i] = ${mul}(a[i],b[i]);
}
""",render_kwds=dict( ctype=cld.dtypes.ctype(dtype),
mul=cld.functions.mul(dtype,dtype,out_dtype=dtype),
add=cld.functions.add(dtype,dtype,out_dtype=dtype)))

MUL = program.MUL

fftx = FFT(ones, axes=(0,))
ffty = FFT(ones, axes=(1,))
fftz = FFT(ones, axes=(2,))

fftxc = fftx.compile(thr1, fast_math=True)
fftyc = ffty.compile(thr1, fast_math=True)
fftzc = fftz.compile(thr1, fast_math=True)


x = np.arange(IEp, dtype=dtype) * dx
y = np.arange(JEp, dtype=dtype) * dy
z = np.arange(KEp, dtype=dtype) * dz

kx = np.fft.fftfreq(IEp, dx) * 2. * np.pi
ky = np.fft.fftfreq(JEp, dy) * 2. * np.pi
kz = np.fft.fftfreq(KEp, dz) * 2. * np.pi

nax = np.newaxis
ikx = kx[:,nax,nax] * ones * 1j
iky = ky[nax,:,nax] * ones * 1j
ikz = kz[nax,nax,:] * ones * 1j

xexp1 = np.exp((-.5)*(x/sig)**2) ## Gaussian distribution
yexp1 = np.exp((-.5)*(y/sig)**2)
zexp1 = np.exp((-.5)*(z/sig)**2)

xexp = xexp1[:,nax,nax] * ones
yexp = yexp1[nax,:,nax] * ones
zexp = zexp1[nax,nax,:] * ones

xexp_dev = thr1.to_device(xexp)
yexp_dev = thr1.to_device(yexp)
zexp_dev = thr1.to_device(zexp)

ikx_dev = thr1.to_device(ikx)
iky_dev = thr1.to_device(iky)
ikz_dev = thr1.to_device(ikz)

fft_xexp_dev = thr1.empty_like(ones)
fft_yexp_dev = thr1.empty_like(ones)
fft_zexp_dev = thr1.empty_like(ones)

ikx_fft_xexp_dev = thr1.empty_like(ones)
iky_fft_yexp_dev = thr1.empty_like(ones)
ikz_fft_zexp_dev = thr1.empty_like(ones)

ifft_ikx_fft_xexp_dev = thr1.empty_like(ones)
ifft_iky_fft_yexp_dev = thr1.empty_like(ones)
ifft_ikz_fft_zexp_dev = thr1.empty_like(ones)

ft = np.fft.fftn
ift= np.fft.ifftn

fftxc(fft_xexp_dev, xexp_dev)
fftyc(fft_yexp_dev, yexp_dev)
fftzc(fft_zexp_dev, zexp_dev)

MUL(ikx_xexp_dev, ikx_dev, fft_xexp_dev, local_size=IEp, global_size = totalsize)
MUL(iky_yexp_dev, iky_dev, fft_yexp_dev, local_size=IEp, global_size = totalsize)
MUL(ikz_zexp_dev, ikz_dev, fft_zexp_dev, local_size=IEp, global_size = totalsize)

fftxc(ifft_ikx_fft_xexp_dev, ikx_fft_xexp_dev, inverse=True)
fftyc(ifft_iky_fft_yexp_dev, iky_fft_yexp_dev, inverse=True)
fftzc(ifft_ikz_fft_zexp_dev, ikz_fft_zexp_dev, inverse=True)

ifft_ikx_fft_xexp_gpu = ifft_ikx_fft_dev.get()
ifft_iky_fft_yexp_gpu = ifft_iky_fft_dev.get()
ifft_ikz_fft_zexp_gpu = ifft_ikz_fft_dev.get()

ifft_ikx_fft_xexp_cpu = ift( ikx * ft( xexp, axes=(0,)), axes=(0,))
ifft_iky_fft_yexp_cpu = ift( iky * ft( yexp, axes=(1,)), axes=(1,))
ifft_ikz_fft_zexp_cpu = ift( ikz * ft( zexp, axes=(2,)), axes=(2,))

###########################################################

The GPU and CPU results are different. Only the z-axis results agree with each other...
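
The operation described in this post, IFT(i * k * FT(f)) along a single axis, can be checked on the CPU with NumPy alone before the GPU is involved. The following is a minimal sketch under stated assumptions: the function name `spectral_derivative`, the grid sizes, and the periodic test field are illustrative choices (a periodic field is used instead of the Gaussian from the post so that the exact derivative is known). Each axis uses its own grid spacing when building the wavenumbers.

```python
import numpy as np

def spectral_derivative(f, spacing, axis):
    """Differentiate f along one axis as IFT(i * k * FT(f))."""
    n = f.shape[axis]
    # Angular wavenumbers for this axis, built from this axis's own spacing
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=spacing)
    # Reshape k so that it broadcasts along the chosen axis only
    shape = [1] * f.ndim
    shape[axis] = n
    ik = 1j * k.reshape(shape)
    return np.fft.ifft(ik * np.fft.fft(f, axis=axis), axis=axis)

# Periodic test field on [0, 2*pi) in each direction (assumed for the check)
nx, ny, nz = 16, 32, 64
dx, dy, dz = 2 * np.pi / nx, 2 * np.pi / ny, 2 * np.pi / nz
x = (np.arange(nx) * dx)[:, None, None]
y = (np.arange(ny) * dy)[None, :, None]
z = (np.arange(nz) * dz)[None, None, :]
f = (np.sin(x) * np.cos(2 * y) * np.sin(3 * z)).astype(np.complex128)

# Analytic partial derivatives of the test field
exact = {
    0: np.cos(x) * np.cos(2 * y) * np.sin(3 * z),
    1: -2 * np.sin(x) * np.sin(2 * y) * np.sin(3 * z),
    2: 3 * np.sin(x) * np.cos(2 * y) * np.cos(3 * z),
}

# Maximum absolute error of the spectral derivative along each axis
errors = {
    axis: np.max(np.abs(spectral_derivative(f, d, axis) - exact[axis]))
    for axis, d in zip((0, 1, 2), (dx, dy, dz))
}
```

Because the wavenumber array is reshaped rather than multiplied by a full array of ones, the same function works for any axis without allocating three extra 3D arrays.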

Bogdan Opanchuk

Feb 20, 2017, 23:48:49
to reikna
The script you posted cannot actually be executed: there are some undefined variables. Next time please make sure that your reproduction script is runnable and has a clearly defined part where you check for the error it is supposed to reproduce.

I took some guesses about what those undefined names should be, and what you are checking for (see the asserts at the end):
https://gist.github.com/fjarri/58d5200f24c9e615b6968501121178cb
This works fine for me as well. Can you confirm that I'm checking for the error you are observing?
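
As an illustration of what a "clearly defined part where you check for the error" could look like, here is a minimal sketch of a per-axis check. Everything in it is hypothetical: `check_axes` and the two stand-in transforms are not from the gist or from reikna. The stand-ins are plain NumPy callables taking `(data, axis)`; in a real reproduction script they would be replaced by a wrapper around the compiled GPU computation.

```python
import numpy as np

def check_axes(transform, data, axes=(0, 1, 2), tol=1e-6):
    """Compare transform(data, axis) against np.fft.fftn for each axis.

    Returns {axis: (relative_error, passed)}. The tolerance is an
    assumed value, not one taken from the thread.
    """
    report = {}
    for axis in axes:
        res = transform(data, axis)
        ref = np.fft.fftn(data, axes=(axis,))
        err = np.linalg.norm(res - ref) / np.linalg.norm(ref)
        report[axis] = (float(err), bool(err < tol))
    return report

def cpu_stand_in(data, axis):
    """Hypothetical placeholder for a correct GPU transform."""
    return np.fft.fft(data, axis=axis)

def broken_stand_in(data, axis):
    """Mimics the reported symptom: zeros except along the last axis."""
    if axis == data.ndim - 1:
        return np.fft.fft(data, axis=axis)
    return np.zeros_like(data)

rng = np.random.default_rng(1)
data = rng.normal(size=(8, 16, 32)).astype(np.complex128)

report_ok = check_axes(cpu_stand_in, data)
report_bad = check_axes(broken_stand_in, data)
```

A report like `report_bad` makes it immediately visible which axes fail and by how much, which is the information a reader of a bug report needs first.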

zlff...@gmail.com

Feb 20, 2017, 23:52:33
to reikna
Sorry, I should have checked it...

I'll debug it and post again. Thank you.

zlff...@gmail.com

Feb 21, 2017, 02:41:12
to reikna
I've found the error in my code. The problem is solved. Thank you!