tvm.te¶
Namespace for Tensor Expression Language
- class tvm.te.ComputeOp¶
Scalar operation.
- class tvm.te.ExternOp¶
External operation.
- class tvm.te.HybridOp¶
Hybrid operation.
- property axis¶
Represent the IterVar axis, also defined when it is a HybridOp
- class tvm.te.PlaceholderOp¶
Placeholder operation.
- class tvm.te.ScanOp¶
Scan operation.
- property scan_axis¶
Represent the scan axis, only defined when it is a ScanOp
- class tvm.te.SpecializedCondition(conditions)¶
Specialized condition to enable op specialization.
- static current()¶
Returns the current specialized condition
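Example
A minimal sketch (the divisibility condition and variable name are illustrative assumptions): conditions are entered as a scope, and the active scope can be queried with current().
from tvm import te

n = te.size_var("n")
# specialize subsequent decisions for sizes divisible by 8 (hypothetical condition)
with te.SpecializedCondition([n % 8 == 0]):
    cond = te.SpecializedCondition.current()  # the condition scope now in effect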
- class tvm.te.Tensor¶
Tensor object. To construct one, use functions such as te.placeholder or te.compute.
- property axis¶
Axis of the tensor.
- property ndim¶
Dimension of the tensor.
- property op¶
The corresponding Operation.
- property shape¶
The output shape of the tensor.
- property value_index¶
The output value index the tensor corresponds to.
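Example
A minimal sketch showing how the properties above can be inspected on a tensor produced by te.compute:
from tvm import te

A = te.placeholder((128, 64), name="A")
B = te.compute(A.shape, lambda i, j: A[i, j] + 1.0, name="B")

print(B.shape)        # [128, 64]
print(B.ndim)         # 2
print(B.op)           # the ComputeOp that produces B
print(B.value_index)  # 0, since B is the op's only output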
- class tvm.te.TensorComputeOp¶
Tensor operation.
- class tvm.te.TensorSlice(tensor, indices)¶
Auxiliary data structure to enable slicing syntax on tensors.
- asobject()¶
Convert slice to object.
- property dtype¶
Data type of the tensor.
- tvm.te.abs(x, span=None)¶
Get absolute value of the input element-wise.
- tvm.te.acos(x)¶
Take acos of input x.
- tvm.te.acosh(x)¶
Take acosh of input x.
- tvm.te.add(lhs, rhs, span=None)¶
Generic add operator.
- Parameters:
lhs (object) – The left operand.
rhs (object) – The right operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
op – The result Expr of the add operation.
- Return type:
tvm.Expr
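Example
A short sketch of the generic arithmetic helpers; tvm.te.subtract and tvm.te.multiply below behave analogously:
from tvm import te

x = te.var("x")
y = te.var("y")
z = te.add(x, y)       # same as x + y
w = te.multiply(z, 2)  # operands may mix Expr and Python numbers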
- tvm.te.all(*args, span=None)¶
Create a new expression that is the intersection (logical AND) of all conditions in the arguments (see the example after tvm.te.any below).
- Parameters:
args (list) – List of symbolic boolean expressions
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
expr – Expression
- Return type:
Expr
- tvm.te.any(*args, span=None)¶
Create a new expression that is the union (logical OR) of all conditions in the arguments.
- Parameters:
args (list) – List of symbolic boolean expressions
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
expr – Expression
- Return type:
Expr
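Example
A sketch (an illustration, not taken from the TVM docs) that uses te.all together with te.if_then_else to zero-pad a 2-D tensor; te.any composes conditions the same way with a logical OR:
import tvm
from tvm import te

n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
# pad A with a one-element border of zeros
Apad = te.compute(
    (n + 2, m + 2),
    lambda i, j: te.if_then_else(
        te.all(i >= 1, i < n + 1, j >= 1, j < m + 1),
        A[i - 1, j - 1],
        tvm.tir.const(0.0, A.dtype),
    ),
    name="Apad",
)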
- tvm.te.asin(x)¶
Take asin of input x.
- tvm.te.asinh(x)¶
Take asinh of input x.
- tvm.te.atan(x)¶
Take atan of input x.
- tvm.te.atanh(x)¶
Take atanh of input x.
- tvm.te.ceil(x, span=None)¶
Take ceil of float input x.
- tvm.te.comm_reducer(fcombine, fidentity, name='reduce')¶
Create a commutative reducer for reduction.
- Parameters:
fcombine (function(Expr -> Expr -> Expr)) – A binary function which takes two Exprs as input and returns an Expr.
fidentity (function(str -> Expr)) – A function which takes a type string as input and returns a const Expr.
- Returns:
reducer – A function which creates a reduce expression over an axis. There are two ways to use it:
accept (expr, axis, where) to produce a Reduce Expr on the specified axis;
simply use it with multiple Exprs.
- Return type:
function
Example
n = te.var("n")
m = te.var("m")
mysum = te.comm_reducer(lambda x, y: x + y,
                        lambda t: tvm.tir.const(0, dtype=t), name="mysum")
A = te.placeholder((n, m), name="A")
k = te.reduce_axis((0, m), name="k")
B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B")
- tvm.te.compute(shape, fcompute, name='compute', tag='', attrs=None, varargs_names=None)¶
Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
- Parameters:
shape (Tuple of Expr) – The shape of the tensor
fcompute (lambda function of indices-> value) – Specifies the input source expression
name (str, optional) – The name hint of the tensor
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
varargs_names (list, optional) – The names to use for each of the varargs. If not supplied, the varargs will be called i1, i2, …
- Returns:
tensor – The created tensor
- Return type:
Tensor
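Example
A minimal sketch of elementwise and reduction computes (names are illustrative):
from tvm import te

n = te.var("n")
A = te.placeholder((n, n), name="A")
# elementwise: B[i, j] = A[i, j] * 2
B = te.compute((n, n), lambda i, j: A[i, j] * 2.0, name="B")
# reduction over an explicit reduce axis: C[i] = sum_k A[i, k]
k = te.reduce_axis((0, n), name="k")
C = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="C")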
- tvm.te.const(value, dtype='int32', span=None)¶
Create a new constant with specified value and dtype
- tvm.te.cos(x)¶
Take cos of input x.
- tvm.te.cosh(x)¶
Take cosh of input x.
- tvm.te.create_prim_func(ops: List[Tensor | Var], index_dtype_override: str | None = None) PrimFunc ¶
Create a TensorIR PrimFunc from tensor expression
- Parameters:
ops (List[Union[_tensor.Tensor, tvm.tir.Var]]) – The source expression.
Example
We define a matmul kernel using following code:
import tvm
from tvm import te
from tvm.te import create_prim_func
import tvm.script

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), "k")
C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C")
func = create_prim_func([A, B, C])
print(func.script())
If we want to use TensorIR schedule to do transformations on such kernel, we need to use create_prim_func([A, B, C]) to create a schedulable PrimFunc. The generated function looks like:
@T.prim_func
def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j, k in T.grid(128, 128, 128):
        with T.block():
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] += A[vi, vk] * B[vj, vk]
- Returns:
func – The created function.
- Return type:
tvm.tir.PrimFunc
- tvm.te.create_schedule(ops)¶
Create a schedule for list of ops
- Parameters:
ops (list of Operations) – The source expression.
- Returns:
sch – The created schedule.
- Return type:
tvm.te.Schedule
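Example
A minimal sketch that creates the default schedule for a compute op and prints the lowered IR (the use of tvm.lower here is for illustration):
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))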
- tvm.te.decl_tensor_intrin(op, fcompute, name='tensor_intrin', binds=None, scalar_params=None, default_buffer_params=None)¶
Declare a tensor intrinsic function.
- Parameters:
op (Operation) – The symbolic description of the intrinsic operation
fcompute (lambda function of inputs, outputs-> stmt) –
Specifies the IR statement to do the computation. See the following note for function signature of fcompute
Note
Parameters
ins (list of tvm.tir.Buffer) - Placeholder for each input
outs (list of tvm.tir.Buffer) - Placeholder for each output
Returns
stmt (tvm.tir.Stmt, or tuple of three stmts)
If a single stmt is returned, it represents the body.
If a tuple of three stmts is returned, they correspond to body, reduce_init and reduce_update.
name (str, optional) – The name of the intrinsic.
binds (dict of Tensor to tvm.tir.Buffer, optional) – Dictionary that maps each Tensor to a Buffer, specifying the data layout requirements of the function. By default, a new compact buffer is created for each tensor in the arguments.
scalar_params (list of variables, optional) – Variables used by op, whose values will be passed as scalar_inputs when the tensor intrinsic is called.
default_buffer_params (Optional[dict]) – Dictionary of buffer arguments to be passed when constructing a buffer.
- Returns:
intrin – A TensorIntrin that can be used in a tensorize schedule.
- Return type:
TensorIntrin
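Example
A sketch of declaring a vector-add intrinsic whose body calls a hypothetical external kernel vadd_update; the extern symbol and the use of default buffer binds are assumptions, following the usual tensorize pattern:
import tvm
from tvm import te

def intrin_vadd(n):
    x = te.placeholder((n,), name="x")
    y = te.placeholder((n,), name="y")
    z = te.compute((n,), lambda i: x[i] + y[i], name="z")

    def intrin_func(ins, outs):
        # ins/outs are tvm.tir.Buffer placeholders for x, y and z
        xx, yy = ins
        zz = outs[0]
        ib = tvm.tir.ir_builder.create()
        ib.emit(tvm.tir.call_extern(
            "int32", "vadd_update",  # hypothetical external function
            zz.access_ptr("w"), xx.access_ptr("r"), yy.access_ptr("r"), n))
        return ib.get()

    return te.decl_tensor_intrin(z.op, intrin_func, name="vadd")

vadd = intrin_vadd(16)  # usable later with s[C].tensorize(axis, vadd)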
- tvm.te.div(a, b, span=None)¶
Compute a / b as in C/C++ semantics.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
When operands are integers, returns truncdiv(a, b, span).
- tvm.te.erf(x)¶
Take gauss error function of the input x.
- tvm.te.exp(x)¶
Take exponential of input x.
- tvm.te.extern(shape, inputs, fcompute, name='extern', dtype=None, in_buffers=None, out_buffers=None, tag='', attrs=None)¶
Compute several tensors via an extern function.
- Parameters:
shape (tuple or list of tuples.) – The shape of the outputs.
inputs (list of Tensor) – The inputs
fcompute (lambda function of inputs, outputs-> stmt) –
Specifies the IR statement to do the computation. See the following note for function signature of fcompute
Note
Parameters
ins (list of tvm.tir.Buffer) - Placeholder for each input
outs (list of tvm.tir.Buffer) - Placeholder for each output
Returns
stmt (tvm.tir.Stmt) - The statement that carries out the array computation.
name (str, optional) – The name hint of the tensor
dtype (str or list of str, optional) – The data types of outputs, by default dtype will be same as inputs.
in_buffers (tvm.tir.Buffer or list of tvm.tir.Buffer, optional) – Input buffers.
out_buffers (tvm.tir.Buffer or list of tvm.tir.Buffer, optional) – Output buffers.
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
- Returns:
tensor – The created tensor or tuple of tensors contains multiple outputs.
- Return type:
Tensor or list of Tensors
Example
In the code below, C is generated by calling the external PackedFunc tvm.contrib.cblas.matmul:
A = te.placeholder((n, l), name="A")
B = te.placeholder((l, m), name="B")
C = te.extern((n, m), [A, B],
              lambda ins, outs: tvm.tir.call_packed(
                  "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], 0, 0),
              name="C")
- tvm.te.extern_primfunc(input_tensors: List[Tensor], primfunc: PrimFunc, **kwargs)¶
Compute tensors via a schedulable TIR PrimFunc
- Parameters:
- Returns:
tensor – The created tensor or tuple of tensors if it contains multiple outputs.
- Return type:
Tensor or list of Tensors
Example
In the code below, a TVMScript-defined TIR PrimFunc is inlined into a TE ExternOp so that it can be used inside a tensor expression graph (for example, one later processed with te.create_prim_func).
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")

@T.prim_func
def before_split(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0

C = te.extern_primfunc([A, B], before_split)
- tvm.te.floor(x: PrimExprWithOp, span=None)¶
Take floor of float input x.
- tvm.te.floordiv(a, b, span=None)¶
Compute the floordiv of two expressions.
- tvm.te.floormod(a, b, span=None)¶
Compute the floormod of two expressions.
- tvm.te.fmod(x, y)¶
Return the remainder of x divided by y with the same sign as x.
- tvm.te.gradient(output, inputs, head=None)¶
Perform reverse-mode automatic differentiation.
- Parameters:
output (Tensor) – The tensor to differentiate.
inputs (List[Tensor]) – The list of input tensors to be differentiated wrt.
head (Tensor) – The adjoint of the output, in other words, some tensor, by which the Jacobians will be multiplied. Its shape must be of the form prefix + output.shape. If None is passed, the identity tensor of shape output.shape + output.shape will be used.
- Returns:
tensors – The result gradient, in the same order as the inputs
- Return type:
List[Tensor]
Example
x = te.placeholder((32, 3, 28, 28), name='x')
w1 = te.placeholder((10, 3, 3, 3), name='w1')
w2 = te.placeholder((10, 10, 3, 3), name='w2')
z1 = topi.nn.conv2d(x, w1, 1, 1, 1)
z2 = topi.nn.conv2d(z1, w2, 1, 1, 1)
y = topi.sum(z2)

# produce gradients
[dw1, dw2] = te.gradient(y, [w1, w2])

# produce Jacobians
[jw1, jw2] = te.gradient(z2, [w1, w2])

# produce gradients, the head adjoint for z2 is provided manually
[dw1, dw2] = te.gradient(z2, [w1, w2], topi.full_like(z2, 1.0))
- tvm.te.if_then_else(cond, t, f, span=None)¶
Conditional selection expression.
- Parameters:
- Returns:
result – The result of conditional expression.
- Return type:
Note
Unlike Select, if_then_else will not execute the branch that does not satisfy the condition. You can use it to guard against out of bound access. Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions.
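Example
A small sketch that guards an out-of-bounds read when shifting a tensor (illustrative only):
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
# B[i] = A[i + 1] for i < n - 1, else 0
B = te.compute(
    (n,),
    lambda i: te.if_then_else(i < n - 1, A[i + 1], tvm.tir.const(0.0, A.dtype)),
    name="B",
)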
- tvm.te.indexdiv(a, b, span=None)¶
Compute floor(a / b) where a and b are non-negative.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
- tvm.te.indexmod(a, b, span=None)¶
Compute the remainder of indexdiv. a and b are non-negative.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
Use this function to split non-negative indices. This function may take advantage of operands’ non-negativeness.
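Example
A sketch of splitting a flattened, non-negative index into row and column for a 16-wide layout (the width is an illustrative assumption); indexdiv gives the quotient and indexmod the remainder:
from tvm import te

idx = te.var("idx")
row = te.indexdiv(idx, 16)  # which 16-element row the flat index falls into
col = te.indexmod(idx, 16)  # offset within that row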
- tvm.te.isfinite(x, span=None)¶
Check if input value is finite.
- tvm.te.isinf(x, span=None)¶
Check if input value is infinite.
- tvm.te.isnan(x, span=None)¶
Check if input value is NaN.
- tvm.te.log(x)¶
Take log of input x.
- tvm.te.log10(x)¶
Take log10 of input x.
- tvm.te.log2(x)¶
Take log2 of input x.
- tvm.te.max(expr, axis, where=None, init=None, *args)¶
Create a max expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this max reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.max represents tvm.te.max or tvm.tir.max.
B = te.compute((m,), lambda i: tvm.max(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
max_res = tvm.max(m, n)
- tvm.te.max_value(dtype: str, span: Span | None = None) Any ¶
maximum value of dtype
- Parameters:
dtype (str) – The data type.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The maximum value of dtype.
- Return type:
tvm.Expr
- tvm.te.min(expr, axis, where=None, init=None, *args)¶
Create a min expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this min reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.min represents tvm.te.min or tvm.tir.min.
B = te.compute((m,), lambda i: tvm.min(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
min_res = tvm.min(m, n)
- tvm.te.min_value(dtype, span=None)¶
minimum value of dtype
- Parameters:
dtype (str) – The data type.
span (Optional[Span]) – The location of this operator in the source code.
- Returns:
value – The minimum value of dtype.
- Return type:
tvm.Expr
- tvm.te.multiply(lhs, rhs, span=None)¶
Generic multiply operator.
- Parameters:
lhs (object) – The left operand.
rhs (object) – The right operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
op – The result Expr of the multiply operation.
- Return type:
tvm.Expr
- tvm.te.nearbyint(x, span=None)¶
Round elements of the array to the nearest integer. This intrinsic uses llvm.nearbyint instead of llvm.round, which is faster but may produce results that differ from te.round. Notably, nearbyint rounds according to the current rounding mode, whereas te.round (llvm.round) ignores it. For the differences between the two, see:
https://en.cppreference.com/w/cpp/numeric/math/round
https://en.cppreference.com/w/cpp/numeric/math/nearbyint
- tvm.te.placeholder(shape, dtype=None, name='placeholder')¶
Construct an empty tensor object.
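Example
A minimal sketch of declaring a symbolic input tensor:
from tvm import te

n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), dtype="float32", name="A")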
- tvm.te.popcount(x)¶
Count the number of set bits in input x.
- tvm.te.power(x, y, span=None)¶
Compute x raised to the power y.
- tvm.te.reduce_axis(dom, name='rv', thread_tag='', span=None)¶
Create a new IterVar for reduction.
- tvm.te.round(x, span=None)¶
Round elements of the array to the nearest integer.
- tvm.te.rsqrt(x)¶
Take reciprocal of square root of input x.
- tvm.te.scan(init, update, state_placeholder, inputs=None, name='scan', tag='', attrs=None)¶
Construct new tensors by scanning over axis.
- Parameters:
init (Tensor or list of Tensor) – The initial condition for the first init.shape[0] time steps.
update (Tensor or list of Tensor) – The update rule of the scan given by symbolic tensor.
state_placeholder (Tensor or list of Tensor) – The placeholder variables used by update.
inputs (Tensor or list of Tensor, optional) – The list of inputs to the scan. This is not required, but can help the compiler detect the scan body faster.
name (str, optional) – The name hint of the tensor
tag (str, optional) – Additional tag information about the compute.
attrs (dict, optional) – The additional auxiliary attributes about the compute.
- Returns:
tensor – The created tensor or tuple of tensors contains multiple outputs.
- Return type:
Tensor or list of Tensors
Example
# The following code is equivalent to numpy.cumsum
m = te.var("m")
n = te.var("n")
X = te.placeholder((m, n), name="X")
s_state = te.placeholder((m, n))
s_init = te.compute((1, n), lambda _, i: X[0, i])
s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
res = tvm.te.scan(s_init, s_update, s_state, X)
- tvm.te.sigmoid(x)¶
Compute the sigmoid of input x.
- tvm.te.sin(x)¶
Take sin of input x.
- tvm.te.sinh(x)¶
Take sinh of input x.
- tvm.te.size_var(name='size', dtype='int32', span=None)¶
Create a new variable representing a tensor shape size, which is non-negative.
- tvm.te.sqrt(x)¶
Take square root of input x.
- tvm.te.subtract(lhs, rhs, span=None)¶
Generic subtract operator.
- Parameters:
lhs (object) – The left operand.
rhs (object) – The right operand.
span (Optional[Span]) – The location of this operator in the source.
- Returns:
op – The result Expr of the subtract operation.
- Return type:
tvm.Expr
- tvm.te.sum(expr, axis, where=None, init=None, *args)¶
Create a sum expression over axis.
- Parameters:
- Returns:
value – The result value.
- Return type:
Example
m = te.var("m")
n = te.var("n")
A = te.placeholder((m, n), name="A")
k = te.reduce_axis((0, n), name="k")

# there are two ways to use this sum reducer:
# mode 1, accept (expr, axis, where) to produce a Reduce Expr
# tvm.sum represents tvm.te.sum or tvm.tir.sum.
B = te.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")

# mode 2, simply use it with multiple Exprs:
sum_res = tvm.sum(m, n)
- tvm.te.tag_scope(tag)¶
The operator tag scope.
- Parameters:
tag (str) – The tag name.
- Returns:
tag_scope – The tag scope object, which can be used as a decorator or context manager.
- Return type:
TagScope
Example
n = te.var('n')
m = te.var('m')
l = te.var('l')
A = te.placeholder((n, l), name='A')
B = te.placeholder((m, l), name='B')
k = te.reduce_axis((0, l), name='k')

with tvm.te.tag_scope(tag='matmul'):
    C = te.compute((n, m), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k))

# or use tag_scope as decorator
@tvm.te.tag_scope(tag="conv")
def compute_relu(data):
    return te.compute(data.shape,
                      lambda *i: tvm.tir.Select(data(*i) < 0, 0.0, data(*i)))
- tvm.te.tan(x)¶
Take tan of input x.
- tvm.te.tanh(x)¶
Take hyperbolic tangent of input x.
- tvm.te.thread_axis(dom=None, tag='', name='', span=None)¶
Create a new IterVar to represent thread index.
- Parameters:
- Returns:
axis – The thread itervar.
- Return type:
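Example
A sketch that binds a split loop to GPU block and thread indices (the split factor is an illustrative assumption):
from tvm import te

n = 1024
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[B].bind(xi, te.thread_axis("threadIdx.x"))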
- tvm.te.trace(args, trace_action='tvm.default_trace_action')¶
Trace tensor data at runtime.
The trace function allows tracing specific tensors at runtime. The value being traced should come as the last argument. The trace action should be specified; by default tvm.default_trace_action is used.
- Parameters:
args (list of Expr or Buffers.) – Positional arguments.
trace_action (str.) – The name of the trace action.
- Returns:
call – The call expression.
- Return type:
See also
tvm.tir.call_packed
Creates packed function.
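Example
A sketch that copies a tensor while printing each element through the default trace action (the shape and dtype are illustrative assumptions):
from tvm import te

n = 4
x = te.placeholder((n,), name="x", dtype="float64")
# the traced value is returned, so y ends up equal to x
y = te.compute(x.shape,
               lambda i: te.trace([x[i]], "tvm.default_trace_action"),
               name="y")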
- tvm.te.trunc(x, span=None)¶
Get truncated value of the input.
The truncated value of the scalar x is the nearest integer i which is closer to zero than x is.
- tvm.te.truncdiv(a, b, span=None)¶
Compute the truncdiv of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
This is the default integer division behavior in C.
- tvm.te.truncmod(a, b, span=None)¶
Compute the truncmod of two expressions.
- Parameters:
- Returns:
res – The result expression.
- Return type:
Note
This is the remainder corresponding to truncdiv, matching the default behavior of the % operator on integers in C.
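Example
A small sketch contrasting truncated and floored division on negative operands; the comments state the arithmetic rather than printed output:
from tvm import te

a = te.var("a")
# for a = -7:
#   truncdiv(a, 4) == -1, truncmod(a, 4) == -3  (rounds toward zero, as in C)
#   floordiv(a, 4) == -2, floormod(a, 4) ==  1  (rounds toward negative infinity)
q_trunc, r_trunc = te.truncdiv(a, 4), te.truncmod(a, 4)
q_floor, r_floor = te.floordiv(a, 4), te.floormod(a, 4)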