Computational Graph Optimization
================================

Prelude
-------

Most of the MLC process can be viewed as transformations among tensor
functions. In the past chapters, we studied how to transform each
primitive tensor function individually. In this chapter, let us discuss
high-level transformations among computational graphs.

.. figure:: ../img/mlc-elem-transform.png

Preparations
------------

To begin with, let us import the necessary dependencies.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # This is needed for deferring annotation parsing in TVMScript
    from __future__ import annotations

    import numpy as np
    import tvm
    from tvm import relax, topi
    from tvm.ir.module import IRModule
    from tvm.script import relax as R
    from tvm.script import tir as T

Pattern Match and Rewriting
---------------------------

To begin with, let us start with the following example.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    @tvm.script.ir_module
    class MyModule:
        @R.function
        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
            with R.dataflow():
                lv0 = relax.op.multiply(x, y)
                gv0 = relax.op.add(lv0, y)
                R.output(gv0)
            return gv0

``MyModule`` contains a relax function with two high-level operators,
``relax.op.multiply`` and ``relax.op.add``. Our goal is to find these two
operators and replace them with a call into the ``relax.op.ewise_fma``
operator.

Before we dive into how to do that exactly, let us first examine the data
structures that make up ``MyModule``. Each IRModule contains a collection
of functions, and each function body is represented by a data structure
called an abstract syntax tree (AST).

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    relax_func = MyModule["main"]

Each function is represented by a ``relax.expr.Function`` node.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    type(relax_func)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tvm.relax.expr.Function

The function contains a list of parameters.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    relax_func.params

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [x, y]

The function contains a ``body`` field that represents its return value
and the set of binding blocks in the function.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    func_body = relax_func.body
    type(func_body)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    tvm.relax.expr.SeqExpr

The function body ``SeqExpr`` contains a sequence of (binding) blocks.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    func_body.blocks

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    [x: R.Tensor((3, 4), dtype="float32")
    y: R.Tensor((3, 4), dtype="float32")
    with R.dataflow():
        lv0: R.Tensor((3, 4), dtype="float32") = R.multiply(x, y)
        gv0: R.Tensor((3, 4), dtype="float32") = R.add(lv0, y)
        R.output(gv0)]

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    dataflow_block = func_body.blocks[0]

In our particular case, we have a single dataflow block that contains two
bindings. Each binding corresponds to one of the following two lines:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    lv0 = relax.op.multiply(x, y)
    gv0 = relax.op.add(lv0, y)

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    dataflow_block.bindings

.. raw:: latex

   \diilbookstyleoutputcell
.. parsed-literal::
    :class: output

    [x: R.Tensor((3, 4), dtype="float32")
    y: R.Tensor((3, 4), dtype="float32")
    lv0: R.Tensor((3, 4), dtype="float32") = R.multiply(x, y), lv0: R.Tensor((3, 4), dtype="float32")
    y: R.Tensor((3, 4), dtype="float32")
    gv0: R.Tensor((3, 4), dtype="float32") = R.add(lv0, y)]

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    binding = dataflow_block.bindings[0]

Each binding has a ``var`` field that corresponds to the left-hand side of
the binding (``lv0``, ``gv0``).

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    binding.var

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    lv0

Its ``value`` field corresponds to the right-hand side of the binding.
Each ``value`` field is a ``relax.Call`` node representing a call into a
primitive function.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    binding.value

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    R.multiply(x, y)

.. figure:: ../img/relax_func_data_structure.png

The above figure summarizes the data structures involved in this
particular function.

One approach to rewriting the program would be to traverse MyModule's AST
recursively and generate a transformed AST. We can certainly do that using
the Python API available. However, we can use extra tooling support to
simplify the process. The following code block follows a design pattern
called the **visitor pattern**, which allows us to visit each AST node and
rewrite it into a transformed version.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    @relax.expr_functor.mutator
    class EwiseFMARewriter(relax.PyExprMutator):
        def visit_call_(self, call):
            call = self.visit_expr_post_order(call)
            add_op = tvm.ir.Op.get("relax.add")
            multiply_op = tvm.ir.Op.get("relax.multiply")
            ewise_fma_op = tvm.ir.Op.get("relax.ewise_fma")

            if call.op != add_op:
                return call

            value = self.lookup_binding(call.args[0])
            if not isinstance(value, relax.Call) or value.op != multiply_op:
                return call

            fma_call = relax.Call(
                ewise_fma_op, [value.args[0], value.args[1], call.args[1]], None, None
            )
            return fma_call


    updated_fn = EwiseFMARewriter().visit_expr(MyModule["main"])
    updated_fn.show()
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import relax as R

    @R.function
    def main(x: R.Tensor((3, 4), dtype="float32"), y: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"):
        with R.dataflow():
            lv0: R.Tensor((3, 4), dtype="float32") = R.multiply(x, y)
            gv0: R.Tensor((3, 4), dtype="float32") = R.ewise_fma(x, y, y)
            R.output(gv0)
        return gv0
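The rewritten ``gv0`` now calls ``R.ewise_fma``, but the original ``lv0``
binding is left behind as dead code. A follow-up cleanup removes bindings
whose results are never referenced; the following call (a minimal sketch
using the ``relax.analysis.remove_all_unused`` utility) produces the
simplified function shown below.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Drop bindings (such as the now-unused lv0) that no longer contribute
    # to the function output.
    relax.analysis.remove_all_unused(updated_fn).show()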
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import relax as R

    @R.function
    def main(x: R.Tensor((3, 4), dtype="float32"), y: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((3, 4), dtype="float32"):
        with R.dataflow():
            gv0: R.Tensor((3, 4), dtype="float32") = R.ewise_fma(x, y, y)
            R.output(gv0)
        return gv0
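The same pattern-match-and-rewrite machinery scales from the toy example to
end-to-end models. The printout below shows a two-layer MLP module whose
``main`` function chains ``permute_dims``, ``matmul``, ``add`` and
``nn.relu`` over constant weights stored in the module metadata. As a
reference, here is one possible way to construct such a module with
``relax.BlockBuilder``; the random weights and the names ``MLPModel``,
``w0``, ``b0``, ``w1`` and ``b1`` are illustrative assumptions rather than
the exact code that produced the printout.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # A hypothetical construction of the two-layer MLP shown below.
    # Real weights would come from a trained model; random values are used here.
    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((1, 784), "float32"))
    w0 = relax.const(np.random.randn(128, 784).astype("float32"))
    b0 = relax.const(np.random.randn(128).astype("float32"))
    w1 = relax.const(np.random.randn(10, 128).astype("float32"))
    b1 = relax.const(np.random.randn(10).astype("float32"))

    with bb.function("main", [x]):
        with bb.dataflow():
            lv0 = bb.emit(relax.op.permute_dims(w0))   # (784, 128)
            lv1 = bb.emit(relax.op.matmul(x, lv0))     # (1, 128)
            lv2 = bb.emit(relax.op.add(lv1, b0))
            lv3 = bb.emit(relax.op.nn.relu(lv2))
            lv4 = bb.emit(relax.op.permute_dims(w1))   # (128, 10)
            lv5 = bb.emit(relax.op.matmul(lv3, lv4))
            gv = bb.emit_output(relax.op.add(lv5, b1))
        bb.emit_func_output(gv)

    MLPModel = bb.get()
    MLPModel.show()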
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
                lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(x, lv, out_dtype="void")
                lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1, metadata["relax.expr.Constant"][1])
                lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
                lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
                lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void")
                lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, metadata["relax.expr.Constant"][3])
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.
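The next transformation is a graph-level fusion: every ``matmul`` that
feeds directly into an ``add`` is outlined into its own relax subfunction
(``fused_matmul_add0``, ``fused_matmul_add1``) annotated with the
``Primitive`` attribute, so that downstream passes can treat each group as
a single unit. The printout below shows the module after such a pass. The
sketch that follows outlines one way to write it, in the same spirit as
the ``EwiseFMARewriter`` above; the names ``MatmulAddFusor``,
``FuseDenseAddPass`` and ``MLPFused`` are illustrative assumptions.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    @relax.expr_functor.mutator
    class MatmulAddFusor(relax.PyExprMutator):
        def __init__(self, mod: IRModule) -> None:
            super().__init__(mod)
            self.mod_ = mod
            self.add_op = tvm.ir.Op.get("relax.add")
            self.matmul_op = tvm.ir.Op.get("relax.matmul")
            self.counter = 0

        def transform(self) -> IRModule:
            for global_var, func in self.mod_.functions.items():
                if not isinstance(func, relax.Function):
                    continue
                updated_func = self.visit_expr(func)
                updated_func = relax.analysis.remove_all_unused(updated_func)
                self.builder_.update_func(global_var, updated_func)
            return self.builder_.get()

        def visit_call_(self, call):
            call = self.visit_expr_post_order(call)
            # Match add(matmul(x, w), b).
            if call.op != self.add_op:
                return call
            value = self.lookup_binding(call.args[0])
            if not isinstance(value, relax.Call) or value.op != self.matmul_op:
                return call
            x, w, b = value.args[0], value.args[1], call.args[1]

            # Outline the matched pattern into a new "Primitive" subfunction.
            param_x = relax.Var("x", x.struct_info)
            param_w = relax.Var("w", w.struct_info)
            param_b = relax.Var("b", b.struct_info)
            bb = relax.BlockBuilder()
            fn_name = "fused_matmul_add%d" % self.counter
            self.counter += 1
            with bb.function(fn_name, [param_x, param_w, param_b]):
                with bb.dataflow():
                    lv0 = bb.emit(relax.op.matmul(param_x, param_w))
                    gv = bb.emit_output(relax.op.add(lv0, param_b))
                bb.emit_func_output(gv)
            fused_fn = bb.get()[fn_name].with_attr("Primitive", 1)
            global_var = self.builder_.add_func(fused_fn, fn_name)

            # Replace the matched pattern with a call into the fused subfunction.
            return relax.Call(global_var, [x, w, b], None, None)


    @tvm.ir.transform.module_pass(opt_level=2, name="MatmulAddFuse")
    class FuseDenseAddPass:
        def transform_module(self, mod, ctx):
            return MatmulAddFusor(mod).transform()


    MLPFused = FuseDenseAddPass()(MLPModel)
    MLPFused.show()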
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def fused_matmul_add0(x: R.Tensor((1, 784), dtype="float32"), w: R.Tensor((784, 128), dtype="float32"), b: R.Tensor((128,), dtype="float32")) -> R.Tensor((1, 128), dtype="float32"):
            R.func_attr({"Primitive": 1})
            with R.dataflow():
                lv: R.Tensor((1, 128), dtype="float32") = R.matmul(x, w, out_dtype="void")
                gv: R.Tensor((1, 128), dtype="float32") = R.add(lv, b)
                R.output(gv)
            return gv

        @R.function
        def fused_matmul_add1(x: R.Tensor((1, 128), dtype="float32"), w: R.Tensor((128, 10), dtype="float32"), b: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"Primitive": 1})
            with R.dataflow():
                lv: R.Tensor((1, 10), dtype="float32") = R.matmul(x, w, out_dtype="void")
                gv: R.Tensor((1, 10), dtype="float32") = R.add(lv, b)
                R.output(gv)
            return gv

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv: R.Tensor((784, 128), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][0], axes=None)
                lv2: R.Tensor((1, 128), dtype="float32") = cls.fused_matmul_add0(x, lv, metadata["relax.expr.Constant"][1])
                lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
                lv4: R.Tensor((128, 10), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][2], axes=None)
                lv6: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul_add1(lv3, lv4, metadata["relax.expr.Constant"][3])
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.
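After fusion, the module still consists of high-level relax operators. The
next step lowers each of them to a loop-level TensorIR function: in the
printout below every operator has become a ``T.prim_func`` (``matmul``,
``add``, ``relu``, ``transpose``) invoked through ``R.call_tir``. One way
to obtain this form is a mutator that maps each relax operator to a
``topi`` compute via ``call_te``, sketched below; the pass name, the
``op_map`` contents and the ``MLPModelTIR`` handle are assumptions, and
newer TVM versions also offer ``relax.transform.LegalizeOps()`` as a
built-in alternative.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    def map_matmul(bb, call):
        x, w = call.args
        return bb.call_te(topi.nn.matmul, x, w)

    def map_add(bb, call):
        a, b = call.args
        return bb.call_te(topi.add, a, b)

    def map_relu(bb, call):
        return bb.call_te(topi.nn.relu, call.args[0])

    def map_transpose(bb, call):
        return bb.call_te(topi.transpose, call.args[0])

    op_map = {
        "relax.matmul": map_matmul,
        "relax.add": map_add,
        "relax.nn.relu": map_relu,
        "relax.permute_dims": map_transpose,
    }


    @relax.expr_functor.mutator
    class LowerToTensorIR(relax.PyExprMutator):
        def __init__(self, mod: IRModule, op_map) -> None:
            super().__init__(mod)
            self.mod_ = mod
            self.op_map = {tvm.ir.Op.get(k): v for k, v in op_map.items()}

        def visit_call_(self, call):
            call = self.visit_expr_post_order(call)
            if call.op in self.op_map:
                # Emit a call_tir into a TensorIR function generated from topi.
                return self.op_map[call.op](self.builder_, call)
            return call

        def transform(self) -> IRModule:
            for global_var, func in self.mod_.functions.items():
                if not isinstance(func, relax.Function):
                    continue
                updated_func = self.visit_expr(func)
                self.builder_.update_func(global_var, updated_func)
            return self.builder_.get()


    @tvm.ir.transform.module_pass(opt_level=0, name="LowerToTensorIR")
    class LowerToTensorIRPass:
        def transform_module(self, mod, ctx):
            return LowerToTensorIR(mod, op_map).transform()


    MLPModelTIR = LowerToTensorIRPass()(MLPFused)
    MLPModelTIR.show()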
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func
        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

        @T.prim_func
        def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

        @T.prim_func
        def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), T_matmul_NN: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
                with T.block("T_matmul_NN"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(T_matmul_NN[v_i, v_j])
                    with T.init():
                        T_matmul_NN[v_i, v_j] = T.float32(0)
                    T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

        @T.prim_func
        def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), T_matmul_NN: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i, j, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
                with T.block("T_matmul_NN"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(rxplaceholder[v_i, v_k], rxplaceholder_1[v_k, v_j])
                    T.writes(T_matmul_NN[v_i, v_j])
                    with T.init():
                        T_matmul_NN[v_i, v_j] = T.float32(0)
                    T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + rxplaceholder[v_i, v_k] * rxplaceholder_1[v_k, v_j]

        @T.prim_func
        def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

        @T.prim_func
        def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

        @T.prim_func
        def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

        @R.function
        def fused_matmul_add0(x: R.Tensor((1, 784), dtype="float32"), w: R.Tensor((784, 128), dtype="float32"), b: R.Tensor((128,), dtype="float32")) -> R.Tensor((1, 128), dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.matmul, (x, w), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                gv = R.call_tir(cls.add, (lv, b), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                R.output(gv)
            return gv

        @R.function
        def fused_matmul_add1(x: R.Tensor((1, 128), dtype="float32"), w: R.Tensor((128, 10), dtype="float32"), b: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.matmul1, (x, w), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                gv = R.call_tir(cls.add1, (lv, b), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                R.output(gv)
            return gv

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.transpose, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((784, 128), dtype="float32"))
                lv2: R.Tensor((1, 128), dtype="float32") = cls.fused_matmul_add0(x, lv, metadata["relax.expr.Constant"][1])
                lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv4 = R.call_tir(cls.transpose1, (metadata["relax.expr.Constant"][2],), out_sinfo=R.Tensor((128, 10), dtype="float32"))
                lv6: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul_add1(lv3, lv4, metadata["relax.expr.Constant"][3])
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.
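At this stage every ``Primitive`` relax subfunction is just a sequence of
``call_tir`` calls. The built-in ``relax.transform.FuseTIR`` pass merges
each such subfunction, together with the TensorIR functions it calls, into
a single ``T.prim_func``, which yields the ``fused_matmul_add0`` and
``fused_matmul_add1`` prim funcs in the printout below. Assuming the
``MLPModelTIR`` handle from the previous sketch, the call looks like this:

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    MLPModelFinal = relax.transform.FuseTIR()(MLPModelTIR)
    MLPModelFinal.show()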
.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    # from tvm.script import ir as I
    # from tvm.script import tir as T
    # from tvm.script import relax as R

    @I.ir_module
    class Module:
        @T.prim_func
        def fused_matmul_add0(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), w: T.Buffer((T.int64(784), T.int64(128)), "float32"), b: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            T_matmul_NN = T.alloc_buffer((T.int64(1), T.int64(128)))
            for i, j, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
                with T.block("T_matmul_NN"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(x[v_i, v_k], w[v_k, v_j])
                    T.writes(T_matmul_NN[v_i, v_j])
                    with T.init():
                        T_matmul_NN[v_i, v_j] = T.float32(0)
                    T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + x[v_i, v_k] * w[v_k, v_j]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_matmul_NN[v_ax0, v_ax1], b[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T_matmul_NN[v_ax0, v_ax1] + b[v_ax1]

        @T.prim_func
        def fused_matmul_add1(x: T.Buffer((T.int64(1), T.int64(128)), "float32"), w: T.Buffer((T.int64(128), T.int64(10)), "float32"), b: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            T_matmul_NN = T.alloc_buffer((T.int64(1), T.int64(10)))
            for i, j, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
                with T.block("T_matmul_NN"):
                    v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                    T.reads(x[v_i, v_k], w[v_k, v_j])
                    T.writes(T_matmul_NN[v_i, v_j])
                    with T.init():
                        T_matmul_NN[v_i, v_j] = T.float32(0)
                    T_matmul_NN[v_i, v_j] = T_matmul_NN[v_i, v_j] + x[v_i, v_k] * w[v_k, v_j]
            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(T_matmul_NN[v_ax0, v_ax1], b[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = T_matmul_NN[v_ax0, v_ax1] + b[v_ax1]

        @T.prim_func
        def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(rxplaceholder[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

        @T.prim_func
        def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

        @T.prim_func
        def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": T.bool(True)})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
                with T.block("T_transpose"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(rxplaceholder[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

        @R.function
        def main(x: R.Tensor((1, 784), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
            cls = Module
            with R.dataflow():
                lv = R.call_tir(cls.transpose, (metadata["relax.expr.Constant"][0],), out_sinfo=R.Tensor((784, 128), dtype="float32"))
                lv2 = R.call_tir(cls.fused_matmul_add0, (x, lv, metadata["relax.expr.Constant"][1]), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv3 = R.call_tir(cls.relu, (lv2,), out_sinfo=R.Tensor((1, 128), dtype="float32"))
                lv4 = R.call_tir(cls.transpose1, (metadata["relax.expr.Constant"][2],), out_sinfo=R.Tensor((128, 10), dtype="float32"))
                lv6 = R.call_tir(cls.fused_matmul_add1, (lv3, lv4, metadata["relax.expr.Constant"][3]), out_sinfo=R.Tensor((1, 10), dtype="float32"))
                gv: R.Tensor((1, 10), dtype="float32") = lv6
                R.output(gv)
            return gv

    # Metadata omitted. Use show_meta=True in script() method to show it.
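Finally, the transformed module can be built and executed like any other
relax module. The snippet below is a minimal sketch, assuming the
``MLPModelFinal`` handle from the previous step and random input data in
place of real images.

.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    # Build and run the fused module on CPU with random input data.
    ex = relax.build(MLPModelFinal, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    data = np.random.rand(1, 784).astype("float32")
    out = vm["main"](tvm.nd.array(data))
    print(out.numpy().shape)  # expected: (1, 10)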