Quick Start
This tutorial is for people who are new to Apache TVM Unity. It walks through a simple example that shows how to use Apache TVM Unity to compile a small neural network.
Prepare the Neural Network Model
Before we get started, let’s prepare a neural network model. To keep things simple, this tutorial defines a two-layer MLP network directly in this script. If you are trying to run a real model, please jump to the next section.
import torch
from torch import nn
class MLPModel(nn.Module):
    def __init__(self):
        super(MLPModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x
torch_model = MLPModel()
Import Model into Apache TVM Unity
We choose PyTorch FX as our frontend. PyTorch FX is a toolkit for tracing PyTorch programs into an intermediate representation (IR) with symbolic shape support.
Note
The original PyTorch FX tracer may not be compatible with HuggingFace models. For those models, please use HuggingFace’s own FX tracing utilities to trace the model.
from tvm import relax
from tvm.relax.frontend.torch import from_fx
from torch import fx
torch_fx_model = fx.symbolic_trace(torch_model)
Unlike an ONNX model, a PyTorch model does not carry information about its inputs, so we need to provide it ourselves. This information consists of the shape and data type of each input tensor, represented as a list of tuples in which each tuple holds the shape and data type of one input tensor.
In this particular example, the shape of the input tensor is (1, 784) and the data type is "float32". We combine the shape and data type into a tuple, ((1, 784), "float32"), and then gather all the input tuples into a list, which gives [((1, 784), "float32")].
input_info = [((1, 784), "float32")]
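When a model takes more than one tensor, the list simply gains one (shape, dtype) tuple per input, presumably in the order the model's forward method receives them (an assumption about how from_fx matches the list to the traced placeholders). The sketch below derives the list from example NumPy arrays; input_info_from_arrays is a hypothetical helper written for illustration, not part of TVM.

```python
import numpy as np


def input_info_from_arrays(*arrays):
    """Hypothetical helper (not a TVM API): build from_fx-style input_info from example arrays."""
    # One (shape, dtype) tuple per input, e.g. ((1, 784), "float32").
    return [(tuple(int(d) for d in a.shape), str(a.dtype)) for a in arrays]


example = np.zeros((1, 784), dtype="float32")
print(input_info_from_arrays(example))  # [((1, 784), 'float32')]
```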
Use the Apache TVM Unity API to convert the PyTorch FX model into a Relax model, and print it out in TVMScript syntax.
with torch.no_grad():
    mod = from_fx(torch_fx_model, input_info)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((784, 256), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
            lv1: R.Tensor((1, 256), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
            lv2: R.Tensor((1, 256), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
            lv3: R.Tensor((1, 256), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((256, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
            lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="float32")
            lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][3])
            gv: R.Tensor((1, 10), dtype="float32") = lv6
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
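One way to read this listing is to replay the dataflow block in NumPy. The sketch below assumes that the four metadata constants are fc1.weight, fc1.bias, fc2.weight and fc2.bias in that order, which matches the shapes in the listing, and fills them with random placeholder values rather than the real weights. Because nn.Linear stores its weight as (out_features, in_features), each matmul is preceded by an R.permute_dims that, with axes=None, reverses the axes just like numpy.transpose.

```python
import numpy as np

# Placeholder stand-ins for the metadata constants (random, not the model's weights).
rng = np.random.default_rng(0)
w1 = rng.standard_normal((256, 784)).astype("float32")  # fc1.weight: (out, in)
b1 = rng.standard_normal(256).astype("float32")          # fc1.bias
w2 = rng.standard_normal((10, 256)).astype("float32")    # fc2.weight: (out, in)
b2 = rng.standard_normal(10).astype("float32")           # fc2.bias


def relax_main(inp_0):
    """Line-by-line NumPy reading of the dataflow block in main."""
    lv = np.transpose(w1)       # R.permute_dims(axes=None): (256, 784) -> (784, 256)
    lv1 = np.matmul(inp_0, lv)  # (1, 784) @ (784, 256) -> (1, 256)
    lv2 = lv1 + b1              # R.add broadcasts the bias over the batch axis
    lv3 = np.maximum(lv2, 0.0)  # R.nn.relu
    lv4 = np.transpose(w2)      # (10, 256) -> (256, 10)
    lv5 = np.matmul(lv3, lv4)   # (1, 256) @ (256, 10) -> (1, 10)
    lv6 = lv5 + b2
    return lv6                  # gv


x = rng.random((1, 784), dtype=np.float32)
y = relax_main(x)
print(y.shape)  # (1, 10)
```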
Up to this point, we have successfully transformed the PyTorch FX model into a TVM IRModule. The IRModule is the central abstraction of Apache TVM Unity, and all subsequent transformations and optimizations operate on it. An IRModule can hold both high-level graph IR (Relax) and low-level tensor IR (TensorIR). At this stage, the IRModule contains only Relax functions, which are marked with the @R.function decorator.
Transform The Model
Apply Optimization Transforms
We can apply a variety of optimization transforms to the IRModule. To simplify their usage, we have predefined a set of optimization transforms, and the get_pipeline function applies this default optimization flow. Along the default path, the following transformations are applied in order:
LegalizeOps: This transform converts Relax operators into call_tir calls to the corresponding TensorIR functions. After this transform, the IRModule contains both Relax functions and TensorIR functions.
AnnotateTIROpPattern: This transform annotates the pattern of each TensorIR function, preparing them for subsequent operator fusion.
FoldConstant: This pass performs constant folding, evaluating at compile time the operations whose inputs are all constants, such as the transposes of the weight matrices in this model.
FuseOps and FuseTIR: These two passes work together to fuse operators based on the patterns annotated in the previous step (AnnotateTIROpPattern). These passes transform both Relax functions and TensorIR functions.
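To make the effect of FoldConstant and FuseOps concrete, here is a deliberately simplified toy model, not TVM's implementation: the dataflow block is a list of (output, op, inputs) tuples, constant folding drops ops whose inputs are all constants, and fusion greedily appends element-wise ops to the group that produces their input. The pattern labels are illustrative stand-ins for what AnnotateTIROpPattern records; TVM's actual passes handle far more general graphs.

```python
# Toy dataflow block mirroring the Relax main function: (output, op, inputs).
GRAPH = [
    ("lv", "permute_dims", ["w1"]),
    ("lv1", "matmul", ["inp_0", "lv"]),
    ("lv2", "add", ["lv1", "b1"]),
    ("lv3", "relu", ["lv2"]),
    ("lv4", "permute_dims", ["w2"]),
    ("lv5", "matmul", ["lv3", "lv4"]),
    ("lv6", "add", ["lv5", "b2"]),
]
CONSTANTS = {"w1", "b1", "w2", "b2"}

# Illustrative op patterns (a simplification of what AnnotateTIROpPattern annotates).
PATTERN = {
    "matmul": "out_ewise_fusable",
    "add": "elemwise",
    "relu": "elemwise",
    "permute_dims": "injective",
}


def fold_constant(graph, constants):
    """Drop every op whose inputs are all constants; its output becomes a constant."""
    constants = set(constants)
    kept = []
    for out, op, inputs in graph:
        if all(i in constants for i in inputs):
            constants.add(out)  # this value would be precomputed at compile time
        else:
            kept.append((out, op, inputs))
    return kept


def fuse_ops(graph):
    """Append an element-wise op to the previous group if it consumes that group's last output."""
    groups = []
    for out, op, inputs in graph:
        if groups and PATTERN[op] == "elemwise" and groups[-1][-1][0] in inputs:
            groups[-1].append((out, op, inputs))
        else:
            groups.append([(out, op, inputs)])
    return ["fused_" + "_".join(op for _, op, _ in g) for g in groups]


print(fuse_ops(fold_constant(GRAPH, CONSTANTS)))
```

The two groups correspond to the fused_matmul_add_relu and fused_matmul1_add1 functions in the printed module; the 1 suffixes presumably come from LegalizeOps emitting separate, shape-specialized matmul and add functions for the second layer.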
mod = relax.get_pipeline()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), param_0: T.Buffer((T.int64(256), T.int64(10)), "float32"), param_1: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv3[v_i0, v_k], param_0[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + lv3[v_i0, v_k] * param_0[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], param_1[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + param_1[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), param_0: T.Buffer((T.int64(784), T.int64(256)), "float32"), param_1: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
        for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1]
        for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], param_1[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + param_1[v_ax1]
        for i0, i1 in T.grid(T.int64(1), T.int64(256)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0))

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
If you are only interested in the changes to the Relax functions and want to omit the TensorIR functions, print only the main function of the IRModule.
mod["main"].show()
# from tvm.script import relax as R

@R.function
def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
    with R.dataflow():
        lv = R.call_tir(fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
        gv = R.call_tir(fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
        R.output(gv)
    return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
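TensorIR functions are explicit loop nests, so the unscheduled fused_matmul_add_relu printed earlier can be transcribed loop for loop into Python. The sketch below is an illustration on NumPy arrays, not code generated by TVM. Note that param_0 already has shape (784, 256) (the weight transpose appears to have been folded into the constant), and that the T.init() body runs once per output element, before the first step of the reduction axis k.

```python
import numpy as np


def fused_matmul_add_relu(inp_0, param_0, param_1):
    """Loop-for-loop NumPy transcription of the unscheduled TensorIR function."""
    # T.alloc_buffer intermediates, plus the output buffer compute_intermediate.
    matmul_intermediate = np.empty((1, 256), dtype="float32")
    T_add_intermediate = np.empty((1, 256), dtype="float32")
    compute_intermediate = np.empty((1, 256), dtype="float32")

    # Block "matmul", axes remapped as "SSR": i0, i1 spatial; k is a reduction axis.
    for i0 in range(1):
        for i1 in range(256):
            for k in range(784):
                if k == 0:  # T.init(): initialize before the first reduction step
                    matmul_intermediate[i0, i1] = 0.0
                matmul_intermediate[i0, i1] += inp_0[i0, k] * param_0[k, i1]

    # Block "T_add": bias broadcast along the batch axis.
    for ax0 in range(1):
        for ax1 in range(256):
            T_add_intermediate[ax0, ax1] = matmul_intermediate[ax0, ax1] + param_1[ax1]

    # Block "compute": ReLU written as T.max(x, 0).
    for i0 in range(1):
        for i1 in range(256):
            compute_intermediate[i0, i1] = max(T_add_intermediate[i0, i1], np.float32(0))
    return compute_intermediate


rng = np.random.default_rng(0)
x = rng.random((1, 784), dtype=np.float32)
w = rng.standard_normal((784, 256)).astype("float32")
b = rng.standard_normal(256).astype("float32")
out = fused_matmul_add_relu(x, w, b)
print(out.shape)  # (1, 256)
```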
Tensor Function Optimization
Tensor function optimization is usually applied after Relax function optimization, because graph transformations change the TIR functions. There are several ways to apply tensor function optimization; in this tutorial we use DLight on the cuda target. Note that DLight is not the only way to optimize tensor functions; for other approaches, please refer to the corresponding tutorials.
import tvm
from tvm import dlight as dl
target = tvm.target.Target("cuda")
with target:
    mod = dl.ApplyDefaultSchedule(
        dl.gpu.Matmul(),
        dl.gpu.GEMV(),
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(lv3: T.Buffer((T.int64(1), T.int64(256)), "float32"), param_0: T.Buffer((T.int64(256), T.int64(10)), "float32"), param_1: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(10)), scope="local")
        matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(10)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0)
                    for ax1_fused_0, u in T.grid(T.int64(16), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(10), ax0_fused_0 * T.int64(10) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(16), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], lv3[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + lv3[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1, v0 = T.axis.remap("RS", [ax0, ax1_fused])
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.block("T_add"):
                        v0 = T.axis.spatial(T.int64(10), ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], param_1[v0])
                        T.writes(T_add_intermediate[T.int64(0), v0])
                        T_add_intermediate[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + param_1[v0]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(inp_0: T.Buffer((T.int64(1), T.int64(784)), "float32"), param_0: T.Buffer((T.int64(784), T.int64(256)), "float32"), param_1: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(256)), scope="local")
        matmul_intermediate_rf_local = T.alloc_buffer((T.int64(16), T.int64(1), T.int64(256)), scope="local")
        for ax0_fused_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
            for ax0_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax1_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul_rf_init"):
                        vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                        T.reads()
                        T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = T.float32(0)
                    for ax1_fused_0, u in T.grid(T.int64(49), 1):
                        with T.block("matmul_rf_update"):
                            vax1_fused_1 = T.axis.spatial(T.int64(16), ax1_fused_1)
                            v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_1)
                            vax1_fused_0 = T.axis.reduce(T.int64(49), ax1_fused_0)
                            T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0], inp_0[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1], param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0])
                            T.writes(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                            matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] = matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0] + inp_0[T.int64(0), vax1_fused_0 * T.int64(16) + vax1_fused_1] * param_0[vax1_fused_0 * T.int64(16) + vax1_fused_1, v0]
            for ax1_fused in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
                    with T.block("matmul"):
                        vax1_fused_1 = T.axis.reduce(T.int64(16), ax0)
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax1_fused)
                        T.reads(matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0])
                        T.writes(matmul_intermediate_local[T.int64(0), v0])
                        with T.init():
                            matmul_intermediate_local[T.int64(0), v0] = T.float32(0)
                        matmul_intermediate_local[T.int64(0), v0] = matmul_intermediate_local[T.int64(0), v0] + matmul_intermediate_rf_local[vax1_fused_1, T.int64(0), v0]
            for ax0_fused_0_1 in T.thread_binding(T.int64(16), thread="threadIdx.x"):
                for ax0_fused_1 in range(T.int64(1)):
                    with T.block("compute"):
                        v0 = T.axis.spatial(T.int64(256), ax0_fused_0 * T.int64(16) + ax0_fused_0_1 + ax0_fused_1)
                        T.reads(matmul_intermediate_local[T.int64(0), v0], param_1[v0])
                        T.writes(compute_intermediate[T.int64(0), v0])
                        compute_intermediate[T.int64(0), v0] = T.max(matmul_intermediate_local[T.int64(0), v0] + param_1[v0], T.float32(0))

    @R.function
    def main(inp_0: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.fused_matmul_add_relu, (inp_0, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
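The scheduled fused_matmul_add_relu uses a reduction-factoring (rfactor) strategy: the 784-long reduction is split as 49 × 16, each of the 16 threadIdx.y lanes accumulates a partial sum in matmul_intermediate_rf_local, and the "matmul" block then adds the 16 partial sums before bias and ReLU are applied. The NumPy sketch below reproduces only that arithmetic on the CPU; it illustrates the structure of the schedule, not how DLight generates code, and the lanes parameter is a hypothetical knob.

```python
import numpy as np


def gemv_rfactor_add_relu(inp_0, param_0, param_1, lanes=16):
    """CPU sketch of the rfactor arithmetic in the scheduled fused_matmul_add_relu."""
    n_in, n_out = param_0.shape
    if n_in % lanes != 0:
        raise ValueError("this sketch assumes the reduction length divides evenly")
    outer = n_in // lanes  # 784 // 16 == 49 for the model in this tutorial
    # Blocks "matmul_rf_init" / "matmul_rf_update": one partial sum per lane
    # (threadIdx.y). Vectorizing over all n_out columns stands in for the
    # blockIdx.x / threadIdx.x binding of the output index v0.
    rf = np.zeros((lanes, n_out), dtype="float32")
    for lane in range(lanes):      # ax1_fused_1
        for k0 in range(outer):    # ax1_fused_0, the serial loop of extent 49
            k = k0 * lanes + lane
            rf[lane, :] += inp_0[0, k] * param_0[k, :]
    # Block "matmul": reduce the partial sums across lanes.
    acc = rf.sum(axis=0)
    # Block "compute": bias add and ReLU in a single epilogue.
    return np.maximum(acc + param_1, np.float32(0))[None, :]


rng = np.random.default_rng(0)
x = rng.random((1, 784), dtype=np.float32)
w = rng.standard_normal((784, 256)).astype("float32")
b = rng.standard_normal(256).astype("float32")
out = gemv_rfactor_add_relu(x, w, b)
print(out.shape)  # (1, 256)
```

Because the partial sums are added in a different order than a plain dot product, float32 results typically differ from an unscheduled computation in the last few bits, which is one reason to compare outputs with explicit tolerances rather than exact equality.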
Note
The DLight framework is still under development. It currently supports only GPU backends and a limited set of operators, specifically the operators commonly used in LLMs. We plan to extend it to support more operators and backends in the future.
Compile and Run
After optimization, we can compile the model into a TVM runtime module. Apache TVM Unity uses the Relax Virtual Machine to run the model. The following code compiles the model:
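# Compile the IRModule into an executable for the CUDA target
ex = relax.build(mod, target=target)
# Device 0 of the target's kind ("cuda")
dev = tvm.device(str(target.kind), 0)
# Load the executable into the Relax Virtual Machine on that device
vm = relax.VirtualMachine(ex, dev)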
Now we can run the model on the TVM runtime module. We first prepare the input data and then invoke the TVM runtime module to get the output.
import numpy as np
data = np.random.rand(1, 784).astype("float32")
tvm_data = tvm.nd.array(data, device=dev)
tvm_out = vm["main"](tvm_data).numpy()
We can also compare the output with the PyTorch model to verify the correctness.
with torch.no_grad():
    torch_out = torch_model(torch.Tensor(data)).numpy()
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)
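np.testing.assert_allclose passes when, element by element, |actual - desired| <= atol + rtol * |desired|; the relative term is scaled by the second argument, here the PyTorch output. A pure-Python sketch of that rule, with within_tolerance as a hypothetical helper, shows that atol dominates for values near zero and rtol for large values.

```python
def within_tolerance(actual, desired, rtol=1e-5, atol=1e-5):
    """Element-wise form of the assert_allclose criterion (hypothetical helper)."""
    # Note the asymmetry: only |desired| scales the relative tolerance.
    return [abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired)]


desired = [1.0, 100.0, 0.0, 0.0]
actual = [1.000005, 100.002, 0.000009, 0.00002]
print(within_tolerance(actual, desired))  # [True, False, True, False]
```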
Relax VM supports timing evaluation. We can use the following code to get the timing result.
timing_res = vm.time_evaluator("main", dev)(tvm_data)
print(timing_res)
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)
   0.0085       0.0085       0.0085       0.0085       0.0000
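Each column summarizes the per-repeat measurements. Identical mean, median, max and min with a zero std is what a single timed repeat produces; as far as we know, time_evaluator takes one repeat by default, so requesting more repeats is what gives a meaningful spread. The sketch below recomputes the five statistics from a list of per-repeat times in seconds; summarize_ms is a hypothetical helper, and using the population standard deviation is an assumption about TVM's convention.

```python
import statistics


def summarize_ms(times_s):
    """Hypothetical helper: mean/median/max/min/std of per-repeat times, in ms."""
    ms = [t * 1e3 for t in times_s]  # seconds -> milliseconds
    return {
        "mean": statistics.mean(ms),
        "median": statistics.median(ms),
        "max": max(ms),
        "min": min(ms),
        # Population std (assumption); unlike statistics.stdev it accepts one sample.
        "std": statistics.pstdev(ms),
    }


def format_row(summary):
    return "  ".join(f"{summary[k]:.4f}" for k in ("mean", "median", "max", "min", "std"))


print(format_row(summarize_ms([8.5e-6])))                  # one repeat: no spread
print(format_row(summarize_ms([8.4e-6, 8.5e-6, 8.9e-6])))  # three repeats
```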