Control flow

Control flow covers if, the for-loop family, and while. Each maps to the obvious CUDA construct.

if

A Python if / else becomes a CUDA if / else. Guard work by a thread/lane comparison, or elect a single issuing thread with T.ptx.elect_sync():

if tx < 128:
    A[tx] = A[tx] * T.float32(2.0)
else:
    A[tx] = A[tx] + T.float32(1.0)

if T.ptx.elect_sync():
    ...                              # one elected lane (e.g. to issue TMA/MMA)

The first if / else lowers to:

if (((int)threadIdx.x) < 128) {
  A_ptr[tx] = A_ptr[tx] * 2.0f;
} else {
  A_ptr[tx] = A_ptr[tx] + 1.0f;
}

For an expression-level choice (no branch), use T.if_then_else(cond, a, b). It lowers to a ternary expression, which the compiler can usually turn into a select instruction, so it typically introduces no control-flow divergence:

O_ptr[tx] = (A_ptr[tx] > 0.0f) ? A_ptr[tx] : 0.0f;
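
The ternary above presumably comes from a DSL statement along the lines of O[tx] = T.if_then_else(A[tx] > T.float32(0.0), A[tx], T.float32(0.0)); that source line is inferred from the generated code, not quoted from it. As a rough host-side illustration in plain Python (not the DSL, and not generated code), the sketch below models the same value-level select, with each loop iteration standing in for one thread. The helper name if_then_else and the sample data are hypothetical.

```python
def if_then_else(cond, a, b):
    """Hypothetical stand-in for T.if_then_else: choose a value, no statement-level branch."""
    return a if cond else b


A = [-1.5, 0.0, 2.0, -0.25]   # made-up input values
O = [0.0] * len(A)

for tx in range(len(A)):       # each iteration plays the role of one thread
    # Every "thread" executes the same assignment; only the selected value differs.
    O[tx] = if_then_else(A[tx] > 0.0, A[tx], 0.0)

print(O)  # [0.0, 0.0, 2.0, 0.0]
```

Every iteration runs the same single assignment, which is the property that distinguishes an expression-level select from an if / else with two different statement bodies.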

Uniform vs. divergent control flow

Per-thread guards such as if tx < 128 are fine for ordinary work, but collective operations must be reached uniformly by every thread they synchronize.

For example, T.cuda.cta_sync() maps to __syncthreads(), which every thread in the thread block must reach. It must never sit inside a thread- or warpgroup-divergent branch: if it is placed inside if wg_id == 0:, the other warpgroups never arrive and the kernel deadlocks. When only one warpgroup needs to synchronize, use the warpgroup-scoped T.cuda.warpgroup_sync(id) (see Scaling GEMM with Warp Specialization and Clusters and CUDA C++/PTX intrinsics).

The same caution applies to barrier setup. An mbarrier .init() lowers to a single-thread guard (if (threadIdx.x < 1)). Nesting it inside another divergent branch can leave the barrier uninitialized, leading to unspecified launch failures.
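
The deadlock described for T.cuda.cta_sync() can be mimicked on the host with Python's threading.Barrier. This is only an analogy under stated assumptions: four CPU threads stand in for the threads of one block, tid >= 2 stands in for the threads outside if wg_id == 0:, and a timeout replaces the indefinite wait a GPU would exhibit. None of the names below belong to the DSL.

```python
import threading

NUM_THREADS = 4  # hypothetical stand-in for the threads of one block


def worker(tid, barrier, divergent, results):
    if divergent and tid >= 2:
        # Models a thread that takes the other side of a divergent branch:
        # it never reaches the barrier.
        results[tid] = "skipped"
        return
    try:
        # A GPU would wait indefinitely; the timeout lets this sketch terminate.
        barrier.wait(timeout=0.5)
        results[tid] = "passed"
    except threading.BrokenBarrierError:
        results[tid] = "stuck"


def run(divergent):
    barrier = threading.Barrier(NUM_THREADS)  # expects all NUM_THREADS arrivals
    results = {}
    threads = [
        threading.Thread(target=worker, args=(tid, barrier, divergent, results))
        for tid in range(NUM_THREADS)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run(divergent=False))
print(run(divergent=True))
```

In the uniform run all four threads report "passed". In the divergent run, threads 0 and 1 wait for arrivals that never come and report "stuck", while threads 2 and 3 report "skipped"; on the GPU there is no timeout, so the waiting threads would simply hang.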

loop

Loops come in four flavors; a plain Python range becomes T.serial:

  • T.serial(n): a sequential loop (ptxas may still unroll it).

  • T.unroll(n): fully unrolled (expanded to straight-line statements).

  • T.vectorized(n): a vectorized loop.

  • T.grid(*extents): a loop nest with one loop per extent.

break / continue work inside loops.

for i, j in T.grid(8, 8):
    B[i, j] = T.max(A[i, j], T.float32(0.0))

This lowers to:

for (int i = 0; i < 8; ++i)
  for (int j = 0; j < 8; ++j)
    B_ptr[i * 8 + j] = max(A_ptr[i * 8 + j], 0.0f);

T.unroll(4) instead expands to four straight-line statements with no loop.
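
To make the iteration order and the flattened index i * 8 + j concrete, here is a small host-side model in plain Python. It is a sketch, not the DSL: grid is a hypothetical helper built on itertools.product, and the input values are made up.

```python
import itertools


def grid(*extents):
    """Hypothetical stand-in for T.grid: nested iteration, last index fastest."""
    return itertools.product(*(range(e) for e in extents))


A = [float(k - 32) for k in range(64)]  # flat 8x8 input holding -32.0 .. 31.0
B = [0.0] * 64
visited = []

for i, j in grid(8, 8):
    flat = i * 8 + j                     # same row-major index as the generated code
    B[flat] = max(A[flat], 0.0)
    visited.append(flat)

print(visited[:10])  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(B[30:35])      # [0.0, 0.0, 0.0, 1.0, 2.0]
```

The visited list shows that the two-level nest walks the flat buffer front to back, which is why the generated inner loop over j touches consecutive elements.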

while

A while loop runs until its condition is false. Use a mutable scalar counter (see Buffers and memory):

i: T.int32 = 0
while i < 64:
    A[i] = A[i] + T.float32(1.0)
    i += 1

It lowers to a while (1) with an early-exit break (the counter is a one-element register buffer):

int i_ptr[1];
i_ptr[0] = 0;
while (1) {
  if (!(i_ptr[0] < 64)) { break; }
  A_ptr[i_ptr[0]] = A_ptr[i_ptr[0]] + 1.0f;
  i_ptr[0] = i_ptr[0] + 1;
}
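
As a sanity check on that lowering, the sketch below writes both shapes in plain Python and confirms they compute the same result. It is a host-side model only: source_form and lowered_form are hypothetical names, and a one-element list stands in for the one-element register buffer.

```python
def source_form(A):
    """The loop as written in the source: condition in the while header."""
    i = 0
    while i < 64:
        A[i] = A[i] + 1.0
        i += 1
    return A


def lowered_form(A):
    """The lowered shape: infinite loop with an early-exit break."""
    i_ptr = [0]                      # one-element buffer holding the counter
    while True:
        if not (i_ptr[0] < 64):      # negated condition guards the exit
            break
        A[i_ptr[0]] = A[i_ptr[0]] + 1.0
        i_ptr[0] = i_ptr[0] + 1
    return A


data = [float(k) for k in range(64)]
same = source_form(list(data)) == lowered_form(list(data))
print(same)  # True
```

The break fires exactly when the original condition first becomes false, so the two forms run the body the same number of times.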