Part 2
------

We discussed building MLC flows for CPU and GPU environments in the previous chapters. This chapter focuses on how we build conceptual programming models for specialized hardware backends.

Preparations
~~~~~~~~~~~~

To begin with, let us import the necessary dependencies.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.ir.module import IRModule
    from tvm.script import relax as R
    from tvm.script import tir as T

Hardware Specialization Trend
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. figure:: ../img/hardware_specialization.png

If we look at the machine learning hardware landscape, one emerging theme is specialization. Traditionally, we build our solutions on generic scalar processors, which perform operations on one floating-point value at a time. Vector instruction sets such as AVX and ARM/Neon provide effective ways to speed up our programs, but they also add complexity to how we write them. The latest machine learning accelerators introduce specialized units for tensor computing, with instructions for multi-dimensional data copy and matrix/tensor computations.

Key Elements of Specialized Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To help us better understand the elements of specialized hardware programming, let us first study the following **low-level NumPy** code. While this code still runs in Python, it resembles a set of operations that can happen in a specialized hardware backend.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def accel_fill_zero(C):
        C[:] = 0

    def accel_tmm_add(C, A, B):
        C[:] += A @ B.T

    def accel_dma_copy(reg, dram):
        reg[:] = dram[:]

    def lnumpy_tmm(A: np.ndarray, B: np.ndarray, C: np.ndarray):
        # a special accumulator memory
        C_accumulator = np.empty((16, 16), dtype="float32")
        A_reg = np.empty((16, 16), dtype="float32")
        B_reg = np.empty((16, 16), dtype="float32")

        for i in range(64):
            for j in range(64):
                accel_fill_zero(C_accumulator[:, :])
                for k in range(64):
                    accel_dma_copy(A_reg[:], A[i * 16 : i * 16 + 16, k * 16 : k * 16 + 16])
                    accel_dma_copy(B_reg[:], B[j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
                    accel_tmm_add(C_accumulator[:, :], A_reg, B_reg)
                accel_dma_copy(C[i * 16 : i * 16 + 16, j * 16 : j * 16 + 16], C_accumulator[:, :])

.. figure:: ../img/hardware_specialization_abc.png

The above low-level NumPy program contains the following key elements:

- The basic unit of computation is a 16x16x16 matrix multiplication (``accel_tmm_add``).
- ``accel_tmm_add`` takes two inputs, ``A_reg`` and ``B_reg``, and accumulates into an accumulator memory.
- The data copy is performed using a special function (``accel_dma_copy``).

In a real-world hardware backend, we usually expect ``A_reg``, ``B_reg``, and ``C_accumulator`` to map to special memory regions (or registers) in the hardware. These are called **special memory scopes**. Additionally, there is only a limited set of hardware-accelerated operations we can perform in these settings. Operations such as ``accel_tmm_add`` can be mapped to real hardware instructions or to an efficient kernel function implementation provided by the vendor.

We can run the following code block to confirm that the low-level NumPy code executes correctly.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    dtype = "float32"
    a_np = np.random.rand(1024, 1024).astype(dtype)
    b_np = np.random.rand(1024, 1024).astype(dtype)
    c_tmm = a_np @ b_np.T

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    c_np = np.empty((1024, 1024), dtype="float32")
    lnumpy_tmm(a_np, b_np, c_np)
    np.testing.assert_allclose(c_np, c_tmm, rtol=1e-5)

A Block with Tensorized Computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

One of our key observations is that the specialized accelerator code is not structured in units of scalar computations. Most of the TensorIR code we have run so far contains a block that computes a single element of the output tensor. Many specialized accelerators instead run computations over regions of tensors. The block construct in TensorIR helps us group such relevant computations.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    @tvm.script.ir_module
    class MatmulBlockModule:
        @T.prim_func
        def main(
            A: T.Buffer((1024, 1024), "float32"),
            B: T.Buffer((1024, 1024), "float32"),
            C: T.Buffer((1024, 1024), "float32"),
        ) -> None:
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            for i0, j0, k0 in T.grid(64, 64, 64):
                with T.block("tmm-16x16"):
                    vi0, vj0, vk0 = T.axis.remap("SSR", [i0, j0, k0])
                    with T.init():
                        for i1, j1 in T.grid(16, 16):
                            with T.block("tmm_init"):
                                vi1, vj1 = T.axis.remap("SS", [i1, j1])
                                C[vi0 * 16 + vi1, vj0 * 16 + vj1] = T.float32(0)
                    for i1, j1, k1 in T.grid(16, 16, 16):
                        with T.block("tmm"):
                            vi1, vj1, vk1 = T.axis.remap("SSR", [i1, j1, k1])
                            C[vi0 * 16 + vi1, vj0 * 16 + vj1] += \
                                A[vi0 * 16 + vi1, vk0 * 16 + vk1] * B[vj0 * 16 + vj1, vk0 * 16 + vk1]

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    MatmulBlockModule.show()
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, j0, k0 in T.grid(64, 64, 64):
                with T.block("tmm-16x16"):
                    vi0, vj0, vk0 = T.axis.remap("SSR", [i0, j0, k0])
                    T.reads(A[vi0 * 16:vi0 * 16 + 16, vk0 * 16:vk0 * 16 + 16], B[vj0 * 16:vj0 * 16 + 16, vk0 * 16:vk0 * 16 + 16])
                    T.writes(C[vi0 * 16:vi0 * 16 + 16, vj0 * 16:vj0 * 16 + 16])
                    with T.init():
                        for i1, j1 in T.grid(16, 16):
                            with T.block("tmm_init"):
                                vi1, vj1 = T.axis.remap("SS", [i1, j1])
                                T.reads()
                                T.writes(C[vi0 * 16 + vi1, vj0 * 16 + vj1])
                                C[vi0 * 16 + vi1, vj0 * 16 + vj1] = T.float32(0)
                    for i1, j1, k1 in T.grid(16, 16, 16):
                        with T.block("tmm"):
                            vi1, vj1, vk1 = T.axis.remap("SSR", [i1, j1, k1])
                            T.reads(C[vi0 * 16 + vi1, vj0 * 16 + vj1], A[vi0 * 16 + vi1, vk0 * 16 + vk1], B[vj0 * 16 + vj1, vk0 * 16 + vk1])
                            T.writes(C[vi0 * 16 + vi1, vj0 * 16 + vj1])
                            C[vi0 * 16 + vi1, vj0 * 16 + vj1] = C[vi0 * 16 + vi1, vj0 * 16 + vj1] + A[vi0 * 16 + vi1, vk0 * 16 + vk1] * B[vj0 * 16 + vj1, vk0 * 16 + vk1]
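The printed module makes the access regions of the ``tmm-16x16`` block explicit: each block instance reads 16x16 tiles of ``A`` and ``B`` and writes a 16x16 tile of ``C``. As a sanity check, here is a small NumPy sketch (a standalone helper of our own, not part of the module above) that follows the same tile-level loop nest and reproduces ``A @ B.T``:

```python
import numpy as np

def blocked_tmm(A, B, tile=16):
    # Mirror the loop nest of the "tmm-16x16" block: iterate over
    # tile x tile output regions, zero-initialize each region, then
    # accumulate A_tile @ B_tile.T over the reduction tiles.
    n = A.shape[0]
    C = np.empty((n, n), dtype="float32")
    nb = n // tile
    for i0 in range(nb):
        for j0 in range(nb):
            C[i0 * tile:(i0 + 1) * tile, j0 * tile:(j0 + 1) * tile] = 0
            for k0 in range(nb):
                C[i0 * tile:(i0 + 1) * tile, j0 * tile:(j0 + 1) * tile] += (
                    A[i0 * tile:(i0 + 1) * tile, k0 * tile:(k0 + 1) * tile]
                    @ B[j0 * tile:(j0 + 1) * tile, k0 * tile:(k0 + 1) * tile].T
                )
    return C

a = np.random.rand(64, 64).astype("float32")
b = np.random.rand(64, 64).astype("float32")
np.testing.assert_allclose(blocked_tmm(a, b), a @ b.T, rtol=1e-4)
```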
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0_0, j0, i0_1, k0 in T.grid(16, 64, 4, 64):
                with T.block("tmm-16x16"):
                    vi0 = T.axis.spatial(64, i0_0 * 4 + i0_1)
                    vj0, vk0 = T.axis.remap("SR", [j0, k0])
                    T.reads(A[vi0 * 16:vi0 * 16 + 16, vk0 * 16:vk0 * 16 + 16], B[vj0 * 16:vj0 * 16 + 16, vk0 * 16:vk0 * 16 + 16])
                    T.writes(C[vi0 * 16:vi0 * 16 + 16, vj0 * 16:vj0 * 16 + 16])
                    with T.init():
                        for i1, j1 in T.grid(16, 16):
                            with T.block("tmm_init"):
                                vi1, vj1 = T.axis.remap("SS", [i1, j1])
                                T.reads()
                                T.writes(C[vi0 * 16 + vi1, vj0 * 16 + vj1])
                                C[vi0 * 16 + vi1, vj0 * 16 + vj1] = T.float32(0)
                    for i1, j1, k1 in T.grid(16, 16, 16):
                        with T.block("tmm"):
                            vi1, vj1, vk1 = T.axis.remap("SSR", [i1, j1, k1])
                            T.reads(C[vi0 * 16 + vi1, vj0 * 16 + vj1], A[vi0 * 16 + vi1, vk0 * 16 + vk1], B[vj0 * 16 + vj1, vk0 * 16 + vk1])
                            T.writes(C[vi0 * 16 + vi1, vj0 * 16 + vj1])
                            C[vi0 * 16 + vi1, vj0 * 16 + vj1] = C[vi0 * 16 + vi1, vj0 * 16 + vj1] + A[vi0 * 16 + vi1, vk0 * 16 + vk1] * B[vj0 * 16 + vj1, vk0 * 16 + vk1]
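The printout above shows the same module after the outer spatial loop ``i0`` has been split (outer extent 16, inner extent 4) and the loops reordered; the input cell that produced it is not shown here, but with the TVM schedule API this would typically be a ``split`` followed by a ``reorder`` (an assumption on our part). The correctness of such a transformation rests on a simple index invariant, which we can check directly:

```python
# Check the index invariant of the split: vi0 = i0_0 * 4 + i0_1
# enumerates 0..63 exactly once, so the split loop nest visits the
# same 16x16 tiles as the original single loop of extent 64.
original = list(range(64))
split = [i0_0 * 4 + i0_1 for i0_0 in range(16) for i0_1 in range(4)]
assert split == original
```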
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i_0, j_0, k_0, i_1, j_1, k_1 in T.grid(64, 64, 64, 16, 16, 16):
                with T.block("matmul"):
                    vi = T.axis.spatial(1024, i_0 * 16 + i_1)
                    vj = T.axis.spatial(1024, j_0 * 16 + j_1)
                    vk = T.axis.reduce(1024, k_0 * 16 + k_1)
                    T.reads(A[vi, vk], B[vj, vk])
                    T.writes(C[vi, vj])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
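The printout above is a different starting point: a plain scalar matmul whose loops have each been split into a tile loop and an intra-tile loop, so every element index decomposes as ``vi = i_0 * 16 + i_1``. Here is a small NumPy version of the same six-deep loop nest (run at a reduced size, since the full 1024-cubed scalar loop would be slow in Python):

```python
import numpy as np

# Scalar matmul with loops split into tile loops (i_0, j_0, k_0) and
# intra-tile loops (i_1, j_1, k_1), mirroring the printed loop nest
# at size n=8 with tile t=4 instead of 1024/16.
n, t = 8, 4
A = np.random.rand(n, n).astype("float32")
B = np.random.rand(n, n).astype("float32")
C = np.zeros((n, n), dtype="float32")
for i_0 in range(n // t):
    for j_0 in range(n // t):
        for k_0 in range(n // t):
            for i_1 in range(t):
                for j_1 in range(t):
                    for k_1 in range(t):
                        vi, vj, vk = i_0 * t + i_1, j_0 * t + j_1, k_0 * t + k_1
                        C[vi, vj] += A[vi, vk] * B[vj, vk]
np.testing.assert_allclose(C, A @ B.T, rtol=1e-4)
```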
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i_0, j_0, k_0 in T.grid(64, 64, 64):
                with T.block("matmul_o"):
                    vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                    T.reads(A[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], B[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16])
                    T.writes(C[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                    with T.init():
                        for i_1, j_1 in T.grid(16, 16):
                            with T.block("matmul_init"):
                                vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                T.reads()
                                T.writes(C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                C[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                    for i_1, j_1, k_1 in T.grid(16, 16, 16):
                        with T.block("matmul"):
                            vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                            T.reads(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                            T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                            C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
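In the printout above, the inner 16x16x16 computation has been grouped under a single outer block ``matmul_o`` whose ``T.reads``/``T.writes`` are declared at tile granularity; in TVM this grouping is what ``blockize`` produces. The following hypothetical helper (our own illustration, not a TVM API) shows how the tile-level regions in the printout derive from the outer block indices:

```python
# Compute the [start, stop) access regions of the outer block
# "matmul_o" for given outer indices, matching the printed
# T.reads/T.writes annotations.
def tile_regions(vi_o, vj_o, vk_o, tile=16):
    a = ((vi_o * tile, vi_o * tile + tile), (vk_o * tile, vk_o * tile + tile))
    b = ((vj_o * tile, vj_o * tile + tile), (vk_o * tile, vk_o * tile + tile))
    c = ((vi_o * tile, vi_o * tile + tile), (vj_o * tile, vj_o * tile + tile))
    return a, b, c

a, b, c = tile_regions(2, 3, 5)
assert a == ((32, 48), (80, 96))   # A[32:48, 80:96]
assert c == ((32, 48), (48, 64))   # C[32:48, 48:64]
```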
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            A_global_A_reg = T.alloc_buffer((1024, 1024), scope="global.A_reg")
            B_global_B_reg = T.alloc_buffer((1024, 1024), scope="global.B_reg")
            C_global_accumulator = T.alloc_buffer((1024, 1024), scope="global.accumulator")
            for i_0, j_0 in T.grid(64, 64):
                for k_0 in range(64):
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("A_global.A_reg"):
                            v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(A[v0, v1])
                            T.writes(A_global_A_reg[v0, v1])
                            A_global_A_reg[v0, v1] = A[v0, v1]
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("B_global.B_reg"):
                            v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(B[v0, v1])
                            T.writes(B_global_B_reg[v0, v1])
                            B_global_B_reg[v0, v1] = B[v0, v1]
                    with T.block("matmul_o"):
                        vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                        T.reads(A_global_A_reg[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], B_global_B_reg[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16])
                        T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                        with T.init():
                            for i_1, j_1 in T.grid(16, 16):
                                with T.block("matmul_init"):
                                    vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                    T.reads()
                                    T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                    C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                        for i_1, j_1, k_1 in T.grid(16, 16, 16):
                            with T.block("matmul"):
                                vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                T.reads(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                T.writes(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("C_global.accumulator"):
                        v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                        v1 = T.axis.spatial(1024, j_0 * 16 + ax1)
                        T.reads(C_global_accumulator[v0, v1])
                        T.writes(C[v0, v1])
                        C[v0, v1] = C_global_accumulator[v0, v1]
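The version above stages data through buffers with the special memory scopes ``global.A_reg``, ``global.B_reg``, and ``global.accumulator``, mirroring the ``accel_dma_copy`` pattern of the low-level NumPy program: tiles are copied in, accumulated, and copied back out. A condensed NumPy restatement of that staging structure (our own sketch, run at a reduced size):

```python
import numpy as np

# Emulate the staged structure: copy 16x16 tiles into "register"
# buffers, accumulate into an accumulator tile over the reduction
# loop, then copy the finished tile back out.
def staged_tmm(A, B, t=16):
    n = A.shape[0]
    C = np.empty((n, n), dtype="float32")
    A_reg = np.empty((t, t), dtype="float32")
    B_reg = np.empty((t, t), dtype="float32")
    acc = np.empty((t, t), dtype="float32")
    for i in range(n // t):
        for j in range(n // t):
            acc[:] = 0
            for k in range(n // t):
                A_reg[:] = A[i*t:(i+1)*t, k*t:(k+1)*t]   # DMA in
                B_reg[:] = B[j*t:(j+1)*t, k*t:(k+1)*t]   # DMA in
                acc += A_reg @ B_reg.T                   # tensor op
            C[i*t:(i+1)*t, j*t:(j+1)*t] = acc[:]         # DMA out
    return C

a = np.random.rand(64, 64).astype("float32")
b = np.random.rand(64, 64).astype("float32")
np.testing.assert_allclose(staged_tmm(a, b), a @ b.T, rtol=1e-4)
```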
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            A_global_A_reg = T.alloc_buffer((1024, 1024), scope="global.A_reg")
            B_global_B_reg = T.alloc_buffer((1024, 1024), scope="global.B_reg")
            C_global_accumulator = T.alloc_buffer((1024, 1024), scope="global.accumulator")
            for i_0, j_0 in T.grid(64, 64):
                with T.block("matmul_o_init"):
                    vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
                    T.reads()
                    T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                    for i_1, j_1 in T.grid(16, 16):
                        with T.block("matmul_init"):
                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                            T.reads()
                            T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                            C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                for k_0 in range(64):
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("A_global.A_reg"):
                            v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(A[v0, v1])
                            T.writes(A_global_A_reg[v0, v1])
                            A_global_A_reg[v0, v1] = A[v0, v1]
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("B_global.B_reg"):
                            v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(B[v0, v1])
                            T.writes(B_global_B_reg[v0, v1])
                            B_global_B_reg[v0, v1] = B[v0, v1]
                    with T.block("matmul_o_update"):
                        vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                        T.reads(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16], A_global_A_reg[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], B_global_B_reg[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16])
                        T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                        for i_1, j_1, k_1 in T.grid(16, 16, 16):
                            with T.block("matmul"):
                                vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                T.reads(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i], B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                T.writes(C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = C_global_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + A_global_A_reg[vi_o * 16 + vi_i, vk_o * 16 + vk_i] * B_global_B_reg[vj_o * 16 + vj_i, vk_o * 16 + vk_i]
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("C_global.accumulator"):
                        v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                        v1 = T.axis.spatial(1024, j_0 * 16 + ax1)
                        T.reads(C_global_accumulator[v0, v1])
                        T.writes(C[v0, v1])
                        C[v0, v1] = C_global_accumulator[v0, v1]
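In the printout above, the initialization has been hoisted out of the reduction loop ``k_0`` into a separate ``matmul_o_init`` block, so each accumulator tile is zeroed once before all updates; in TVM this corresponds to decomposing the reduction. The transformation is valid because hoisted initialization computes the same result, which we can illustrate on a one-dimensional reduction:

```python
import numpy as np

# Reduction decomposition: zero-initializing the accumulator once
# before the reduction loop (hoisted init) gives the same result as
# initializing at the first reduction step (fused init-plus-update).
vals = np.random.rand(64).astype("float32")

# fused form: init happens at the first reduction step
acc_fused = None
for k, v in enumerate(vals):
    if k == 0:
        acc_fused = np.float32(0)
    acc_fused += v

# decomposed form: init hoisted before the loop
acc_dec = np.float32(0)
for v in vals:
    acc_dec += v

np.testing.assert_allclose(acc_fused, acc_dec)
```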
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T

    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
            # with T.block("root"):
            A_global_A_reg = T.alloc_buffer((1024, 1024), scope="global.A_reg")
            B_global_B_reg = T.alloc_buffer((1024, 1024), scope="global.B_reg")
            C_global_accumulator = T.alloc_buffer((1024, 1024), scope="global.accumulator")
            for i_0, j_0 in T.grid(64, 64):
                with T.block("matmul_o_init"):
                    vi_o, vj_o = T.axis.remap("SS", [i_0, j_0])
                    T.reads()
                    T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                    for i_1, j_1 in T.grid(16, 16):
                        with T.block("matmul_init"):
                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                            T.reads()
                            T.writes(C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                            C_global_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                for k_0 in range(64):
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("A_global.A_reg"):
                            v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(A[v0, v1])
                            T.writes(A_global_A_reg[v0, v1])
                            A_global_A_reg[v0, v1] = A[v0, v1]
                    for ax0, ax1 in T.grid(16, 16):
                        with T.block("B_global.B_reg"):
                            v0 = T.axis.spatial(1024, j_0 * 16 + ax0)
                            v1 = T.axis.spatial(1024, k_0 * 16 + ax1)
                            T.reads(B[v0, v1])
                            T.writes(B_global_B_reg[v0, v1])
                            B_global_B_reg[v0, v1] = B[v0, v1]
                    with T.block("matmul_o_update"):
                        vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                        T.reads(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16], A_global_A_reg[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], B_global_B_reg[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16])
                        T.writes(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16])
                        A_1 = T.match_buffer(A_global_A_reg[vi_o * 16:vi_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], (16, 16), strides=("A_s0", 1), scope="global.A_reg", offset_factor=16)
                        B_1 = T.match_buffer(B_global_B_reg[vj_o * 16:vj_o * 16 + 16, vk_o * 16:vk_o * 16 + 16], (16, 16), strides=("B_s0", 1), scope="global.B_reg", offset_factor=16)
                        C_1 = T.match_buffer(C_global_accumulator[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16], (16, 16), strides=("C_s0", 1), scope="global.accumulator", offset_factor=16)
                        T.call_extern("int32", "tmm16", T.tvm_access_ptr(T.type_annotation("float32"), C_1.data, C_1.elem_offset, C_1.strides[0] * 16, 2), T.tvm_access_ptr(T.type_annotation("float32"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), T.tvm_access_ptr(T.type_annotation("float32"), B_1.data, B_1.elem_offset, B_1.strides[0] * 16, 1), A_1.strides[0], B_1.strides[0], C_1.strides[0])
                for ax0, ax1 in T.grid(16, 16):
                    with T.block("C_global.accumulator"):
                        v0 = T.axis.spatial(1024, i_0 * 16 + ax0)
                        v1 = T.axis.spatial(1024, j_0 * 16 + ax1)
                        T.reads(C_global_accumulator[v0, v1])
                        T.writes(C[v0, v1])
                        C[v0, v1] = C_global_accumulator[v0, v1]
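Finally, the inner 16x16x16 computation has been replaced by an external call to ``tmm16``, which receives raw access pointers and strides for the three tiles; this is the tensorized form that maps the block onto a hardware instruction or a vendor-provided micro-kernel. A hypothetical Python stand-in for the ``tmm16`` contract (the real implementation would be supplied by the hardware vendor):

```python
import numpy as np

# A hypothetical stand-in for the "tmm16" micro-kernel contract:
# C_tile[16, 16] += A_tile[16, 16] @ B_tile[16, 16].T, updating the
# accumulator tile in place.
def tmm16(C_tile, A_tile, B_tile):
    C_tile += A_tile @ B_tile.T

A = np.random.rand(16, 16).astype("float32")
B = np.random.rand(16, 16).astype("float32")
C = np.zeros((16, 16), dtype="float32")
tmm16(C, A, B)
np.testing.assert_allclose(C, A @ B.T, rtol=1e-4)
```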