After some poking around in the dispatchers, I came up with the decorator below, which seems to work. Essentially, I apply `overload` to a dummy function, jit a wrapper function that calls it, and shoehorn the result into `type_callable`/`lower_builtin`.
However, I am slightly worried that I might be missing more subtle issues (explicitly keeping weakly referenced objects alive, etc., does not sound like a good idea...). Any ideas?
def overload_binop(op_name):
    """Return a decorator that registers ``func`` as an implementation of ``op_name``.

    ``overload`` cannot be applied to the operator function itself (overloading
    e.g. ``operator.mul`` makes Numba resolve the call back to ``*`` and
    recurse). Instead, ``func`` is attached via ``overload`` to a private dummy
    function. A jitted wrapper that calls the dummy is created, and
    ``type_callable``/``lower_builtin`` registrations for ``op_name`` forward
    typing and lowering to that wrapper.

    Parameters
    ----------
    op_name
        Key passed to ``type_callable`` and ``lower_builtin`` for the operator
        (presumably ``operator.mul``/``'*'`` style keys; confirm against the
        Numba version in use).

    Returns
    -------
    callable
        A decorator that performs the registration and returns ``func``
        unchanged.
    """
    def decorator(func):
        # Overload a private dummy with ``func``. Overloading the operator
        # itself would resolve back to the operator and recurse.
        def dummy(x, y):
            raise NotImplementedError

        overload(dummy)(func)

        # Jitted wrapper whose body calls ``dummy``. Compiling it for concrete
        # argument types picks up the ``func`` overload registered above.
        def dummy2(x, y):
            return dummy(x, y)

        disp = jit(nopython=True)(dummy2)
        # Numba type of the wrapper; the lowering below uses it to fetch the
        # compiled implementation.
        disp_type = types.Dispatcher(disp)
        # Per the original note, ``types.Dispatcher.dispatcher`` is held only
        # weakly. Binding it here keeps it alive as long as the typer closure.
        dispatcher = disp_type.dispatcher

        # Distinguishes "never attempted" from a cached typing failure (None).
        missing = object()
        # (x, y) argument types -> resolved Signature, or None if typing failed.
        cache = {}

        @type_callable(op_name)
        def typer(ctx):
            def type_binop(x, y):
                key = (x, y)
                sig = cache.get(key, missing)
                if sig is not missing:
                    return sig

                try:
                    # Asks the wrapper's dispatcher to type the call; this may
                    # compile ``func`` for these argument types.
                    template, _pysig, _args, _kws = dispatcher.get_call_template(key, {})
                    sig = template(ctx).apply(key, {})
                except TypingError:
                    sig = None

                # ``apply`` may also report "no match" by returning None. Guard
                # against it instead of failing on ``sig.args``.
                if sig is not None:
                    # Lower ``op_name`` for these argument types by calling the
                    # wrapper's compiled implementation.
                    @lower_builtin(op_name, *sig.args)
                    def lower_binop(context, builder, lower_sig, llargs):
                        impl = context.get_function(disp_type, lower_sig)
                        return impl(builder, llargs)

                cache[key] = sig
                return sig

            return type_binop

        return func

    return decorator