Quick Start

This tutorial is for people who are new to Apache TVM Unity. It uses a simple example to show how to compile a small neural network with Apache TVM Unity.

Prepare the Neural Network Model

Before we get started, let’s prepare a neural network model. To keep things simple, we define a two-layer MLP network directly in this script. If you want to run a real model, please jump to the next section.

import torch
from torch import nn


class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


torch_model = MLPModel()
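As a quick sanity check on the layer sizes, the sketch below recomputes the activation shapes and parameter counts by hand in plain Python, without PyTorch. The helper name `linear_param_count` is our own and is not part of PyTorch or TVM.

```python
# Hand-computed shapes and parameter counts for the two-layer MLP above.
# `linear_param_count` is a hypothetical helper, not a PyTorch or TVM API.

def linear_param_count(in_features, out_features):
    """Weights (out_features x in_features) plus one bias per output."""
    return in_features * out_features + out_features


layers = [(784, 256), (256, 10)]  # (in_features, out_features) of fc1 and fc2

shape = (1, 784)  # one flattened input sample
shapes = [shape]
for in_features, out_features in layers:
    assert shape[1] == in_features, "layer input size must match"
    shape = (shape[0], out_features)
    shapes.append(shape)

param_counts = [linear_param_count(i, o) for i, o in layers]
total_params = sum(param_counts)

print(shapes)        # [(1, 784), (1, 256), (1, 10)]
print(param_counts)  # [200960, 2570]
print(total_params)  # 203530
```

The ReLU layer has no parameters, so only fc1 and fc2 contribute to the total.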

Import Model into Apache TVM Unity

We choose PyTorch FX as our frontend. PyTorch FX is a toolkit for tracing PyTorch programs into an intermediate representation (IR) with symbolic shape support.

Note

The original PyTorch FX tracer may not be compatible with HuggingFace models. For those models, please use HuggingFace's own FX tracer instead.

from tvm import relax
from tvm.relax.frontend.torch import from_fx
from torch import fx

torch_fx_model = fx.symbolic_trace(torch_model)

Unlike an ONNX model, a PyTorch model does not carry input information, so we need to provide it ourselves. This includes the shape and data type of each input tensor, represented as a list of tuples with one tuple per input tensor.

In this particular example, the shape of the input tensor is (1, 784) and the data type is "float32". We combine the shape and data type in a tuple like ((1, 784), "float32"). Then we gather all the input tuples into a list, which looks like [((1, 784), "float32")].

input_info = [((1, 784), "float32")]
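For a model with several inputs, the list simply gets one tuple per input. The sketch below builds that structure with a small helper; `make_input_info` is a hypothetical convenience function written for this illustration (it is not part of TVM), and it assumes static integer shapes and a single shared data type.

```python
# `make_input_info` is a hypothetical convenience helper (not part of TVM);
# it only builds the list-of-tuples structure described above.

def make_input_info(shapes, dtype="float32"):
    """Return [(shape, dtype), ...] with one entry per input tensor."""
    info = []
    for shape in shapes:
        shape = tuple(int(dim) for dim in shape)
        if any(dim <= 0 for dim in shape):
            raise ValueError(f"non-positive dimension in shape {shape}")
        info.append((shape, dtype))
    return info


# Single input, as in this tutorial.
single = make_input_info([(1, 784)])
print(single)  # [((1, 784), 'float32')]

# Two inputs, for an imagined model that takes two tensors.
double = make_input_info([(1, 784), (1, 10)])
print(double)  # [((1, 784), 'float32'), ((1, 10), 'float32')]
```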

Use the Apache TVM Unity API to convert the PyTorch FX model into a Relax model, then print it in TVMScript syntax.

with torch.no_grad():
    mod = from_fx(torch_fx_model, input_info)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][3])
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
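In the printed module above, each nn.Linear layer appears as three operators: R.permute_dims, R.matmul, and R.add. PyTorch stores a linear layer's weight with shape (out_features, in_features) and computes y = x Wᵀ + b, so the weight is transposed before the matrix multiplication. The plain-Python sketch below reproduces that sequence on toy data; all helper names are our own and the sizes are shrunk for readability.

```python
# Toy illustration of why each nn.Linear becomes permute_dims + matmul + add.
# All helper names here are hypothetical; sizes are shrunk for readability.

def permute_dims(matrix):
    """Transpose a 2-D nested list (reversing the axes of a 2-D tensor)."""
    return [list(column) for column in zip(*matrix)]


def matmul(a, b):
    """Multiply an (m x k) nested list by a (k x n) nested list."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(matrix, bias):
    """Add a length-n bias to every row of an (m x n) nested list."""
    return [[value + b for value, b in zip(row, bias)] for row in matrix]


x = [[1.0, 2.0, 3.0]]          # shape (1, 3): one sample, 3 input features
weight = [[1.0, 0.0, 0.0],     # shape (2, 3): (out_features, in_features)
          [0.0, 1.0, 1.0]]
bias = [0.5, -0.5]             # shape (2,)

lv = permute_dims(weight)      # shape (3, 2)
lv1 = matmul(x, lv)            # shape (1, 2)
lv2 = add_bias(lv1, bias)      # shape (1, 2)
print(lv2)                     # [[1.5, 4.5]]
```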

Up to this point, we have successfully transformed the PyTorch FX model into a TVM IRModule. The IRModule is the central abstraction of Apache TVM Unity and is used for all subsequent transformations and optimizations. It can hold both high-level graph IR (Relax) and low-level tensor IR (TensorIR). So far, it contains only Relax functions, which are marked with the @R.function decorator.

Transform the Model

Apply Optimization Transforms

We can apply a variety of optimization transforms to the IRModule. We have predefined a set of optimization transforms to simplify their usage. By using the get_pipeline function, we can apply the default optimization flow. By following the default path, the following transformations will be applied in order:

  • LegalizeOps: This transform converts the Relax operators into call_tir calls to the corresponding TensorIR functions. After this transform, the IRModule will contain both Relax functions and TensorIR functions.

  • AnnotateTIROpPattern: This transform annotates the pattern of the TensorIR functions, preparing them for subsequent operator fusion.

  • FoldConstant: This pass performs constant folding, optimizing operations involving constants.

  • FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform both Relax functions and TensorIR functions.

mod = relax.get_pipeline()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), param_0: T.Buffer((T.int64(256), T.int64(10)), "float32"), param_1: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], param_0[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * param_0[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], param_1[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + param_1[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), param_0: T.Buffer((T.int64(784), T.int64(256)), "float32"), param_1: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], param_1[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + param_1[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0))

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
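To see what fusion changes structurally, compare running matmul, add, and relu as three separate graph-level operators with running them inside one function, as fused_matmul_add_relu does in the output above. The plain-Python sketch below uses toy data and hypothetical function names; it mirrors the structure of the fused function and is not TVM's implementation.

```python
# Toy sketch of operator fusion. Names are hypothetical and the data is tiny;
# this mirrors the structure of the fused function, not TVM's implementation.

def matmul(x, w):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def add(m, bias):
    return [[v + b for v, b in zip(row, bias)] for row in m]


def relu(m):
    return [[max(v, 0.0) for v in row] for row in m]


def unfused(x, w, bias):
    # Three graph-level operators, each returning a full intermediate tensor.
    lv1 = matmul(x, w)
    lv2 = add(lv1, bias)
    return relu(lv2)


def fused_matmul_add_relu(x, w, bias):
    # One function: intermediates are local values, never graph-level tensors.
    out = []
    for row in x:
        out_row = []
        for j, col in enumerate(zip(*w)):
            acc = 0.0
            for a, b in zip(row, col):
                acc += a * b
            out_row.append(max(acc + bias[j], 0.0))
        out.append(out_row)
    return out


x = [[1.0, -2.0]]
w = [[1.0, 2.0], [3.0, 1.0]]   # shape (2, 2), already in (in, out) layout
bias = [0.5, 0.5]

print(unfused(x, w, bias))                 # [[0.0, 0.5]]
print(fused_matmul_add_relu(x, w, bias))   # [[0.0, 0.5]]
```

Both versions return the same values; the difference is that the fused version exposes a single call to the graph level, which is what the two R.call_tir lines in main reflect.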

If you are only interested in how the Relax functions changed and want to omit the TensorIR functions, print only the main function of the IRModule.

mod["main"].show()
# from tvm.script import relax as R

@R.function
def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    with R.dataflow():
        lv = R.call_tir(fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
        gv = R.call_tir(fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
        R.output(gv)
    return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

Tensor Function Optimization

We usually apply tensor function optimization after Relax function optimization, because graph transformations change the TIR functions. There are several ways to optimize tensor functions; in this tutorial we use DLight on the cuda target. DLight is not the only option; for other approaches, please refer to the corresponding tutorials.

import tvm
from tvm import dlight as dl

target = tvm.target.Target("cuda")

with target:
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), param_0: T.Buffer((T.int64(256), T.int64(10)), "float32"), param_1: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(10)), scope="local")
        matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(10)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0)
                    for ax1_fused_0, u in T.grid(T.int64(16), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(16), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], lv3[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + lv3[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax1_fused])
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], param_1[v0])
                        T.writes(T_add_intermediate[T.int64(0), v0])
                        T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + param_1[v0]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), param_0: T.Buffer((T.int64(784), T.int64(256)), "float32"), param_1: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(256)), scope="local")
        matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0)
                    for ax1_fused_0, u in T.grid(T.int64(49), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(49), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], inp_0[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + inp_0[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.block("compute"):
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], param_1[v0])
                        T.writes(compute_intermediate[T.int64(0), v0])
                        compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + param_1[v0], T.float32(0))

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
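The scheduled fused_matmul_add_relu above splits the reduction over 784 elements into 49 × 16: each of the 16 threadIdx.y threads accumulates a partial sum over indices of the form outer * 16 + inner (blocks matmul_rf_init and matmul_rf_update), and the partial sums are then combined (block matmul). The plain-Python sketch below simulates that two-stage reduction with ordinary loops; the function names are hypothetical, and integer data is used so the comparison with a direct sum is exact.

```python
# Plain-Python sketch of the two-stage reduction visible in the scheduled
# fused_matmul_add_relu: 784 = 49 * 16, so k = outer * 16 + inner.
# Function names are hypothetical; the GPU threads are simulated by loops.

def dot_direct(a, b):
    return sum(x * y for x, y in zip(a, b))


def dot_split(a, b, inner_extent=16):
    assert len(a) == len(b) and len(a) % inner_extent == 0
    outer_extent = len(a) // inner_extent
    # Stage 1: one partial sum per inner index (threadIdx.y in the TIR).
    partial = [0] * inner_extent
    for inner in range(inner_extent):
        for outer in range(outer_extent):      # serial loop of extent 49
            k = outer * inner_extent + inner
            partial[inner] += a[k] * b[k]
    # Stage 2: combine the partial sums into the final result.
    total = 0
    for inner in range(inner_extent):
        total += partial[inner]
    return total


a = list(range(784))             # integer data keeps the comparison exact
b = [k % 7 for k in range(784)]

print(dot_split(a, b) == dot_direct(a, b))  # True
```

With floating-point data the two orders of summation can differ by rounding error, which is one reason the correctness check later in this tutorial uses tolerances.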

Note

The DLight framework is still under development and currently supports only GPU backends with a limited set of operators, specifically the common operators used in LLMs. We plan to improve the framework to support more operators and backends.

Compile and Run

After the optimization, we can compile the model into a TVM runtime module. Apache TVM Unity uses the Relax Virtual Machine to run the model. The following code shows how to compile the model:

# Named `executable` to avoid shadowing Python's built-in `exec`.
executable = relax.build(mod, target=target)
dev = tvm.device(str(target.kind), 0)
vm = relax.VirtualMachine(executable, dev)

Now we can run the model on the TVM runtime module. We first prepare the input data and then invoke the TVM runtime module to get the output.

import numpy as np

data = np.random.rand(1, 784).astype("float32")
tvm_data = tvm.nd.array(data, device=dev)
tvm_out = vm["main"](tvm_data).numpy()

We can also compare the output with the PyTorch model to verify the correctness.

with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(data)).numpy()

np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)
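The tolerances passed to assert_allclose combine into a single elementwise bound: the check passes when |actual - desired| <= atol + rtol * |desired|. The sketch below restates that rule for scalars in plain Python; `is_close` is a hypothetical helper written for this illustration and is not a NumPy function.

```python
# Pure-Python restatement of the elementwise check behind assert_allclose:
#     |actual - desired| <= atol + rtol * |desired|
# `is_close` is a hypothetical helper written for this illustration.

def is_close(actual, desired, rtol=1e-5, atol=1e-5):
    return abs(actual - desired) <= atol + rtol * abs(desired)


print(is_close(1.000001, 1.0))   # True
print(is_close(1.001, 1.0))      # False
print(is_close(1e-6, 0.0))       # True
```

The absolute term matters for values near zero, where a purely relative bound would shrink to nothing.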

Relax VM supports timing evaluation. We can use the following code to get the timing result.

timing_res = vm.time_evaluator("main", dev)(tvm_data)
print(timing_res)
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
   0.0085       0.0085       0.0085       0.0085       0.0000
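The summary reports the mean, median, max, min, and standard deviation over the collected measurements; here all five values coincide and the standard deviation is zero, which is consistent with a single recorded measurement. The sketch below shows how such a summary could be computed in plain Python. `time_summary` is a hypothetical stand-in, not TVM's time_evaluator: it times a CPU function with time.perf_counter and does not handle device synchronization, which a real GPU timer must. The use of the population standard deviation is also our assumption.

```python
import statistics
import time

# Hypothetical stand-in for a timing helper; it is not TVM's time_evaluator.
def time_summary(func, repeat=5):
    """Run `func` `repeat` times and summarize the wall-clock times in ms."""
    samples_ms = []
    for _ in range(repeat):
        start = time.perf_counter()
        func()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "mean": statistics.mean(samples_ms),
        "median": statistics.median(samples_ms),
        "max": max(samples_ms),
        "min": min(samples_ms),
        "std": statistics.pstdev(samples_ms),  # population std (assumption)
    }


summary = time_summary(lambda: sum(range(10_000)))
print(summary)
```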
