RFC: Evolving VMVX to a portable, performant and jittable backend


Stella Laurenzo

Jul 6, 2022, 8:45:02 PM
to iree-d...@googlegroups.com

(don't like reading long markdown emails: read at the source)

Most people who come to IREE ask a couple of the same questions about the design:

  • Is VMVX just a reference implementation or is there a bigger plan there?
  • What is up with the VM dialect and implementation? Is that really needed? Why not just lower directly to X?

The two questions are linked, and I'll try to elaborate the design behind VMVX and project where I think it should be going. I've done some initial implementation work on this in the PR Lower VMVX Linalg ops to microkernels. Most of this follows on the careful work that Ben Vanik laid out at the beginning of the project but that we have not yet taken the time to elaborate on.

Some background

For those unfamiliar, IREE has multiple code generation backends (referred to as HAL target backends in the code). These generate code compatible with some set of runtime drivers that are responsible for loading and scheduling the work. In most cases, there are easy correlations:

  • The vulkan runtime driver handles executables generated by the vulkan-spirv compiler HAL target.
  • The local-sync runtime driver handles synchronous CPU execution for the CPU-based llvm-aot HAL target.
  • The local-task runtime driver handles multi-threaded CPU execution.
  • The cuda runtime driver handles executables generated by the cuda compiler HAL target.
  • etc

The local-* runtime drivers can execute modules that contain VMVX kernels as well. This is used today as part of test suites, and the compiler itself uses it to compile constant expressions for generalized constant evaluation at compile time.

Is it just a reference implementation?

Not really. It is actually a full CPU backend which does tiling and distribution to threads in a similar fashion to the other backends, but it stops lowering after bufferizing linalg ops, instead taking a path which performs simplification/cleanup and ultimately emits VM bytecode for the loopy, lowered program. Today, this always lowers completely to scalar code. As expected, this is not particularly fast (i.e. in the default mode, it is both interpreted and at the scalar level). While the code it generates is very slow, it is also:

  • Very simple: The lowering to scalar form is the most basic use of the compiler infra and has some value as a reference.
  • Post memory planning: Buffers have been assigned, layouts have been optimized, etc (actually somewhat limited today since this target is only triggering platform independent heuristics, but can be more).
  • Parallel: like the other targets, the program is fused/broken into tiles and grid scheduled. On multi-threaded drivers (i.e. local-task), this achieves generally good parallelism and uses similar heuristics as its more industrial strength siblings.

Talk to me about the "VM" in "VMVX"

IREE's VM is both a compile-time target and multiple execution modalities. At compile time, it is represented by the VM Dialect. While it bears some similarity to LLVM IR, it is important to look at both the points of omission and deviation:

  • Non-polymorphic op names: The op name encodes the entire semantics of the op (i.e. no attributes).
  • High level runtime types:
    • Exposes built-in types/ops for buffers (!vm.buffer)
    • Lists (!vm.list)
    • Pluggable ref-counting type bridging (the VM provides a rich facility for integrating custom types and passing them across ABI boundaries)
  • Function import/export: VM modules exist in a namespace and can be integrated at runtime via function import/export
  • No pointers: Careful use of high level types and ops lets us express real programs without pointers.
  • Other misc features:
    • Co-routine primitives
    • Debugging
    • etc

While it is not uncommon to implement the above features in various lowerings to LLVM IR, the resulting program is very low level and, basically, is only suitable for use by an LLVM backend. By stopping at this intermediate level, we have an IR that is still:

  • Reasonably easy to verify/augment for safety in heightened security contexts.
  • Reasonably efficient to interpret: we're actually somewhat shocked that IREE's VM interpreter has survived as long as it has as the default way to run IREE host programs. It hasn't yet been a real bottleneck, and when it is, a jitter can take it the rest of the way.
  • Structured with respect to concurrency control: co-operative execution semantics can still be extracted and handled at this level.
  • Able to be cleanly lowered to C code for targets that prefer that integration modality (or perhaps are missing a compatible LLVM backend).
  • Relatively portable between host/device: modern approaches to targeting devices often rely on having a level of host program representation that is easy to re-target so that parts of it can run on the device. The VM is constrained enough that this can be done with a handful of patterns as needed (versus needing to reason about LLVM->* style things).
  • Amenable to JITing: IREE started its life as an embedded/mobile solution. In this model, the host program may need to be run with or without JITing, based on runtime constraints. In all such scenarios where JITing is appropriate, we want on-device JITers to be exceptionally light "splat jitters" that are basically doing some light-weight register allocation, machine code emission and intrinsic function expansion. Already being in a limited, optimized form enables this use case with cheap JIT solutions that are far more simplistic than a full LLVM backend.

Talk to me about the "VX" in "VMVX"

The "VX" stands for "Vector Extensions" -- which is probably somewhat troubling to folks, considering that the current implementation only lowers to scalars :)

The dialect README talks about how to add ops, but in fact, the dialect is empty and little guidance is given on where this is going. Upstream terminology and layers of abstraction have shifted a bit and the above README needs some polishing. The primary guidance is: "The operations added here are modeled as close to a machine ISA as reasonable, meaning that there are no shapes, element types are encoded as part of the operations, and memory access is tightly restricted."

The work described from here outlines the next steps and where this can go.

Better defining the VMVX dialect

As a first set of extensions, the above PR adds four ops to VMVX and implements loopy reference kernels for them:

def VMVX_AddOp : VMVX_Op<"add", [SameVariadicOperandSize]> {
  let summary = "Performs a strided elementwise add of two same-rank buffers";
  let description = [{
    Performs addition in-place as if:
      OUT = LHS + RHS

    All operands have the same rank.
  }];
  let arguments = (ins
    // LHS.
    VMVX_Buffer:$lhs_buffer,
    VMVX_Index:$lhs_offset,
    Variadic<VMVX_Index>:$lhs_strides,
    // RHS.
    VMVX_Buffer:$rhs_buffer,
    VMVX_Index:$rhs_offset,
    Variadic<VMVX_Index>:$rhs_strides,
    // OUT.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    Variadic<VMVX_Index>:$out_strides,

    // Dimensions.
    Variadic<VMVX_Index>:$size
  );

  let assemblyFormat = [{
    `lhs` `` `(` $lhs_buffer `offset` $lhs_offset `strides` `[` $lhs_strides `]` `:` type($lhs_buffer) `)`
    `rhs` `` `(` $rhs_buffer `offset` $rhs_offset `strides` `[` $rhs_strides `]` `:` type($rhs_buffer) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `strides` `[` $out_strides `]` `:` type($out_buffer) `)`
    `size` `` `(` $size `)`
    attr-dict
  }];
}

def VMVX_CopyOp : VMVX_Op<"copy", [SameVariadicOperandSize]> {
  let summary = "Copy from one buffer to another";
  let arguments = (ins
    // INP.
    VMVX_Buffer:$inp_buffer,
    VMVX_Index:$inp_offset,
    Variadic<VMVX_Index>:$inp_strides,
    // OUT.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    Variadic<VMVX_Index>:$out_strides,

    // Dimensions.
    Variadic<VMVX_Index>:$size
  );
  let assemblyFormat = [{
    `inp` `` `(` $inp_buffer `offset` $inp_offset `strides` `[` $inp_strides `]` `:` type($inp_buffer) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `strides` `[` $out_strides `]` `:` type($out_buffer) `)`
    `size` `` `(` $size `)`
    attr-dict
  }];
}

def VMVX_Fill2DOp : VMVX_Op<"fill2d"> {
  let summary = "Fill a tile with a scalar";
  let description = [{
    Fills a tile with dimensions [m, n] with a scalar.
  }];
  let arguments = (ins
    VMVX_ElementType:$scalar,
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    VMVX_Index:$out_row_stride,

    // Dimensions.
    VMVX_Index:$m,
    VMVX_Index:$n
  );

  let assemblyFormat = [{
    `scalar` `` `(` $scalar `:` type($scalar) `)`
    `out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)`
    `size` `` `(` $m `,` $n `)`
    attr-dict
  }];
}

def VMVX_MatmulOp : VMVX_Op<"matmul"> {
  let summary = "Matmul";
  let description = [{
    General matrix-multiply of the form:

      OUT = alpha * (LHS * RHS) + beta * OUT
  }];
  let arguments = (ins
    // Lhs buffer.
    VMVX_Buffer:$lhs_buffer,
    VMVX_Index:$lhs_offset,
    VMVX_Index:$lhs_row_stride,
    // Rhs buffer.
    VMVX_Buffer:$rhs_buffer,
    VMVX_Index:$rhs_offset,
    VMVX_Index:$rhs_row_stride,
    // Out buffer.
    VMVX_Buffer:$out_buffer,
    VMVX_Index:$out_offset,
    VMVX_Index:$out_row_stride,

    // Dimensions.
    VMVX_Index:$m,
    VMVX_Index:$n,
    VMVX_Index:$k,

    // Scale factors.
    VMVX_ElementType:$alpha,
    VMVX_ElementType:$beta,

    // Execution flags.
    I32Attr:$flags
  );

  let assemblyFormat = [{
    `lhs` `` `(` $lhs_buffer `offset` $lhs_offset `row_stride` $lhs_row_stride `:` type($lhs_buffer) `)`
    `rhs` `` `(` $rhs_buffer `offset` $rhs_offset `row_stride` $rhs_row_stride `:` type($rhs_buffer)`)`
    `out` `` `(` $out_buffer `offset` $out_offset `row_stride` $out_row_stride `:` type($out_buffer) `)`
    `size` `` `(` $m `,` $n `,` $k `)`
    `scale` `` `(` $alpha `:` type($alpha) `,` $beta `:` type($beta) `)`
    `flags` `` `(` $flags `)`
    attr-dict
  }];
}

These "microkernel" ops are not final and will be elaborated in followon work. While they are inspired from various sources and work on the related topics, they have been ultimately hand crafted to satisfy the following design goals:

  • Unique signatures map to a monomorphic implementation: While various implementations may switch internally for selecting certain fast path cases, the code structure of the reference is fixed (i.e. same loops/index calculations/etc) based on name/attributes/types.
  • A systematic mapping to a concrete, named implementation microkernel can be extracted from the signature alone, and this will result in a vm.call to an appropriate microkernel (a minimal C sketch of one such reference kernel follows this list).
  • All values needed to perform the operation are expressed as operands of simple buffers, index types and scalars.
  • Not geared for transformation but amenable to very local, peephole optimizations (e.g. eliding a fill of zero when it is consumed by a matmul that only assigns).
  • Strides are used liberally to implement common nd-array transformations as part of the operation.
  • When there is a benefit to highly specializing variants (i.e. different conv microkernel sizes and explicit specialization for 2D variants common in tiled programs are the usual suspects), we expect multiple named variants of the microkernel to be defined and selected as possible. Microkernels should be seen as relatively cheap from this perspective, and since they represent fundamental units of linear algebra programs, we expect the number of them to be asymptotically reasonable over time.
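
To make the first two goals concrete, here is a minimal sketch of what a loopy reference implementation behind a name like vmvx.add.2d.f32 could look like in C. This is illustrative only: the function name, parameter order, and the use of element (rather than byte) offsets are assumptions, not the actual runtime code.

  #include <stddef.h>

  // Hypothetical reference microkernel for a strided 2D elementwise f32 add:
  //   OUT[i, j] = LHS[i, j] + RHS[i, j]
  // Each operand is a flat f32 array plus an element offset and per-dimension
  // strides, mirroring the operands of the vmvx.add op above.
  static void vmvx_add_2d_f32_ref(
      const float* lhs, size_t lhs_offset, size_t lhs_stride0, size_t lhs_stride1,
      const float* rhs, size_t rhs_offset, size_t rhs_stride0, size_t rhs_stride1,
      float* out, size_t out_offset, size_t out_stride0, size_t out_stride1,
      size_t size0, size_t size1) {
    for (size_t i = 0; i < size0; ++i) {
      for (size_t j = 0; j < size1; ++j) {
        out[out_offset + i * out_stride0 + j * out_stride1] =
            lhs[lhs_offset + i * lhs_stride0 + j * lhs_stride1] +
            rhs[rhs_offset + i * rhs_stride0 + j * rhs_stride1];
      }
    }
  }

Because the element type, rank and strided access pattern are fully determined by the op name and operand list, an optimized implementation is free to switch to SIMD fast paths internally (e.g. when the innermost strides are 1) without the compiler needing to know anything beyond the name.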

Remember that because of where these operations are generated and intended to be of maximum use, their operands have likely already been heavily tiled, and they are being run as part of fused, grid-scheduled dispatches. As such, while they can be implemented in terms of very high level op libraries, such libraries often do their own tiling and threading and are overkill here: the intent is for these implementations to shine when the "inner parts" of such libraries are used directly. This gives the microkernel implementor freedom to specialize the innermost loops of already-optimized linear algebra operations while leaving the high level optimizations applied to the entire program to the higher level compiler.

Further, since they are an optimization, the program is still legal without identifying microkernels: we just expect that over time, we end up with a semi-comprehensive library of these primitives so that performance sensitive parts of ML models are completely covered by fast paths.

It is the job of the overall compiler heuristics to attempt to fuse, tile and memory plan vmvx-containing dispatch regions such that they operate on reasonably good trade-off points that are common in practice (i.e. tile sizes that are within some multiple of typical L1 cache sizes for a class of targets, as an example). We expect this to be an evolving area of work, but it is not dissimilar to the decisions that every backend must make when applying its high level optimizations.

Future work on the VMVX runtime

The current VMVX runtime is fully interpreted. While this is fine for simple things (indexing math, loops, a few calls, etc), it isn't great. However, we do expect that for many programs, it will be good enough and on relative par with existing, interpreted op-by-op tensor executors: it may be doing a bit more work in aggregate, but it is much more comprehensively threaded and has very low code size -- which yields non-obvious cache benefits. We intend it to be useful as both a reference and a fallback for portable cases that have no other option.

VM Interpreter Overheads:

To give a hint of the overheads in the current implementation (on a several-year-old AMD Threadripper):

Simple loopy-sum: 5.05ns/iteration interpreted vs 2.3ns C

  vm.func @loop_sum(%count : i32) -> i32 {
    %c1 = vm.const.i32 1
    %i0 = vm.const.i32.zero
    vm.br ^loop(%i0 : i32)
  ^loop(%i : i32):
    %in = vm.add.i32 %i, %c1 : i32
    %cmp = vm.cmp.lt.i32.s %in, %count : i32
    vm.cond_br %cmp, ^loop(%in : i32), ^loop_exit(%in : i32)
  ^loop_exit(%ie : i32):
    vm.return %ie : i32
  }
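
For comparison, the native baseline is essentially the same trivial counting loop compiled as C. The actual benchmark harness isn't reproduced here; the following is just a sketch of the shape of the C side being measured.

  #include <stdint.h>

  // Roughly the C equivalent of the loop_sum VM function above: count up by
  // one until reaching `count` and return the final value. The real benchmark
  // harness (timing, preventing the compiler from folding the loop, etc.) is
  // not shown.
  static int32_t loop_sum_c(int32_t count) {
    int32_t i = 0;
    do {
      i += 1;
    } while (i < count);
    return i;
  }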

Buffer reduce: 12.1ns/element interpreted vs 2.19ns/element C

  vm.func @buffer_reduce(%count : i32) -> i32 {
    %c0 = vm.const.i64.zero
    %c0_i32 = vm.const.i32.zero
    %c1 = vm.const.i32 1
    %c4 = vm.const.i32 4
    %max = vm.mul.i32 %count, %c4 : i32
    %max_i64 = vm.ext.i32.i64.u %max : i32 -> i64
    %buf = vm.buffer.alloc %max_i64 : !vm.buffer
    vm.buffer.fill.i32 %buf, %c0, %max_i64, %c1 : i32 -> !vm.buffer
    vm.br ^loop(%c0_i32, %c0_i32 : i32, i32)
  ^loop(%i : i32, %sum : i32):
    %i_i64 = vm.ext.i32.i64.u %i : i32 -> i64
    %element = vm.buffer.load.i32 %buf[%i_i64] : !vm.buffer -> i32
    %new_sum = vm.add.i32 %sum, %element : i32
    %ip4 = vm.add.i32 %i, %c4 : i32
    %cmp = vm.cmp.lt.i32.s %ip4, %max : i32
    vm.cond_br %cmp, ^loop(%ip4, %new_sum : i32, i32), ^loop_exit(%new_sum : i32)
  ^loop_exit(%result : i32):
    vm.return %result : i32
  }

Call to imported func: 51.6ns/call interpreted, 1.61ns/call C

NOTE: This case has not been optimized very much.

  vm.func @call_imported_func(%arg0 : i32) -> i32 {
    %0 = vm.call @native_import_module.add_1(%arg0) : (i32) -> i32
    %1 = vm.call @native_import_module.add_1(%0) : (i32) -> i32
    %2 = vm.call @native_import_module.add_1(%1) : (i32) -> i32
    %3 = vm.call @native_import_module.add_1(%2) : (i32) -> i32
    %4 = vm.call @native_import_module.add_1(%3) : (i32) -> i32
    %5 = vm.call @native_import_module.add_1(%4) : (i32) -> i32
    %6 = vm.call @native_import_module.add_1(%5) : (i32) -> i32
    %7 = vm.call @native_import_module.add_1(%6) : (i32) -> i32
    %8 = vm.call @native_import_module.add_1(%7) : (i32) -> i32
    %9 = vm.call @native_import_module.add_1(%8) : (i32) -> i32
    %10 = vm.call @native_import_module.add_1(%9) : (i32) -> i32
    %11 = vm.call @native_import_module.add_1(%10) : (i32) -> i32
    %12 = vm.call @native_import_module.add_1(%11) : (i32) -> i32
    %13 = vm.call @native_import_module.add_1(%12) : (i32) -> i32
    %14 = vm.call @native_import_module.add_1(%13) : (i32) -> i32
    %15 = vm.call @native_import_module.add_1(%14) : (i32) -> i32
    %16 = vm.call @native_import_module.add_1(%15) : (i32) -> i32
    %17 = vm.call @native_import_module.add_1(%16) : (i32) -> i32
    %18 = vm.call @native_import_module.add_1(%17) : (i32) -> i32
    %19 = vm.call @native_import_module.add_1(%18) : (i32) -> i32
    %20 = vm.call @native_import_module.add_1(%19) : (i32) -> i32
    vm.return %20 : i32
  }

Future work

There are a few additional avenues that likely have merit for extending the VMVX runtime:

  • Integration with the C emission framework: With some work to extend the C code emission to VMVX, combined C code for the VM/host and VMVX/device-kernels could be generated, requiring nothing more than a C compiler and a header with inline implementations of the used microkernels. This will likely be great for DSPs and other embedded targets that would benefit from C and a bring-your-own-microkernel strategy (a sketch of what such a header entry might look like follows this list).
  • On the fly jitting of VMVX modules to native code (more below)
  • Full VMVX AOT with builtins (more below)
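
As a sketch of the "header with inline implementations" idea in the first bullet, such a header could carry plain C reference kernels that the emitted C program calls directly. The header name, function name, and the omission of the flags operand are assumptions made for illustration; they are not an actual interface.

  // vmvx_microkernels.h (hypothetical): inline reference kernels that a
  // C-emission build could include directly, and that a vendor could replace
  // with hand-tuned versions.
  #include <stddef.h>

  // Reference GEMM matching the vmvx.matmul semantics described earlier:
  //   OUT = alpha * (LHS * RHS) + beta * OUT
  // Buffers are addressed as base[offset + row * row_stride + col].
  static inline void vmvx_matmul_f32f32f32_ref(
      const float* lhs, size_t lhs_offset, size_t lhs_row_stride,
      const float* rhs, size_t rhs_offset, size_t rhs_row_stride,
      float* out, size_t out_offset, size_t out_row_stride,
      size_t m, size_t n, size_t k, float alpha, float beta) {
    for (size_t i = 0; i < m; ++i) {
      for (size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (size_t p = 0; p < k; ++p) {
          acc += lhs[lhs_offset + i * lhs_row_stride + p] *
                 rhs[rhs_offset + p * rhs_row_stride + j];
        }
        float* out_elem = &out[out_offset + i * out_row_stride + j];
        *out_elem = alpha * acc + beta * *out_elem;
      }
    }
  }

Swapping such a header for a hand-tuned one would not require touching the compiler or the generated C program, which is the "bring-your-own-microkernel" part of the bullet above.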

On the Fly Jitting

As mentioned previously, we expect that the constraints of the VM code emitter are such that a really lightweight JIT could be created to generate native code from VMVX kernels on the fly. At runtime, these kernels are already serialized to a stable flatbuffer format with embedded constants and a bytecode-based program structure. A JIT at this level could just focus on straight-line code emission with a few peephole optimizations:

  • Inlining of vm.call for known "intrinsic functions" for microkernels, optimizing across the call boundary.
  • Promotion or elimination of some dynamic range checks.

We have seen such JITs in the wild that are table-based and consist of a few thousand lines of unsurprising code. The result should be pretty efficient, and with the inlining support mentioned above, this would enable some fairly powerful but simple microkernel specializations to be emitted.

Such JITs were previously used by the VM's author when writing xenia under similar constraints and goals, and the result was successful at achieving smooth JITing of PowerPC Xbox 360 AAA games on x86. Much of the thinking that proved successful there is baked into the design decisions here.

See below for an example of a VMVX kernel today, which dispatches to microkernels to perform a fused 64x64 matmul and broadcasted add. Note that this is a prototype intended to show the relative complexity ceiling (and that this kernel is not representative of whole programs, which typically fuse the add and fill according to different heuristics):

module attributes {vm.toplevel} {
  vm.module public @module {
    vm.rodata private @__constant_384x64xf32 dense<7.000000e+00> : tensor<24576xf32>
    vm.import @vmvx.add.2d.f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_strides : tuple<i64, i64>, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_strides : tuple<i64, i64>, %out_buffer : !vm.buffer, %out_offset : i64, %out_strides : tuple<i64, i64>, %size : tuple<i64, i64>) attributes {sym_visibility = "private"}
    vm.import @vmvx.fill.2d.x32(%fill_value : i32, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %size_m : i64, %size_n : i64) attributes {sym_visibility = "private"}
    vm.import @vmvx.matmul.f32f32f32(%lhs_buffer : !vm.buffer, %lhs_offset : i64, %lhs_row_stride : i64, %rhs_buffer : !vm.buffer, %rhs_offset : i64, %rhs_row_stride : i64, %out_buffer : !vm.buffer, %out_offset : i64, %out_row_stride : i64, %m : i64, %n : i64, %k : i64, %alpha : f32, %beta : f32, %flags : i32) attributes {sym_visibility = "private"}
    vm.func private @tensor_float_dispatch_0(%arg0: !vm.buffer, %arg1: !vm.buffer, %arg2: !vm.list<!vm.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32) {
      %c1 = vm.const.i64 1
      %zero = vm.const.i64.zero
      %c384 = vm.const.i64 384
      %c64 = vm.const.i64 64
      %c128 = vm.const.i64 128
      %c2 = vm.const.i32 2
      %c384_0 = vm.const.i32 384
      %0 = vm.const.f32 1.000000e+00
      %c1_1 = vm.const.i32 1
      %c64_2 = vm.const.i32 64
      %c512 = vm.const.i32 512
      %c128_3 = vm.const.i32 128
      %zero_4 = vm.const.i32.zero
      %__constant_384x64xf32 = vm.const.ref.rodata @__constant_384x64xf32 : !vm.buffer
      %buffer = vm.list.get.ref %arg2, %zero_4 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %buffer_5 = vm.list.get.ref %arg2, %c1_1 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %buffer_6 = vm.list.get.ref %arg2, %c2 : (!vm.list<!vm.buffer>, i32) -> !vm.buffer
      %1 = vm.mul.i32 %arg4, %c64_2 : i32
      %2 = vm.mul.i32 %arg10, %c64_2 : i32
      %3 = vm.mul.i32 %arg3, %c64_2 : i32
      %4 = vm.mul.i32 %arg9, %c64_2 : i32
      vm.br ^bb1(%1 : i32)
    ^bb1(%5: i32):  // 2 preds: ^bb0, ^bb5
      %slt = vm.cmp.lt.i32.s %5, %c512 : i32
      vm.cond_br %slt, ^bb2, ^bb6
    ^bb2:  // pred: ^bb1
      %6 = vm.mul.i32 %5, %c128_3 : i32
      %7 = vm.mul.i32 %5, %c384_0 : i32
      vm.br ^bb3(%3 : i32)
    ^bb3(%8: i32):  // 2 preds: ^bb2, ^bb4
      %slt_7 = vm.cmp.lt.i32.s %8, %c128_3 : i32
      vm.cond_br %slt_7, ^bb4, ^bb5
    ^bb4:  // pred: ^bb3
      %9 = vm.add.i32 %6, %8 : i32
      %10 = vm.ext.i32.i64.s %9 : i32 -> i64
      vm.call @vmvx.fill.2d.x32(%zero_4, %buffer_6, %10, %c128, %c64, %c64) : (i32, !vm.buffer, i64, i64, i64, i64) -> ()
      %11 = vm.ext.i32.i64.s %7 : i32 -> i64
      vm.call @vmvx.matmul.f32f32f32(%buffer, %11, %c384, %__constant_384x64xf32, %zero, %c64, %buffer_6, %10, %c128, %c64, %c64, %c384, %0, %0, %zero_4) : (!vm.buffer, i64, i64, !vm.buffer, i64, i64, !vm.buffer, i64, i64, i64, i64, i64, f32, f32, i32) -> ()
      %12 = vm.ext.i32.i64.s %8 : i32 -> i64
      vm.call @vmvx.add.2d.f32(%buffer_5, %12, %zero, %c1, %buffer_6, %10, %c128, %c1, %buffer_6, %10, %c128, %c1, %c64, %c64) : (!vm.buffer, i64, i64, i64, !vm.buffer, i64, i64, i64, !vm.buffer, i64, i64, i64, i64, i64) -> ()
      %13 = vm.add.i32 %8, %4 : i32
      vm.br ^bb3(%13 : i32)
    ^bb5:  // pred: ^bb3
      %14 = vm.add.i32 %5, %2 : i32
      vm.br ^bb1(%14 : i32)
    ^bb6:  // pred: ^bb1
      vm.return
    }
    vm.export @tensor_float_dispatch_0
  }
}

Full VMVX AOT with Builtins

It is likely profitable on some systems to treat VMVX as just another code generation methodology (alongside the existing expert compilers) when targeting CPUs. In this case, we would not emit VMVX bytecode but would AOT compile the microkernels via the existing builtins libdevice library import support (currently used for math calls, datatype converters, etc). This largely impacts how the microkernel implementations are managed: if they are included in the builtins library in addition to the runtime/interpreter/JIT layer, then this should all happen seamlessly and support microkernel-based code emission. This would allow profitability decisions to be made between VMVX-based and normal MLIR/vector-based code generation.

Work Plan

While the initial prototype demonstrates a ~200X improvement over stock VMVX, just using loopy reference kernels, this is shooting pretty low :) There is a fair amount of work to make this great -- much of it incremental or separable. We believe that we aren't far from having performant/portable VMVX microkernels for a number of important workloads. We will proceed in a few steps:

  1. Landing the prototype VMVX microkernel support for f32 matmul.
  2. Iterating on some upstream memref improvements (mostly moving things that are bundled into the LLVM dialect conversion into MLIR proper, so that non-LLVM bound lowerings can use them).
  3. Identifying the handful of microkernels needed to get fast paths on common matmul and conv heavy models (i.e. BERT and a mobilenet TBD when getting here) and plumbing through.
  4. Implementing the VM(VX)Jit (HELP WANTED).
  5. Making VMVX a supported AOT compilation flow (HELP WANTED).
  6. Full elaboration/sweep/burn down of microkernel needs (HELP WANTED).
  7. Platform-specific, optimized implementations of key microkernels (HELP WANTED).

I plan to keep working on #1-3 and then will be looking for help (or for someone to take outright responsibility) on the rest.

Scott Todd

Jul 6, 2022, 10:16:05 PM
to Stella Laurenzo, iree-discuss
The overall direction here sounds good to me, and thanks for the thorough background information.

It looks like the VMVX microkernel implementations will be built into the runtime (using the same compiler toolchain as the rest of the runtime, i.e. clang/gcc/etc.), rather than into individual program executables. How much do we expect they will contribute to the runtime's binary size? Will the "extensions" part of VMVX be optional? I see some complexity lurking around "building the runtime with just the ops that you need" (like https://www.tensorflow.org/lite/guide/reduce_binary_size). Since we'll be doing more work in the compiler than in more traditional "kernel library" runtimes, we can keep the microkernel surface area smaller, but do we have some threshold in mind for how small/large is acceptable?


Stella Laurenzo

Jul 6, 2022, 10:24:47 PM
to Scott Todd, Stella Laurenzo, iree-discuss
On Wed, Jul 6, 2022 at 7:16 PM 'Scott Todd' via iree-discuss <iree-d...@googlegroups.com> wrote:
The overall direction here sounds good to me, and thanks for the thorough background information.

It looks like the VMVX microkernel implementations will be built into the runtime (using the same compiler toolchain as the rest of the runtime, i.e. clang/gcc/etc.), rather than into individual program executables. How much do we expect they will contribute to the runtime's binary size? Will the "extensions" part of VMVX be optional? I see some complexity lurking around "building the runtime with just the ops that you need" (like https://www.tensorflow.org/lite/guide/reduce_binary_size). Since we'll be doing more work in the compiler than in more traditional "kernel library" runtimes, we can keep the microkernel surface area smaller, but do we have some threshold in mind for how small/large is acceptable?

We're going to have to see. As an intuition, I think this is going to be of an entirely different scale compared to a high level executor's typical op explosion problems. I think we're talking about a useful set of ~dozens before this becomes interesting, but then I don't know what the actual ceiling is. Ben wrote me up a strawman for how to do this so that the VMVX code we generate can emit both the microkernel *and* scalar callback code and switch at runtime based on import availability. Sounded totally feasible but also not something I want to frontload.

This is all behind a flag in the compiler that won't be enabled any time soon. I propose that we drive it to some min-viable set over the coming weeks and then assess based on what we've learned. It is pretty easy to just comment out most of the module.c file for checking size impacts (once we have something to check against).

So yes, complexity... but I think we can push a lot of it to the compiler and give the runtime some outs.
 

Ben Vanik

Jul 7, 2022, 9:34:18 AM
to Stella Laurenzo, Scott Todd, Stella Laurenzo, iree-discuss
Exciting!

We can use optional imports for versioning (run new artifacts on old runtimes) and also configurable kernel sets, though because vmvx is just a VM module we could also allow large additions to be done in custom modules registered when the loader is created. Could imagine things like vmvx_image or whatnot that the user can choose to compile in or not. Hopefully we can keep things small enough that the decision of whether or not to include vmvx support at all (IREE_HAL_EXECUTABLE_LOADER_VMVX_MODULE) is usually all people have to decide. 

Geoffrey Martin-Noble

Jul 7, 2022, 1:20:36 PM
to Ben Vanik, Stella Laurenzo, Scott Todd, Stella Laurenzo, iree-discuss
This overall sounds good to me. I do think we want to preserve the property that VMVX is usable as a reference backend, which means being "obviously correct", simple and easily debuggable. Therefore, I think we want this to be guarded behind a flag (even if we flip that flag's default), such that we can always choose to pay a performance penalty and get the "obviously correct" scalar code. Right now, VMVX's usefulness as a reference backend is actually limited by its speed though, so this level of "a few kernels make it 200x faster" seems especially excellent. But I think we will necessarily come to a point where the goals of it being a reference backend and the goals of it being a usable production backend will come into conflict, so I think we really want to maintain the optionality there so we can have both. It sounds like that is very much the plan, but I've also seen us drop support for things in kind of a big bang way before, and I'd be slightly worried about us setting ourselves up for feeling like we need to do that here for the fully-scalar path as well.

Useful functionality we might be able to have in the future would be the ability to only include (or only exclude) some subset of microkernels, such that we could debug issues with them without having to take the full performance hit of dropping everything back to scalar code.

Stella Laurenzo

Jul 7, 2022, 1:24:25 PM
to Geoffrey Martin-Noble, Ben Vanik, Scott Todd, Stella Laurenzo, iree-discuss
On Thu, Jul 7, 2022 at 10:20 AM Geoffrey Martin-Noble <gc...@google.com> wrote:
This overall sounds good to me. I do think we want to preserve the property that VMVX is usable as a reference backend, which means being "obviously correct", simple and easily debuggable. Therefore, I think we want this to be guarded behind a flag (even if we flip that flag's default), such that we can always choose to pay a performance penalty and get the "obviously correct" scalar code. Right now, VMVX's usefulness as a reference backend is actually limited by its speed though, so this level of "a few kernels make it 200x faster" seems especially excellent. But I think we will necessarily come to a point where the goals of it being a reference backend and the goals of it being a usable production backend will come into conflict, so I think we really want to maintain the optionality there so we can have both. It sounds like that is very much the plan, but I've also seen us drop support for things in kind of a big bang way before, and I'd be slightly worried about us setting ourselves up for feeling like we need to do that here for the fully-scalar path as well.

Useful functionality we might be able to have in the future would be the ability to only include (or only exclude) some subset of microkernels, such that we could debug issues with them without having to take the full performance hit of dropping everything back to scalar code.

Agreed. It is a flag now (and will continue to be). If this is successful, there will probably be even more knobs as this tradeoff space is likely to be useful for folks and we will need to provide some level of control + reasonable defaults.

Stella Laurenzo

Jul 9, 2022, 11:35:58 PM
to Stella Laurenzo, Nicolas Vasilache, Geoffrey Martin-Noble, Ben Vanik, Scott Todd, iree-discuss, benoi...@google.com
Ok, the first patch is in. The compiler side of it is still a little rough and needs some upstream work to be really good (which +Nicolas Vasilache is helping with). It's flag protected and sufficient for e2e testing a couple of simple things.

Since Benoit was asking for some starting points to work on, here is some more information.

Here is the test program I've mostly been iterating on to get it going:

func.func @tensor_float() -> tensor<512x128xf32> {
  %0 = util.unfoldable_constant dense<3.0> : tensor<512x384xf32>
  %1 = arith.constant dense<7.0> : tensor<128x384xf32>
  %2 = util.unfoldable_constant dense<1.0> : tensor<128xf32>
  %result = "tosa.fully_connected"(%0, %1, %2) : (tensor<512x384xf32>, tensor<128x384xf32>, tensor<128xf32>) -> tensor<512x128xf32>
  return %result : tensor<512x128xf32>
}

And a couple of command lines to get going:

  • ./tools/iree-compile --iree-input-type=tosa --iree-hal-target-backends=vmvx ~/scratch/vmvx/fully_connected.mlir -o ~/scratch/vmvx/fully_connected.vmfb --mlir-disable-threading --mlir-print-ir-before=iree-vmvx-lower-linalg-microkernels --mlir-print-ir-after=iree-vmvx-lower-linalg-microkernels --iree-vmvx-enable-microkernels
  • ./tools/iree-run-module --module_file=$HOME/scratch/vmvx/fully_connected.vmfb --entry_function=tensor_float --device=local-sync://
Some file locations:


