tvm.relax.transform

Relax transformations.

tvm.relax.transform.AdjustMatmulOrder()

Reorder x*(A*B) to (x*A)*B

Useful for optimizing LoRA computations, where matmul(x, LoraA*LoraB) may be computed as matmul(matmul(x, LoraA), LoraB), reducing the total memory usage.

Returns:

ret – The corresponding pass.

Return type:

tvm.transform.Pass
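
A minimal usage sketch, assuming an IRModule mod that contains a LoRA-style computation:

# before: out = R.matmul(x, R.matmul(lora_A, lora_B))
# after:  out = R.matmul(R.matmul(x, lora_A), lora_B)
mod = tvm.relax.transform.AdjustMatmulOrder()(mod)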

tvm.relax.transform.AllocateWorkspace() Pass

Allocate a workspace, represented by a tensor large enough for all external functions that require temporary storage, and append it to the arguments of those external functions.

An external function can specify its workspace requirement by the kWorkspaceSize attribute.

Returns:

ret – The registered pass for allocating workspace.

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.AlterOpImpl(op_impl_map: Dict[str, PrimFunc], op_buffer_transforms: Dict[str, List[IndexMap | Callable]], op_buffer_axis_separators: Dict[str, List[axis_separator | Callable]])

Replace all PrimFuncs that have a matching ‘operator_name’ attribute with replacement PrimFuncs that may have different layouts on i/o buffers. The layout transformations on i/o buffers are given in the op_buffer_transforms map. The pass inserts layout transformations at the call sites of the replaced PrimFuncs to convert the i/o tensors into the layout expected by the new PrimFunc.

Parameters:
  • op_impl_map (Dict[str, PrimFunc]) – op_kind to PrimFunc map

  • op_buffer_transforms (Dict[str, List[Union[IndexMap, Callable]]]) – op_kind to layout transformation map for each of the buffers

  • op_buffer_axis_separators (Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]) – op_kind to axis_separator for each index_map

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.AnnotateTIROpPattern() Pass

Annotate Op Pattern Kind for TIR functions

Returns:

ret

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.AttachExternModules(extern_modules: List[ExternModule])

Attach the given list of external modules (ExternModule) to the IRModule.

tvm.relax.transform.AttachGlobalSymbol() Pass

Attach global_symbol to Relax functions and TIR PrimFuncs for codegen.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.BindParams(func_name: str, params: Dict[str | Var, NDArray | ndarray]) Pass

Bind params of function of the module to constant tensors.

Parameters:
  • func_name (str) – The function name to be bound

  • params (Dict[Union[str, relax.Var], Union[tvm.runtime.NDArray, np.ndarray]]) – The map from parameter or parameter name to constant tensors.

Returns:

ret

Return type:

tvm.ir.transform.Pass
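
A minimal usage sketch, assuming an IRModule Module whose function “main” has a parameter named “weight” of shape (2, 3):

import numpy as np
from tvm import relax

# the parameter name "weight" and its shape are assumptions for illustration
weight = np.random.rand(2, 3).astype("float32")
mod = relax.transform.BindParams("main", {"weight": weight})(Module)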

tvm.relax.transform.BindSymbolicVars(binding_map: Mapping[str | Var, PrimExpr], func_name: str | None = None) Pass

Bind symbolic variables in the functions of the module to constant values.

Parameters:
  • binding_map (Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr]) – The map from a symbolic variable, or its name, to the value it should be bound to.

  • func_name (Optional[str]) – The function name to be bound. If None (default), all functions within the module will be updated.

Returns:

ret

Return type:

tvm.ir.transform.Pass
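
A minimal usage sketch, assuming an IRModule Module whose function “main” uses a symbolic shape variable named “n”:

from tvm import relax

# replace every occurrence of the symbolic variable "n" with 16 in "main"
mod = relax.transform.BindSymbolicVars({"n": 16}, func_name="main")(Module)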

tvm.relax.transform.BundleModelParams() Pass

Bundle several model parameters into a single tuple parameter.

For each function with the attribute “num_input”, the parameters are separated into run-time parameters and compile-time weights. Run-time parameters (e.g. activations) are the first num_input parameters, and the remainder are compile-time weights.

Returns:

ret – The registered pass for bundling model parameters.

Return type:

tvm.transform.Pass

tvm.relax.transform.CallTIRRewrite() Pass

Perform explicit tensor allocation for call_tir and call_dps_packed.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.CanonicalizeBindings() Pass

Canonicalizes variable definitions (e.g., if there is y = x and z = y, it replaces uses of y and z with x). Also simplifies match cast nodes (eliminating redundant checks) and tuple indices.

Best combined with constant folding and the elimination of unused definitions.

Note: If a dataflow var is used only in a binding to the dataflow block output var (i.e., a non-dataflow var), this pass will also remove the dataflow var and replace the output var’s binding with the dataflow var’s direct definition.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.CombineParallelMatmul(check=None)

Combine multiple matmul operators sharing the same LHS matrix into one, followed by slicing. When all matmul branches in a tree have the same set of fused ops, the fused ops are applied to the combined matmul output before slicing.

Currently, only a limited set of fused ops is supported: bias add and the relu, gelu, gelu_tanh, and silu activations.

Parameters:

check (Callable[[relax.Var, List[relax.Var], List[relax.Var], Dict[relax.Var, Expr]], bool]) – A function to filter out unwanted branches, with the signature (input, [rhs], [bias], binding) -> bool.

Returns:

ret – The corresponding pass.

Return type:

tvm.transform.Pass
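
A minimal sketch of a check callback, following the signature documented above (the two-branch threshold is an illustrative choice):

from tvm import relax

def check(lhs, rhs_list, bias_list, binding):
    # only combine when at least two parallel matmuls share the same LHS
    return len(rhs_list) >= 2

mod = relax.transform.CombineParallelMatmul(check)(mod)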

tvm.relax.transform.ConvertLayout(desired_layouts: Dict[str, List[str]]) Pass

Automatic layout conversion pass.

Parameters:

desired_layouts (Dict[str, List[str]]) – A map from an op name to the desired layouts of its feature map, weight, and output. For example, to convert the layout of conv2d from NCHW to NHWC, set the desired layout of conv2d to {"relax.nn.conv2d": ["NHWC", "OHWI"]}.

Returns:

ret – The registered pass for layout conversion.

Return type:

tvm.transform.Pass
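
A minimal usage sketch, applying the NCHW-to-NHWC conversion from the example above to an assumed module mod:

from tvm import relax

mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(mod)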

tvm.relax.transform.ConvertToDataflow(min_size: int = 2) Pass

A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks.

Parameters:

min_size (int) – The minimum number of consecutive dataflow bindings the pass needs to extract a new block.

Returns:

ret – The pass.

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.DataflowBlockPass

A pass that works on each tvm.relax.DataflowBlock in a module.

tvm.relax.transform.DataflowUseInplaceCalls() Pass

Pass that changes calls to operators that can be done in-place (generally, these are elementwise operations) into in-place implementations. Supported operators will be replaced by calls to call_tir_inplace that invoke in-place PrimFunc implementations of those operators (which are based on the legalizations of those operators).

Returns:

ret – The pass

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.DeadCodeElimination(entry_functions: List[str] | None = None) Pass

Remove dead code in the IRModule. Currently it removes:

  1. Unused local VarBindings in a DataflowBlock.

  2. Unused DataflowBlocks in a function.

  3. Unused Relax functions in the module. We detect the call chain from the entry function, and remove all unused functions.

Notes

For function-wise DCE, use tvm.relax.analysis.remove_all_unused.

Parameters:

entry_functions (Optional[List[str]]) – The set of entry functions to start from.

Returns:

ret – The registered pass.

Return type:

tvm.transform.Pass
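
A minimal usage sketch, assuming the module’s public entry point is named “main”:

from tvm import relax

mod = relax.transform.DeadCodeElimination(entry_functions=["main"])(mod)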

tvm.relax.transform.DecomposeOpsForInference(func_name: str | None = None) Pass

Decompose composite operators into their constituent operators for inference. For example, batch norm (whose result is a triple) will be simplified. Attention, tensor_to_shape, etc. can also be decomposed into a number of simpler operators.

Parameters:

func_name (Optional[str]) – The name of the specified function. If not specified, the pass will run on all functions.

Returns:

ret – The registered pass

Return type:

tvm.transform.Pass

tvm.relax.transform.DecomposeOpsForTraining(func_name: str | None = None) Pass

Decompose composite operators into their constituent operators for training. For example, batch norm (whose result is a triple) will be simplified. Attention, tensor_to_shape, etc. can also be decomposed into a number of simpler operators.

Parameters:

func_name (Optional[str]) – The name of the specified function. If not specified, the pass will run on all functions.

Returns:

ret – The registered pass

Return type:

tvm.transform.Pass

tvm.relax.transform.EliminateCommonSubexpr(call_only=False) FunctionPass

Eliminate common subexpressions within functions.

Note: For nested functions, this pass performs CSE within those functions.

Parameters:

call_only (bool) – If True, only call nodes are eliminated.

Returns:

ret – The registered pass that eliminates common subexpressions.

Return type:

tvm.transform.Pass

tvm.relax.transform.ExpandMatmulOfSum()

Expand matmul(x, A+B) to matmul(x,A) + matmul(x,B)

If either operand can be fully computed at compile-time (only depends on function parameters after kNumInput), this expansion is suppressed.

Useful for optimizing LoRA computations, where matmul(x, Base + LoraA*LoraB) may be expanded to matmul(x, Base) + matmul(x, LoraA*LoraB), allowing it to be optimized with CombineParallelMatmul.

Returns:

ret – The corresponding pass.

Return type:

tvm.transform.Pass

tvm.relax.transform.ExpandTupleArguments() Pass

Expand tuple arguments to internal functions

Returns:

ret

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.FastMathTransform(*args, **kwargs)

Pass to convert expensive non-linear functions to their fast but approximate counterparts.

tvm.relax.transform.FewShotTuning(valid_count: int = 1, benchmark: bool = False) Pass

The pass is designed for few-shot tuning of static-shape PrimFuncs. It examines all the blocks within the PrimFunc and conducts loop fusion, splitting, and other transformations based on MetaSchedule schedule rules, but directly samples from the search space instead of using the tuning algorithm. Users can specify the number of valid counts to try and whether to use the runner for benchmarking.

Parameters:
  • valid_count (int) – The number of valid counts to try.

  • benchmark (bool) – Whether to use runner for benchmarking.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.FoldConstant() Pass

Fold constant expressions.

Returns:

ret

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.FunctionPass

A pass that works on each tvm.relax.Function in a module. A function pass class should be created through function_pass.

tvm.relax.transform.FuseOps(fuse_opt_level=-1) Pass

This pass groups bindings in the dataflow blocks of Relax functions and generates a new grouped Relax function for each group, according to the fusion algorithm described in the pass implementation. The bindings in the function being manipulated are then replaced with calls to the new grouped functions.

A follow-up pass named “FuseTIR” will generate a TIR PrimFunc for each grouped function.

Parameters:

fuse_opt_level (int) – The level of fuse optimization. -1 indicates that the level will be inferred from pass context.

Returns:

ret – The registered pass for operator fusion.

Return type:

tvm.transform.Pass

tvm.relax.transform.FuseOpsByPattern(patterns: List[FusionPattern | Tuple], bind_constants: bool = True, annotate_codegen: bool = False) Pass

Apply pattern matching to each function in the given module, and group matched expressions into a new function.

The end result is similar to FuseOps, but fusion is driven completely by the provided patterns.

Parameters:
  • patterns (List[Union[FusionPattern, Tuple]]) –

    A list of patterns to be matched. The order of the patterns determines the order of priority in which they are matched. Higher-priority patterns should come earlier in the list.

    In addition to FusionPattern, a tuple can be passed as item of this list. The pattern will be constructed through FusionPattern(*item)

  • bind_constants (bool) – Whether or not to keep bound constants in the grouped function.

  • annotate_codegen (bool) –

    If True, wrap each created composite function with another function, whose body consists only of a call to the composite function, and annotate the outer function with “Codegen” and “global_symbol” attributes. The “Codegen” attribute is set as the prefix of the corresponding pattern name. For example, “dnnl” if the pattern name is “dnnl.conv2d_relu”.

    This must be True if the created composite functions are intended to be offloaded to an external backend without using the MergeCompositeFunctions pass.

Returns:

ret – The registered pass for pattern-based fusion.

Return type:

tvm.transform.Pass
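
A minimal sketch, assuming a module mod that contains conv2d followed by relu; the pattern name “dnnl.conv2d_relu” is purely illustrative:

from tvm import relax
from tvm.relax.dpl import is_op, wildcard

conv = is_op("relax.nn.conv2d")(wildcard(), wildcard())
pattern = is_op("relax.nn.relu")(conv)

# each tuple in the list is expanded into FusionPattern(*item)
mod = relax.transform.FuseOpsByPattern(
    [("dnnl.conv2d_relu", pattern)],
    annotate_codegen=True,
)(mod)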

tvm.relax.transform.FuseTIR() Pass

Fuse primitive Relax functions into larger TIR functions when possible.

Returns:

ret – The registered pass for tir fusion.

Return type:

tvm.transform.Pass

class tvm.relax.transform.FusionPattern(name: str, pattern: DFPattern, annotation_patterns: Mapping[str, DFPattern] | None = None, check: Callable[[PatternCheckContext], bool] | None = None, attrs_getter: Callable[[Dict[str, Expr]], Dict[str, str]] | None = None)

The pattern used by FuseOpsByPattern. It is mainly a DFPattern, together with additional information used during the fusion pass.

Parameters:
  • name (str) – The name of the pattern. Usually it starts with the name of the backend, e.g. ‘cutlass.matmul’.

  • pattern (DFPattern) – The dataflow pattern that will be used to match expressions that can be handled by external backends.

  • annotation_patterns (Mapping[str, DFPattern]) – The map which is used to extract important expressions from the pattern match result. All DFPattern in this map should be part of the pattern.

  • check (Callable[[PatternCheckContext], bool]) – The function to check whether the match result is accepted.

tvm.relax.transform.Gradient(func_name: str, require_grads: Var | List[Var] | None = None, target_index: int = 0) Pass

Reverse-mode automatic differentiation.

This pass will differentiate one function in the IRModule. Currently, the input function must have only one dataflow block.

For a given function specified by func_name, it generates a new function with the name func_name + “_adjoint”. The new function computes the gradient of the differentiation target with respect to the arguments specified by require_grads of the original function.

If the function has only one return value, that value is the target. If the function has more than one return value, the target is the target_index-th return value. The target must be a scalar (0-dim tensor).

The new function will be like:

@R.function
def main_adjoint(original_parameters):
    with R.dataflow():
        # the bindings of the original function
        ...
        # calculating the gradients
        ...
        R.output(original_outputs, grad_1, grad_2, ...)
    return (original_return_value, (grad_1, grad_2, ...))

This AD pass also supports checkpointing as described in “Training deep nets with sublinear memory cost.” - Chen, Tianqi, et al. (2016). See tvm.relax.testing.nn.checkpoint for more details.

Parameters:
  • func_name (str) – The name of the specific function.

  • require_grads (Optional[Union[relax.Var, List[relax.Var]]]) – The relax variables whose adjoints are needed. They must be parameters of the given function and must not contain duplicates. If not specified, adjoints of all parameters are computed.

  • target_index (int) – If the specified function has more than one return value, specify the index of the return value to use as the target. If not specified, the first return value is the target.

Returns:

ret – The Pass.

Return type:

tvm.ir.transform.Pass

Examples

The following code shows how to use this pass:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            # use R.sum to reduce the tensor to a scalar
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

After = relax.transform.Gradient("main")(Module)

The module after the Gradient pass will be:

@I.ir_module
class After:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

    @R.function
    def main_adjoint(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(
        R.Tensor((), dtype="float32"),
        R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
    ):
        with R.dataflow():
            # original bindings
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            # bindings w.r.t. intermediate variables
            lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
            lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                lv2_adjoint, (3, 3)
            )
            # bindings w.r.t. parameters
            x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            R.output(lv2, x_adjoint, y_adjoint)
        # return value: (orig_return_values, tuple(adjoints))
        return (lv2, (x_adjoint, y_adjoint))

The second example returns multiple values and specifies the target with target_index:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(lv1, lv2)
        return (lv1, lv2)

After = relax.transform.Gradient("main", target_index=1)(Module)

The module after the Gradient pass will be:

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            R.output(lv1, lv2)
        return (lv1, lv2)

    @R.function
    def main_adjoint(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(
        R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")),
        R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
    ):
        with R.dataflow():
            # original bindings
            lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
            # bindings w.r.t. intermediate variables
            # gradient of intermediate variables that is not related to the target will not
            # be calculated
            lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
            # bindings w.r.t. parameters
            x_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros((3, 3), dtype="float32")
            y_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                lv2_adjoint, (3, 3)
            )
            R.output(lv1, lv2, x_adjoint, y_adjoint)
        # return value: (orig_return_values, tuple(adjoints))
        return ((lv1, lv2), (x_adjoint, y_adjoint))

tvm.relax.transform.InlinePrivateFunctions() Pass

Inline all private relax functions

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.KillAfterLastUse() Pass

Drop all tensor/storage objects after last use

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.LambdaLift() Pass

A pass that lifts local functions into global.

Returns:

ret

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.LazyTransformParams(fget_item='get_item', fset_item='set_item', extra_get_item_params=None, extra_set_item_params=None)

Convert transform_params functions into a lazy version. (Load the input to memory on demand, and immediately free it after the last use.)

Note: ToNonDataflow() and RemovePurityChecking() should be invoked before this pass.

Parameters:
  • fget_item (str) – The name of the get_item function.

  • fset_item (str) – The name of the set_item function.

  • extra_get_item_params (list of relax.Var) – The parameters of the get_item function except index. The given parameters will be placed before index. For example, if extra_get_item_params is [param1, param2], then the pass will generate call_packed(fget_item, [param1, param2, index])

  • extra_set_item_params (list of relax.Var) – The parameters of the set_item function except index and value. The given parameters will be placed before index and value. For example, if extra_set_item_params is [param1, param2], then the pass will generate call_packed(fset_item, [param1, param2, index, value])
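
A minimal usage sketch, assuming mod already contains a lifted transform_params function:

import tvm
from tvm import relax

seq = tvm.transform.Sequential(
    [
        relax.transform.ToNonDataflow(),
        relax.transform.RemovePurityChecking(),
        relax.transform.LazyTransformParams(),
    ]
)
mod = seq(mod)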

tvm.relax.transform.LegalizeOps(customize_legalize_map: Dict[str, Callable[[BlockBuilder, Call], Expr]] | None = None, enable_warning: bool = False)

Legalize high-level operator calls in Relax functions to call_tir with corresponding low-level TIR PrimFuncs.

For each high-level operator, we register the way of legalizing it as a function, which takes a context BlockBuilder and the relax.Call being legalized as input, and returns the legalized call. Here the input BlockBuilder is mainly used for adding the PrimFunc created by call_te into the context IRModule.

The legalization function for each operator is registered as an attribute (with attribute key FLegalize) of the operator.

This pass provides customizability for users to use their own legalization function for operators. The pass takes an optional customized map, with the key to be the operator name (str) and value to be the function (LegalizeFunc). The default legalization function will be overridden by the customized one.

Parameters:
  • customize_legalize_map (Optional[Dict[str, LegalizeFunc]]) – The customized operator legalization function map. The customized function will override the default one.

  • enable_warning (bool) – Whether to print a warning for CallNodes whose op has no registered legalization function. By default, no warnings are printed.

Returns:

ret – The registered pass

Return type:

tvm.transform.Pass

Examples

The following code shows how to use this pass:

# Define the pass input IRModule
@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z: R.Tensor((2, 3), "float32") = R.add(x, y)
        r: R.Tensor((2, 3), "float32") = R.multiply(y, z)
        return r

# Define the customized legalization function for "relax.add"
def customize_legalize_add(bb: relax.BlockBuilder, call: relax.Call) -> relax.Expr:
    from tvm import topi
    return bb.call_te(topi.add, call.args[1], call.args[0])

# Apply the pass with the customized function to the module.
mod = LegalizeOps({"relax.add": customize_legalize_add})(Module)

Printing the result with mod.show(), we can see that the IRModule after legalization becomes:

@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 3), "float32"):
        z = R.call_tir(add, (y, x), (2, 3), dtype="float32")
        r = R.call_tir(multiply, (y, z), (2, 3), dtype="float32")
        return r

    @T.prim_func
    def add(
        A: T.Buffer((2, 3), "float32"),
        B: T.Buffer((2, 3), "float32"),
        T_add: T.Buffer((2, 3), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func
    def multiply(
        A: T.Buffer((2, 3), "float32"),
        B: T.Buffer((2, 3), "float32"),
        T_multiply: T.Buffer((2, 3), "float32"),
    ):
        T.func_attr({"tir.noalias": True})
        for ax0, ax1 in T.grid(2, 3):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_multiply[v_ax0, v_ax1])
                T_multiply[v_ax0, v_ax1] = A[v_ax0, v_ax1] * B[v_ax0, v_ax1]

tvm.relax.transform.LiftTransformParams() Pass

Lift transformation of the parameters of a function.

When some inputs of the function are marked as ‘parameters’ (the model weights), this pass identifies the transformation of the parameters and lifts it to a separate function called transform_params. transform_params takes a tuple of the original parameters as input and returns a tuple of the transformed parameters. The original function is rewritten to accept a tuple of transformed parameters as input.

Users are expected to invoke the transform_params function at runtime and pass the transformed parameters to the original function as input.

Returns:

ret – The registered pass for lifting transformation of parameters.

Return type:

tvm.transform.Pass
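
A minimal usage sketch, assuming a module mod whose “main” function carries a “num_input” attribute that separates run-time inputs from weights:

from tvm import relax

mod = relax.transform.LiftTransformParams()(mod)
# mod now additionally contains a parameter-transformation function
# (e.g. "main_transform_params"; the exact name depends on the original function),
# which should be invoked once at runtime to produce the transformed weights.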

tvm.relax.transform.LowerAllocTensor() Pass

Lower remaining instances of R.builtin.alloc_tensor

The static memory planner removes static instances of R.builtin.alloc_tensor, replacing them with R.memory.alloc_storage and R.memory.alloc_tensor. However, R.builtin.alloc_tensor still remains for any dynamic allocations.

This transform replaces any remaining R.builtin.alloc_tensor instances with R.memory.alloc_storage and R.memory.alloc_tensor. If no R.builtin.alloc_tensor are present, this pass has no effect.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.MergeCompositeFunctions() Pass

Group one or multiple composite functions created by FuseOpsByPattern into a new function. The new function will be annotated with “Codegen” and “global_symbol” attributes, and it is intended to be offloaded to an external backend.

Returns:

ret – The registered pass for merging composite functions.

Return type:

tvm.transform.Pass

tvm.relax.transform.MetaScheduleApplyDatabase(work_dir: str | None = None, enable_warning: bool = False) Pass

Apply the best schedule from tuning database.

Parameters:
  • work_dir (Optional[str]) – The work directory used to deduce the default database if a database is not provided (ignored when the user passes a database).

  • enable_warning (bool) – Whether to print a warning for TIR functions that do not show up in the database. By default, no warnings are printed.

Returns:

ret – The registered pass

Return type:

tvm.transform.Pass

tvm.relax.transform.MetaScheduleTuneIRMod(params: Dict[str, NDArray], work_dir: str, max_trials_global: int, max_trials_per_task: int | None = None, op_names: List[str] | None = None) Pass

Tune Relax IRModule with MetaSchedule.

Parameters:
  • params (Dict[str, NDArray]) – model params

  • work_dir (str) – work directory

  • max_trials_global (int) – maximum number of total trials allowed for tuning

  • max_trials_per_task (int) – maximum number of trials per task

  • op_names (Optional[List[str]]) – A list of operator names specifying which ops to tune. When it is None, all operators are tuned.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.MetaScheduleTuneTIR(work_dir: str, max_trials_global: int) Pass

Tune TIR with MetaSchedule.

Parameters:
  • work_dir (str) – work directory

  • max_trials_global (int) – maximum number of total trials allowed for tuning

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.Normalize() Pass

Transform Relax IR to normal form, i.e., expressions are normalized (no nesting, so the AST is in ANF), and the checked_type_ and shape_ of every expression are available.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.NormalizeGlobalVar() Pass

Possibly rename the GlobalVar in an IRModule to ensure these properties:

  1. (Invariant) Every public function has the same name as its “global_symbol” attribute.

  2. To ensure 1., private functions with conflicting names may be renamed.

  3. The name of every GlobalVar is unique in the IRModule.

Returns:

ret

Return type:

tvm.ir.transform.Pass

class tvm.relax.transform.OptimizeLayoutTransform

Pass to remove redundant layout-transform operators introduced by the AlterOpImpl pass.

class tvm.relax.transform.PatternCheckContext

The input of check function FusionPattern.check.

Parameters:
  • matched_expr (Expr) – The expression that’s matched with the FusionPattern.pattern.

  • annotated_expr (Mapping[str, Expr]) – A map which contains all expressions matched by the sub patterns in FusionPattern.annotation_patterns.

  • matched_bindings (Mapping[relax.Var, Expr]) – Map from variables to their values. It contains the variables from bindings that are being fused by FuseOpsByPattern.

  • var_usages (Mapping[relax.Var, Sequence[relax.Var]]) – A map from variable definitions to their uses. It contains all variables used in the function.

  • value_to_bound_var (Mapping[Expr, relax.Var]) – Map from values to their bound variables. It does not contain variables bound after the matched expression.

tvm.relax.transform.RealizeVDevice() Pass

Propagate virtual device information.

Returns:

ret – The registered pass

Return type:

tvm.transform.Pass

tvm.relax.transform.RemovePurityChecking() Pass

Activate relax.force_pure on all pure functions in the module and unwrap all pure override ops into the normal versions.

This effectively means that there will be no more purity tracking, which is useful for low-level code generation.

Returns:

ret – The Pass.

Return type:

tvm.ir.transform.Pass

Note

Should be used after ToNonDataflow()

class tvm.relax.transform.RemoveRedundantReshape

Transformation pass to remove redundant reshape operators.

tvm.relax.transform.RemoveUnusedOutputs() Pass

Remove unused outputs from internal functions

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.RemoveUnusedParameters() Pass

Remove unused arguments to internal functions

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.ReorderTakeAfterMatmul()

Reorder matmul(x, take(weights, indices)) to take(matmul(x, weights), indices)

Useful for optimizing LoRA computations, where several LoRAs may be batched together.

Returns:

ret – The corresponding pass.

Return type:

tvm.transform.Pass

tvm.relax.transform.RewriteCUDAGraph() Pass

Rewrite a Relax module for executing with CUDA graph. This pass identifies the regions that can be executed with CUDA graph and lifts them into new functions for runtime graph capturing.

Returns:

ret – The registered pass for rewriting cuda graph

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.RewriteDataflowReshape() Pass

Convert all reshape-like call_tir to VM reshape operator call. The VM reshape operator calls will be further lowered to a CreateView operation at runtime, instead of doing real data copy. Here “reshape-like” includes reshape, expand_dims, flatten, etc.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.RunCodegen(target_options: dict | None = None, entry_functions: List[str] | None = None) Pass

Produce runtime::Modules for the functions annotated with codegen and global symbol attributes.

Parameters:
  • target_options (Optional[dict]) – Pairs of a target name and compilation options

  • entry_functions (Optional[List[str]]) – The set of entry functions to start from.

Returns:

ret – The registered pass for running external codegen.

Return type:

tvm.transform.Pass

tvm.relax.transform.SplitCallTIRByPattern(patterns: List[PrimFunc], fcodegen: Callable) Pass

Split a PrimFunc into two parts: the first part is a TIR PrimFunc matched by one of the given patterns, and the second part is the rest of the original PrimFunc. fcodegen is called to generate code for the matched pattern, which is then replaced with an ExternFunc call.

Parameters:
  • patterns (List[PrimFunc]) – The list of patterns to match.

  • fcodegen (Callable[[List[MatchResult]], List[Object]]) – The function to generate the code for the matched patterns.

Returns:

ret – The registered pass for splitting call_tir.

Return type:

tvm.transform.Pass

tvm.relax.transform.StaticPlanBlockMemory() Pass

The static memory planning pass on the BindingBlock level. The pass reuses allocated memory on a best-effort basis, in order to reduce the total amount of memory allocated.

The pass supports dynamic shapes via TIR variable upper-bound annotations. We can optionally annotate a Relax function with the attribute “tir_var_upper_bound”. The attribute value is a dict from strings to integers, mapping the names of TIR variables to their upper-bound values. Note: the annotated upper bound only applies to TIR vars in the function signature, for clarity.

For example, we can annotate a Relax function with R.func_attr({"tir_var_upper_bound": {"n": 1024}}). This means the variable named “n” in the function signature has an upper bound of 1024, and 1024 will be used as its value during memory planning.

Returns:

ret

Return type:

tvm.ir.transform.Pass
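
A minimal TVMScript sketch of the upper-bound annotation described above (the function body is illustrative; the pass itself is normally applied later in the lowering pipeline):

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("n", 8), "float32")) -> R.Tensor(("n", 8), "float32"):
        # "n" is bounded by 1024 for memory planning purposes
        R.func_attr({"tir_var_upper_bound": {"n": 1024}})
        y = R.add(x, x)
        return y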

tvm.relax.transform.ToMixedPrecision(out_dtype='float32', fp16_input_names: List[str] | None = None) Pass

Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only, and will automatically cast fp32 to fp16 for certain ops.

Parameters:
  • out_dtype (str) – The output data type of gemm/conv, which is the data type of the accumulator.

  • fp16_input_names (List[str]) – The names of function parameters whose dtype should become fp16. The function signature would change accordingly.

Returns:

ret – The registered pass for mixed precision.

Return type:

tvm.transform.Pass
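
A minimal usage sketch, assuming an fp32-only module mod:

from tvm import relax

# cast eligible ops to fp16 while keeping fp32 accumulation for gemm/conv
mod = relax.transform.ToMixedPrecision(out_dtype="float32")(mod)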

tvm.relax.transform.ToNonDataflow() Pass

Transform all dataflow structure to non-dataflow version.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.UpdateParamStructInfo(sinfo_func: Callable[[Var], StructInfo | None])

Update struct info of parameters

Internal bindings and the function return type will be updated using Relax’s struct inference rules. Errors resulting from struct inference will be propagated to the user.

Parameters:

sinfo_func (Callable[[relax.Var], Optional[StructInfo]]) – A function that is called once for each function parameter, and returns the updated struct info to be used for it. If the function returns None, the parameter is not modified.

Returns:

ret – The corresponding pass.

Return type:

tvm.transform.Pass
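
A minimal sketch of a sinfo_func callback, assuming we want to turn every float32 tensor parameter into float16 and leave other parameters untouched:

from tvm import relax

def update_param(param: relax.Var):
    sinfo = param.struct_info
    if isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == "float32":
        # return the new struct info to use for this parameter
        return relax.TensorStructInfo(sinfo.shape, "float16")
    # returning None keeps the parameter unchanged
    return None

mod = relax.transform.UpdateParamStructInfo(update_param)(mod)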

tvm.relax.transform.UpdateVDevice(new_vdevice: VDevice, index: int) Pass

Update virtual device.

Parameters:
  • new_vdevice (tvm.ir.VDevice) – The new virtual device.

  • index (int) – The device index indicates the device on which the update will be performed.

Returns:

ret – The registered pass that modifies the virtual device.

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.VMBuiltinLower() Pass

Lower generic intrinsics to VM intrinsics.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.VMShapeLower(*, emit_err_ctx: bool = True) Pass

Lower symbolic shape computations and the struct info matching of arguments and match casts.

Parameters:

emit_err_ctx (Optional[bool]) – Whether to emit the error context string; can be turned off for testing purposes.

Returns:

ret

Return type:

tvm.ir.transform.Pass

tvm.relax.transform.dataflowblock_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False) Callable | DataflowBlockPass

Decorate a dataflowblock pass.

This function returns a callback when pass_func is provided. Otherwise, it returns the created dataflowblock pass using the given optimization function.

Parameters:
  • pass_func (Optional[Callable[(DataflowBlock, Module, PassContext) -> DataflowBlock]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this dataflowblock pass.

  • name (Optional[str]) – The name of the dataflowblock pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the dataflowblock pass is dependent on.

  • traceable (bool) – Whether the dataflowblock pass is traceable.

Returns:

create_dataflowblock_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new DataflowBlockPass will be returned when we decorate a pass function. A new DataflowBlockPass class will be returned when we decorate a class type.

Return type:

Union[Callable, DataflowBlockPass]

Examples

The following code block decorates a dataflowblock pass class.

@relax.transform.dataflowblock_pass(opt_level=1)
class TestReplaceBinding:
    # Simple test function to replace the first VarBinding with another.

    def __init__(self):
        # create a new VarBinding
        m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
        lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32"))
        val = relax.const(np.random.rand(24, 56))
        self.new_binding = relax.VarBinding(lv0, val)

    def transform_dataflowblock(self, block, mod, ctx):
        # just for demo purposes
        # Replace the first binding in the DataflowBlock
        new_bindings = [self.new_binding, block.bindings[1]]
        new_block = relax.expr.DataflowBlock(new_bindings, block.span)
        return new_block

@tvm.script.ir_module
class InputMod:
    @R.function
    def f1(x: Tensor[(m, n), "float32"]):
        with relax.dataflow():
            lv0 = relax.multiply(x, x)
            gv0 = relax.add(x, x)
            relax.output(gv0)
        return gv0
# block_pass is now a special pass that replaces every
# first binding to the constant value binding
block_pass = TestReplaceBinding()
# now every first binding in DataflowBlock of InputMod
# is replaced by new_binding
updated_mod = block_pass(InputMod)

The following code creates a dataflowblock pass by decorating a user defined transform function.

@relax.transform.dataflowblock_pass(opt_level=2)
def transform(block, mod, ctx):
    # my transformations here.
    return block

block_pass = transform
assert isinstance(block_pass, relax.transform.DataflowBlockPass)
assert block_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = block_pass(m)
# Now transform should have been applied to every DataflowBlock in
# the provided module m. And the updated module will be returned.

tvm.relax.transform.function_pass(pass_func=None, opt_level=None, name=None, required=None, traceable=False) Callable | FunctionPass

Decorate a function pass.

This function returns a callback when pass_func is provided. Otherwise, it returns the created function pass using the given optimization function.

Parameters:
  • pass_func (Optional[Callable[(Function, Module, PassContext) -> Function]]) – The transformation function or class.

  • opt_level (int) – The optimization level of this function pass.

  • name (Optional[str]) – The name of the function pass. The name could be empty. In this case, the name of the optimization function will be used as the pass name.

  • required (Optional[List[str]]) – The list of passes that the function pass is dependent on.

  • traceable (bool) – Whether the function pass is traceable.

Returns:

create_function_pass – A decorator will be returned if pass_func is not provided, otherwise return the decorated result. The returned decorator has two behaviors depending on the input: A new FunctionPass will be returned when we decorate a pass function. A new FunctionPass class will be returned when we decorate a class type.

Return type:

Union[Callable, FunctionPass]

Examples

The following code block decorates a function pass class.

@relax.transform.function_pass(opt_level=1)
class TestReplaceFunc:
    def __init__(self, new_func):
        self.new_func = new_func

    def transform_function(self, func, mod, ctx):
        # just for demo purposes
        # transform func to new_func
        return self.new_func

@R.function
def f1(x: Tensor[(m, n), "float32"]):
    return x

@tvm.script.ir_module
class InputMod:
    @R.function
    def f2(x: Tensor[(m, n), "float32"]):
        gv0 = relax.add(x, x)
        return gv0
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# now every function in InputMod is replaced by f1
updated_mod = fpass(InputMod)

The following code creates a function pass by decorating a user defined transform function.

@relax.transform.function_pass(opt_level=2)
def transform(func, mod, ctx):
    # my transformations here.
    return func

function_pass = transform
assert isinstance(function_pass, relax.transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the following:
updated_mod = function_pass(m)
# Now transform should have been applied to every function in
# the provided module m. And the updated module will be returned.