I’ve been following the development of the higher-order ops in PyTorch nightlies for a little while, and got a chance to try out while_loop. The best examples right now are in the tests, but here’s another one, a Mandelbrot set:
```python
import torch
from torch._higher_order_ops.while_loop import while_loop
import matplotlib.pyplot as plt


def mandelbrot_step(z, c):
    """Performs one iteration of the Mandelbrot sequence."""
    return z**2 + c


def mandelbrot(c, max_iter, threshold):
    """Compute Mandelbrot set membership for a grid of complex numbers."""

    def cond_fn(z, iter_count, mask):
        # Must return a boolean tensor, not a Python bool
        return torch.any(mask & (iter_count < max_iter))

    def body_fn(z, iter_count, mask):
        z_next = mandelbrot_step(z, c)
        diverged = torch.abs(z_next) > threshold
        mask_next = mask & ~diverged
        # int32 + bool stays int32, so the carried dtype does not change
        iter_count_next = iter_count + mask_next
        return z_next, iter_count_next, mask_next

    # Initialize variables
    z0 = torch.zeros_like(c)
    iter_count = torch.zeros(c.shape, dtype=torch.int32)
    mask = torch.ones(c.shape, dtype=torch.bool)  # All points start as candidates

    final_state = while_loop(cond_fn, body_fn, (z0, iter_count, mask))
    _, iterations, _ = final_state
    return iterations


# Define the grid of complex numbers
x = torch.linspace(-2.0, 1.0, 500)
y = torch.linspace(-1.5, 1.5, 500)
# indexing="xy" puts the imaginary axis on rows, which is what imshow expects
xx, yy = torch.meshgrid(x, y, indexing="xy")
complex_grid = xx + 1j * yy

# Compute the Mandelbrot set
max_iter = 100
threshold = 2.0
mandelbrot_set = mandelbrot(complex_grid, max_iter, threshold)

# Plot the Mandelbrot set
plt.figure(figsize=(10, 10))
# origin="lower" draws row 0 (imaginary part -1.5) at the bottom, matching extent
plt.imshow(mandelbrot_set.numpy(), extent=(-2, 1, -1.5, 1.5), origin="lower", cmap="inferno")
plt.colorbar(label="Iteration count")
plt.title("Mandelbrot Set")
plt.xlabel("Real")
plt.ylabel("Imaginary")
plt.show()
```
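As a sanity check, the same iteration counts can be computed with an ordinary eager Python loop over the same update. `mandelbrot_reference` below is a hypothetical helper written for illustration, not how while_loop works internally; it uses `z * z` in place of `z**2`, which is mathematically the same update.

```python
import torch


def mandelbrot_reference(c, max_iter, threshold):
    """Eager-mode Mandelbrot iteration counts using a plain Python while loop."""
    z = torch.zeros_like(c)
    iter_count = torch.zeros(c.shape, dtype=torch.int32)
    mask = torch.ones(c.shape, dtype=torch.bool)
    # Same stopping rule as cond_fn; `while` calls bool() on the 0-dim tensor.
    while torch.any(mask & (iter_count < max_iter)):
        z = z * z + c                      # same update as mandelbrot_step
        diverged = torch.abs(z) > threshold
        mask = mask & ~diverged            # once a point escapes, it stays out
        iter_count = iter_count + mask     # count only still-bounded points
    return iter_count


# A few points with easy-to-check behavior
c = torch.tensor([0, 2, 1, -1, 3], dtype=torch.complex64)
print(mandelbrot_reference(c, 100, 2.0).tolist())  # → [100, 1, 2, 100, 0]
```

Running both versions on a small grid is a cheap way to catch shape or dtype mistakes in body_fn. Tiny mismatches right at the set boundary are possible if `z**2` and `z * z` round differently.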

In general, the only non-obvious things about while_loop are that cond_fn returns a tensor (a single-element boolean tensor), not a Python bool, so make sure your types are right, and that the carried values must keep consistent shapes (and dtypes) from one iteration to the next. If you need more accumulation-style behavior, look at scan!
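The shape rule bites as soon as you want to collect something per iteration: you can't append to a growing tensor, but you can write into a buffer allocated up front. The cond_fn/body_fn below are a hypothetical minimal sketch of that pattern, not taken from the tests. They are stepped with an ordinary Python loop so the snippet runs without the higher-order op; passing the same functions and `state` to `while_loop` should follow the same contract.

```python
import torch


def cond_fn(i, buf):
    # Returns a 0-dim bool tensor, not a Python bool.
    return i < buf.shape[0]


def body_fn(i, buf):
    # Overwrite slot i of a preallocated buffer instead of appending to it,
    # so the carried shape never changes between iterations.
    positions = torch.arange(buf.shape[0])
    buf_next = torch.where(positions == i, i * i, buf)
    return i + 1, buf_next


# Carried state: a 0-dim step counter and a fixed-size int64 buffer.
state = (torch.tensor(0), torch.zeros(5, dtype=torch.int64))

# Plain-Python driver for illustration only; it also checks that every
# carried value keeps its shape and dtype from one iteration to the next.
while cond_fn(*state):
    new_state = body_fn(*state)
    for old, new in zip(state, new_state):
        assert old.shape == new.shape and old.dtype == new.dtype
    state = new_state

i, squares = state
print(squares.tolist())  # → [0, 1, 4, 9, 16]
```

When every iteration produces an output you want to keep, scan is usually the more natural fit, since it stacks per-step outputs for you instead of requiring a hand-managed buffer.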
