I've created the following simple example code.
from mlir import ir
import mlir.dialects.stablehlo as stablehlo
import mlir.dialects.func as func
from
mlir.ir import Context, Location, InsertionPoint, Module, IntegerType, IntegerAttr, ArrayAttr
import numpy as np
ASM_FORMAT = """
module {
func.func @main(%arg0: tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) {
%result = "stablehlo.uniform_quantize"(%arg0) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
func.return
}
}
"""
def test_reference_api():
arg = np.asarray([4, 15], np.int8)
with ir.Context() as context:
stablehlo.register_dialect(context)
m = ir.Module.parse(ASM_FORMAT)
args = [ir.DenseElementsAttr.get(arg)]
assert m.operation.verify()
stablehlo.eval_module(m, args)
test_reference_api()
But I encounter the following error.
loc("-":3:3): error: invalid input argument type at index 0, input type was 'tensor<2xi8>' but entry function expected 'tensor<2x!quant.uniform<i8:f32:0, {1.000000e-01:-30,5.000000e-01:-20}>>'
ValueError: interpreter failed
How do I resolve this? Any help & ideas are appreciated.