Integration with Machine Learning Frameworks
============================================

Prelude
-------

In the past chapters, we have learned about abstractions for machine
learning compilation and transformations among tensor functions.

This chapter will discuss how to bring machine learning models from
existing ML frameworks into an MLC flow.

Preparations
------------

To begin with, we will import necessary dependencies.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.ir.module import IRModule
    from tvm.script import relax as R
    from tvm.script import tir as T

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import torch
    import torch.nn as nn
    from torch import fx
    from torch.nn import functional as F

Build an IRModule Through a Builder
-----------------------------------

In the past chapters, we have been building IRModule by directly writing
TVMScript. As the model gets larger, we need a programmatic way to build
up an IRModule. In this section, let us review some of the tools that
support that process.

Tensor Expression for TensorIR Creation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, we review the tensor expression domain-specific language to build
TensorIR functions.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    from tvm import te

We begin by creating a placeholder object, which represents an input to
a TensorIR function.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    A = te.placeholder((128, 128), name="A", dtype="float32")
    B = te.placeholder((128, 128), name="B", dtype="float32")

Each input and intermediate result here is represented as a
``te.Tensor`` object.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    type(A)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tvm.te.tensor.Tensor

Each ``te.Tensor`` has a ``shape`` field and a ``dtype`` field that
track the shape and data type of the computation.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    A.shape

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [128, 128]

We can describe computations through a sequence of tensor expression
computations. Here ``te.compute`` takes the signature
``te.compute(output_shape, fcompute)``, and the ``fcompute`` function
describes how we want to compute the value of each element ``[i, j]``
for a given index.

The ``te_matmul`` function takes in objects of type ``te.Tensor`` and
returns the matrix multiplication result. Note how we build up
computations depending on A and B’s input shapes. The ``te_matmul``
function works for A and B with different input shapes.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
        assert A.shape[1] == B.shape[0]
        n = A.shape[0]
        m = B.shape[1]
        k = te.reduce_axis((0, A.shape[1]), name="k")
        return te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul")

We can create the result of matmul by calling ``te_matmul`` with A and
B.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    C = te_matmul(A, B)

To create a TensorIR function, we can call ``te.create_prim_func`` and
pass in the input and output values.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    te.create_prim_func([A, B, C]).show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import tir as T
    
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), matmul: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
    
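To make the generated loop nest concrete, the following pure-Python sketch mirrors it on small nested lists. The helper name ``matmul_reference`` and the small inputs are illustrative assumptions and not part of TVM. The ``T.init()`` block corresponds to zeroing the output element when the reduction index ``k`` takes its first value.

```python
def matmul_reference(A, B):
    """Mirror the i, j, k loop nest of the generated TensorIR on nested lists."""
    n, k_extent, m = len(A), len(B), len(B[0])
    assert len(A[0]) == k_extent
    out = [[None] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for k in range(k_extent):
                if k == 0:
                    # plays the role of the T.init() block
                    out[i][j] = 0.0
                out[i][j] = out[i][j] + A[i][k] * B[k][j]
    return out


A_small = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
B_small = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
result = matmul_reference(A_small, B_small)
print(result)
```

The spatial axes ``i`` and ``j`` each own one output element, while the reduction axis ``k`` accumulates into it, which is what the ``"SSR"`` remap string in the TensorIR function records.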
We can create a tensor expression for relu computation in a similar
fashion. Here we write it so that the ``te_relu`` function works for a
``te.Tensor`` of any dimension and shape.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def te_relu(A: te.Tensor) -> te.Tensor:
        return te.compute(A.shape, lambda *i: te.max(A(*i), 0), name="relu")

Let us try out ``te_relu`` on two different input shapes and dimensions.
First ``X1`` with shape ``(10,)``.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    X1 = te.placeholder((10,), name="X1", dtype="float32")
    Y1 = te_relu(X1)
    te.create_prim_func([X1, Y1]).show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import tir as T
    
    @T.prim_func
    def main(X1: T.Buffer((10,), "float32"), relu: T.Buffer((10,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0 in range(10):
            with T.block("relu"):
                v_i0 = T.axis.spatial(10, i0)
                T.reads(X1[v_i0])
                T.writes(relu[v_i0])
                relu[v_i0] = T.max(X1[v_i0], T.float32(0))
    
Then ``X2`` with shape ``(10, 20)``.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    X2 = te.placeholder((10, 20), name="X1", dtype="float32")
    Y2 = te_relu(X2)
    te.create_prim_func([X2, Y2]).show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import tir as T
    
    @T.prim_func
    def main(X1: T.Buffer((10, 20), "float32"), relu: T.Buffer((10, 20), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(10, 20):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(X1[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(X1[v_i0, v_i1], T.float32(0))
    
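The ``lambda *i`` in ``te_relu`` is what makes it rank-agnostic: ``fcompute`` receives one index per output dimension, and ``*i`` collects however many there are. The sketch below imitates this with a hypothetical eager ``compute`` helper over dictionaries keyed by index tuples. It is only an illustration under that assumption; the real ``te.compute`` builds a symbolic expression instead of evaluating values.

```python
import itertools


def compute(shape, fcompute):
    """Illustrative stand-in: evaluate fcompute at every index of `shape`."""
    indices = itertools.product(*(range(extent) for extent in shape))
    return {idx: fcompute(*idx) for idx in indices}


def relu_any_rank(data, shape):
    # `*i` gathers all indices into a tuple, whatever the rank is
    return compute(shape, lambda *i: max(data[i], 0.0))


x1 = {(0,): -1.0, (1,): 2.0}
x2 = {(0, 0): -3.0, (0, 1): 4.0, (1, 0): 5.0, (1, 1): -6.0}

y1 = relu_any_rank(x1, (2,))
y2 = relu_any_rank(x2, (2, 2))
print(y1)
print(y2)
```

The same function body handles the one-dimensional and two-dimensional inputs, just as ``te_relu`` produces a single loop for ``(10,)`` and a two-level loop nest for ``(10, 20)``.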
One final thing that the ``te`` API allows us to do is to compose
operations and create “fused” operators. For example, we can take the
result of matmul and apply relu to it.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    C = te_matmul(A, B)
    D = te_relu(C)

We can create a TensorIR function by only passing the input and output
values of interest, skipping intermediate values. This causes the result
of matmul to be allocated as a temporary buffer inside the TensorIR
function.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    te.create_prim_func([A, B, D]).show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import tir as T
    
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), relu: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul = T.alloc_buffer((128, 128))
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(matmul[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(matmul[v_i0, v_i1], T.float32(0))
    
We can also pass the intermediate result C into the argument list. In
this case, the TensorIR function expects the caller to also pass in the
buffer for C. Normally we recommend passing in only the inputs and
outputs, so that more advanced fusion can happen inside the function.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    te.create_prim_func([A, B, C, D]).show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import tir as T
    
    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), matmul: T.Buffer((128, 128), "float32"), relu: T.Buffer((128, 128), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(128, 128, 128):
            with T.block("matmul"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_k, v_j])
                T.writes(matmul[v_i, v_j])
                with T.init():
                    matmul[v_i, v_j] = T.float32(0)
                matmul[v_i, v_j] = matmul[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
        for i0, i1 in T.grid(128, 128):
            with T.block("relu"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(matmul[v_i0, v_i1])
                T.writes(relu[v_i0, v_i1])
                relu[v_i0, v_i1] = T.max(matmul[v_i0, v_i1], T.float32(0))
    
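One reason to keep the intermediate out of the signature is that later transformations are then free to eliminate it. The hand-written sketch below contrasts a two-stage version that uses a temporary buffer with a version that applies relu as soon as each reduction finishes. It is plain Python with illustrative function names, and it is not the output of any TVM pass.

```python
def matmul_relu_two_stage(A, B):
    """Matmul into a temporary buffer, then relu in a second loop nest."""
    n, k_extent, m = len(A), len(B), len(B[0])
    tmp = [[0.0] * m for _ in range(n)]  # plays the role of T.alloc_buffer
    for i in range(n):
        for j in range(m):
            for k in range(k_extent):
                tmp[i][j] += A[i][k] * B[k][j]
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            out[i][j] = max(tmp[i][j], 0.0)
    return out


def matmul_relu_fused(A, B):
    """Same result without materializing the full intermediate matrix."""
    n, k_extent, m = len(A), len(B), len(B[0])
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0  # scalar accumulator replaces the temporary buffer
            for k in range(k_extent):
                acc += A[i][k] * B[k][j]
            out[i][j] = max(acc, 0.0)
    return out


A_small = [[1.0, -2.0], [-3.0, 4.0]]
B_small = [[1.0, 1.0], [1.0, -1.0]]
two_stage = matmul_relu_two_stage(A_small, B_small)
fused = matmul_relu_fused(A_small, B_small)
print(two_stage, fused)
```

If the caller had to receive the matmul result as an explicit argument, the second form would not be a valid replacement, because the caller could observe that the intermediate buffer is never written.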
Use BlockBuilder to Create an IRModule
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

So far, we have created a single TensorIR function. In order to build
end-to-end model execution, we also need to be able to connect multiple
TensorIR functions through a computational graph.

A block builder helps us incrementally build a ``relax.Function``. Let
us first create the variables that will serve as the inputs of that
function.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    A = relax.Var("A", relax.TensorStructInfo((128, 128), "float32"))
    B = relax.Var("B", relax.TensorStructInfo((128, 128), "float32"))

We construct the relax function by creating a block builder and then
emitting a sequence of primitive tensor operations.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    bb = relax.BlockBuilder()

    with bb.function("main"):
        with bb.dataflow():
            C = bb.emit_te(te_matmul, A, B)
            D = bb.emit_te(te_relu, C)
            R = bb.emit_output(D)
        bb.emit_func_output(R, params=[A, B])

    MyModule = bb.get()
    MyModule.show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R
    
    @I.ir_module
    class Module:
        @T.prim_func
        def te_matmul(rxplaceholder: T.Buffer((T.int64(128), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul: T.Buffer((T.int64(128), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(128), T.int64(128), T.int64(128)):
                with T.block("matmul"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(matmul[v_i, v_j])
                    with T.init():
                        matmul[v_i, v_j] = T.float32(0)
                    matmul[v_i, v_j] = matmul[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]
    
        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(128), T.int64(128)), "float32"), relu: T.Buffer((T.int64(128), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(128), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))
    
        @R.function
        def main(A: R.Tensor((128, 128), dtype="float32"), B: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.te_matmul, (A, B), out_sinfo=R.Tensor((128, 128), dtype="float32"))
                lv1 = R.call_tir(cls.te_relu, (lv,), out_sinfo=R.Tensor((128, 128), dtype="float32"))
                gv: R.Tensor((128, 128), dtype="float32") = lv1
                R.output(gv)
            return gv
    
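Each ``R.call_tir`` in ``main`` follows a destination-passing convention: the callee receives its inputs plus a preallocated output buffer and writes the result into that buffer. The sketch below imitates the convention with nested lists. ``call_tir_sketch`` and the two kernels are hypothetical stand-ins written for illustration, not TVM APIs.

```python
def te_matmul_kernel(a, b, out):
    """Destination-passing matmul: writes into `out` and returns nothing."""
    for i in range(len(a)):
        for j in range(len(b[0])):
            out[i][j] = 0.0
            for k in range(len(b)):
                out[i][j] += a[i][k] * b[k][j]


def te_relu_kernel(x, out):
    """Destination-passing relu."""
    for i in range(len(x)):
        for j in range(len(x[0])):
            out[i][j] = max(x[i][j], 0.0)


def call_tir_sketch(kernel, args, out_shape):
    """Allocate the output, invoke the kernel on inputs + output, return it."""
    rows, cols = out_shape
    out = [[0.0] * cols for _ in range(rows)]
    kernel(*args, out)
    return out


def main(a, b):
    # same two-step structure as the relax `main` function
    lv = call_tir_sketch(te_matmul_kernel, (a, b), (len(a), len(b[0])))
    lv1 = call_tir_sketch(te_relu_kernel, (lv,), (len(lv), len(lv[0])))
    return lv1


res = main([[1.0, -2.0]], [[3.0, -1.0], [1.0, 1.0]])
print(res)
```

Because the allocation is hidden inside the call wrapper, the graph-level function only sees values flowing from one call to the next, which is what lets the dataflow block be treated as a side-effect-free computational graph.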
Deep Dive into Block Builder APIs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Now let us do a deep dive into each block builder API. It is helpful to
put the block builder code and the resulting module side by side.

.. figure:: ../img/integration_block_builder.png

The block builder comes with scopes that correspond to the scopes in the
relax function. For example, ``bb.dataflow()`` creates a dataflow block,
and all the block builder calls inside the scope belong to that dataflow
block.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    with bb.function("main"):
        with bb.dataflow():
            # every emit call generates a variable inside a dataflow block.

Each intermediate result is a ``relax.Var`` corresponding to a variable
that stores the result of the computation. ``DataflowVar`` indicates
that the var is an intermediate step inside a dataflow block
(computational graph).

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    type(C)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tvm.relax.expr.DataflowVar

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    isinstance(C, relax.Var)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    True

Each line in the relax function is generated by an ``emit_te`` call. For
example,

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    lv = R.call_tir(cls.te_matmul, (A, B), out_sinfo=R.Tensor((128, 128), dtype="float32"))

is generated by

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    C = bb.emit_te(te_matmul, A, B)

Under the hood, ``bb.emit_te`` does the following things:

- Create an input ``te.placeholder`` for A and B.
- Run them through the ``te_matmul`` function.
- Call into ``te.create_prim_func`` to create a TensorIR function.
- Generate a call into the function via ``call_tir``.

We can find that the result is a computational graph with two
intermediate values, with one node corresponding to the ``te_matmul``
operation and another one corresponding to ``te_relu``.
We can create the output variable of each dataflow block through
``bb.emit_output``.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    with bb.dataflow():
        ...
        R = bb.emit_output(D)

The above code marks that D is a variable that can be referred to
outside of the dataflow block.

Finally, the function output is marked by ``bb.emit_func_output``. We
can only call ``emit_func_output`` once in each function scope.

Notably, we can specify the list of parameters of the function in the
output emission stage. Doing so helps us in cases where we collect the
list of parameters on the fly.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    with bb.function("main"):
        ...
        # specify parameters in the end
        bb.emit_func_output(R, params=[A, B])

Alternatively, we can specify the list of parameters at the beginning of
the function scope.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # specify parameters in the beginning.
    with bb.function("main", params=[A, B]):
        ...
        bb.emit_func_output(R)

Import Model From PyTorch
-------------------------

Now that we have learned the tools to construct an IRModule
programmatically, let us use them to bring a model from PyTorch into the
IRModule format.

Most machine learning frameworks come with computational graph
abstractions, where each node corresponds to an operation and the edges
correspond to the dependencies among them. We will take a PyTorch model,
obtain a computational graph in PyTorch’s native format, and translate
that into an IRModule.

Let us begin by defining a model in PyTorch. To keep the example
consistent, we will use the matmul relu example.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.weight = nn.Parameter(torch.randn(128, 128))

        def forward(self, x):
            x = torch.matmul(x, self.weight)
            x = torch.relu(x)
            return x

Create TorchFX GraphModule
~~~~~~~~~~~~~~~~~~~~~~~~~~

We use TorchFX to trace a graph from the PyTorch module.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    model = MyModel()
    fx_module = fx.symbolic_trace(model)
    type(fx_module)

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl

The ``fx_module`` contains a simple computation graph view that can be
printed as tabular data. Our goal is to translate this graph into an
IRModule.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    fx_module.graph.print_tabular()

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    opcode         name    target                                                     args         kwargs
    -------------  ------  ---------------------------------------------------------  -----------  --------
    placeholder    x       x                                                          ()           {}
    get_attr       weight  weight                                                     ()           {}
    call_function  matmul                                                             (x, weight)  {}
    call_function  relu                                                               (matmul,)    {}
    output         output  output                                                     (relu,)      {}

Create Map Function
~~~~~~~~~~~~~~~~~~~

Let us define the overall high-level translation logic. The main flow is
as follows:

- Create a ``node_map`` that maps ``fx.Node`` to the corresponding
  ``relax.Var`` that represents the translated node in IRModule.
- Iterate over the nodes in the fx graph in topological order.
- Compute the mapped output of the node given the mapped inputs.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def map_param(param: nn.Parameter):
        return relax.const(
            param.data.cpu().numpy(), relax.TensorStructInfo(param.data.shape, "float32")
        )

    def fetch_attr(fx_mod, target: str):
        """Helper function to fetch an attr"""
        target_atoms = target.split('.')
        attr_itr = fx_mod
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
        input_index = 0
        node_map = {}
        named_modules = dict(fx_mod.named_modules())

        bb = relax.BlockBuilder()

        fn_inputs = []
        fn_output = None
        with bb.function("main"):
            with bb.dataflow():
                for node in fx_mod.graph.nodes:
                    if node.op == "placeholder":
                        # create input placeholder
                        shape = input_shapes[input_index]
                        input_index += 1
                        input_var = relax.Var(
                            node.target, relax.TensorStructInfo(shape, "float32")
                        )
                        fn_inputs.append(input_var)
                        node_map[node] = input_var
                    elif node.op == "get_attr":
                        node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                    elif node.op == "call_function":
                        node_map[node] = call_function_map[node.target](bb, node_map, node)
                    elif node.op == "call_module":
                        named_module = named_modules[node.target]
                        node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)
                    elif node.op == "output":
                        output = node_map[node.args[0]]
                        assert fn_output is None
                        fn_output = bb.emit_output(output)
            # output and finalize the function
            bb.emit_func_output(output, fn_inputs)
        return bb.get()

We did not define the function map in the ``from_fx`` function. We will
supply the translation rule of each torch function via a map.
Specifically, the following code block shows how we can do that through
the ``emit_te`` API.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    def map_matmul(bb, node_map, node: fx.Node):
        A = node_map[node.args[0]]
        B = node_map[node.args[1]]
        return bb.emit_te(te_matmul, A, B)

    def map_relu(bb, node_map, node: fx.Node):
        A = node_map[node.args[0]]
        return bb.emit_te(te_relu, A)

    MyModule = from_fx(
        fx_module,
        input_shapes = [(1, 128)],
        call_function_map = {
          torch.matmul: map_matmul,
          torch.relu: map_relu,
        },
        call_module_map={},
    )

    MyModule.show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R
    
    @I.ir_module
    class Module:
        @T.prim_func
        def te_matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(128)):
                with T.block("matmul"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(matmul[v_i, v_j])
                    with T.init():
                        matmul[v_i, v_j] = T.float32(0)
                    matmul[v_i, v_j] = matmul[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]
    
        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), relu: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))
    
        @R.function
        def main(x: R.Tensor((1, 128), dtype="float32")) -> R.Tensor((1, 128), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.te_matmul, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv1 = R.call_tir(cls.te_relu, (lv,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                gv: R.Tensor((1, 128), dtype="float32") = lv1
                R.output(gv)
            return lv1
    
    # Metadata omitted. Use show_meta=True in script() method to show it.
    
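The translation loop in ``from_fx`` follows a general pattern: walk the nodes in topological order, look up already-translated inputs in ``node_map``, and dispatch on the node kind. The sketch below replays that pattern on a hand-written toy graph so that it runs without PyTorch or TVM. The ``Node`` tuple, the string targets, and the ``translate`` function are illustrative assumptions, and the "translated" values are plain expression strings rather than ``relax.Var`` objects.

```python
from collections import namedtuple

# Toy stand-in for fx.Node: args refer to earlier nodes by name.
Node = namedtuple("Node", ["op", "name", "target", "args"])

graph = [
    Node("placeholder", "x", "x", ()),
    Node("get_attr", "weight", "weight", ()),
    Node("call_function", "matmul", "matmul", ("x", "weight")),
    Node("call_function", "relu", "relu", ("matmul",)),
    Node("output", "output", "output", ("relu",)),
]

# One translation rule per function target, supplied from the outside.
call_function_map = {
    "matmul": lambda node_map, node: (
        f"matmul({node_map[node.args[0]]}, {node_map[node.args[1]]})"
    ),
    "relu": lambda node_map, node: f"relu({node_map[node.args[0]]})",
}


def translate(nodes, call_function_map):
    node_map = {}  # node name -> translated expression
    inputs, output = [], None
    for node in nodes:  # nodes are assumed to be in topological order
        if node.op == "placeholder":
            node_map[node.name] = node.target
            inputs.append(node.target)
        elif node.op == "get_attr":
            node_map[node.name] = f"const({node.target})"
        elif node.op == "call_function":
            node_map[node.name] = call_function_map[node.target](node_map, node)
        elif node.op == "output":
            assert output is None
            output = node_map[node.args[0]]
    return inputs, output


inputs, output = translate(graph, call_function_map)
print(inputs, output)
```

Keeping the per-operator rules in a map, instead of hard-coding them in the loop, is what allows the same driver to be reused for other models by only changing the map.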
Coming back to FashionMNIST Example
-----------------------------------

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import torch
    import torchvision

    test_data = torchvision.datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor()
    )
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    img, label = next(iter(test_loader))
    img = img.reshape(1, 28, 28).numpy()

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import matplotlib.pyplot as plt

    plt.figure()
    plt.imshow(img[0])
    plt.colorbar()
    plt.grid(False)
    plt.show()

    print("Class:", class_names[label[0]])

.. figure:: output_index_4f28a7_60_0.png

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Class: Sneaker

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    # Hide outputs
    !wget -nc https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl

.. figure:: ../img/e2e_fashionmnist_mlp_model.png

The above is our model of interest. We can build the PyTorch model as
follows.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    class MLP(nn.Module):
        def __init__(self):
            super(MLP, self).__init__()
            self.linear0 = nn.Linear(784, 128, bias=True)
            self.relu = nn.ReLU()
            self.linear1 = nn.Linear(128, 10, bias=True)

        def forward(self, x):
            x = self.linear0(x)
            x = self.relu(x)
            x = self.linear1(x)
            return x

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    import pickle as pkl

    mlp_model = MLP()

    # load the pre-trained weights downloaded above
    with open("fasionmnist_mlp_params.pkl", "rb") as f:
        mlp_params = pkl.load(f)
    mlp_model.linear0.weight.data = torch.from_numpy(mlp_params["w0"])
    mlp_model.linear0.bias.data = torch.from_numpy(mlp_params["b0"])
    mlp_model.linear1.weight.data = torch.from_numpy(mlp_params["w1"])
    mlp_model.linear1.bias.data = torch.from_numpy(mlp_params["b1"])

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    torch_res = mlp_model(torch.from_numpy(img.reshape(1, 784)))

    pred_kind = np.argmax(torch_res.detach().numpy(), axis=1)
    print("Torch Prediction:", class_names[pred_kind[0]])

.. raw:: latex

    \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Torch Prediction: Sneaker

Let us try to translate from fx by defining mapping functions for the
corresponding ``nn.Module``. Here we are reusing pre-defined TE
libraries from TVM ``topi`` instead of defining our own tensor
expressions.

- ``topi.nn.dense(x, w)`` performs transposed matrix multiplication
  ``x @ w.T``.
- ``topi.add`` performs broadcast add.

.. raw:: latex

    \diilbookstyleinputcell

.. code:: python

    from tvm import topi

    def map_nn_linear(bb, node_map, node, nn_mod):
        x = node_map[node.args[0]]
        w = map_param(nn_mod.weight)
        y = bb.emit_te(topi.nn.dense, x, w)
        if nn_mod.bias is None:
            # no bias term: the dense result is already the layer output
            return y
        b = map_param(nn_mod.bias)
        return bb.emit_te(topi.add, y, b)

    def map_nn_relu(bb, node_map, node, nn_mod):
        return map_relu(bb, node_map, node)

    MLPModule = from_fx(
        fx.symbolic_trace(mlp_model),
        input_shapes = [(1, 784)],
        call_function_map={
        },
        call_module_map={
            torch.nn.Linear: map_nn_linear,
            torch.nn.ReLU: map_nn_relu,
        },
    )

    MLPModule.show()

.. raw:: latex

    \diilbookstyleoutputcell

.. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R
    
    @I.ir_module
    class Module:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]
    
        @T.prim_func
        def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]
    
        @T.prim_func
        def dense(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_matmul_NT: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
                with T.block("T_matmul_NT"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_j, v_k])
                    T.writes(T_matmul_NT[v_i, v_j])
                    with T.init():
                        T_matmul_NT[v_i, v_j] = T.float32(0)
                    T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_j, v_k]
    
        @T.prim_func
        def dense1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_matmul_NT: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
                with T.block("T_matmul_NT"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_j, v_k])
                    T.writes(T_matmul_NT[v_i, v_j])
                    with T.init():
                        T_matmul_NT[v_i, v_j] = T.float32(0)
                    T_matmul_NT[v_i, v_j] = T_matmul_NT[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_j, v_k]
    
        @T.prim_func
        def te_relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), relu: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("relu"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(relu[v_i0, v_i1])
                    relu[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))
    
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.dense, (x, metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv1 = R.call_tir(cls.add, (lv, metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv2 = R.call_tir(cls.te_relu, (lv1,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv3 = R.call_tir(cls.dense1, (lv2, metadata["relax.expr.Constant"][2]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                lv4 = R.call_tir(cls.add1, (lv3, metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                gv: R.Tensor((1, 10), dtype="float32") = lv4
                R.output(gv)
            return lv4
    
    # Metadata omitted. Use show_meta=True in script() method to show it.
    
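One way to read the ``main`` function above: each ``R.call_tir`` allocates an output buffer with the shape given by ``out_sinfo``, passes it to the primitive function as the last argument, and returns it. The following plain-Python sketch mimics that destination-passing convention for the ``add`` and ``te_relu`` functions. It uses nested lists instead of TVM buffers, and ``call_tir_sketch`` is a hypothetical helper for illustration, not a TVM API.

```python
def call_tir_sketch(prim_func, inputs, out_shape):
    """Hypothetical stand-in for R.call_tir: allocate, call, return."""
    rows, cols = out_shape
    out = [[0.0] * cols for _ in range(rows)]
    prim_func(*inputs, out)  # the output buffer is passed as the last argument
    return out


def add(x, bias, t_add):
    # Mirrors the "T_add" block: bias has shape (n,) and is broadcast over rows.
    for ax0 in range(len(t_add)):
        for ax1 in range(len(t_add[0])):
            t_add[ax0][ax1] = x[ax0][ax1] + bias[ax1]


def te_relu(x, relu):
    # Mirrors the "relu" block: elementwise max(x, 0).
    for i0 in range(len(relu)):
        for i1 in range(len(relu[0])):
            relu[i0][i1] = max(x[i0][i1], 0.0)


lv = [[-1.0, 2.0, -3.0]]
lv1 = call_tir_sketch(add, (lv, [0.5, 0.5, 0.5]), out_shape=(1, 3))
lv2 = call_tir_sketch(te_relu, (lv1,), out_shape=(1, 3))
print(lv2)
```

The primitive functions never allocate memory themselves, which is what lets the compiler plan buffers for the whole graph.
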
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    ex = relax.build(MLPModule, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    data_nd = tvm.nd.array(img.reshape(1, 784))

    nd_res = vm["main"](data_nd)

    pred_kind = np.argmax(nd_res.numpy(), axis=1)
    print("MLPModule Prediction:", class_names[pred_kind[0]])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    MLPModule Prediction: Sneaker

Remark: Translating into High-level Operators
---------------------------------------------

When importing from most machine learning frameworks, it is often helpful to first translate the model into high-level built-in primitive operators. The following code block gives an example of how to do that.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def map_nn_relu_op(bb, node_map, node, nn_mod):
        # Emit the built-in relu operator instead of a TensorIR function.
        A = node_map[node.args[0]]
        return bb.emit(relax.op.nn.relu(A))

    def map_nn_linear_op(bb, node_map, node, nn_mod):
        # Emit the built-in linear operator with the module's weight and bias.
        x = node_map[node.args[0]]
        w = map_param(nn_mod.weight)
        b = map_param(nn_mod.bias)
        return bb.emit(relax.op.linear(x, w, b))

    MLPModuleHighLevel = from_fx(
        fx.symbolic_trace(mlp_model),
        input_shapes=[(1, 784)],
        call_function_map={},
        call_module_map={
            torch.nn.Linear: map_nn_linear_op,
            torch.nn.ReLU: map_nn_relu_op,
        },
    )

    MLPModuleHighLevel.show()

.. raw:: html

    # from tvm.script import ir as I
    # from tvm.script import relax as R
    
    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
                lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
                lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
                lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
                lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
                lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
                lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][3])
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return lv6
    
    # Metadata omitted. Use show_meta=True in script() method to show it.
    
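In the listing above, each ``torch.nn.Linear`` appears as ``R.permute_dims`` followed by ``R.matmul`` and ``R.add``, whereas the earlier TensorIR version used a single ``dense`` loop nest that indexes the weight as ``[v_j, v_k]``. The plain-Python sketch below checks, on a toy input, that the two formulations compute the same values. The helper names and the nested-list representation are assumptions made for illustration only; they are not TVM APIs.

```python
def dense_nt(x, w):
    """Loop nest of the `dense` function: out[i][j] = sum_k x[i][k] * w[j][k]."""
    return [[sum(x[i][k] * w[j][k] for k in range(len(w[0])))
             for j in range(len(w))] for i in range(len(x))]


def permute_dims(w):
    """Transpose a 2-D nested list (axes=None reverses the axis order)."""
    return [list(col) for col in zip(*w)]


def matmul(a, b):
    """Plain matrix multiplication: out[i][j] = sum_k a[i][k] * b[k][j]."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


x = [[1.0, 2.0, 3.0]]   # shape (1, 3)
w = [[1.0, 0.0, 1.0],   # shape (2, 3), stored as (out, in) like nn.Linear
     [0.5, 0.5, 0.5]]

high_level = matmul(x, permute_dims(w))  # permute_dims + matmul
low_level = dense_nt(x, w)               # single NT loop nest
print(high_level, low_level)
```

Because both forms agree, a later pass is free to pick whichever lowering suits the target, for example a fused loop nest or a library matmul.
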
The model is now in an IRModule made of built-in operator calls. These built-in operators are a **higher-level abstraction** than TensorIR functions, and there are different opportunities to further translate them into either library calls or TensorIR functions.

In most cases, it is helpful to translate into high-level built-ins when they are available. However, there are many cases where we cannot find a corresponding high-level built-in, or where we want to specify the TensorIR function directly. In those cases, we can customize the translation logic or transformation to generate ``call_dps_packed`` or call into the library functions. Usually, we get the best result by combining the high-level op, TensorIR, and library abstractions. We will discuss the tradeoffs in the follow-up lectures.

Discussions
-----------

In this chapter, we focus on the **develop** part of the MLC flow. We studied different ways to bring models from machine learning frameworks into an IRModule. We also briefly touched upon the high-level primitive operators.

Once we get the model into an IRModule, we can introduce more kinds of transformations on primitive functions and computational graph functions. A good MLC process composes these transformations to produce the final deployment form.

.. figure:: ../img/mlc_process.png

Summary
-------

- The tensor expression API allows us to create a primitive TensorIR function.
- The BlockBuilder API creates an IRModule through ``emit_te`` and other functions.
- We can integrate with existing machine learning frameworks by translating their models into an IRModule.