Nevertheless I found two possibilities to speed up your code:
The first one: it helps if you do not construct a new array every time you call the function. In the code below I replaced the ones(n,n,k) call with an array that I define beforehand.
The second, more interesting one: in the Numba 0.34 development version you can use the keyword "parallel" in the (n)jit decorator. If you have a multicore machine, that can give nice speedups. Please check the documentation for how to install a development version.
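To make the two ideas concrete, here is a minimal sketch that combines them. It is not your function: the name scale_into, its loop body, and the small test arrays are made up for illustration. It only assumes a Numba version that provides prange together with parallel=True, and it falls back to plain Python if Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs without Numba (no speedup in that case).
    prange = range

    def njit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap


@njit(parallel=True)
def scale_into(out, theta, distance):
    # Hypothetical kernel: `out` is allocated once by the caller and
    # overwritten in place, so no new array is created on each call.
    n = distance.shape[0]
    k = theta.shape[0]
    for i in prange(n):  # outer loop is split across cores when Numba is used
        for j in range(n):
            for m in range(k):
                out[i, j, m] = np.exp(-theta[m] * distance[i, j])
    return out


n, k = 4, 3
theta = np.linspace(0.5, 1.5, k)
idx = np.arange(n)
distance = np.abs(idx[:, None] - idx[None, :]).astype(np.float64)

buf = np.ones((n, n, k))  # allocated once, outside the function
for _ in range(3):        # repeated calls reuse the same buffer
    scale_into(buf, theta, distance)
```

The iterations of the outer loop write to disjoint slices of the buffer, which is what makes it safe to run them in parallel.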
Both measures together give a nice 7x speedup on my machine (28 cores). The code follows.
Leo
import numpy as np
import time
from numba import *
@njit(parallel=True)
def njit_updatePsi(n, theta, distance, pl):