Using jit/njit with solvers (higher-order functions)


Jason Sachs

Apr 7, 2017, 1:20:59 PM
to numba...@continuum.io
I have a question on how to make the best use of numba's jit/njit with a root-finding solver like scipy.optimize.brentq (which is coded in C and very fast).

Here is a very simple test case, where I try to find the value of omega such that H(j*omega) has a specified phase, with H(s) = 1/(L*s+R). In order to use brentq, I need a real-valued scalar function, which means that I can't pass in H(s) directly; I have to use an adapter function. But it's a trivial adapter function, requiring two multiplies, one add, one call to H(s), and one call to np.angle(), so at first glance this seems like a prime candidate for numba.
  • If I know H(s) beforehand, I can just write the adapter function myself (see f1() below) and use njit and everything is nice and fast.
  • If I don't know H(s) beforehand, or I want to generalize (Don't Repeat Yourself) and find a way to solve this problem for several different H(s) functions, I can't seem to find a fast way of doing it.
    • My attempt f2() with no additional use of numba, besides using njit with H(s) itself, is more than 6 times slower than the specific case, even though the adapter function is so simple.
    • My attempt f3() to try to jit a general adapter function makes things even slower; nopython won't work (presumably because it is a higher-order function) and a regular jit just slows things down.
Any suggestions? This is just a simple self-contained example; my real code is much more complex for H(s), but it's the same type of computation.

--Jason

--------------------

import numpy as np
import numba
from scipy.optimize import brentq

p = np.rec.array([0.5, 1.0e-3], dtype=[('R','f8'),('L','f8')])

@numba.njit
def H(s,p):
    return 1.0 / (p['R'] + p['L']*s)

@numba.njit
def f1(w,p, phi):
    return np.angle(H(1j*w,p)) + phi/180.0*np.pi

@numba.jit   # can't use nopython mode
def anglefunc(w,H,p,phi):
    return np.angle(H(1j*w,p)) + phi/180.0*np.pi

def f2(H,p,phi,w1=1,w2=1e5):
    def f(w,p):
        return np.angle(H(1j*w,p)) + phi/180.0*np.pi
    return brentq(f,w1,w2,(p,))

def f3(H,p,phi,w1=1,w2=1e5):
    return brentq(anglefunc,w1,w2,(H,p,phi))

print(brentq(f1,1,1e5,(p,80)))
%timeit brentq(f1,1,1e5,(p,80))
# 2835.64090981
# 100000 loops, best of 3: 14.8 µs per loop

print(f2(H,p,80))
%timeit f2(H,p,80)
# 2835.64090981
# 10000 loops, best of 3: 95.7 µs per loop

print(f3(H,p,80))
%timeit f3(H,p,80)
# 2835.64090981
# 1000 loops, best of 3: 471 µs per loop


Kevin Sheppard

Apr 7, 2017, 1:37:53 PM
to numba...@continuum.io

This seems to work:

 

def factory(H,p,phi,w1=1,w2=1e5, n=10):
    H_njit=numba.njit(H)

    @numba.njit
    def f(w,p, phi):
        return np.angle(H_njit(1j*w,p)) + phi/180.0*np.pi

    for i in range(n):
        out = brentq(f,w1,w2,(p,phi))
    return out

factory(H,p,80)

There is a lot of overhead in a single call, which is due to the local use of njit. This is why I am running brentq multiple times for timing.

%timeit factory(H,p,80,n=1)
10 loops, best of 3: 110 ms per loop

%timeit factory(H,p,80,n=10)
10 loops, best of 3: 111 ms per loop

%timeit factory(H,p,80,n=100)
10 loops, best of 3: 112 ms per loop

%timeit factory(H,p,80,n=1000)
10 loops, best of 3: 122 ms per loop

Kevin

--
You received this message because you are subscribed to the Google Groups "Numba Public Discussion - Public" group.
To unsubscribe from this group and stop receiving emails from it, send an email to numba-users...@continuum.io.
To post to this group, send email to numba...@continuum.io.
To view this discussion on the web visit https://groups.google.com/a/continuum.io/d/msgid/numba-users/CAOo6sON27_GG-W5utAWpywB%2BkAjcdiLFkm6t9WE1wjosKnrgig%40mail.gmail.com.
For more options, visit https://groups.google.com/a/continuum.io/d/optout.

 

Jason Sachs

Apr 7, 2017, 1:45:21 PM
to numba...@continuum.io
One more datapoint: if I quit using recarrays and just access by index, it is about 1.6 times as fast (9.17 µs vs. 14.8 µs per brentq call). It would be really great if there were a way to use Good Software Design Principles and pass in a structured array so I can access fields by name, but get the same kind of speedup. It seems like numba should be able to optimize the indexing here.

@numba.njit
def Hidx(s,pidx):
    return 1.0 / (pidx[0] + pidx[1]*s)

@numba.njit
def f1idx(w,pidx,phi):
    return np.angle(Hidx(1j*w,pidx)) + phi/180.0*np.pi

pidx = np.array([0.5, 1.0e-3])
print(brentq(f1idx,1,1e5,(pidx,80)))
%timeit brentq(f1idx,1,1e5,(pidx,80))
# 2835.64090981
# 100000 loops, best of 3: 9.17 µs per loop

Jason Sachs

Apr 7, 2017, 1:53:46 PM
to numba...@continuum.io
Yeah, I thought of doing that, but the njit in each call is too expensive. I suppose I could cache the njitted function, so if you pass in the same H(s) then it would reuse the njit results.

Just to reproduce your results on my system so I can compare timing:

def f4(H,p,phi,w1=1,w2=1e5, n=10):
    @numba.njit
    def f(w,p, phi):
        return np.angle(H(1j*w,p)) + phi/180.0*np.pi

    for i in range(n):
        out = brentq(f,w1,w2,(p,phi))
    return out

print(f4(H,p,80,n=1))
%timeit f4(H,p,80,n=1)
%timeit f4(H,p,80,n=1001)
# 2835.64090981
# 10 loops, best of 3: 54.9 ms per loop
# 10 loops, best of 3: 71.9 ms per loop

This implies about 17 µs per call to brentq on my system: (71.9 ms - 54.9 ms) / 1000 extra calls ≈ 17 µs.


Jason Sachs

Apr 7, 2017, 1:59:54 PM
to numba...@continuum.io
Whee! Here we go, caching to the rescue: (Now I just need to figure out how to avoid memory leaks.)

def f5_factory():
    cache = {}
    def f5(H,p,phi,w1=1,w2=1e5):
        f = cache.get(id(H))
        if f is None:
            @numba.njit
            def f(w,p, phi):
                return np.angle(H(1j*w,p)) + phi/180.0*np.pi
            cache[id(H)] = f
        return brentq(f,w1,w2,(p,phi))
    return f5
f5 = f5_factory()
 
print(f5(H,p,80))
%timeit f5(H,p,80)
print(f5(H,p,80))  # just make sure the cache still works
# 2835.64090981
# 100000 loops, best of 3: 15.8 µs per loop
# 2835.64090981