Integration with Machine Learning Frameworks
============================================

Prelude
-------

In the past chapters, we learned about abstractions for machine
learning compilation and transformations among tensor functions. This
chapter discusses how to bring machine learning models from existing
ML frameworks into an MLC flow.

Preparations
------------

To begin with, we import the necessary dependencies.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.ir.module import IRModule
    from tvm.script import relax as R
    from tvm.script import tir as T

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import torch
    import torch.nn as nn
    from torch import fx
    from torch.nn import functional as F

Build an IRModule Through a Builder
-----------------------------------

In the past chapters, we built IRModules by directly writing TVMScript.
As models get larger, we need a programmatic way to build up an
IRModule. In this section, let us review some of the tools that support
that process.

Tensor Expression for TensorIR Creation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, we review the tensor expression (TE) domain-specific language
for building TensorIR functions.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    from tvm import te

We begin by creating a placeholder object, which represents an input to
a TensorIR function.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    A = te.placeholder((128, 128), name="A", dtype="float32")
    B = te.placeholder((128, 128), name="B", dtype="float32")

Each input and intermediate result here is represented as a
``te.Tensor`` object.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    type(A)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tvm.te.tensor.Tensor

Each ``te.Tensor`` has a ``shape`` field and a ``dtype`` field that
track the shape and data type of the computation.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    A.shape

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [128, 128]

We can describe computations through a sequence of tensor expression
computations. Here ``te.compute`` takes the signature
``te.compute(output_shape, fcompute)``, where the ``fcompute`` function
describes how we want to compute the value of each element ``[i, j]``
for a given index.

The ``te_matmul`` function below takes in objects of type ``te.Tensor``
and returns the matrix multiplication result. Note how we build up the
computation depending on A and B's input shapes: ``te_matmul`` works
for A and B with different input shapes.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
        assert A.shape[1] == B.shape[0]
        n = A.shape[0]
        m = B.shape[1]
        k = te.reduce_axis((0, A.shape[1]), name="k")
        return te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")

We can create the matmul result by calling ``te_matmul`` with A and B.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    C = te_matmul(A, B)

To create a TensorIR function, we can call ``te.create_prim_func`` and
pass in the input and output values.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    te.create_prim_func([A, B, C]).show()

.. raw:: latex

   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output

    # from tvm.script import tir as T

    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), matmul: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
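
``te.compute`` is not limited to matmul. As a second example, consider
an element-wise relu. The output cells below were produced by a relu
helper whose definition is not preserved in this text; the following is
a reconstructed sketch, assuming the helper builds the computation from
the input's own shape so that it works for tensors of any rank.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Sketch of the assumed helper: A(*i) indexes the tensor with an
    # arbitrary-rank index tuple, so the same code handles any shape.
    def te_relu(A: te.Tensor) -> te.Tensor:
        return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name="relu")

We can try it on tensors with different shapes. First, a 1-D input:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X1 = te.placeholder((10,), name="X1", dtype="float32")
    Y1 = te_relu(X1)
    te.create_prim_func([X1, Y1]).show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output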

    # from tvm.script import tir as T

    @T.prim_func
    def main(X1: T.Buffer((10,), "float32"), relu: T.Buffer((10,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0 in range(10):
            with T.block("relu"):
                v_i0 = T.axis.spatial(10, i0)
                T.reads(X1[v_i0])
                T.writes(relu[v_i0])
                relu[v_i0] = T.max(X1[v_i0], T.float32(0))
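
Then a 2-D input. Note that the printed buffer below is again named
``X1``, so the placeholder in this sketch keeps that name:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    X2 = te.placeholder((10, 20), name="X1", dtype="float32")
    Y2 = te_relu(X2)
    te.create_prim_func([X2, Y2]).show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output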

    # from tvm.script import tir as T

    @T.prim_func
    def main(X1: T.Buffer((10, 20), "float32"), relu: T.Buffer((10, 20), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(10, 20):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(X1[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(X1[v_i0, v_i1], T.float32(0))
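
Tensor expressions also compose. We can feed the matmul result ``C``
into ``te_relu`` and then create a TensorIR function that exposes only
the inputs and the final output. Because the intermediate ``C`` is not
in the argument list, its storage is allocated inside the generated
function (note the ``T.alloc_buffer`` in the output below). A sketch of
the cell that produces this output:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    D = te_relu(C)
    te.create_prim_func([A, B, D]).show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output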

    # from tvm.script import tir as T

    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), relu: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(matmul[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(matmul[v_i0, v_i1], T.float32(0))
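
Alternatively, we can pass the intermediate result ``C`` into the
argument list as well. In that case, the TensorIR function expects the
caller to also supply a buffer for ``C``, as the output below shows.
Normally we recommend passing only the inputs and outputs, which leaves
more room for advanced fusion inside the function later on.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    te.create_prim_func([A, B, C, D]).show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output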

    # from tvm.script import tir as T

    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), matmul: T.Buffer((128, 128), "float32"), relu: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(matmul[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(matmul[v_i0, v_i1], T.float32(0))
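
Use BlockBuilder to Create an IRModule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

So far we have created single TensorIR functions. To build end-to-end
model executions, we also need to connect multiple TensorIR functions
through a computational graph. ``relax.BlockBuilder`` lets us build a
Relax function incrementally; its ``emit_te`` method generates a
TensorIR function from a TE helper and emits a ``call_tir`` binding
that invokes it. The following is a sketch, assuming the ``te_matmul``
and ``te_relu`` helpers defined above, that produces the module shown
below.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32"))
    B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))

    bb = relax.BlockBuilder()

    with bb.function("main"):
        with bb.dataflow():
            C = bb.emit_te(te_matmul, A, B)
            D = bb.emit_te(te_relu, C)
            # Named R_out to avoid shadowing the `relax as R` import.
            R_out = bb.emit_output(D)
        bb.emit_func_output(R_out, params=[A, B])

    MyModule = bb.get()
    MyModule.show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output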

    # from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func
        def te_matmul(rxplaceholder: T.Buffer((T.int64(128), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(128), T.int64(128), T.int64(128)):
                with T.block("matmul"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(matmul[v_i, v_j])
                    with T.init():
                        matmul[v_i, v_j] = T.float32(0)
                    matmul[v_i, v_j] = matmul[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(128), T.int64(128)), "float32"), relu: T.Buffer((T.int64(128), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(128), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

        @R.function
        def main(A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.te_matmul, (A, B), out_sinfo=R.Tensor((128, 128), dtype="float32"))
                lv1 = R.call_tir(cls.te_relu, (lv,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
                gv: R.Tensor((128, 128), dtype="float32") = lv1
                R.output(gv)
            return gv
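
Import Model From PyTorch
-------------------------

Now that we have tools to construct an IRModule programmatically, let
us use them to bring a model from PyTorch into the IRModule format.
Most machine learning frameworks come with a computational-graph
abstraction; here we use TorchFX, which represents a traced PyTorch
model as a graph of ``fx.Node`` objects. We start with a small model
that matches the module shown at the end of this subsection.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(128, 128))

        def forward(self, x):
            x = torch.matmul(x, self.weight)
            x = torch.relu(x)
            return x

    model = MyModel()
    fx_module = fx.symbolic_trace(model)

The import itself walks the FX graph node by node and emits a Relax
binding for each node. The translator that generated the outputs in
this chapter is not preserved in this text, so the following is a
reconstructed sketch along the same lines: ``map_param`` binds a torch
parameter as a Relax constant, ``fetch_attr`` resolves dotted attribute
paths, and ``from_fx`` drives a ``relax.BlockBuilder``, dispatching
``call_function`` and ``call_module`` nodes to user-supplied map
functions.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def map_param(param: nn.Parameter):
        # Bind a torch parameter as a Relax constant expression.
        return relax.const(
            param.data.cpu().numpy(), relax.TensorStructInfo(param.data.shape, "float32")
        )

    def fetch_attr(fx_mod, target: str):
        # Resolve a dotted attribute path such as "sub.weight".
        attr_itr = fx_mod
        for atom in target.split("."):
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
        input_index = 0
        node_map = {}
        named_modules = dict(fx_mod.named_modules())

        bb = relax.BlockBuilder()

        fn_inputs = []
        fn_output = None
        with bb.function("main"):
            with bb.dataflow():
                for node in fx_mod.graph.nodes:
                    if node.op == "placeholder":
                        # Model inputs become Relax function parameters.
                        shape = input_shapes[input_index]
                        input_index += 1
                        input_var = relax.Var(
                            node.target, relax.TensorStructInfo(shape, "float32")
                        )
                        fn_inputs.append(input_var)
                        node_map[node] = input_var
                    elif node.op == "get_attr":
                        # Model parameters become constants.
                        node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                    elif node.op == "call_function":
                        node_map[node] = call_function_map[node.target](bb, node_map, node)
                    elif node.op == "call_module":
                        named_module = named_modules[node.target]
                        node_map[node] = call_module_map[type(named_module)](
                            bb, node_map, node, named_module
                        )
                    elif node.op == "output":
                        output = node_map[node.args[0]]
                        assert fn_output is None
                        fn_output = bb.emit_output(output)
            bb.emit_func_output(fn_output, fn_inputs)
        return bb.get()

With per-operator map functions for ``torch.matmul`` and ``torch.relu``
that reuse the TE helpers from the previous section, we can translate
the traced model. The model's weight shows up in the result as
``metadata["relax.expr.Constant"][0]``.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def map_matmul(bb, node_map, node: fx.Node):
        A = node_map[node.args[0]]
        B = node_map[node.args[1]]
        return bb.emit_te(te_matmul, A, B)

    def map_relu(bb, node_map, node: fx.Node):
        A = node_map[node.args[0]]
        return bb.emit_te(te_relu, A)

    MyModule = from_fx(
        fx_module,
        input_shapes=[(1, 128)],
        call_function_map={
            torch.matmul: map_matmul,
            torch.relu: map_relu,
        },
        call_module_map={},
    )

    MyModule.show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output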

    # from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func
        def te_matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(128)):
                with T.block("matmul"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(matmul[v_i, v_j])
                    with T.init():
                        matmul[v_i, v_j] = T.float32(0)
                    matmul[v_i, v_j] = matmul[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), relu: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

        @R.function
        def main(x: R.Tensor((1, 128), dtype="float32")) -> R.Tensor((1, 128), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.te_matmul, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv1 = R.call_tir(cls.te_relu, (lv,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                gv: R.Tensor((1, 128), dtype="float32") = lv1
                R.output(gv)
            return lv1

    # Metadata omitted. Use show_meta=True in script() method to show it.
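
Importing an MLP Model
~~~~~~~~~~~~~~~~~~~~~~

The same recipe scales to a multi-layer perceptron. ``nn.Linear`` and
``nn.ReLU`` appear in the FX graph as ``call_module`` nodes, so we map
them through ``call_module_map``. For the linear layer we lower to
``topi.nn.dense`` plus ``topi.add`` from TVM's TOPI operator library
(both are TE computations, so ``emit_te`` applies directly; ``dense``
computes ``x @ w.T``, matching ``nn.Linear``'s weight layout). The code
below is a sketch consistent with the output that follows; we
instantiate the model with fresh random parameters here, and loading
pretrained weights works the same way since only the printed structure
matters for this discussion.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear0 = nn.Linear(784, 128, bias=True)
            self.relu = nn.ReLU()
            self.linear1 = nn.Linear(128, 10, bias=True)

        def forward(self, x):
            x = self.linear0(x)
            x = self.relu(x)
            x = self.linear1(x)
            return x

    mlp_model = MLP()

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    from tvm import topi

    def map_nn_linear(bb, node_map, node, nn_mod):
        x = node_map[node.args[0]]
        w = map_param(nn_mod.weight)
        if nn_mod.bias is not None:
            b = map_param(nn_mod.bias)
        # dense computes x @ w.T; add applies the bias.
        y = bb.emit_te(topi.nn.dense, x, w)
        return bb.emit_te(topi.add, y, b)

    def map_nn_relu(bb, node_map, node, nn_mod):
        return map_relu(bb, node_map, node)

    MLPModule = from_fx(
        fx.symbolic_trace(mlp_model),
        input_shapes=[(1, 784)],
        call_function_map={},
        call_module_map={
            torch.nn.Linear: map_nn_linear,
            torch.nn.ReLU: map_nn_relu,
        },
    )

    MLPModule.show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output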

    # from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

        @T.prim_func
        def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

        @T.prim_func
        def dense(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_matmul_NT: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
                with T.block("T_matmul_NT"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_j, v_k])
                    T.writes(T_matmul_NT[v_i, v_j])
                    with T.init():
                        T_matmul_NT[v_i, v_j] = T.float32(0)
                    T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_j, v_k]

        @T.prim_func
        def dense1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_matmul_NT: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
                with T.block("T_matmul_NT"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_j, v_k])
                    T.writes(T_matmul_NT[v_i, v_j])
                    with T.init():
                        T_matmul_NT[v_i, v_j] = T.float32(0)
                    T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_j, v_k]

        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), relu: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.dense, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv1 = R.call_tir(cls.add, (lv, metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv2 = R.call_tir(cls.te_relu, (lv1,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv3 = R.call_tir(cls.dense1, (lv2, metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                lv4 = R.call_tir(cls.add1, (lv3, metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                gv: R.Tensor((1, 10), dtype="float32") = lv4
                R.output(gv)
            return lv4

    # Metadata omitted. Use show_meta=True in script() method to show it.
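
Translating into High-level Operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Instead of translating every node straight into TE-backed ``call_tir``
functions, we can first translate into high-level Relax operators and
decide how to lower them later. The following is a sketch of map
functions that emit the built-in ops ``relax.op.matmul``,
``relax.op.add``, ``relax.op.nn.relu`` and ``relax.op.permute_dims``
(the transpose accounts for ``nn.Linear`` storing its weight as
``(out_features, in_features)``), consistent with the output below.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def map_nn_linear_op(bb, node_map, node, nn_mod):
        x = node_map[node.args[0]]
        w = map_param(nn_mod.weight)
        b = map_param(nn_mod.bias)
        # matmul(x, w.T) + b, expressed with built-in Relax operators.
        y = bb.emit(relax.op.matmul(x, relax.op.permute_dims(w)))
        return bb.emit(relax.op.add(y, b))

    def map_nn_relu_op(bb, node_map, node, nn_mod):
        A = node_map[node.args[0]]
        return bb.emit(relax.op.nn.relu(A))

    MLPModuleHighLevel = from_fx(
        fx.symbolic_trace(mlp_model),
        input_shapes=[(1, 784)],
        call_function_map={},
        call_module_map={
            torch.nn.Linear: map_nn_linear_op,
            torch.nn.ReLU: map_nn_relu_op,
        },
    )

    MLPModuleHighLevel.show()

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output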

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
                lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
                lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
                lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
                lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
                lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
                lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][3])
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return lv6

    # Metadata omitted. Use show_meta=True in script() method to show it.
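
Note that the resulting module keeps ``R.matmul``, ``R.add`` and
``R.nn.relu`` as high-level operator calls rather than ``call_tir``
invocations of TensorIR functions. The decision of how to implement
these operators, whether by lowering them to TensorIR or by dispatching
to library kernels, is deferred to later stages of the MLC flow.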