TensorIR: Tensor Program Abstraction Case Study
-----------------------------------------------

Install Packages
~~~~~~~~~~~~~~~~

For the purpose of this course, we will use some ongoing development in
TVM, an open-source machine learning compilation framework. We provide
the following command to install a packaged version for the MLC course.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: bash

   python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels

Prelude
~~~~~~~

.. figure:: ../img/tensor_func_linear_relu.png

To begin today’s lecture, let us recap the key principle of the MLC
process. Most of the MLC process can be viewed as transformations among
tensor functions. The main questions we aim to answer in the upcoming
lectures are:

- What are the possible abstractions to represent tensor functions?
- What are the possible transformations among tensor functions?

Today we are going to cover part of that by focusing on primitive
tensor functions.

Learning one Tensor Program Abstraction – TensorIR
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We have gone over primitive tensor functions and discussed the
high-level idea of tensor program abstractions. Now we are ready to
learn one specific instance of tensor program abstraction called
TensorIR. TensorIR is the tensor program abstraction in Apache TVM,
one of the standard machine learning compilation frameworks.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import numpy as np
   import tvm
   from tvm.ir.module import IRModule
   from tvm.script import tir as T

The primary purpose of tensor program abstraction is to represent loops
and corresponding hardware acceleration choices, such as threading, the
use of specialized hardware instructions, and memory access. To help
our explanations, let us use the following sequence of tensor
computations as a motivating example. Specifically, for two
:math:`128 \times 128` matrices A and B, let us perform the following
two steps of tensor computations.
- :math:`Y_{i, j} = \sum_k A_{i, k} \times B_{k, j}`
- :math:`C_{i, j} = \mathbb{relu}(Y_{i, j}) = \mathbb{max}(Y_{i, j}, 0)`

The above computations resemble a typical primitive tensor function
commonly seen in neural networks – a linear layer with relu activation.
To begin with, we can implement the two operations using array
computations in NumPy as follows.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   dtype = "float32"
   a_np = np.random.rand(128, 128).astype(dtype)
   b_np = np.random.rand(128, 128).astype(dtype)
   # a @ b is equivalent to np.matmul(a, b)
   c_mm_relu = np.maximum(a_np @ b_np, 0)

Under the hood, NumPy calls into libraries (such as OpenBLAS) and some
of its own implementations in lower-level C to execute these
computations. From the tensor program abstraction point of view, we
would like to see through the details **under the hood** of these array
computations. Specifically, we want to ask: what are the possible ways
to implement the corresponding computations?

For the purpose of illustrating the details under the hood, we will
write examples in a restricted subset of the NumPy API – which we call
**low-level numpy** – that uses the following conventions:

- We will use loops instead of array functions when necessary to
  demonstrate the possible loop computations.
- When possible, we always explicitly allocate arrays via
  ``numpy.empty`` and pass them around.

Note that this is not how one would typically write NumPy programs.
Still, it closely resembles what happens under the hood – most
real-world deployment solutions handle allocations separately from
computations. The specific libraries perform the computation using
different forms of loops and arithmetic, and they are primarily
implemented in lower-level languages such as ``C``.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
       Y = np.empty((128, 128), dtype="float32")
       for i in range(128):
           for j in range(128):
               for k in range(128):
                   if k == 0:
                       Y[i, j] = 0
                   Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
       for i in range(128):
           for j in range(128):
               C[i, j] = max(Y[i, j], 0)

The program above is one way to implement the ``mm_relu`` operation.
The program contains two stages: first, we allocate an intermediate
storage :math:`Y` and store the result of the matrix multiplication
there. Then we compute the relu in a second sequence of for loops.

One thing you might notice is that this is certainly not the only way
to implement ``mm_relu``, and likely not the first one that comes to
mind. Nevertheless, it is one way to implement ``mm_relu``, and we can
verify the correctness of the code by comparing our result to the
original array computation. We will come back and revisit other
possible implementations later in this tutorial.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   c_np = np.empty((128, 128), dtype=dtype)
   lnumpy_mm_relu(a_np, b_np, c_np)
   np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

The above example code shows how we can bring out an **under the hood**
implementation of ``mm_relu``. Of course, the code itself runs much
more slowly because of the Python interpreter. Nevertheless, the
example numpy code contains all the elements we will use in real-world
implementations of those computations:

- Multi-dimensional buffers (arrays).
- Loops over array dimensions.
- Computation statements executed under the loops.

With the low-level NumPy example in mind, we are now ready to introduce
TensorIR. The code block below shows a TensorIR implementation of
``mm_relu``. The particular code is implemented in a language called
TVMScript, which is a domain-specific dialect embedded in the Python
AST.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @tvm.script.ir_module
   class MyModule:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"),
                   B: T.Buffer((128, 128), "float32"),
                   C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
           Y = T.alloc_buffer((128, 128), dtype="float32")
           for i, j, k in T.grid(128, 128, 128):
               with T.block("Y"):
                   vi = T.axis.spatial(128, i)
                   vj = T.axis.spatial(128, j)
                   vk = T.axis.reduce(128, k)
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
           for i, j in T.grid(128, 128):
               with T.block("C"):
                   vi = T.axis.spatial(128, i)
                   vj = T.axis.spatial(128, j)
                   C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

It is helpful to see the numpy code and the TensorIR code side by side
and check the corresponding elements; we are going to walk through each
of them in detail.

.. figure:: ../img/tensor_func_and_numpy.png

Let us first review the elements that have a direct correspondence
between the numpy and TensorIR sides. Then we will come back and review
additional elements that are not part of the numpy program.

Function Parameters and Buffers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, let us look at the function parameters. The function parameters
correspond to the same set of parameters in the numpy function.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorIR
   def mm_relu(A: T.Buffer((128, 128), "float32"),
               B: T.Buffer((128, 128), "float32"),
               C: T.Buffer((128, 128), "float32")):
       ...

   # numpy
   def lnumpy_mm_relu(A: np.ndarray, B: np.ndarray, C: np.ndarray):
       ...

Here A, B, and C take the type ``T.Buffer``, with shape argument
``(128, 128)`` and data type ``float32``. This additional information
helps a possible MLC process generate code that specializes in the
shape and data type.

Similarly, TensorIR also uses a buffer type for the intermediate result
allocation.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorIR
   Y = T.alloc_buffer((128, 128), dtype="float32")

   # numpy
   Y = np.empty((128, 128), dtype="float32")

For Loop Iterations
^^^^^^^^^^^^^^^^^^^

There is also a direct correspondence of loop iterations. ``T.grid`` is
syntactic sugar in TensorIR for writing multiple nested iterators.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorIR
   for i, j, k in T.grid(128, 128, 128):

   # numpy
   for i in range(128):
       for j in range(128):
           for k in range(128):

Computational Block
^^^^^^^^^^^^^^^^^^^

One of the main differences comes from the computational statement.
TensorIR contains an additional construct called ``T.block``.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # TensorIR
   with T.block("Y"):
       vi = T.axis.spatial(128, i)
       vj = T.axis.spatial(128, j)
       vk = T.axis.reduce(128, k)
       with T.init():
           Y[vi, vj] = T.float32(0)
       Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

   # corresponding numpy code
   vi, vj, vk = i, j, k
   if vk == 0:
       Y[vi, vj] = 0
   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

A **block** is a basic unit of computation in TensorIR. Notably, the
block contains more information than the plain NumPy code. A block
contains a set of block axes (``vi, vj, vk``) and the computations
defined around them.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   vi = T.axis.spatial(128, i)
   vj = T.axis.spatial(128, j)
   vk = T.axis.reduce(128, k)

The above three lines declare the **key properties** of the block axes
in the following syntax.

::

   [block_axis] = T.axis.[axis_type]([axis_range], [mapped_value])

The three lines contain the following information:

- They define where ``vi``, ``vj``, ``vk`` should be bound (in this
  case, to ``i``, ``j``, ``k``).
- They declare the original ranges that ``vi``, ``vj``, ``vk`` are
  supposed to cover (the ``128`` in ``T.axis.spatial(128, i)``).
- They declare the properties of the iterators (``spatial``,
  ``reduce``).

Let us walk through these properties one by one. First, consider the
binding relation:
``vi = T.axis.spatial(128, i)`` effectively implies ``vi = i``. The
``[axis_range]`` value provides the expected range of the
``[block_axis]``. For example, the ``128`` in
``vi = T.axis.spatial(128, i)`` indicates that ``vi`` should be in
``range(0, 128)``.

Block Axis Properties
^^^^^^^^^^^^^^^^^^^^^

Let us now take a closer look at the block axis properties. These axis
properties mark the relation of the axis to the computation being
performed. The figure below summarizes the block (iteration) axes and
the read/write relations of block Y. Note that, strictly speaking, the
block is doing (reduction) updates to ``Y``; we mark this as a write
for now, as we do not need the value of ``Y`` from another block.

.. figure:: ../img/tensor_ir_block_axis.png

In our example, block Y computes the result ``Y[vi, vj]`` by reading
values from ``A[vi, vk]`` and ``B[vk, vj]`` and performing a sum over
all possible ``vk``. In this particular example, if we fix ``vi``,
``vj`` to ``(0, 1)`` and run the block for ``vk in range(0, 128)``, we
can effectively compute ``Y[0, 1]`` independently of the other possible
locations (those with different values of ``vi``, ``vj``).

Notably, for a fixed value of ``vi`` and ``vj``, the computation block
produces a point value at a spatial location of Y (``Y[vi, vj]``) that
is independent of the other locations in ``Y`` (those with different
``vi, vj`` values). We call ``vi``, ``vj`` **spatial axes**, as they
directly correspond to the beginning of a spatial region of the buffers
that the block writes to. The axes involved in reduction (``vk``) are
named **reduce axes**.

Why Extra Information in Block
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

One crucial observation is that the additional information (block axis
ranges and their properties) makes the block **self-contained** with
respect to the iterations it is supposed to carry out, independent of
the external loop nest ``i``, ``j``, ``k``.
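To make the spatial/reduce distinction concrete, the following
low-level numpy snippet (our own illustration, with fresh random inputs
so it runs standalone) fixes the spatial axes ``vi, vj`` to ``(0, 1)``
and iterates only over the reduce axis ``vk``, which is enough to
produce ``Y[0, 1]``:

```python
import numpy as np

a_np = np.random.rand(128, 128).astype("float32")
b_np = np.random.rand(128, 128).astype("float32")

# Fix the spatial axes vi, vj and iterate only over the reduce axis vk.
vi, vj = 0, 1
y_point = 0.0
for vk in range(128):
    y_point += a_np[vi, vk] * b_np[vk, vj]

# The value matches Y[0, 1] of the full matrix multiplication,
# independent of any other (vi, vj) location.
np.testing.assert_allclose(y_point, (a_np @ b_np)[vi, vj], rtol=1e-5)
```

No other ``(vi, vj)`` location needs to be touched, which is exactly
what the spatial-axis property promises.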
The block axis information also provides additional properties that
help us validate the correctness of the external loops used to carry
out the computation. For example, the code block below will result in
an error because the block expects an iterator of size ``128``, but we
only bound it to a for loop of size ``127``.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # wrong program due to loop and block iteration mismatch
   for i in range(127):
       with T.block("C"):
           vi = T.axis.spatial(128, i)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
           # error here due to iterator size mismatch
       ...

This additional information also helps in subsequent machine learning
compilation analysis. For example, while we can always parallelize over
spatial axes, parallelizing over reduce axes requires specific
strategies.

Sugars for Block Axes Binding
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In situations where each block axis is directly mapped to an outer loop
iterator, we can use ``T.axis.remap`` to declare all the block axes in
a single line.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   # SSR means the properties of each axis are "spatial", "spatial", "reduce"
   vi, vj, vk = T.axis.remap("SSR", [i, j, k])

which is equivalent to

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   vi = T.axis.spatial(range_of_i, i)
   vj = T.axis.spatial(range_of_j, j)
   vk = T.axis.reduce(range_of_k, k)

So we can also write the program as follows.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @tvm.script.ir_module
   class MyModuleWithAxisRemapSugar:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"),
                   B: T.Buffer((128, 128), "float32"),
                   C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})
           Y = T.alloc_buffer((128, 128), dtype="float32")
           for i, j, k in T.grid(128, 128, 128):
               with T.block("Y"):
                   vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
           for i, j in T.grid(128, 128):
               with T.block("C"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   C[vi, vj] = T.max(Y[vi, vj], T.float32(0))

Function Attributes and Decorators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

So far, we have covered most of the elements in TensorIR. In this part,
we will go over the remaining elements of the script.

The function attribute information contains extra information about the
function.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   T.func_attr({"global_symbol": "mm_relu", "tir.noalias": True})

Here ``global_symbol`` corresponds to the name of the function, and
``tir.noalias`` is an attribute indicating that none of the buffer
memories overlap. Feel free to safely skip these attributes for now, as
they do not affect the overall understanding of the high-level
concepts.

The two decorators, ``@tvm.script.ir_module`` and ``@T.prim_func``,
indicate the type of the corresponding part. ``@tvm.script.ir_module``
indicates that ``MyModule`` is an ``IRModule``. An IRModule is the
container object that holds a collection of tensor functions in machine
learning compilation.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   type(MyModule)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   tvm.ir.module.IRModule

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   type(MyModule["mm_relu"])

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
   :class: output

   tvm.tir.function.PrimFunc

Up until now, we have only seen IRModules containing a single tensor
function. An IRModule in the MLC process can contain multiple tensor
functions. The following code block shows an example of an IRModule
with two functions.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   @tvm.script.ir_module
   class MyModuleWithTwoFunctions:
       @T.prim_func
       def mm(A: T.Buffer((128, 128), "float32"),
              B: T.Buffer((128, 128), "float32"),
              Y: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm", "tir.noalias": True})
           for i, j, k in T.grid(128, 128, 128):
               with T.block("Y"):
                   vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]

       @T.prim_func
       def relu(A: T.Buffer((128, 128), "float32"),
                B: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "relu", "tir.noalias": True})
           for i, j in T.grid(128, 128):
               with T.block("B"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   B[vi, vj] = T.max(A[vi, vj], T.float32(0))

Section Checkpoint
^^^^^^^^^^^^^^^^^^

So far, we have gone through one example instance of a TensorIR program
and covered most of its elements, including:

- Buffer declarations in parameters and intermediate temporary memory.
- For loop iterations.
- **Block** and block axes properties.

TensorIR contains more elements than what we went over in this section,
but the elements covered here are the key parts that can get us started
on the MLC journey. We will cover new elements as we encounter them in
later chapters.

Transformation
~~~~~~~~~~~~~~

In the last section, we learned about TensorIR and its key elements.
Now, let us get to the main ingredient of all MLC flows:
transformations of primitive tensor functions. The last section gave an
example of how to write ``mm_relu`` using low-level numpy.
In practice, there can be multiple ways to implement the same
functionality, and each implementation can result in different
performance. We will discuss the reasons behind the performance
differences and how to leverage these variants in future lectures. In
this lecture, let us focus on the ability to get different
implementation variants through transformations.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
       Y = np.empty((128, 128), dtype="float32")
       for i in range(128):
           for j0 in range(32):
               for k in range(128):
                   for j1 in range(4):
                       j = j0 * 4 + j1
                       if k == 0:
                           Y[i, j] = 0
                       Y[i, j] = Y[i, j] + A[i, k] * B[k, j]
       for i in range(128):
           for j in range(128):
               C[i, j] = max(Y[i, j], 0)

   c_np = np.empty((128, 128), dtype=dtype)
   lnumpy_mm_relu_v2(a_np, b_np, c_np)
   np.testing.assert_allclose(c_mm_relu, c_np, rtol=1e-5)

The above code block shows a slightly different variation of
``mm_relu``. To see its relation to the original program:

- We replace the ``j`` loop with two loops, ``j0`` and ``j1``.
- The order of the iterations changes slightly.

To obtain ``lnumpy_mm_relu_v2``, we had to write a new function (or
manually copy, paste, and edit). TensorIR introduces a utility called
Schedule that allows us to do this programmatically.

To remind ourselves, let us look again at the current ``MyModule``
content.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

   import IPython
   IPython.display.Code(MyModule.script(), language="python")
.. raw:: latex

   \diilbookstyleoutputcell

.. code:: python

   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
           # with T.block("root"):
           Y = T.alloc_buffer((128, 128))
           for i, j, k in T.grid(128, 128, 128):
               with T.block("Y"):
                   vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                   T.reads(A[vi, vk], B[vk, vj])
                   T.writes(Y[vi, vj])
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
           for i, j in T.grid(128, 128):
               with T.block("C"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   T.reads(Y[vi, vj])
                   T.writes(C[vi, vj])
                   C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
.. raw:: latex

   \diilbookstyleoutputcell

.. code:: python

   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
           # with T.block("root"):
           Y = T.alloc_buffer((128, 128))
           for i, j_0, j_1, k in T.grid(128, 32, 4, 128):
               with T.block("Y"):
                   vi = T.axis.spatial(128, i)
                   vj = T.axis.spatial(128, j_0 * 4 + j_1)
                   vk = T.axis.reduce(128, k)
                   T.reads(A[vi, vk], B[vk, vj])
                   T.writes(Y[vi, vj])
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
           for i, j in T.grid(128, 128):
               with T.block("C"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   T.reads(Y[vi, vj])
                   T.writes(C[vi, vj])
                   C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
.. raw:: latex

   \diilbookstyleoutputcell

.. code:: python

   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
           # with T.block("root"):
           Y = T.alloc_buffer((128, 128))
           for i, j_0, k, j_1 in T.grid(128, 32, 128, 4):
               with T.block("Y"):
                   vi = T.axis.spatial(128, i)
                   vj = T.axis.spatial(128, j_0 * 4 + j_1)
                   vk = T.axis.reduce(128, k)
                   T.reads(A[vi, vk], B[vk, vj])
                   T.writes(Y[vi, vj])
                   with T.init():
                       Y[vi, vj] = T.float32(0)
                   Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
           for i, j in T.grid(128, 128):
               with T.block("C"):
                   vi, vj = T.axis.remap("SS", [i, j])
                   T.reads(Y[vi, vj])
                   T.writes(C[vi, vj])
                   C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
.. raw:: latex

   \diilbookstyleoutputcell

.. code:: python

   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
           # with T.block("root"):
           Y = T.alloc_buffer((128, 128))
           for i, j_0 in T.grid(128, 32):
               for k, j_1 in T.grid(128, 4):
                   with T.block("Y"):
                       vi = T.axis.spatial(128, i)
                       vj = T.axis.spatial(128, j_0 * 4 + j_1)
                       vk = T.axis.reduce(128, k)
                       T.reads(A[vi, vk], B[vk, vj])
                       T.writes(Y[vi, vj])
                       with T.init():
                           Y[vi, vj] = T.float32(0)
                       Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
               for ax0 in range(4):
                   with T.block("C"):
                       vi = T.axis.spatial(128, i)
                       vj = T.axis.spatial(128, j_0 * 4 + ax0)
                       T.reads(Y[vi, vj])
                       T.writes(C[vi, vj])
                       C[vi, vj] = T.max(Y[vi, vj], T.float32(0))
.. raw:: latex

   \diilbookstyleoutputcell

.. code:: python

   # from tvm.script import ir as I
   # from tvm.script import tir as T

   @I.ir_module
   class Module:
       @T.prim_func
       def mm_relu(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
           T.func_attr({"global_symbol": "mm_relu", "tir.noalias": T.bool(True)})
           # with T.block("root"):
           Y = T.alloc_buffer((128, 128))
           for i, j_0 in T.grid(128, 16):
               for k, j_1 in T.grid(128, 8):
                   with T.block("Y"):
                       vi = T.axis.spatial(128, i)
                       vj = T.axis.spatial(128, j_0 * 8 + j_1)
                       vk = T.axis.reduce(128, k)
                       T.reads(A[vi, vk], B[vk, vj])
                       T.writes(Y[vi, vj])
                       with T.init():
                           Y[vi, vj] = T.float32(0)
                       Y[vi, vj] = Y[vi, vj] + A[vi, vk] * B[vk, vj]
               for ax0 in range(8):
                   with T.block("C"):
                       vi = T.axis.spatial(128, i)
                       vj = T.axis.spatial(128, j_0 * 8 + ax0)
                       T.reads(Y[vi, vj])
                       T.writes(C[vi, vj])
                       C[vi, vj] = T.max(Y[vi, vj], T.float32(0))