Refactoring dot operation algorithms (and possibly more)

Tamás Danyluk

Feb 19, 2024, 10:16:16 AM
to openxla...@openxla.org
Hi all,

Currently, for dot operations with F32 storage type, we have 3 precisions: DEFAULT, HIGH and HIGHEST.
But the meaning of these precisions depends on the type and version of the accelerator.
This can cause a surprising loss of precision when the same program runs on different accelerators, and the scheme is also not flexible enough to represent other algorithms, such as "bf16_6x".
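The following minimal NumPy sketch makes the precision loss concrete. It assumes that TF32 inputs can be emulated by rounding each F32 operand to 10 explicit mantissa bits (round-to-nearest-even) before an F32-accumulated dot. The helper `round_to_tf32` is hypothetical, and real hardware may round or truncate differently. The point is only that when a precision setting resolves to TF32 instead of F32, the same F32 program drops 13 of the 23 mantissa bits of every input.

```python
import numpy as np

def round_to_tf32(x):
    """Emulate rounding float32 values to TF32 (10 explicit mantissa bits).

    Hypothetical helper: uses round-to-nearest-even on the float32 bit pattern
    and returns a float32 array. NaN/Inf are not handled specially, and real
    hardware may use a different rounding mode.
    """
    bits = np.array(x, dtype=np.float32, ndmin=1).view(np.uint32)
    # float32 has 23 explicit mantissa bits and TF32 keeps 10, so drop the low 13.
    lsb = (bits >> np.uint32(13)) & np.uint32(1)
    rounded = (bits + np.uint32(0x0FFF) + lsb) & np.uint32(0xFFFFE000)
    return rounded.view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal(4096).astype(np.float32)
b = rng.standard_normal(4096).astype(np.float32)

reference = np.dot(a.astype(np.float64), b.astype(np.float64))
f32_dot = np.dot(a, b)                                  # F32 inputs, F32 accumulation
tf32_dot = np.dot(round_to_tf32(a), round_to_tf32(b))   # TF32 inputs, F32 accumulation

print("F32 inputs,  abs error:", abs(float(f32_dot) - reference))
print("TF32 inputs, abs error:", abs(float(tf32_dot) - reference))
```

With F32 inputs the error typically stays near F32 rounding level. The TF32-input error is usually orders of magnitude larger, even though the program and its storage types are unchanged.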

So we plan to add new, more specific precisions that explicitly define the algorithm.

These are the ones we are considering adding now:
  • DOT_BF16_BF16_F32 (meaning: Each primitive tile is calculated using a primitive dot operation which operates on BF16 inputs and accumulates the results into F32 values.)
  • DOT_BF16_BF16_F32_X3
  • DOT_BF16_BF16_F32_X6 (meaning: Each primitive tile is calculated using 6 primitive dot operations which operate on BF16 inputs and accumulate the results into F32 values. This algorithm uses 6 BF16 dots to reach a precision similar to that of a single F32 dot.)
  • DOT_TF32_TF32_F32
  • DOT_TF32_TF32_F32_X3
  • DOT_F32_F32_F32
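The X3 and X6 entries rely on splitting each F32 operand into a sum of BF16 values and adding up several BF16 dots. The message does not say exactly which partial products XLA computes. The following NumPy sketch therefore assumes the common scheme:

- X3 splits each operand into hi + lo and keeps three of the four cross products.
- X6 splits each operand into hi + mid + lo and keeps the six most significant of the nine.

BF16 is emulated by rounding F32 bit patterns. All function names are hypothetical.

```python
import numpy as np

def round_to_bf16(x):
    """Emulate rounding float32 values to bfloat16 (8 significant bits) with
    round-to-nearest-even; the result is returned as a float32 array."""
    bits = np.array(x, dtype=np.float32, ndmin=1).view(np.uint32)
    # bfloat16 is the upper 16 bits of a float32, so drop the low 16 bits.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

def split_bf16(x, num_parts):
    """Split float32 values into num_parts bf16-representable terms (hi, mid, lo)."""
    parts = []
    residual = np.array(x, dtype=np.float32, ndmin=1)
    for _ in range(num_parts):
        part = round_to_bf16(residual)
        parts.append(part)
        residual = residual - part  # exact for normal float32 values
    return parts

def dot_bf16_x1(a, b):
    # One BF16 x BF16 dot with F32 accumulation (like DOT_BF16_BF16_F32).
    return np.dot(round_to_bf16(a), round_to_bf16(b))

def dot_bf16_x3(a, b):
    # Three BF16 dots (like DOT_BF16_BF16_F32_X3); the small a_lo*b_lo term is dropped.
    a_hi, a_lo = split_bf16(a, 2)
    b_hi, b_lo = split_bf16(b, 2)
    return np.dot(a_lo, b_hi) + np.dot(a_hi, b_lo) + np.dot(a_hi, b_hi)

def dot_bf16_x6(a, b):
    # Six BF16 dots (like DOT_BF16_BF16_F32_X6): keep the six most significant of
    # the nine cross products; the dropped ones are roughly 2**-27 relative or smaller.
    a_hi, a_mid, a_lo = split_bf16(a, 3)
    b_hi, b_mid, b_lo = split_bf16(b, 3)
    # Add the smallest terms first to limit F32 rounding error.
    return (np.dot(a_mid, b_mid) + np.dot(a_hi, b_lo) + np.dot(a_lo, b_hi)
            + np.dot(a_hi, b_mid) + np.dot(a_mid, b_hi) + np.dot(a_hi, b_hi))

example_a = np.array([1 + 2**-9 + 2**-20], dtype=np.float32)
example_b = np.array([1.0], dtype=np.float32)
for fn in (dot_bf16_x1, dot_bf16_x3, dot_bf16_x6):
    print(fn.__name__, float(fn(example_a, example_b)))
```

Take a = 1 + 2**-9 + 2**-20 and b = 1. The single BF16 dot returns 1.0 and the X3 variant returns 1 + 2**-9. Only the X6 variant recovers the full F32 value, which is consistent with six BF16 dots approaching the precision of one F32 dot.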
The naming convention is DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}], where A_TYPE, B_TYPE and ACCUM_TYPE correspond to the types in the "primitive dot operations" (such as TensorCore operations) and NUM_OPS is the number of such operations used per "primitive tile" (such as a 16x16 tile). When the NUM_OPS field is omitted, it is assumed to be 1. The types mentioned in the name are independent of the storage types.
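The naming convention is regular enough to decode mechanically. This Python sketch splits a name into its operand types, accumulation type and number of primitive dots, defaulting NUM_OPS to 1 when the suffix is omitted. The `DotAlgorithm` class and `parse_dot_algorithm` function are hypothetical and not part of XLA.

```python
import re
from dataclasses import dataclass

# DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}]; type names contain no underscores.
_DOT_ALGORITHM_RE = re.compile(
    r"DOT_([A-Z0-9]+)_([A-Z0-9]+)_([A-Z0-9]+)(?:_X([1-9][0-9]*))?")

@dataclass(frozen=True)
class DotAlgorithm:
    a_type: str       # operand type of each primitive dot, e.g. "BF16"
    b_type: str
    accum_type: str   # accumulation type, e.g. "F32"
    num_ops: int = 1  # primitive dots per primitive tile

def parse_dot_algorithm(name: str) -> DotAlgorithm:
    match = _DOT_ALGORITHM_RE.fullmatch(name)
    if match is None:
        raise ValueError(f"not a dot algorithm name: {name!r}")
    a_type, b_type, accum_type, num_ops = match.groups()
    # An omitted _X{NUM_OPS} suffix means a single primitive dot.
    return DotAlgorithm(a_type, b_type, accum_type, int(num_ops) if num_ops else 1)

for name in ("DOT_BF16_BF16_F32", "DOT_BF16_BF16_F32_X6", "DOT_TF32_TF32_F32_X3"):
    print(parse_dot_algorithm(name))
```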

The mapping of accelerators and existing precisions to algorithms when the input and output storage types are F32:

Accelerator         DEFAULT             HIGH                HIGHEST
GPU from Ampere     DOT_TF32_TF32_F32   DOT_TF32_TF32_F32   DOT_F32_F32_F32
GPU below Ampere    DOT_F32_F32_F32     DOT_F32_F32_F32     DOT_F32_F32_F32
CPU                 DOT_F32_F32_F32     DOT_F32_F32_F32     DOT_F32_F32_F32
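The table can also be read as a lookup from (accelerator, precision) to algorithm. The following hypothetical Python transcription is not an XLA data structure. It makes the portability problem explicit: collecting the algorithms that one precision setting resolves to across platforms shows that DEFAULT and HIGH each have two possible meanings, while HIGHEST has one.

```python
# Hypothetical transcription of the mapping table (F32 input and output storage
# types); the names here are illustrative only, not an XLA API.
PRECISIONS = ("DEFAULT", "HIGH", "HIGHEST")

PRECISION_TO_ALGORITHM = {
    "GPU from Ampere": {
        "DEFAULT": "DOT_TF32_TF32_F32",
        "HIGH": "DOT_TF32_TF32_F32",
        "HIGHEST": "DOT_F32_F32_F32",
    },
    "GPU below Ampere": {p: "DOT_F32_F32_F32" for p in PRECISIONS},
    "CPU": {p: "DOT_F32_F32_F32" for p in PRECISIONS},
}

def algorithm_for(platform: str, precision: str) -> str:
    """Return the algorithm a legacy precision setting resolves to on a platform."""
    return PRECISION_TO_ALGORITHM[platform][precision]

def algorithms_across_platforms(precision: str) -> set:
    """Collect every algorithm one precision setting can mean; more than one
    entry means the same program may compute with different precision."""
    return {table[precision] for table in PRECISION_TO_ALGORITHM.values()}

for precision in PRECISIONS:
    print(precision, sorted(algorithms_across_platforms(precision)))
```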


The initial step is to add these to XLA and use them during codegen; later we plan to expose them in StableHLO, so that JAX and other libraries can set the precision explicitly.

We will, of course, send out an RFC for the StableHLO changes.

Some details are not fully clear yet:
- Whether we will reuse operand_precision or add a separate algorithm field.
- (Whether we will also add Convolution algorithms.)
- (Whether we will add algorithms for F8, F16, BF16, etc. storage types.)

The details of the proposal may crystallise during the implementation.

Regards,
Tamás Danyluk

Farid Zakaria

Feb 22, 2024, 4:12:07 PM
to Tamás Danyluk, Kevin Gleason, openxla...@openxla.org
Thanks for the discussion.

Please reach out to me or @Kevin Gleason if you want to discuss the StableHLO specific changes and drafting an RFC.

--
You received this message because you are subscribed to the Google Groups "OpenXLA Discuss" group.
To view this discussion on the web visit https://groups.google.com/a/openxla.org/d/msgid/openxla-discuss/CACibGPUdATBWFT%3DCT_S1OT2F%3DkYR%2BMDupxqtNKz3Mw1UU%3Dvj4Q%40mail.gmail.com.



Tamás Danyluk

Mar 15, 2024, 9:56:35 AM
to Farid Zakaria, Kevin Gleason, openxla...@openxla.org, George Karpenkov
Hi all,

I've posted an RFC for adding "algorithm" to dot_general's parameters in the StableHLO specification.
Please take a look and feel free to leave constructive comments.

https://github.com/openxla/stablehlo/pull/2096

Best,
Tamas