Currently, for dot operations with F32 storage type, we have 3 precisions: DEFAULT, HIGH, and HIGHEST.
But the meaning of these precisions depends on the type and version of the accelerator.
This can cause surprising precision loss when running the same computation on different accelerators, and it is also not flexible enough to represent other algorithms, such as "bf16_6x".
So we plan to add new, specific precisions which explicitly define the algorithm.
These are the ones we are considering adding now:
- DOT_BF16_BF16_F32 (meaning: Each primitive tile is calculated using a primitive dot operation which operates on BF16 inputs and accumulates the results into F32 values.)
- DOT_BF16_BF16_F32_X3
- DOT_BF16_BF16_F32_X6 (meaning: Each primitive tile is calculated using 6 primitive dot operations which operate on BF16 inputs and accumulate the results into F32 values. In this algorithm, 6 BF16 dots are used to get a precision similar to a single F32 dot.)
- DOT_TF32_TF32_F32
- DOT_TF32_TF32_F32_X3
- DOT_F32_F32_F32
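To give a feel for why the _X3 and _X6 variants recover precision, here is a rough scalar sketch in plain Python. It emulates BF16 by truncating the F32 bit pattern, and the function names and the exact choice of cross terms are illustrative, not XLA's actual implementation (which works on tiles with hardware dot operations):

```python
import struct

def to_bf16(x: float) -> float:
    # Emulate rounding to BF16 by truncating an F32 value to its top
    # 8 mantissa bits (real BF16 uses round-to-nearest-even).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def split2(x: float):
    # Split an F32 value into (hi, lo) BF16 parts with hi + lo ~= x.
    hi = to_bf16(x)
    lo = to_bf16(x - hi)
    return hi, lo

def dot_bf16_x1(a, b):
    # DOT_BF16_BF16_F32: one BF16 product per element pair,
    # accumulated in higher precision.
    return sum(to_bf16(x) * to_bf16(y) for x, y in zip(a, b))

def dot_bf16_x3(a, b):
    # DOT_BF16_BF16_F32_X3 style: three BF16 products per element pair;
    # the lo*lo term is dropped because its contribution is far below
    # the precision of the F32 result.
    acc = 0.0
    for x, y in zip(a, b):
        xh, xl = split2(x)
        yh, yl = split2(y)
        acc += xh * yh + xh * yl + xl * yh
    return acc
```

With inputs like 1/3 that are not exactly representable in BF16, dot_bf16_x3 lands much closer to the exact result than dot_bf16_x1, because the hi/lo split preserves roughly twice as many mantissa bits of each operand.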
The naming convention is DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}], where A_TYPE, B_TYPE, and ACCUM_TYPE correspond to the types in the "primitive dot operations" (such as TensorCore operations), and NUM_OPS is the number of such operations used per "primitive tile" (such as a 16x16 tile). When the NUM_OPS suffix is omitted, it is assumed to be 1. The types mentioned in the name are independent of the storage types.
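The convention is regular enough to be parsed mechanically; a hypothetical Python helper (not part of XLA, just a sanity check of the scheme) could look like this:

```python
def parse_dot_algorithm(name: str):
    # Parse DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] into
    # (a_type, b_type, accum_type, num_ops); num_ops defaults to 1.
    parts = name.split("_")
    if parts[0] != "DOT" or len(parts) not in (4, 5):
        raise ValueError(f"not a dot algorithm name: {name}")
    num_ops = 1
    if len(parts) == 5:
        if not (parts[4].startswith("X") and parts[4][1:].isdigit()):
            raise ValueError(f"bad NUM_OPS suffix: {parts[4]}")
        num_ops = int(parts[4][1:])
    a_type, b_type, accum_type = parts[1:4]
    return a_type, b_type, accum_type, num_ops
```

For example, parse_dot_algorithm("DOT_BF16_BF16_F32_X6") yields ("BF16", "BF16", "F32", 6).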
The mapping of accelerators and existing precisions to algorithms when the input and output storage types are F32:
| Accelerator            | DEFAULT           | HIGH              | HIGHEST         |
| GPU (Ampere and newer) | DOT_TF32_TF32_F32 | DOT_TF32_TF32_F32 | DOT_F32_F32_F32 |
| GPU (pre-Ampere)       | DOT_F32_F32_F32   | DOT_F32_F32_F32   | DOT_F32_F32_F32 |
| CPU                    | DOT_F32_F32_F32   | DOT_F32_F32_F32   | DOT_F32_F32_F32 |
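The same mapping can also be written as a small lookup, e.g. in Python. The accelerator keys here are made-up labels for this sketch, not an XLA API:

```python
F32 = "DOT_F32_F32_F32"
TF32 = "DOT_TF32_TF32_F32"

# Which algorithm each existing precision resolves to, per accelerator,
# when the input and output storage types are F32 (mirrors the table above).
F32_DOT_ALGORITHM = {
    "gpu_ampere_or_newer": {"DEFAULT": TF32, "HIGH": TF32, "HIGHEST": F32},
    "gpu_pre_ampere":      {"DEFAULT": F32,  "HIGH": F32,  "HIGHEST": F32},
    "cpu":                 {"DEFAULT": F32,  "HIGH": F32,  "HIGHEST": F32},
}
```

Written out this way, the problem is easy to see: only "gpu_ampere_or_newer" distinguishes the precision levels at all, and DEFAULT silently means something different there than everywhere else.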
The initial step is to add these to XLA and use them in codegen, but later we plan to expose them in StableHLO, so that JAX and other libraries can set the precision explicitly.
For the StableHLO changes, we will of course send out an RFC.
Some details are not fully clear yet:
- Whether we will reuse operand_precision or add a separate algorithm field.
- Whether we will also add Convolution algorithms.
- Whether we will add algorithms for F8, F16, BF16, etc. storage types.
The details of the proposal may crystallise during the implementation.
Regards,
Tamás Danyluk