自动程序优化 ============ 前言 ---- 在过去的章节中,我们学习了如何构建元张量函数并将它们连接起来以进行端到端的模型执行。到目前为止,我们使用了三种主要的抽象类型。 - 驱动高层执行的计算图抽象 - 元张量函数的抽象 - 通过注册环境函数从而能被调用的库函数 所有这些元素都封装在一个 IRModule 中。大多数 MLC 过程可以看作是元张量函数之间的变换。 有许多不同的方法可以变换同一个程序。本章将讨论自动化一些流程的方法。 准备工作 -------- 首先,我们将导入必要的依赖项并创建辅助函数。 .. raw:: latex \diilbookstyleinputcell .. code:: python import numpy as np import tvm from tvm import relax from tvm.ir.module import IRModule from tvm.script import relax as R from tvm.script import tir as T .. raw:: latex \diilbookstyleinputcell .. code:: python import IPython def code2html(code): """Helper function to use pygments to turn the code string into highlighted html.""" import pygments from pygments.formatters import HtmlFormatter from pygments.lexers import Python3Lexer formatter = HtmlFormatter() html = pygments.highlight(code, Python3Lexer(), formatter) return "%s\n" % (formatter.get_style_defs(".highlight"), html) 回顾:变换元张量函数 -------------------- 让我们首先回顾一下我们在前几章中所做的事情——变换单个元张量函数。 .. raw:: latex \diilbookstyleinputcell .. code:: python @tvm.script.ir_module class MyModule: @T.prim_func def main( A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32"), ): T.func_attr({"global_symbol": "main", "tir.noalias": True}) for i, j, k in T.grid(128, 128, 128): with T.block("C"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] 首先,让我们定义一组用于测试评估的输入输出。 .. raw:: latex \diilbookstyleinputcell .. code:: python dtype = "float32" a_np = np.random.rand(128, 128).astype(dtype) b_np = np.random.rand(128, 128).astype(dtype) c_mm = a_np @ b_np 我们可以如下构建和运行 ``MyModule``\ 。 .. raw:: latex \diilbookstyleinputcell .. code:: python a_nd = tvm.nd.array(a_np) b_nd = tvm.nd.array(b_np) c_nd = tvm.nd.empty((128, 128), dtype="float32") lib = tvm.build(MyModule, target="llvm") f_timer_before = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule: %.3f ms" % (f_timer_before(a_nd, b_nd, c_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Time cost of MyModule: 2.076 ms 接下来,我们通过重新组织循环访问模式来稍微变换 ``MyModule``\ 。 .. raw:: latex \diilbookstyleinputcell .. code:: python def schedule_mm(sch: tvm.tir.Schedule, jfactor=4): block_C = sch.get_block("C", "main") i, j, k = sch.get_loops(block=block_C) j_0, j_1 = sch.split(loop=j, factors=[None, jfactor]) sch.reorder(i, j_0, k, j_1) sch.decompose_reduction(block_C, k) return sch .. raw:: latex \diilbookstyleinputcell .. code:: python sch = tvm.tir.Schedule(MyModule) sch = schedule_mm(sch) IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j_0 in T.grid(128, 32):
                for j_1_init in range(4):
                    with T.block("C_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
                        T.reads()
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.float32(0)
                for k, j_1 in T.grid(128, 4):
                    with T.block("C_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 4 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
然后我们可以构建并运行重新组织后的程序。 .. raw:: latex \diilbookstyleinputcell .. code:: python lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule=>schedule_mm: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Time cost of MyModule=>schedule_mm: 1.797 ms 变换的历史轨迹 ~~~~~~~~~~~~~~ 除了 ``sch.mod``\ ,\ ``tir.Schedule`` 提供的另一个数据结构是\ **历史轨迹** (trace),它包含了 IRModule 在变换过程中所涉及的步骤。我们可以使用以下代码将其打印出来。 .. raw:: latex \diilbookstyleinputcell .. code:: python print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) l4, l5 = sch.split(loop=l2, factors=[None, 4], preserve_unit_iters=True) sch.reorder(l1, l4, l3, l5) b6 = sch.decompose_reduction(block=b0, loop=l3) .. raw:: latex \diilbookstyleinputcell .. code:: python def schedule_mm(sch: tvm.tir.Schedule, jfactor=4): block_C = sch.get_block("C", "main") i, j, k = sch.get_loops(block=block_C) j_0, j_1 = sch.split(loop=j, factors=[None, jfactor]) sch.reorder(i, j_0, k, j_1) sch.decompose_reduction(block_C, k) return sch 上面的历史轨迹与我们在 ``schedule_mm`` 中指定的变换一致。需要注意的一点是,历史轨迹加上原始程序一起,为我们提供了一种能够完全重新生成最终输出程序的方法。记住这一点,我们将在本章中使用历史轨迹作为检查变换的另一种方式。 随机调度变换 (Stochastic Schedule Transformation) ------------------------------------------------- 到目前为止,我们已经详细说明了我们想要对原始 TensorIR 程序进行哪些变换。其中许多选择基于我们对底层环境的理解,例如缓存和硬件单元。 然而,在实践中,我们可能无法准确地决定每一个细节。因而,我们想指定\ **什么是变换程序的可能方法,同时省略一些细节**\ 。 实现目标的一种自然方法是在我们的变换中添加一些随机元素。下面的代码就是这样做的。 .. raw:: latex \diilbookstyleinputcell .. code:: python def stochastic_schedule_mm(sch: tvm.tir.Schedule): block_C = sch.get_block("C", "main") i, j, k = sch.get_loops(block=block_C) j_factors = sch.sample_perfect_tile(loop=j, n=2) j_0, j_1 = sch.split(loop=j, factors=j_factors) sch.reorder(i, j_0, k, j_1) sch.decompose_reduction(block_C, k) return sch .. figure:: ../img/auto_prog_optim_stoch_sch_transformation.png 让我们对比 ``stochastic_schedule_mm`` 和 ``schedule_mm``\ 。可以发现,它们唯一的区别是指定 ``j_factors`` 的方式。在 ``schedule_mm`` 中, ``j_factors`` 作为我们指定的参数传入。在 ``stochastic_schedule_mm`` 中,它来自 ``sch.sample_perfect_tile``\ 。 顾名思义,\ ``sch.sample_perfect_tile`` 尝试使用随机数来作为 ``j_factors`` 的值。它在输入循环的长度的因子中进行采样,以便采样结果能完美地分割循环。例如,当原始循环长度为 ``128`` 时,拆分循环的可能方式包括:\ ``[8, 16]``\ 、\ ``[32, 4]``\ 、\ ``[2, 64]``\ (注意 ``8 * 16 = 32 * 4 = 2 * 64 = 128``\ )。 我们首先通过运行以下代码块来尝试查看 ``stochastic_schedule_mm`` 的效果。尝试多次运行以下代码块并观察结果差异。你可能会发现每次运行代码块时, ``j_1`` 的循环边界都会发生变化。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch = tvm.tir.Schedule(MyModule) sch = stochastic_schedule_mm(sch) IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j_0 in T.grid(128, 32):
                for j_1_init in range(4):
                    with T.block("C_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 4 + j_1_init)
                        T.reads()
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.float32(0)
                for k, j_1 in T.grid(128, 4):
                    with T.block("C_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 4 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
这里发生的是,每次我们运行 ``stochastic_schedule_mm`` 时,它都会随机采样一组不同的 ``j_factors``\ 。我们可以打印出最新的历史轨迹,以查看我们在采样中做出的决定。 .. raw:: latex \diilbookstyleinputcell .. code:: python print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[32, 4]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) 当我们查看历史轨迹时,请密切注意 ``sample_perfect_tile`` 的 ``decision=[...]`` 部分。 它们对应于我们上次调用 ``stochastic_schedule_mm`` 时 ``sampling_perfect_tile`` 返回的值。 作为查看 ``stochastic_schedule_mm`` 不同采样结果的另一种方法,我们可以多次运行以下代码块并查看历史轨迹。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch = tvm.tir.Schedule(MyModule) sch = stochastic_schedule_mm(sch) print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) 深入研究随机变换 ~~~~~~~~~~~~~~~~ 现在让我们更深入地研究随机调度变换中发生的事情。我们可以发现它是原始确定性变换的简单泛化,包含两个附加元素: - 来自 ``sample_perfect_tile`` 的随机变量和我们在示例中未涵盖的其他采样操作。 - 利用随机变量进行的后续变换操作。 让我们尝试逐步运行随机变换。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch = tvm.tir.Schedule(MyModule) block_C = sch.get_block("C", "main") i, j, k = sch.get_loops(block=block_C) j_factors = sch.sample_perfect_tile(loop=j, n=2) .. raw:: latex \diilbookstyleinputcell .. code:: python type(j_factors[0]) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output tvm.tir.expr.Var ``j_factors`` 中的元素并不是实整数。相反,它们是指被采样的随机变量的\ **符号变量**\ 。我们可以将这些变量传递给变换的 API 从而指定诸如因子值之类的选择。 .. raw:: latex \diilbookstyleinputcell .. code:: python print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) 调度的历史轨迹在 ``decisions`` 字段中记录这些符号变量的选择。因此后续步骤将能够查找这些选择来决定如何拆分循环。 .. raw:: latex \diilbookstyleinputcell .. code:: python IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j, k in T.grid(128, 128, 128):
                with T.block("C"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(C[vi, vj])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
如果我们查看当前时间点的代码,我们可以发现 IRModule 保持不变,因为我们只对随机变量进行了采样,但还没有基于它们进行任何变换操作。 现在让我们采取一些行动: .. raw:: latex \diilbookstyleinputcell .. code:: python j_0, j_1 = sch.split(loop=j, factors=j_factors) sch.reorder(i, j_0, k, j_1) 这些操作被记录在以下历史轨迹中。 .. raw:: latex \diilbookstyleinputcell .. code:: python print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) 如果我们重新查看代码,变换后的模块现在对应于执行操作后的更新版本。 .. raw:: latex \diilbookstyleinputcell .. code:: python IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j_0, k, j_1 in T.grid(128, 8, 128, 16):
                with T.block("C"):
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
                    vk = T.axis.reduce(128, k)
                    T.reads(A[vi, vk], B[vk, vj])
                    T.writes(C[vi, vj])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
我们可以做一些进一步的变换以达到最终状态。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch.reorder(i, j_0, k, j_1) sch.decompose_reduction(block_C, k) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output tir.BlockRV(0x4eba280) .. raw:: latex \diilbookstyleinputcell .. code:: python IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j_0 in T.grid(128, 8):
                for j_1_init in range(16):
                    with T.block("C_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
                        T.reads()
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.float32(0)
                for k, j_1 in T.grid(128, 16):
                    with T.block("C_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 16 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
随机变换搜索 ------------ 你可能会意识到,\ ``stochastic_schedule_mm`` 创建了一个\ **可能程序的搜索空间**\ ,具体取决于在每个采样步骤中做出的具体决定。 .. figure:: ../img/auto_prog_optim_transformation_search.png 回到我们最初的直觉,我们希望能够指定一组\ **可能的程序**\ 而不是一个程序。 ``stochastic_schedule_mm`` 正是这样做的。当然,接下来要问的一个很自然的问题是:什么是最佳选择? 我们需要一个搜索算法来做到这一点。为了展示这里可以做什么,让我们首先在下面的代码块中尝试最直接的搜索算法——随机搜索。它尝试重复运行 ``stochastic_schedule_mm``\ ,获取转换后的模块,运行测试,然后保留历史上最好(用时最短)的模块。 .. raw:: latex \diilbookstyleinputcell .. code:: python def random_search(mod: tvm.IRModule, num_trials=5): best_result = None best_sch = None for i in range(num_trials): sch = stochastic_schedule_mm(tvm.tir.Schedule(mod)) lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) result = f_timer_after(a_nd, b_nd, c_nd).mean print("=====Attempt %d, time-cost: %.3f ms====" % (i, result * 1000)) print(sch.trace) # book keep the best result so far if best_result is None or result < best_result: best_result = result best_sch = sch return best_sch sch = random_search(MyModule) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output =====Attempt 0, time-cost: 1.586 ms==== # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[128, 1]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) =====Attempt 1, time-cost: 1.177 ms==== # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) =====Attempt 2, time-cost: 1.166 ms==== # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) =====Attempt 3, time-cost: 1.668 ms==== # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[128, 1]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) =====Attempt 4, time-cost: 1.126 ms==== # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) 如果我们运行代码,我们会发现它经过了几个选择,然后在五次试验中返回了最佳运行。 .. raw:: latex \diilbookstyleinputcell .. code:: python print(sch.trace) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output # from tvm import tir def apply_trace(sch: tir.Schedule) -> None: b0 = sch.get_block(name="C", func_name="main") l1, l2, l3 = sch.get_loops(block=b0) v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[8, 16]) l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True) sch.reorder(l1, l6, l3, l7) b8 = sch.decompose_reduction(block=b0, loop=l3) 在实践中,我们使用更智能的算法。如果我们对其它设备的优化感兴趣,我们还需要提供额外的工具,例如远程设备上的基准测试。 TVM 的 Meta-Schedule API 提供了这些附加功能。 ``meta_schedule`` 是支持搜索可能变换空间的命名空间。Meta-Schedule 在幕后做了很多额外的事情: - 跨越多个进程的并行基准测试。 - 使用\ **代价模型** (cost model) 来避免每次都进行基准测试。 - 基于历史轨迹进行\ **遗传搜索** (evolutionary search),而不是每次都随机采样。 尽管有这些工具,但我们关键思想是保持不变的:\ **使用随机变换来指定好的程序的搜索空间,使用 ``tune_tir`` API 帮助在搜索空间内搜索并找到最优的调度变换**\ 。 .. raw:: latex \diilbookstyleinputcell .. code:: python from tvm import meta_schedule as ms database = ms.tune_tir( mod=MyModule, target="llvm --num-cores=1", max_trials_global=64, num_trials_per_iter=64, space=ms.space_generator.ScheduleFn(stochastic_schedule_mm), work_dir="./tune_tmp", task_name="main" ) sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1") .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output 2023-03-18 11:48:04 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0 .. raw:: html
Name FLOP Weight Speed (GFLOPS) Latency (us) Weighted Latency (us) Trials Done
0 main 4194304 1 3.4059 1231.4906 1231.4906 5 Y
.. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Total trials: 5 Total latency (us): 1231.49 2023-03-18 11:48:04 [DEBUG] [task_scheduler.cc:318] ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done ------------------------------------------------------------------------------------------------------ 0 | main | 4194304 | 1 | 3.4059 | 1231.4906 | 1231.4906 | 5 | Y ------------------------------------------------------------------------------------------------------ Total trials: 5 Total latency (us): 1231.49 ``tune_tir`` 函数返回在调优过程中找到的优化后的调度。 .. raw:: latex \diilbookstyleinputcell .. code:: python sch.trace.show() .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output /usr/share/miniconda/envs/mlc/lib/python3.8/site-packages/tvm/script/highlight.py:117: UserWarning: No module named 'black' To print formatted TVM script, please install the formatter 'Black': /usr/share/miniconda/envs/mlc/bin/python -m pip install "black==22.3.0" --upgrade --user warnings.warn( .. raw:: html
# from tvm import tir
    def apply_trace(sch: tir.Schedule) -> None:
      b0 = sch.get_block(name="C", func_name="main")
      l1, l2, l3 = sch.get_loops(block=b0)
      v4, v5 = sch.sample_perfect_tile(loop=l2, n=2, max_innermost_factor=16, decision=[16, 8])
      l6, l7 = sch.split(loop=l2, factors=[v4, v5], preserve_unit_iters=True)
      sch.reorder(l1, l6, l3, l7)
      b8 = sch.decompose_reduction(block=b0, loop=l3)
      sch.enter_postproc()
    
.. raw:: latex \diilbookstyleinputcell .. code:: python IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            for i, j_0 in T.grid(128, 16):
                for j_1_init in range(8):
                    with T.block("C_init"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1_init)
                        T.reads()
                        T.writes(C[vi, vj])
                        C[vi, vj] = T.float32(0)
                for k, j_1 in T.grid(128, 8):
                    with T.block("C_update"):
                        vi = T.axis.spatial(128, i)
                        vj = T.axis.spatial(128, j_0 * 8 + j_1)
                        vk = T.axis.reduce(128, k)
                        T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                        T.writes(C[vi, vj])
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    
.. raw:: latex \diilbookstyleinputcell .. code:: python lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Time cost of MyModule after tuning: 1.156 ms 利用默认的自动调度 ~~~~~~~~~~~~~~~~~~ 在上一节中,我们展示了如何使用我们精心设计的随机变换来优化 IRModule 的计算。Meta-Schedule 带有内置通用随机变换集合,能够适用于广泛的 TensorIR 计算。这种方法也称为\ **自动调度** (auto-scheduling),因为搜索空间是由系统生成的。我们可以通过删除行 ``space=ms.space_generator.ScheduleFn(stochastic_schedule_mm)`` 来运行它。 在底层,Meta-Schedule 分析每个 TensorIR block 的数据访问和循环模式,并提出对程序的随机变换方式。我们不会在本章中讨论这些通用的变换,但要注意它们也只是随机转换加上代码分析而已。我们可以使用上一节中学到的相同机制来增强自动调度。我们将在以后的章节中触及这个主题。 .. raw:: latex \diilbookstyleinputcell .. code:: python database = ms.tune_tir( mod=MyModule, target="llvm --num-cores=1", max_trials_global=64, num_trials_per_iter=64, work_dir="./tune_tmp", task_name="main", ) sch = ms.tir_integration.compile_tir(database, MyModule, "llvm --num-cores=1") .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output 2023-03-18 11:48:59 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0 .. raw:: html
Name FLOP Weight Speed (GFLOPS) Latency (us) Weighted Latency (us) Trials Done
0 main 4194304 1 22.3003 188.0832 188.0832 64 Y
.. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Total trials: 64 Total latency (us): 188.083 2023-03-18 11:48:59 [DEBUG] [task_scheduler.cc:318] ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done ------------------------------------------------------------------------------------------------------ 0 | main | 4194304 | 1 | 22.3003 | 188.0832 | 188.0832 | 64 | Y ------------------------------------------------------------------------------------------------------ Total trials: 64 Total latency (us): 188.083 .. raw:: latex \diilbookstyleinputcell .. code:: python lib = tvm.build(sch.mod, target="llvm") f_timer_after = lib.time_evaluator("main", tvm.cpu()) print("Time cost of MyModule after tuning: %.3f ms" % (f_timer_after(a_nd, b_nd, c_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Time cost of MyModule after tuning: 0.260 ms 结果比我们的原始代码快得多。我们可以查看历史轨迹和最终代码。就本章而言,你不需要了解所有变换。在高层次的理解中,历史轨迹包含: - 更多级的循环转换 - 中间计算的矢量化 - 并行化和循环展开 .. raw:: latex \diilbookstyleinputcell .. code:: python sch.trace.show() .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output /usr/share/miniconda/envs/mlc/lib/python3.8/site-packages/tvm/script/highlight.py:117: UserWarning: No module named 'black' To print formatted TVM script, please install the formatter 'Black': /usr/share/miniconda/envs/mlc/bin/python -m pip install "black==22.3.0" --upgrade --user warnings.warn( .. raw:: html
# from tvm import tir
    def apply_trace(sch: tir.Schedule) -> None:
      b0 = sch.get_block(name="C", func_name="main")
      b1 = sch.get_block(name="root", func_name="main")
      sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
      l2, l3, l4 = sch.get_loops(block=b0)
      v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64, decision=[16, 4, 1, 2])
      l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8], preserve_unit_iters=True)
      v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 4, 16])
      l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16], preserve_unit_iters=True)
      v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64, decision=[16, 8])
      l23, l24 = sch.split(loop=l4, factors=[v21, v22], preserve_unit_iters=True)
      sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)
      b25 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")
      sch.reverse_compute_at(block=b25, loop=l18, preserve_unit_loops=True, index=-1)
      sch.annotate(block_or_loop=b1, ann_key="meta_schedule.parallel", ann_val=16)
      sch.annotate(block_or_loop=b1, ann_key="meta_schedule.vectorize", ann_val=64)
      v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)
      sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v26)
      sch.enter_postproc()
      b27 = sch.get_block(name="root", func_name="main")
      sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.parallel")
      sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.vectorize")
      sch.unannotate(block_or_loop=b27, ann_key="meta_schedule.unroll_explicit")
      b28, b29 = sch.get_child_blocks(b27)
      l30, l31, l32, l33, l34, l35, l36, l37, l38, l39 = sch.get_loops(block=b28)
      l40 = sch.fuse(l30, l31, l32, preserve_unit_iters=True)
      sch.parallel(loop=l40)
      l41 = sch.fuse(l39, preserve_unit_iters=True)
      sch.vectorize(loop=l41)
      sch.annotate(block_or_loop=l40, ann_key="pragma_auto_unroll_max_step", ann_val=64)
      sch.annotate(block_or_loop=l40, ann_key="pragma_unroll_explicit", ann_val=1)
      l42, l43, l44, l45 = sch.get_loops(block=b29)
      l46 = sch.fuse(l45, preserve_unit_iters=True)
      sch.vectorize(loop=l46)
      sch.annotate(block_or_loop=l42, ann_key="pragma_auto_unroll_max_step", ann_val=64)
      sch.annotate(block_or_loop=l42, ann_key="pragma_unroll_explicit", ann_val=1)
      b47 = sch.get_block(name="C", func_name="main")
      l48, l49, l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b47)
      b56 = sch.decompose_reduction(block=b47, loop=l50)
    
.. raw:: latex \diilbookstyleinputcell .. code:: python IPython.display.HTML(code2html(sch.mod.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            C_global = T.alloc_buffer((128, 128))
            for i_0_j_0_i_1_fused in T.parallel(64, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for j_1 in range(2):
                    for i_2_init, j_2_init, i_3_init in T.grid(1, 4, 2):
                        for j_3_fused_init in T.vectorized(16):
                            with T.block("C_init"):
                                vi = T.axis.spatial(128, i_0_j_0_i_1_fused // 4 * 8 + i_0_j_0_i_1_fused % 4 * 2 + i_2_init * 2 + i_3_init)
                                vj = T.axis.spatial(128, j_1 * 64 + j_2_init * 16 + j_3_fused_init)
                                T.reads()
                                T.writes(C_global[vi, vj])
                                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
                                C_global[vi, vj] = T.float32(0)
                    for k_0, i_2, j_2, k_1, i_3 in T.grid(16, 1, 4, 8, 2):
                        for j_3_fused in T.vectorized(16):
                            with T.block("C_update"):
                                vi = T.axis.spatial(128, i_0_j_0_i_1_fused // 4 * 8 + i_0_j_0_i_1_fused % 4 * 2 + i_2 * 2 + i_3)
                                vj = T.axis.spatial(128, j_1 * 64 + j_2 * 16 + j_3_fused)
                                vk = T.axis.reduce(128, k_0 * 8 + k_1)
                                T.reads(C_global[vi, vj], A[vi, vk], B[vk, vj])
                                T.writes(C_global[vi, vj])
                                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
                                C_global[vi, vj] = C_global[vi, vj] + A[vi, vk] * B[vk, vj]
                    for ax0 in range(2):
                        for ax1_fused in T.vectorized(64):
                            with T.block("C_global"):
                                v0 = T.axis.spatial(128, i_0_j_0_i_1_fused // 4 * 8 + i_0_j_0_i_1_fused % 4 * 2 + ax0)
                                v1 = T.axis.spatial(128, j_1 * 64 + ax1_fused)
                                T.reads(C_global[v0, v1])
                                T.writes(C[v0, v1])
                                C[v0, v1] = C_global[v0, v1]
    
章节检查点 ~~~~~~~~~~ 让我们对到目前为止所学的内容进行检查。 - 随机调度允许我们表示“可能的变换是什么”。 - Meta-Schedule 的 ``tune_tir`` API 帮助我们在搜索空间内找到一个好的解决方案。 - Meta-Schedule 带有一组默认的内置随机变换,涵盖了广泛的搜索空间。 回到端到端模型执行 ------------------ 到目前为止,我们已经学会了自动优化单个元张量函数。我们如何才能把利用它改进我们的端到端模型执行? 从 MLC 的角度来看,自动搜索是一个模块化的步骤,我们只需要用调优结果提供的新的元张量函数实现替换原始的元张量函数实现。 我们将重用上一章中的两层 MLP 示例。 .. raw:: latex \diilbookstyleinputcell .. code:: python import torch import torchvision test_data = torchvision.datasets.FashionMNIST( root="data", train=False, download=True, transform=torchvision.transforms.ToTensor() ) test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True) class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] img, label = next(iter(test_loader)) img = img.reshape(1, 28, 28).numpy() .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz 100.0% Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz 100.0% Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz 100.0% Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz 100.0%Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw .. raw:: latex \diilbookstyleinputcell .. code:: python import matplotlib.pyplot as plt plt.figure() plt.imshow(img[0]) plt.colorbar() plt.grid(False) plt.show() print("Class:", class_names[label[0]]) .. figure:: output_index_1f4d27_59_0.png .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output Class: Shirt 我们同样下载我们将在示例中使用的预训练模型参数。 .. raw:: latex \diilbookstyleinputcell .. code:: python # Hide outputs !wget -nc https://github.com/mlc-ai/web-data/raw/main/models/fasionmnist_mlp_params.pkl .. figure:: ../img/e2e_fashionmnist_mlp_model.png 提醒一下,上面的图展示了我们用到的模型。 .. raw:: latex \diilbookstyleinputcell .. code:: python import pickle as pkl mlp_params = pkl.load(open("fasionmnist_mlp_params.pkl", "rb")) data_nd = tvm.nd.array(img.reshape(1, 784)) nd_params = {k: tvm.nd.array(v) for k, v in mlp_params.items()} 让我们使用一个混合 IRModule。它其中大多数步骤都调用环境函数,同时带有一个 TensorIR 函数 ``linear0``\ 。 .. raw:: latex \diilbookstyleinputcell .. code:: python @tvm.script.ir_module class MyModuleMixture: @T.prim_func def linear0(X: T.Buffer((1, 784), "float32"), W: T.Buffer((128, 784), "float32"), B: T.Buffer((128,), "float32"), Z: T.Buffer((1, 128), "float32")): T.func_attr({"global_symbol": "linear0", "tir.noalias": True}) Y = T.alloc_buffer((1, 128), "float32") for i, j, k in T.grid(1, 128, 784): with T.block("Y"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): Y[vi, vj] = T.float32(0) Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk] for i, j in T.grid(1, 128): with T.block("Z"): vi, vj = T.axis.remap("SS", [i, j]) Z[vi, vj] = Y[vi, vj] + B[vj] @R.function def main(x: R.Tensor((1, 784), "float32"), w0: R.Tensor((128, 784), "float32"), b0: R.Tensor((128,), "float32"), w1: R.Tensor((10, 128), "float32"), b1: R.Tensor((10,), "float32")): with R.dataflow(): lv0 = R.call_dps_packed("linear0", (x, w0, b0), R.Tensor((1, 128), dtype="float32")) lv1 = R.call_dps_packed("env.relu", (lv0,), R.Tensor((1, 128), dtype="float32")) out = R.call_dps_packed("env.linear", (lv1, w1, b1), R.Tensor((1, 10), dtype="float32")) R.output(out) return out .. raw:: latex \diilbookstyleinputcell .. code:: python @tvm.register_func("env.linear", override=True) def torch_linear(x: tvm.nd.NDArray, w: tvm.nd.NDArray, b: tvm.nd.NDArray, out: tvm.nd.NDArray): x_torch = torch.from_dlpack(x) w_torch = torch.from_dlpack(w) b_torch = torch.from_dlpack(b) out_torch = torch.from_dlpack(out) torch.mm(x_torch, w_torch.T, out=out_torch) torch.add(out_torch, b_torch, out=out_torch) @tvm.register_func("env.relu", override=True) def lnumpy_relu(x: tvm.nd.NDArray, out: tvm.nd.NDArray): x_torch = torch.from_dlpack(x) out_torch = torch.from_dlpack(out) torch.maximum(x_torch, torch.Tensor([0.0]), out=out_torch) 我们可以绑定参数,看看它是否给出了正确的预测。 .. raw:: latex \diilbookstyleinputcell .. code:: python MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModuleMixture) .. raw:: latex \diilbookstyleinputcell .. code:: python ex = relax.build(MyModuleWithParams, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd) pred_kind = np.argmax(nd_res.numpy(), axis=1) print("MyModuleWithParams Prediction:", class_names[pred_kind[0]]) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output MyModuleWithParams Prediction: Shirt 以下代码计算模块在变化前的运行时间。请注意,因为这是一个小模型,不同次运行之间数字可能会有所波动,所以我们只需要关注整体幅度。 .. raw:: latex \diilbookstyleinputcell .. code:: python ftimer = vm.module.time_evaluator("main", tvm.cpu(), number=100) print("MyModuleWithParams time-cost: %g ms" % (ftimer(data_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output MyModuleWithParams time-cost: 0.25141 ms 我们现在准备好调整 ``linear0``\ 。下图总结了我们的整个流程。 .. figure:: ../img/auto_prog_optim_optim_flow.png 目前,调优 API 只接受一个带有一个 ``main`` 函数的 IRModule,所以我们首先将 ``linear0`` 取出到另一个模块的 main 函数中并将其传递给 ``tune_tir``\ 。 .. raw:: latex \diilbookstyleinputcell .. code:: python mod_linear = tvm.IRModule.from_expr(MyModuleMixture["linear0"].with_attr("global_symbol", "main")) IPython.display.HTML(code2html(mod_linear.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    
    @I.ir_module
    class Module:
        @T.prim_func
        def main(X: T.Buffer((1, 784), "float32"), W: T.Buffer((128, 784), "float32"), B: T.Buffer((128,), "float32"), Z: T.Buffer((1, 128), "float32")):
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # with T.block("root"):
            Y = T.alloc_buffer((1, 128))
            for i, j, k in T.grid(1, 128, 784):
                with T.block("Y"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    T.reads(X[vi, vk], W[vj, vk])
                    T.writes(Y[vi, vj])
                    with T.init():
                        Y[vi, vj] = T.float32(0)
                    Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
            for i, j in T.grid(1, 128):
                with T.block("Z"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    T.reads(Y[vi, vj], B[vj])
                    T.writes(Z[vi, vj])
                    Z[vi, vj] = Y[vi, vj] + B[vj]
    
.. raw:: latex \diilbookstyleinputcell .. code:: python database = ms.tune_tir( mod=mod_linear, target="llvm --num-cores=1", max_trials_global=64, num_trials_per_iter=64, work_dir="./tune_tmp", task_name="main", ) sch = ms.tir_integration.compile_tir(database, mod_linear, "llvm --num-cores=1") .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output 2023-03-18 11:49:55 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0 .. raw:: html
Name FLOP Weight Speed (GFLOPS) Latency (us) Weighted Latency (us) Trials Done
0 main 200832 1 6.9648 28.8354 28.8354 64 Y
.. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output 2023-03-18 11:49:55 [DEBUG] [task_scheduler.cc:318] ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done ----------------------------------------------------------------------------------------------------- 0 | main | 200832 | 1 | 6.9648 | 28.8354 | 28.8354 | 64 | Y ----------------------------------------------------------------------------------------------------- Total trials: 64 Total latency (us): 28.8354 Total trials: 64 Total latency (us): 28.8354 现在我们需要在调优后用新函数替换原来的 ``linear0``\ 。我们可以通过首先获得一个 ``global_var``\ (一个指向 IRModule 中函数的 ``pointer`` 引用),然后调用 ``update_func`` 来用新的函数替换原本的函数。 .. raw:: latex \diilbookstyleinputcell .. code:: python MyModuleWithParams2 = relax.transform.BindParams("main", nd_params)(MyModuleMixture) new_func = sch.mod["main"].with_attr("global_symbol", "linear0") gv = MyModuleWithParams2.get_global_var("linear0") MyModuleWithParams2.update_func(gv, new_func) IPython.display.HTML(code2html(MyModuleWithParams2.script())) .. raw:: html
# from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R
    
    @I.ir_module
    class Module:
        @T.prim_func
        def linear0(X: T.Buffer((1, 784), "float32"), W: T.Buffer((128, 784), "float32"), B: T.Buffer((128,), "float32"), Z: T.Buffer((1, 128), "float32")):
            T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
            # with T.block("root"):
            Y = T.alloc_buffer((1, 128))
            for i_0, j_0 in T.grid(1, 1):
                for i_1, j_1 in T.grid(1, 8):
                    for i_2_init, j_2_init, i_3_init in T.grid(1, 2, 1):
                        for j_3_fused_init in T.vectorized(8):
                            with T.block("Y_init"):
                                vi = T.axis.spatial(1, i_0 + i_1 + i_2_init + i_3_init)
                                vj = T.axis.spatial(128, j_0 * 128 + j_1 * 16 + j_2_init * 8 + j_3_fused_init)
                                T.reads()
                                T.writes(Y[vi, vj])
                                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
                                Y[vi, vj] = T.float32(0)
                    for k_0, i_2, j_2, k_1, i_3 in T.grid(14, 1, 2, 56, 1):
                        for j_3_fused in T.vectorized(8):
                            with T.block("Y_update"):
                                vi = T.axis.spatial(1, i_0 + i_1 + i_2 + i_3)
                                vj = T.axis.spatial(128, j_0 * 128 + j_1 * 16 + j_2 * 8 + j_3_fused)
                                vk = T.axis.reduce(784, k_0 * 56 + k_1)
                                T.reads(Y[vi, vj], X[vi, vk], W[vj, vk])
                                T.writes(Y[vi, vj])
                                T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"})
                                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
                for ax0, ax1 in T.grid(1, 128):
                    with T.block("Z"):
                        vi, vj = T.axis.remap("SS", [ax0, ax1])
                        T.reads(Y[vi, vj], B[vj])
                        T.writes(Z[vi, vj])
                        Z[vi, vj] = Y[vi, vj] + B[vj]
    
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            with R.dataflow():
                lv0 = R.call_dps_packed("linear0", (x, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv1 = R.call_dps_packed("env.relu", (lv0,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                out = R.call_dps_packed("env.linear", (lv1, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                R.output(out)
            return out
    
    # Metadata omitted. Use show_meta=True in script() method to show it.
    
我们可以发现上面代码中的 ``linear0`` 已经被替换了。 .. raw:: latex \diilbookstyleinputcell .. code:: python ex = relax.build(MyModuleWithParams2, target="llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) nd_res = vm["main"](data_nd) pred_kind = np.argmax(nd_res.numpy(), axis=1) print("MyModuleWithParams2 Prediction:", class_names[pred_kind[0]]) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output MyModuleWithParams2 Prediction: Shirt 再次运行代码,我们可以发现我们得到了明显的时间减少,这主要归功于新的 ``linear0`` 函数。 .. raw:: latex \diilbookstyleinputcell .. code:: python ftimer = vm.module.time_evaluator("main", tvm.cpu(), number=50) print("MyModuleWithParams2 time-cost: %g ms" % (ftimer(data_nd).mean * 1000)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output MyModuleWithParams2 time-cost: 0.158885 ms 讨论 ---- 我们可能会注意到,我们前两章关注的是\ **抽象**\ ,而本章开始关注\ **变换**\ 。随机变换指定了可以优化的内容,而无需显式地确定所有选择。Meta-Schedule API 帮助我们搜索可能的变换空间并选择最佳变换。 重要的是,将搜索结果放回端到端流程只是将原始函数的实现替换为调优后的新函数的实现。 因此,我们再次遵循下图中的通用 MLC 流程。在以后的内容中,我们将介绍更多关于元张量函数和计算图函数的变换。一个好的 MLC 流程将这些变换组合在一起,形成最终部署的形式。 .. figure:: ../img/mlc_process.png 概括 ---- - 随机变换帮助我们指定可能程序的搜索空间。 - Meta-Schedule 在搜索空间中搜索,并找到优化后的程序。 - 我们可以使用另一种变换,将初始的元张量函数替换为优化后的函数,并更新端到端执行流程。