第一部分 -------- 在上一章中,我们讨论了 CPU 环境中的 MLC 流程。本章节将讨论如何将一些优化带到 GPU 上。我们将使用 CUDA 编程语言。然而,同样的概念也适用于其他类型的 GPU。 安装环境 ~~~~~~~~ 在本课程中,我们将使用 TVM 中正在进行的一些开发。在\ **第一部分**\ 中,我们依赖于 CUDA 11 环境,因此需要安装特定的包。 .. raw:: latex \diilbookstyleinputcell .. code:: bash python3 -m pip install mlc-ai-nightly-cu110 -f https://mlc.ai/wheels **注意:我们目前的系统尚不支持 GPU,因此部分代码不会被运行。** 准备工作 ~~~~~~~~ 首先,让我们导入必要的依赖项。 .. raw:: latex \diilbookstyleinputcell .. code:: python import numpy as np import tvm from tvm import relax from tvm.ir.module import IRModule from tvm.script import relax as R from tvm.script import tir as T GPU 体系结构 ~~~~~~~~~~~~ 让我们首先回顾一下 GPU 体系结构。典型的 GPU 包含一组流式多处理器 (streaming multiprocessors, SM),每个流式多处理器都有许多核心。 GPU 设备是大规模并行的,允许我们同时执行许多任务。 .. figure:: ../img/gpu_arch.png 要对 GPU 进行编程,我们需要创建一组线程块 (thread blocks),每个 thread 映射到单个核心,而 block 映射到流式多处理器 (SM)。 .. figure:: ../img/gpu_stream_processors.png 让我们使用向量相加示例开始 GPU 编程。以下 TensorIR 程序采用两个向量 ``A`` 和 ``B``\ ,执行元素相加,并将结果存储在 ``C`` 中。 .. raw:: latex \diilbookstyleinputcell .. code:: python @tvm.script.ir_module class MyModuleVecAdd: @T.prim_func def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i in T.grid(1024): with T.block("C"): vi = T.axis.remap("S", [i]) C[vi] = A[vi] + B[vi] 我们首先将循环 ``i`` 拆分成两个循环。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch = tvm.tir.Schedule(MyModuleVecAdd) block_C = sch.get_block("C") i, = sch.get_loops(block=block_C) i0, i1 = sch.split(i, [None, 128]) sch.mod.show() .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output /usr/share/miniconda/envs/mlc/lib/python3.8/site-packages/tvm/script/highlight.py:117: UserWarning: No module named 'black' To print formatted TVM script, please install the formatter 'Black': /usr/share/miniconda/envs/mlc/bin/python -m pip install "black==22.3.0" --upgrade --user warnings.warn( .. raw:: html
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # Element-wise add of two 1024-element float32 vectors, iterated as an
    # 8 x 128 loop nest.
    @T.prim_func
    def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        # i_0 walks the 8 chunks, i_1 the 128 elements inside one chunk.
        for i_0, i_1 in T.grid(8, 128):
            with T.block("C"):
                # Flat element index rebuilt from the two loop variables.
                vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                T.reads(A[vi], B[vi])
                T.writes(C[vi])
                C[vi] = A[vi] + B[vi]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # Element-wise add of two 1024-element float32 vectors, with both loops
    # of the 8 x 128 nest bound to GPU thread axes.
    @T.prim_func
    def main(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        # Outer loop (extent 8) is bound to blockIdx.x.
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            # Inner loop (extent 128) is bound to threadIdx.x.
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                with T.block("C"):
                    # One output element per (blockIdx.x, threadIdx.x) pair.
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A[vi], B[vi])
                    T.writes(C[vi])
                    C[vi] = A[vi] + B[vi]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # Sliding-window sum: B[i] = A[i] + A[i + 1] + A[i + 2] for i in [0, 1024).
    # A is declared with 1027 elements; the highest index read is 1023 + 2.
    @T.prim_func
    def main(A: T.Buffer((1027,), "float32"), B: T.Buffer((1024,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        # 8 x 128 nest bound to blockIdx.x / threadIdx.x.
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    # Each output reads a window of three consecutive inputs,
                    # so neighbouring outputs read overlapping parts of A.
                    T.reads(A[vi:vi + 3])
                    T.writes(B[vi])
                    B[vi] = A[vi] + A[vi + 1] + A[vi + 2]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # Sliding-window sum B[i] = A[i] + A[i + 1] + A[i + 2], with the input
    # first copied into a buffer of scope "shared" and the sum read from it.
    @T.prim_func
    def main(A: T.Buffer((1027,), "float32"), B: T.Buffer((1024,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        # Staging copy of A, declared with the same extent as A.
        A_shared = T.alloc_buffer((1027,), scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                # Copy loop nested inside the threadIdx.x loop. It covers 130
                # elements: the 128 outputs of this i_0 chunk need inputs
                # i_0 * 128 .. i_0 * 128 + 129 (128 + 2 extra for the window).
                # The loop is a plain serial range, not bound to a thread axis.
                for ax0 in range(130):
                    with T.block("A_shared"):
                        v0 = T.axis.spatial(1027, i_0 * 128 + ax0)
                        T.reads(A[v0])
                        T.writes(A_shared[v0])
                        A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    # The window is now read from the staged copy, not from A.
                    T.reads(A_shared[vi:vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # Sliding-window sum B[i] = A[i] + A[i + 1] + A[i + 2], reading from a
    # "shared"-scope staging buffer whose copy loop is split into 2 x 128 with
    # the inner part bound to threadIdx.x.
    @T.prim_func
    def main(A: T.Buffer((1027,), "float32"), B: T.Buffer((1024,), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        A_shared = T.alloc_buffer((1027,), scope="shared")
        for i_0 in T.thread_binding(8, thread="blockIdx.x"):
            for i_1 in T.thread_binding(128, thread="threadIdx.x"):
                # The 130-element copy is expressed as 2 rounds of 128.
                for ax0_0 in range(2):
                    # ax0_1 is bound to threadIdx.x, the same axis as i_1.
                    for ax0_1 in T.thread_binding(128, thread="threadIdx.x"):
                        with T.block("A_shared"):
                            v0 = T.axis.spatial(1027, i_0 * 128 + (ax0_0 * 128 + ax0_1))
                            # 2 * 128 = 256 iterations exceed the 130 elements
                            # needed; the predicate masks off the excess.
                            T.where(ax0_0 * 128 + ax0_1 < 130)
                            T.reads(A[v0])
                            T.writes(A_shared[v0])
                            A_shared[v0] = A[v0]
                with T.block("C"):
                    vi = T.axis.spatial(1024, i_0 * 128 + i_1)
                    T.reads(A_shared[vi:vi + 3])
                    T.writes(B[vi])
                    B[vi] = A_shared[vi] + A_shared[vi + 1] + A_shared[vi + 2]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # 1024 x 1024 float32 matrix multiply C = A @ B. Each (i_1, j_1) thread
    # iteration accumulates an 8 x 8 tile of the result in C_local
    # (scope "local") and copies it to C at the end.
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        C_local = T.alloc_buffer((1024, 1024), scope="local")
        # 16 x 16 blocks (blockIdx.y, blockIdx.x), each covering a 64 x 64
        # region of C.
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                # 8 x 8 threads (threadIdx.y, threadIdx.x), each covering an
                # 8 x 8 tile of that region.
                for i_1 in T.thread_binding(8, thread="threadIdx.y"):
                    for j_1 in T.thread_binding(8, thread="threadIdx.x"):
                        # Step 1: zero the 8 x 8 accumulator tile.
                        for i_2_init, j_2_init in T.grid(8, 8):
                            with T.block("C_init"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2_init)
                                vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2_init)
                                T.reads()
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = T.float32(0)
                        # Step 2: reduce over k, split as 256 x 4 with the
                        # inner factor marked for unrolling.
                        for k_0 in range(256):
                            for k_1 in T.unroll(4):
                                for i_2, j_2 in T.grid(8, 8):
                                    with T.block("C_update"):
                                        vi = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + i_2)
                                        vj = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + j_2)
                                        vk = T.axis.reduce(1024, k_0 * 4 + k_1)
                                        T.reads(C_local[vi, vj], A[vi, vk], B[vk, vj])
                                        T.writes(C_local[vi, vj])
                                        C_local[vi, vj] = C_local[vi, vj] + A[vi, vk] * B[vk, vj]
                        # Step 3: write the finished tile from C_local to C.
                        for ax0, ax1 in T.grid(8, 8):
                            with T.block("C_local"):
                                v0 = T.axis.spatial(1024, i_0 * 64 + i_1 * 8 + ax0)
                                v1 = T.axis.spatial(1024, j_0 * 64 + j_1 * 8 + ax1)
                                T.reads(C_local[v0, v1])
                                T.writes(C[v0, v1])
                                C[v0, v1] = C_local[v0, v1]
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
    # 1024 x 1024 float32 matrix multiply C = A @ B. For every k_0 step the
    # needed tiles of A and B are first copied into "shared"-scope buffers,
    # and each thread accumulates its 8 x 8 result tile in C_local.
    @T.prim_func
    def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # with T.block("root"):
        C_local = T.alloc_buffer((1024, 1024), scope="local")
        A_shared = T.alloc_buffer((1024, 1024), scope="shared")
        B_shared = T.alloc_buffer((1024, 1024), scope="shared")
        for i_0 in T.thread_binding(16, thread="blockIdx.y"):
            for j_0 in T.thread_binding(16, thread="blockIdx.x"):
                # The 8 x 8 thread tile is fused into a single 64-wide
                # threadIdx.x axis; "// 8" and "% 8" recover the row and
                # column of the tile.
                for i_1_j_1_fused in T.thread_binding(64, thread="threadIdx.x"):
                    # Step 1: zero the 8 x 8 accumulator tile.
                    for i_2_init, j_2_init in T.grid(8, 8):
                        with T.block("C_init"):
                            vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2_init)
                            vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2_init)
                            T.reads()
                            T.writes(C_local[vi, vj])
                            C_local[vi, vj] = T.float32(0)
                    # Step 2: reduce over k in 128 steps of 8.
                    for k_0 in range(128):
                        # Copy a 64 x 8 tile of A (512 elements) as
                        # 2 rounds x 64 threads x 4-wide vectorized accesses.
                        for ax0_ax1_fused_0 in range(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        v0 = T.axis.spatial(1024, i_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 8)
                                        v1 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 8)
                                        T.reads(A[v0, v1])
                                        T.writes(A_shared[v0, v1])
                                        A_shared[v0, v1] = A[v0, v1]
                        # Copy an 8 x 64 tile of B (512 elements) with the
                        # same 2 x 64 x 4 decomposition.
                        for ax0_ax1_fused_0 in range(2):
                            for ax0_ax1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        v0 = T.axis.spatial(1024, k_0 * 8 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 64)
                                        v1 = T.axis.spatial(1024, j_0 * 64 + (ax0_ax1_fused_0 * 256 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 64)
                                        T.reads(B[v0, v1])
                                        T.writes(B_shared[v0, v1])
                                        B_shared[v0, v1] = B[v0, v1]
                        # Accumulate the 8 x 8 tile over the 8 k values of
                        # this step, reading operands from the staged tiles.
                        for k_1, i_2, j_2 in T.grid(8, 8, 8):
                            with T.block("C_update"):
                                vi = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + i_2)
                                vj = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + j_2)
                                vk = T.axis.reduce(1024, k_0 * 8 + k_1)
                                T.reads(C_local[vi, vj], A_shared[vi, vk], B_shared[vk, vj])
                                T.writes(C_local[vi, vj])
                                C_local[vi, vj] = C_local[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
                    # Step 3: write the finished tile from C_local to C.
                    for ax0, ax1 in T.grid(8, 8):
                        with T.block("C_local"):
                            v0 = T.axis.spatial(1024, i_0 * 64 + i_1_j_1_fused // 8 * 8 + ax0)
                            v1 = T.axis.spatial(1024, j_0 * 64 + i_1_j_1_fused % 8 * 8 + ax1)
                            T.reads(C_local[v0, v1])
                            T.writes(C[v0, v1])
                            C[v0, v1] = C_local[v0, v1]