I'm fairly new to Numba, and was wondering whether someone could explain why the performance of this code doesn't seem to benefit very much at all from the JIT. Any suggestions on how to speed this up would be much appreciated. Superficially, this seems to me like the kind of thing that Numba should be good at -- but maybe I'm wrong.
import numpy as np
from numba import autojit
def ub_cases(a, b, mode):
    """Return the local cost for one pair of encoded samples.

    The result depends only on the signs of ``a`` and ``b`` and, when the
    signs differ, on ``mode``:

    * ``a > 0`` and ``b > 0``: ``(a - b) ** 2``.
    * ``a > 0`` and ``b < 0``: ``a ** 2`` if ``mode == 'l'``, otherwise
      ``-b * a ** 2``.
    * ``a < 0`` and ``b > 0``: ``b ** 2`` if ``mode == 't'``, otherwise
      ``-a * b ** 2``.
    * anything else (both negative, or either value equal to zero): ``0``.

    Args:
        a: Sample taken from the first sequence.
        b: Sample taken from the second sequence.
        mode: Direction tag. Only ``'l'`` and ``'t'`` are treated
            specially; every other value (the caller in this file also
            passes ``'d'``) selects the "otherwise" formulas above.

    Returns:
        The non-negative cost for the pair.
    """
    a_pos = a > 0
    b_pos = b > 0

    if a_pos and b_pos:
        return (a - b) ** 2

    if a_pos and b < 0:
        # -b is positive here, so it acts as a multiplier on a ** 2.
        return a ** 2 if mode == 'l' else -b * (a ** 2)

    if a < 0 and b_pos:
        # Mirror image of the previous case with the roles of a and b swapped.
        return b ** 2 if mode == 't' else -a * (b ** 2)

    # Both negative, or at least one value is exactly zero.
    return 0
def awarp(s, t):
    """Return the accumulated cost between two encoded 1-D sequences.

    A table ``D`` of shape ``(len(s) + 1, len(t) + 1)`` is filled row by
    row. Each cell ``D[i + 1, j + 1]`` is the minimum of three candidates:
    the diagonal, top and left neighbours, each increased by the local
    cost that ``ub_cases`` returns for ``s[i]`` and ``t[j]``. For the very
    first cell the diagonal candidate uses the plain squared difference
    instead of ``ub_cases``.

    Args:
        s: First sequence; must expose ``.shape[0]`` and integer indexing
            (a 1-D NumPy array in the calling code of this file).
        t: Second sequence, same requirements as ``s``.

    Returns:
        The value in the bottom-right cell of the table. This is ``0``
        when both inputs are empty and the border sentinel ``10**10``
        when exactly one of them is empty.
    """
    n = s.shape[0]
    m = t.shape[0]

    # Large sentinel marking the unreachable border cells.
    border = int(1e10)

    # Allocate the table as 64-bit integers directly: this skips the
    # float-then-cast copy and guarantees the sentinel fits even on
    # platforms where the default integer is 32 bits wide.
    D = np.zeros((n + 1, m + 1), dtype=np.int64)
    D[:, 0] = border
    D[0, :] = border
    D[0, 0] = 0

    for i in range(n):
        for j in range(m):
            if i > 0 and j > 0:
                a_d = D[i, j] + ub_cases(s[i], t[j], mode='d')
            else:
                # On the first row/column only (0, 0) has a finite
                # predecessor; it uses the plain squared difference.
                a_d = D[i, j] + (s[i] - t[j]) ** 2
            a_t = D[i + 1, j] + ub_cases(s[i], t[j], mode='t')
            a_l = D[i, j + 1] + ub_cases(s[i], t[j], mode='l')

            # Builtin min on three scalars avoids building a list and a
            # temporary array for every cell of the table.
            D[i + 1, j + 1] = min(a_d, a_t, a_l)

    return D[-1, -1]
# Build the JIT-wrapped variant of awarp that is benchmarked below.
# NOTE(review): autojit is presumably unavailable in recent Numba releases
# (superseded by jit) -- confirm against the installed version.
awarp_ = autojit(awarp)
# Two random test sequences of 100 samples each, drawn from the values -3 and 1.
s = np.random.choice((-3, 1), 100)
t = np.random.choice((-3, 1), 100)
# %timeit is an IPython magic: the two lines below only run inside an
# IPython/Jupyter session. They time the plain and the wrapped function
# on the same inputs.
%timeit awarp(s, t)
%timeit awarp_(s, t)