(Don't like reading long markdown emails? Read at the source.)
Most people who come to IREE ask a couple of the same questions about the design. The two questions are linked, and I'll try to elaborate the design behind VMVX and project where I think it should be going. I've done some initial implementation work on this in the PR "Lower VMVX Linalg ops to microkernels". Most of this follows on the careful work that Ben Vanik laid out at the beginning of the project but which we have not yet taken the time to elaborate.
For those unfamiliar, IREE has multiple code generation backends (referred to as HAL target backend in the code). These generate code compatible with some set of runtime drivers that are responsible for loading and scheduling the work. In most cases, there are easy correlations:
- The vulkan runtime driver handles executables generated by the vulkan-spirv compiler HAL target.
- The local-sync runtime driver handles synchronous CPU execution for the CPU-based llvm-aot HAL target.
- The local-task runtime driver handles CPU multi-threaded execution.
- The cuda runtime driver handles executables generated by the cuda HAL target.

The local-* runtime drivers can execute modules that contain VMVX kernels as well. This is used today as part of test suites, and the compiler itself uses it to compile constant expressions for generalized constant evaluation at compile time.

Is it just a reference implementation?
Not really. It is actually a full CPU backend which does tiling and distribution to threads in a similar fashion to the other backends, but it stops lowering after bufferizing linalg ops, instead taking a path which performs simplification/cleanup and ultimately emits VM bytecode for the loopy, lowered program. Today, this always lowers completely to scalar code. As expected, this is not particularly fast (i.e. in the default mode, it is both interpreted and operating at the scalar level). While the code it generates is very slow, it is also well parallelized: when scheduled by the multi-threaded runtime driver (local-task), it achieves generally good parallelism and uses similar heuristics to its more industrial-strength siblings.

IREE's VM is both a compile-time target and multiple execution modalities. At compile time, it is represented by the VM dialect. While it bears some similarity to LLVM IR, it is important to look at both the points of omission and deviation:
- Buffers are a first-class type (!vm.buffer).
- Lists are a first-class type (!vm.list).

While it is not uncommon to implement the above features in various lowerings to LLVM IR, the resulting program is very low level and, basically, only suitable for use by an LLVM backend. By stopping at this intermediate level, we have an IR that is still relatively high level.
The "VX" stands for "Vector Extensions" -- which is probably somewhat troubling to folks, considering that the current implementation only lowers to scalars :)
The dialect README talks about how to add ops, but in fact, the dialect is empty and little guidance is given on where this is going. Upstream terminology and layers of abstraction have shifted a bit and the above README needs some polishing. The primary guidance is: "The operations added here are modeled as close to a machine ISA as reasonable, meaning that there are no shapes, element types are encoded as part of the operations, and memory access is tightly restricted."
The work described from here lays out the next steps and where this can go.
As a first set of extensions, the above PR adds four ops to VMVX and implements loopy reference kernels for them:
def VMVX_AddOp : VMVX_Op<"add", [SameVariadicOperandSize]> {
  let summary = "Performs a strided elementwise add of two same-rank buffers";
  let description = [{
    Performs addition in-place as if:
      OUT = LHS + RHS
    All operands have the same rank.
  }];
  let arguments = (ins
    // LHS.
    VMVX_Buffer:$lhs_buffer,
    VMVX_Index:$lhs_offset,
    Variadic<VMVX_Index>:$lhs_strides,
    // RHS.
    VMVX_Buffer:$rhs_buffer,
    VMVX_Index:$rhs_offset,
    Variadic<VMVX_Index>:$rhs_strides,
    // OUT.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    Variadic<VMVX_Index>:$out_strides,
    // Dimensions.
    Variadic<VMVX_Index>:$size
  );
  let assemblyFormat = [{
    `lhs` `` `(` $lhs_buffer `offset` $lhs_offset `strides` `[` $lhs_strides `]` `:` type($lhs_buffer) `)`
    `rhs` `` `(` $rhs_buffer `offset` $rhs_offset `strides` `[` $rhs_strides `]` `:` type($rhs_buffer) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `strides` `[` $out_strides `]` `:` type($out_buffer) `)`
    `size` `` `(` $size `)`
    attr-dict
  }];
}
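To make the strided-access semantics concrete, here is a Python sketch of a rank-2 reference kernel for this op. The function name and the flat-list buffer model are illustrative, not the actual runtime ABI; real VMVX buffers are byte buffers addressed by element offsets computed the same way.

```python
def vmvx_add_2d(lhs, lhs_offset, lhs_strides,
                rhs, rhs_offset, rhs_strides,
                out, out_offset, out_strides,
                size):
    """Strided elementwise add, OUT = LHS + RHS, over an [m, n] tile.

    Each operand is a flat buffer plus an (offset, strides) view,
    mirroring the (buffer, offset, strides) triples of the VMVX op.
    """
    m, n = size
    for i in range(m):
        for j in range(n):
            l = lhs[lhs_offset + i * lhs_strides[0] + j * lhs_strides[1]]
            r = rhs[rhs_offset + i * rhs_strides[0] + j * rhs_strides[1]]
            out[out_offset + i * out_strides[0] + j * out_strides[1]] = l + r
```

Note that a zero stride in one dimension yields a broadcast for free, which is exactly how the fused bias-add in the example kernel later in this note works.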
def VMVX_CopyOp : VMVX_Op<"copy", [SameVariadicOperandSize]> {
  let summary = "Copy from one buffer to another";
  let arguments = (ins
    // INP.
    VMVX_Buffer:$inp_buffer,
    VMVX_Index:$inp_offset,
    Variadic<VMVX_Index>:$inp_strides,
    // OUT.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    Variadic<VMVX_Index>:$out_strides,
    // Dimensions.
    Variadic<VMVX_Index>:$size
  );
  let assemblyFormat = [{
    `inp` `` `(` $inp_buffer `offset` $inp_offset `strides` `[` $inp_strides `]` `:` type($inp_buffer) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `strides` `[` $out_strides `]` `:` type($out_buffer) `)`
    `size` `` `(` $size `)`
    attr-dict
  }];
}
def VMVX_Fill2DOp : VMVX_Op<"fill2d"> {
  let summary = "Fill a tile with a scalar";
  let description = [{
    Fills a tile with dimensions [m, n] with a scalar.
  }];
  let arguments = (ins
    VMVX_ElementType:$scalar,
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    VMVX_Index:$out_row_stride,
    // Dimensions.
    VMVX_Index:$m,
    VMVX_Index:$n
  );
  let assemblyFormat = [{
    `scalar` `` `(` $scalar `:` type($scalar) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)`
    `size` `` `(` $m `,` $n `)`
    attr-dict
  }];
}
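The fill op is the simplest of the set; a Python sketch of its reference semantics (illustrative names, flat-list buffer model, columns assumed unit-stride as implied by the single row_stride operand):

```python
def vmvx_fill_2d(scalar, out, out_offset, out_row_stride, m, n):
    """Fill an [m, n] tile with `scalar`.

    Only a row stride is carried by the op, so columns are contiguous;
    each row i starts at out_offset + i * out_row_stride.
    """
    for i in range(m):
        base = out_offset + i * out_row_stride
        for j in range(n):
            out[base + j] = scalar
```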
def VMVX_MatmulOp : VMVX_Op<"matmul"> {
  let summary = "Matmul";
  let description = [{
    General matrix-multiply of the form:
      OUT = alpha * (LHS * RHS) + beta * OUT
  }];
  let arguments = (ins
    // Lhs buffer.
    VMVX_Buffer:$lhs_buffer,
    VMVX_Index:$lhs_offset,
    VMVX_Index:$lhs_row_stride,
    // Rhs buffer.
    VMVX_Buffer:$rhs_buffer,
    VMVX_Index:$rhs_offset,
    VMVX_Index:$rhs_row_stride,
    // Out buffer.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    VMVX_Index:$out_row_stride,
    // Dimensions.
    VMVX_Index:$m,
    VMVX_Index:$n,
    VMVX_Index:$k,
    // Scale factors.
    VMVX_ElementType:$alpha,
    VMVX_ElementType:$beta,
    // Execution flags.
    I32Attr:$flags
  );
  let assemblyFormat = [{
    `lhs` `` `(` $lhs_buffer `offset` $lhs_offset `row_stride` $lhs_row_stride `:` type($lhs_buffer) `)`
    `rhs` `` `(` $rhs_buffer `offset` $rhs_offset `row_stride` $rhs_row_stride `:` type($rhs_buffer) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)`
    `size` `` `(` $m `,` $n `,` $k `)`
    `scale` `` `(` $alpha `:` type($alpha) `,` $beta `:` type($beta) `)`
    `flags` `` `(` $flags `)`
    attr-dict
  }];
}
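For reference, a loopy Python sketch of the matmul semantics above (illustrative name and flat-list buffer model; the flags operand is omitted here since its meaning is not spelled out above):

```python
def vmvx_matmul(lhs, lhs_offset, lhs_row_stride,
                rhs, rhs_offset, rhs_row_stride,
                out, out_offset, out_row_stride,
                m, n, k, alpha, beta):
    """OUT = alpha * (LHS @ RHS) + beta * OUT.

    Each matrix is a flat buffer with a row stride and unit column
    stride, matching the operand layout of vmvx.matmul.
    """
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += (lhs[lhs_offset + i * lhs_row_stride + p]
                        * rhs[rhs_offset + p * rhs_row_stride + j])
            idx = out_offset + i * out_row_stride + j
            out[idx] = alpha * acc + beta * out[idx]
```

The alpha/beta form means a fill-then-accumulate sequence can often be folded into a single call (beta = 0 replaces the zero-fill), which is one of the specializations a microkernel implementor can exploit.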
These "microkernel" ops are not final and will be elaborated in follow-on work. While they are inspired by various sources and work on related topics, they have ultimately been hand-crafted to satisfy the following design goals:
- Lowerable to a single vm.call to an appropriate microkernel.

Remember that because of where these operations are generated and intended to be of maximum use, their operands have likely already been heavily tiled, and they are being run as part of fused, grid dispatches. As such, while they can be implemented in terms of very high-level op libraries, such libraries are often also doing those things and are overkill: the intent is for these implementations to shine when the "inner parts" of such libraries are used directly. This gives the microkernel implementor freedom to specialize the innermost loops of already-optimized linear algebra operations while leaving the high-level, whole-program optimizations to the higher-level compiler.
Further, since they are an optimization, the program is still legal without identifying microkernels: we just expect that over time we will end up with a semi-comprehensive library of these primitives so that the performance-sensitive parts of ML models are completely covered by fast paths.
It is the job of the overall compiler heuristics to fuse, tile, and memory-plan vmvx-containing dispatch regions such that they operate at reasonably good trade-off points that are common in practice (e.g. tile sizes within some multiple of typical L1 cache sizes for a class of targets). We expect this to be an evolving area of work, but it is not dissimilar to the decisions that every backend must make when applying its high-level optimizations.
The current VMVX runtime is fully interpreted. While this is fine for simple things (indexing math, loops, a few calls, etc), it isn't great. However, we do expect that for many programs, it will be good enough and on relative par with existing, interpreted op-by-op tensor executors: it may be doing a bit more work in aggregate, but it is much more comprehensively threaded and has very low code size -- which yields non-obvious cache benefits. We intend it to be useful as both a reference and a fallback for portable cases that have no other option.
To give a hint of the overheads in the current implementation (on a several-year-old AMD Threadripper):
vm.func @loop_sum(%count : i32) -> i32 {
  %c1 = vm.const.i32 1
  %i0 = vm.const.i32.zero
  vm.br ^loop(%i0 : i32)
^loop(%i : i32):
  %in = vm.add.i32 %i, %c1 : i32
  %cmp = vm.cmp.lt.i32.s %in, %count : i32
  vm.cond_br %cmp, ^loop(%in : i32), ^loop_exit(%in : i32)
^loop_exit(%ie : i32):
  vm.return %ie : i32
}
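Read as ordinary code, this benchmark is just a counted loop. A Python rendering, useful for following the block structure (not for performance comparison):

```python
def loop_sum(count):
    """Python rendering of @loop_sum: increment i until it reaches count.

    Mirrors the blocks above: ^loop adds 1 and branches back while
    i < count; ^loop_exit returns the final value.
    """
    i = 0
    while True:
        i = i + 1            # vm.add.i32
        if not (i < count):  # vm.cmp.lt.i32.s + vm.cond_br
            return i         # ^loop_exit
```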
vm.func @buffer_reduce(%count : i32) -> i32 {
  %c0 = vm.const.i64.zero
  %c0_i32 = vm.const.i32.zero
  %c1 = vm.const.i32 1
  %c4 = vm.const.i32 4
  %max = vm.mul.i32 %count, %c4 : i32
  %max_i64 = vm.ext.i32.i64.u %max : i32 -> i64
  %buf = vm.buffer.alloc %max_i64 : !vm.buffer
  vm.buffer.fill.i32 %buf, %c0, %max_i64, %c1 : i32 -> !vm.buffer
  vm.br ^loop(%c0_i32, %c0_i32 : i32, i32)
^loop(%i : i32, %sum : i32):
  %i_i64 = vm.ext.i32.i64.u %i : i32 -> i64
  %element = vm.buffer.load.i32 %buf[%i_i64] : !vm.buffer -> i32
  %new_sum = vm.add.i32 %sum, %element : i32
  %ip4 = vm.add.i32 %i, %c4 : i32
  %cmp = vm.cmp.lt.i32.s %ip4, %max : i32
  vm.cond_br %cmp, ^loop(%ip4, %new_sum : i32, i32), ^loop_exit(%new_sum : i32)
^loop_exit(%result : i32):
  vm.return %result : i32
}
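In Python terms this benchmark is roughly the following (assuming the fill writes the 32-bit value 1 across the buffer, as the constants above suggest; byte offsets are stepped by 4 per i32 load):

```python
def buffer_reduce(count):
    """Rough Python rendering of @buffer_reduce: allocate count i32
    elements, fill them with 1, then sum by striding through byte
    offsets in steps of 4."""
    buf = [1] * count         # stands in for the filled !vm.buffer
    max_bytes = count * 4
    i, total = 0, 0
    while True:
        total += buf[i // 4]  # vm.buffer.load.i32 at byte offset i
        i += 4
        if i >= max_bytes:
            return total
```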
NOTE: This case has not been optimized very much.
vm.func @call_imported_func(%arg0 : i32) -> i32 {
  %0 = vm.call @native_import_module.add_1(%arg0) : (i32) -> i32
  %1 = vm.call @native_import_module.add_1(%0) : (i32) -> i32
  %2 = vm.call @native_import_module.add_1(%1) : (i32) -> i32
  %3 = vm.call @native_import_module.add_1(%2) : (i32) -> i32
  %4 = vm.call @native_import_module.add_1(%3) : (i32) -> i32
  %5 = vm.call @native_import_module.add_1(%4) : (i32) -> i32
  %6 = vm.call @native_import_module.add_1(%5) : (i32) -> i32
  %7 = vm.call @native_import_module.add_1(%6) : (i32) -> i32
  %8 = vm.call @native_import_module.add_1(%7) : (i32) -> i32
  %9 = vm.call @native_import_module.add_1(%8) : (i32) -> i32
  %10 = vm.call @native_import_module.add_1(%9) : (i32) -> i32
  %11 = vm.call @native_import_module.add_1(%10) : (i32) -> i32
  %12 = vm.call @native_import_module.add_1(%11) : (i32) -> i32
  %13 = vm.call @native_import_module.add_1(%12) : (i32) -> i32
  %14 = vm.call @native_import_module.add_1(%13) : (i32) -> i32
  %15 = vm.call @native_import_module.add_1(%14) : (i32) -> i32
  %16 = vm.call @native_import_module.add_1(%15) : (i32) -> i32
  %17 = vm.call @native_import_module.add_1(%16) : (i32) -> i32
  %18 = vm.call @native_import_module.add_1(%17) : (i32) -> i32
  %19 = vm.call @native_import_module.add_1(%18) : (i32) -> i32
  %20 = vm.call @native_import_module.add_1(%19) : (i32) -> i32
  vm.return %20 : i32
}
There are two additional avenues that likely have merit for extending the VMVX runtime:
As mentioned previously, we expect that the constraints of the VM code emitter are such that a really lightweight JIT could be created to generate binaries from VMVX kernels on the fly. At runtime, these kernels are already serialized to a stable flatbuffer format with embedded constants and a bytecode-based program structure. A JIT at this level could focus on straight-line code emission with a few peephole optimizations:
- Inline vm.call sites to known "intrinsic functions" for microkernels, optimizing across the call boundary.

We have seen such JITs in the wild be table based and consist of a few thousand lines of unsurprising code. The result should be pretty efficient and, with the inlining support mentioned above, would enable some fairly powerful but simple microkernel specializations to be emitted.
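To give a flavor of what "table based" means here, below is a deliberately tiny Python toy (not the proposed implementation): each bytecode-style op maps to a code template, and the JIT simply concatenates templates into a straight-line function. Python source stands in for machine code; a real template JIT would emit instruction bytes the same way.

```python
# Toy table-driven "template JIT". Each opcode maps to one line of
# target code; emission is straight-line concatenation, exactly the
# shape of JIT described above.
TEMPLATES = {
    "const": "r{d} = {imm}",
    "add":   "r{d} = r{a} + r{b}",
    "mul":   "r{d} = r{a} * r{b}",
    "ret":   "return r{a}",
}

def jit(ops):
    """Translate a list of (opcode, operands) pairs into a callable."""
    body = "\n    ".join(TEMPLATES[op].format(**args) for op, args in ops)
    src = "def kernel():\n    " + body + "\n"
    ns = {}
    exec(src, ns)  # "code emission": materialize the function
    return ns["kernel"]

# A small straight-line program: 6 * 7.
kernel = jit([
    ("const", {"d": 0, "imm": 6}),
    ("const", {"d": 1, "imm": 7}),
    ("mul",   {"d": 2, "a": 0, "b": 1}),
    ("ret",   {"a": 2}),
])
```

The interesting extensions (register allocation over virtual registers, constant folding, inlining known intrinsic calls) all stay at this same "lookup and emit" level of complexity.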
Such JITs were previously used by the VM author when writing Xenia under similar constraints and goals, and the result was successful at achieving smooth JITing of PowerPC Xbox 360 AAA games on x86. Much of the thinking that proved successful there is baked into the design decisions here.
Below is an example of a VMVX kernel today, which dispatches to microkernels to perform a fused 64x64 matmul and broadcasted add. Note that this is a prototype intended to show the relative complexity ceiling (also note that this kernel is not typical of whole programs, which typically fuse the add and fill according to different heuristics):
module attributes {vm.toplevel} {
  vm.module public @module {
    vm.rodata private @__constant_384x64xf32 dense<7.000000e+00> : tensor<24576xf32>
    vm.import @vmvx.add.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %size : tuple<i64, i64>) attributes {sym_visibility = "private"}
    vm.import @vmvx.fill.2d.x32(%fill_value : i32, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %size_m : i64, %size_n : i64) attributes {sym_visibility = "private"}
    vm.import @vmvx.matmul.f32f32f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %alpha : f32, %beta : f32, %flags : i32) attributes {sym_visibility = "private"}
    vm.func private @tensor_float_dispatch_0(%arg0: !vm.buffer, %arg1: !vm.buffer, %arg2: !vm.list<!vm.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
      %c1 = vm.const.i64 1
      %zero = vm.const.i64.zero
      %c384 = vm.const.i64 384
      %c64 = vm.const.i64 64
      %c128 = vm.const.i64 128
      %c2 = vm.const.i32 2
      %c384_0 = vm.const.i32 384
      %0 = vm.const.f32 1.000000e+00
      %c1_1 = vm.const.i32 1
      %c64_2 = vm.const.i32 64
      %c512 = vm.const.i32 512
      %c128_3 = vm.const.i32 128
      %zero_4 = vm.const.i32.zero
      %__constant_384x64xf32 = vm.const.ref.rodata @__constant_384x64xf32 : !vm.buffer
      %buffer = vm.list.get.ref %arg2, %zero_4 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %buffer_5 = vm.list.get.ref %arg2, %c1_1 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %buffer_6 = vm.list.get.ref %arg2, %c2 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %1 = vm.mul.i32 %arg4, %c64_2 : i32
      %2 = vm.mul.i32 %arg10, %c64_2 : i32
      %3 = vm.mul.i32 %arg3, %c64_2 : i32
      %4 = vm.mul.i32 %arg9, %c64_2 : i32
      vm.br ^bb1(%1 : i32)
    ^bb1(%5: i32):  // 2 preds: ^bb0, ^bb5
      %slt = vm.cmp.lt.i32.s %5, %c512 : i32
      vm.cond_br %slt, ^bb2, ^bb6
    ^bb2:  // pred: ^bb1
      %6 = vm.mul.i32 %5, %c128_3 : i32
      %7 = vm.mul.i32 %5, %c384_0 : i32
      vm.br ^bb3(%3 : i32)
    ^bb3(%8: i32):  // 2 preds: ^bb2, ^bb4
      %slt_7 = vm.cmp.lt.i32.s %8, %c128_3 : i32
      vm.cond_br %slt_7, ^bb4, ^bb5
    ^bb4:  // pred: ^bb3
      %9 = vm.add.i32 %6, %8 : i32
      %10 = vm.ext.i32.i64.s %9 : i32 -> i64
      vm.call @vmvx.fill.2d.x32(%zero_4, %buffer_6, %10, %c128, %c64, %c64) : (i32, !vm.buffer, i64, i64, i64, i64) -> ()
      %11 = vm.ext.i32.i64.s %7 : i32 -> i64
      vm.call @vmvx.matmul.f32f32f32(%buffer, %11, %c384, %__constant_384x64xf32, %zero, %c64, %buffer_6, %10, %c128, %c64, %c64, %c384, %0, %0, %zero_4) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, f32, f32, i32) -> ()
      %12 = vm.ext.i32.i64.s %8 : i32 -> i64
      vm.call @vmvx.add.2d.f32(%buffer_5, %12, %zero, %c1, %buffer_6, %10, %c128, %c1, %buffer_6, %10, %c128, %c1, %c64, %c64) : (!vm.buffer, i64, i64, i64, !vm.buffer, i64, i64, i64, !vm.buffer, i64, i64, i64, i64, i64) -> ()
      %13 = vm.add.i32 %8, %4 : i32
      vm.br ^bb3(%13 : i32)
    ^bb5:  // pred: ^bb3
      %14 = vm.add.i32 %5, %2 : i32
      vm.br ^bb1(%14 : i32)
    ^bb6:  // pred: ^bb1
      vm.return
    }
    vm.export @tensor_float_dispatch_0
  }
}
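Stripped of tiling and strides, the dispatch above computes a matmul followed by a row-broadcast bias add (the bias's zero row stride in the vmvx.add call is what produces the broadcast). A pure-Python sketch of that functional behavior, with the shapes and broadcast pattern read off the example, not taken from any runtime API:

```python
def fused_dispatch(a, b, bias, m, n, k):
    """What the dispatch computes, untiled: for each output element,
    zero-fill, accumulate the matmul, then add the row-broadcast bias:
        OUT[i][j] = sum_p A[i][p] * B[p][j] + bias[j]
    """
    out = [[0.0] * n for _ in range(m)]  # the vmvx.fill.2d.x32 step
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):           # the vmvx.matmul step
                acc += a[i][p] * b[p][j]
            out[i][j] = acc + bias[j]    # the broadcasting vmvx.add step
    return out
```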
It is likely profitable on some systems to treat VMVX as just another code generation methodology (alongside the existing expert compilers) when targeting CPUs. In this case, we would not emit VMVX bytecode but would AOT-compile the microkernels via the existing builtins libdevice library import support (currently used for math calls, datatype converters, etc). This largely impacts how the microkernel implementations are managed: if they are included in the builtins library in addition to the runtime/interpreter/JIT layer, then this should all just happen seamlessly and support microkernel-based code emission. This would allow profitability decisions to be made between VMVX-based and normal MLIR/vector-based code generation.
While the initial prototype demonstrates a ~200X improvement over stock VMVX just using loopy reference kernels, this is shooting pretty low :) There is a fair amount of work to make this great -- much of it incremental or separable. We believe that we aren't far from having performant/portable VMVX microkernels for a number of important workloads. We will proceed in a few steps:
- memref improvements (mostly moving things that are bundled into the LLVM dialect conversion into MLIR proper, so that non-LLVM-bound lowerings can use them).

I plan to keep working on #1-3 and then will be looking for help (or for someone to take outright responsibility) on the rest.
--
You received this message because you are subscribed to the Google Groups "iree-discuss" group.
To unsubscribe from this group and stop receiving emails from it, send an email to iree-discuss...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/iree-discuss/CAEkedjgu0NzMk5-5sS0LqyhW_mmKtNfyN3OmPwV41Ei%2BLYz%2B4A%40mail.gmail.com.
The overall direction here sounds good to me, and thanks for the thorough background information.

It looks like the VMVX microkernel implementations will be built into the runtime (using the same compiler toolchain as the rest of the runtime, i.e. clang/gcc/etc.), rather than into individual program executables. How much do we expect they will contribute to the runtime's binary size? Will the "extensions" part of VMVX be optional? I see some complexity lurking around "building the runtime with just the ops that you need" (like https://www.tensorflow.org/lite/guide/reduce_binary_size). Since we'll be doing more work in the compiler than in more traditional "kernel library" runtimes, we can keep the microkernel surface area smaller, but do we have some threshold in mind for how small/large is acceptable?
This overall sounds good to me. I do think we want to preserve the property that VMVX is usable as a reference backend, which means being "obviously correct", simple, and easily debuggable. Therefore, I think we want this to be guarded behind a flag (even if we flip that flag's default), such that we can always choose to pay a performance penalty and get the "obviously correct" scalar code. Right now, VMVX's usefulness as a reference backend is actually limited by its speed, though, so especially at this level of "a few kernels make it 200x faster" this seems excellent. But I think we will necessarily come to a point where the goals of it being a reference backend and the goals of it being a usable production backend come into conflict, so I think we really want to maintain the optionality there so we can have both. It sounds like that is very much the plan, but I've also seen us drop support for things in kind of a big-bang way before, and I'd be slightly worried about us setting ourselves up for feeling like we need to do that here for the fully-scalar path as well.

Useful functionality we might have in the future would be the ability to include (or exclude) only some subset of microkernels, such that we could debug issues with them without having to take the full performance hit of dropping everything back to scalar code.
func.func @tensor_float() -> tensor<512x128xf32> {
  %0 = util.unfoldable_constant dense<3.0> : tensor<512x384xf32>
  %1 = arith.constant dense<7.0> : tensor<128x384xf32>
  %2 = util.unfoldable_constant dense<1.0> : tensor<128xf32>
  %result = "tosa.fully_connected"(%0, %1, %2) : (tensor<512x384xf32>, tensor<128x384xf32>, tensor<128xf32>) -> tensor<512x128xf32>
  return %result : tensor<512x128xf32>
}