How to get attribute for quantized type?


Axel Wong

May 30, 2024, 10:13:07 PM
to OpenXLA Discuss
Hi everyone,

I would like to create an attribute with a quantized element type using the MLIR Python bindings.
I've tried the following Python code snippet.

from mlir import ir
from mlir.dialects import quant
import numpy as np

with ir.Context():
    i8 = ir.IntegerType.get_signless(8)
    f32 = ir.F32Type.get()
    uniform = quant.UniformQuantizedType.get(
        quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.264, 17, -8, 7
    )
    ir.DenseElementsAttr.get(np.zeros(shape=(1, 10), dtype=np.float32), type=uniform)

But I encounter the following error.
python: StableHLO/llvm-project/mlir/lib/IR/Types.cpp:126: unsigned int mlir::Type::getIntOrFloatBitWidth() const: Assertion `isIntOrFloat() && "only integers and floats have a bitwidth"' failed.

How can I create the following type?

tensor<1x10x!quant.uniform<i8<-8:7>:f32, 0.264:17>>
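For reference, the parameters in that type encode the standard affine quantization mapping (scale 0.264, zero point 17, storage range [-8, 7]). A plain-Python sketch of that math, independent of MLIR (the function names `quantize`/`dequantize` are illustrative, not part of any binding):

```python
def quantize(x, scale=0.264, zero_point=17, qmin=-8, qmax=7):
    """Map a real value to its i8<-8:7> storage value."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # clamp to the storage range

def dequantize(q, scale=0.264, zero_point=17):
    """Recover the approximate real value from a storage value."""
    return (q - zero_point) * scale

print(quantize(-3.0))              # -> 6
print(dequantize(quantize(-3.0)))  # approximately -2.904
```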

Thanks.
Axel

Sandeep Dasgupta

May 30, 2024, 11:02:10 PM
to Axel Wong, OpenXLA Discuss
Hello Axel
Can you please try the following?

    from mlir import ir
    from mlir.dialects import quant as quant_dialect

    # Inside an active ir.Context:
    i8 = ir.IntegerType.get_signless(8)
    f32 = ir.F32Type.get()
    uniform = quant_dialect.UniformQuantizedType.get(
        quant_dialect.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.264, 17, -8, 7)
    tensor = ir.RankedTensorType.get(shape=[1, 10], element_type=uniform)

I do not think we need to create a DenseElementsAttr for this specific need. In any case, please let me know.

Regards,
Sandeep


Axel Wong

May 31, 2024, 12:40:11 AM
to OpenXLA Discuss, Sandeep Dasgupta, Axel Wong
Hi Sandeep,

This works for the result type of any operation, but what I would like to achieve is creating a DenseElementsAttr for input arguments that have a quantized type.
I also found that some classes are missing from the Python bindings for the 'quant' dialect; there should be conversion operations as in the 'quant' Dialect - MLIR (llvm.org) documentation.

Thanks
Axel

Jacques Pienaar

May 31, 2024, 1:08:26 PM
to Axel Wong, OpenXLA Discuss, Sandeep Dasgupta
Hey,

Probably also good to ask on the MLIR Discord, as more folks working on these are there. Without knowing much more, I'd recommend following the example in llvm-project/mlir/test/Dialect/Quant/canonicalize.mlir (create a scalar constant and then quant.scast it) as a starting point. I'd also look at the larger stack trace (or grab a backtrace using gdb) to see where this goes wrong. E.g., `ir.DenseElementsAttr.get(np.zeros(shape=(1,10), dtype=np.float32))` feels off, as I'd have expected an integer element type there ... but I haven't used the quant dialect in ages, nor do I see many examples of it used directly. That said, if you take the canonicalize example and modify it so that there isn't an add and the scast just folds, it should show you what the IR should look like (and you can always add some .dump() calls on the C++ side to see the exact C++ structs and work it out from the test).
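For illustration, the constant-plus-scast pattern that canonicalize.mlir exercises boils down to something like the following, here using the quantized type from the original question (a sketch; exact syntax may vary across MLIR versions):

```mlir
// Store the quantized values as a plain i8 constant, then
// storage-cast them to the quantized tensor type.
func.func @quantized_constant() -> tensor<2x!quant.uniform<i8:f32, 0.264:17>> {
  %cst = arith.constant dense<[4, 15]> : tensor<2xi8>
  %q = "quant.scast"(%cst) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32, 0.264:17>>
  return %q : tensor<2x!quant.uniform<i8:f32, 0.264:17>>
}
```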

-- Jacques

Axel Wong

Jun 3, 2024, 4:44:28 AM
to OpenXLA Discuss, Jacques Pienaar, Sandeep Dasgupta, Axel Wong
I've created the following simple example code.

from mlir import ir
import mlir.dialects.stablehlo as stablehlo
import numpy as np

ASM_FORMAT = """
module {
  func.func @main(%arg0: tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) {
    %result = "stablehlo.uniform_quantize"(%arg0) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
    func.return
  }
}
"""

def test_reference_api():
  arg = np.asarray([4, 15], np.int8)
  with ir.Context() as context:
    stablehlo.register_dialect(context)
    m = ir.Module.parse(ASM_FORMAT)
    args = [ir.DenseElementsAttr.get(arg)]

    assert m.operation.verify()
    stablehlo.eval_module(m, args)

test_reference_api()

But I encounter the following error.

loc("-":3:3): error: invalid input argument type at index 0, input type was 'tensor<2xi8>' but entry function expected 'tensor<2x!quant.uniform<i8:f32:0, {1.000000e-01:-30,5.000000e-01:-20}>>'
ValueError: interpreter failed

How do I resolve this? Any help & ideas are appreciated.

Thanks
Axel

Sandeep Dasgupta

unread,
Jun 3, 2024, 12:27:48 PMJun 3
to Axel Wong, OpenXLA Discuss, Jacques Pienaar
Hello Axel


>> loc("-":3:3): error: invalid input argument type at index 0, input type was 'tensor<2xi8>' but entry function expected 'tensor<2x!quant.uniform<i8:f32:0, {1.000000e-01:-30,5.000000e-01:-20}>>'
The error comes from https://github.com/openxla/stablehlo/blob/2f97b6ccf5c72618d8c391fd734aafb5cc3b31a7/stablehlo/reference/Api.cpp#L134, which checks the compatibility of the actual input arguments against the formal arguments and fails.
Also, note that the StableHLO interpreter in its current state does not support quantized types or operations on those types. We have plans to support that in the near future.

In order to make progress for now, how about we do the following:
1. Use the same type for the main function's argument as the actual input argument (tensor<2xi8>).
2. Have the main function return some value, else the interpreter output will be `none`.
3. Use bitcast_convert to cast the argument to the quantized type.
4. Use decompose_stablehlo_quantized_function to decompose the quantized types/operations into primitive integer types.
5. Run the interpreter on the decomposed module.


  def test(self):
    arg = np.asarray([4, 15], np.int8)
    with ir.Context() as context:
      stablehlo_dialect.register_dialect(context)
      ASM_FORMAT = """
module {
  func.func @main(%arg0: tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>> {
    %bcast = "stablehlo.bitcast_convert"(%arg0) : (tensor<2xi8>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
    %result = "stablehlo.uniform_quantize"(%bcast) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
    func.return %result : tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
  }
}
"""
      m = ir.Module.parse(ASM_FORMAT)
      args = [ir.DenseElementsAttr.get(arg)]

      assert m.operation.verify()
      module = decompose_stablehlo_quantized_function(m)
      output = stablehlo_dialect.eval_module(module, args)
      print(output)


# Assumes: from mlir import passmanager as pm, and the MHLO Python
# bindings imported as mhlo (for register_mhlo_passes).
def decompose_stablehlo_quantized_function(module: ir.Module) -> ir.Module:
  # Work on a clone so the original module is left untouched.
  cloned_module = ir.Module.parse(str(module), context=module.context)
  with cloned_module.context:
    mhlo.register_mhlo_passes()
    pipeline = [
        'stablehlo-legalize-to-hlo',
        'func.func(mhlo-quant-legalize-to-int)',
        'func.func(chlo-legalize-to-hlo)',
        'func.func(shape-legalize-to-hlo{legalize-constraints=true})',
        'hlo-legalize-to-stablehlo',
    ]
    pipeline = pm.PassManager.parse(f"builtin.module({','.join(pipeline)})")
    pipeline.run(cloned_module.operation)
    return cloned_module


Please let us know if this works for your case. 

On a separate note:
The interpreter support for quantized types is under development, and we would like to know more about your use cases to help improve the experience. Please let us know.

Sandeep Dasgupta

Jun 3, 2024, 12:56:32 PM
to Axel Wong, OpenXLA Discuss, Jacques Pienaar
On second thought, we do not need the modifications involving bitcast (step 3); we can use the following StableHLO module instead (which is exactly what you had before):

ASM_FORMAT = """
module {
  func.func @main(%arg0: tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>> {
    %result = "stablehlo.uniform_quantize"(%arg0) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
    func.return %result : tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
  }
}
"""

The decomposition (step 4) will take care of the quantized types of the formal input arguments.
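For intuition, the decomposed integer module ends up computing ordinary requantization arithmetic: dequantize with the input parameters, then quantize with the output parameters. A plain-Python sketch of that math (the function name `requantize` is hypothetical; rounding-mode details are elided):

```python
def requantize(q, s_in, z_in, s_out, z_out, qmin=-128, qmax=127):
    """Reference math for stablehlo.uniform_quantize on an
    already-quantized value; per-channel parameters would be
    applied element-wise."""
    real = (q - z_in) * s_in              # dequantize to float
    q_new = round(real / s_out) + z_out   # quantize with output params
    return max(qmin, min(qmax, q_new))    # clamp to storage range

# Channel 0 of the example uses scale 0.1, zero point -30; identical
# input/output parameters make the op an identity on storage values.
print(requantize(4, 0.1, -30, 0.1, -30))   # -> 4
```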

Regards
Sandeep