PyTorch while_loop

November 27, 2024

(Updated: Apr 22, 2026)

I’ve been following the development of the higher-order ops in PyTorch nightlies for a while, and got a chance to try out while_loop. The best examples right now are in PyTorch's own tests, but here's another one that computes the Mandelbrot set:

import torch
from torch._higher_order_ops.while_loop import while_loop
import matplotlib.pyplot as plt

def mandelbrot_step(z, c):
    """Performs one iteration of the Mandelbrot sequence."""
    return z**2 + c

def mandelbrot(c, max_iter, threshold):
    """Compute Mandelbrot set membership for a grid of complex numbers."""
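    def cond_fn(z, iter_count, mask):
        # Must return a 0-d bool tensor: keep looping while any
        # still-bounded point has iterations left
        return torch.any(mask & (iter_count < max_iter))

    def body_fn(z, iter_count, mask):
        # Outputs must match the inputs in shape and dtype
        z_next = mandelbrot_step(z, c)
        diverged = torch.abs(z_next) > threshold
        mask_next = mask & ~diverged  # once a point diverges it stays out
        iter_count_next = iter_count + mask_next  # int32 + bool stays int32
        return z_next, iter_count_next, mask_next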

    # Initialize variables
    z0 = torch.zeros_like(c)
    iter_count = torch.zeros(c.shape, dtype=torch.int32)
    mask = torch.ones(c.shape, dtype=torch.bool)  # All points start as candidates
    final_state = while_loop(cond_fn, body_fn, (z0, iter_count, mask))
    
    _, iterations, _ = final_state
    return iterations

# Define the grid of complex numbers
x = torch.linspace(-2.0, 1.0, 500)
y = torch.linspace(-1.5, 1.5, 500)
xx, yy = torch.meshgrid(x, y, indexing="xy")  # rows follow y (imaginary), columns follow x (real)
complex_grid = xx + 1j * yy

# Compute the Mandelbrot set
max_iter = 100
threshold = 2.0
mandelbrot_set = mandelbrot(complex_grid, max_iter, threshold)

# Plot the Mandelbrot set
plt.figure(figsize=(10, 10))
plt.imshow(mandelbrot_set.numpy(), extent=(-2, 1, -1.5, 1.5), origin="lower", cmap="inferno")
plt.colorbar(label="Iteration count")
plt.title("Mandelbrot Set")
plt.xlabel("Real")
plt.ylabel("Imaginary")
plt.show()
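One way to sanity-check the while_loop version is to run the same update as an ordinary Python loop. The sketch below is a hypothetical helper, `mandelbrot_eager`, not part of PyTorch: its loop condition is the same predicate as `cond_fn`, except that `bool()` turns the 0-d tensor into a Python bool so a plain `while` can use it. It writes the step as `z * z + c` rather than `z**2 + c`, which is mathematically the same update.

```python
import torch

def mandelbrot_eager(c, max_iter, threshold):
    """Plain-Python-loop counterpart of the while_loop Mandelbrot (hypothetical helper)."""
    z = torch.zeros_like(c)
    iter_count = torch.zeros(c.shape, dtype=torch.int32)
    mask = torch.ones(c.shape, dtype=torch.bool)
    # Same predicate as cond_fn; bool() converts the 0-d tensor to a Python bool
    while bool(torch.any(mask & (iter_count < max_iter))):
        z = z * z + c  # same update as mandelbrot_step
        diverged = torch.abs(z) > threshold
        mask = mask & ~diverged  # diverged points drop out for good
        iter_count = iter_count + mask  # only still-bounded points are counted
    return iter_count

# A few points on the real axis: 0 and -1 stay bounded, 1 and 2 escape quickly
c = torch.tensor([0j, -1 + 0j, 1 + 0j, 2 + 0j], dtype=torch.complex64)
print(mandelbrot_eager(c, max_iter=100, threshold=2.0))
```

Running it on `complex_grid` and comparing against `mandelbrot_set` (for example, counting mismatches with `(mandelbrot_eager(complex_grid, max_iter, threshold) != mandelbrot_set).sum()`) is a quick way to check that the while_loop version is doing what you expect.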

In general, the only non-obvious things about while_loop are that cond_fn returns a tensor (a 0-d boolean tensor such as the result of torch.any(...) here), not a Python bool, so make sure you're getting your types right, and that every carried value must keep the same shape and dtype from one iteration to the next. If you need behavior that accumulates outputs across iterations, look at scan!
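To make those two rules concrete, here is an eager stand-in that runs the loop in plain Python and checks them on every step. `reference_while_loop` is a hypothetical helper, not how PyTorch implements while_loop (the real op traces `cond_fn` and `body_fn` into graphs); it only mirrors the calling convention `while_loop(cond_fn, body_fn, carried_inputs)` used above and is deliberately strict about the predicate being a 0-d bool tensor.

```python
import torch

def reference_while_loop(cond_fn, body_fn, carried_inputs):
    """Eager sketch of the while_loop contract (hypothetical helper)."""
    carried = tuple(carried_inputs)
    while True:
        pred = cond_fn(*carried)
        # Rule 1: the predicate is a 0-d boolean tensor, e.g. torch.any(...)
        if not (isinstance(pred, torch.Tensor) and pred.dtype == torch.bool and pred.dim() == 0):
            raise TypeError(f"cond_fn must return a 0-d bool tensor, got {pred!r}")
        if not pred.item():
            return carried
        new_carried = tuple(body_fn(*carried))
        if len(new_carried) != len(carried):
            raise ValueError("body_fn must return as many values as it receives")
        # Rule 2: every carried value keeps its shape and dtype across iterations
        for old, new in zip(carried, new_carried):
            if old.shape != new.shape or old.dtype != new.dtype:
                raise ValueError(
                    f"carried value changed from {tuple(old.shape)}/{old.dtype} "
                    f"to {tuple(new.shape)}/{new.dtype}"
                )
        carried = new_carried

# Doubling a vector five times: the counter and the vector keep shape and dtype
i, x = reference_while_loop(
    lambda i, x: i < 5,
    lambda i, x: (i + 1, x * 2.0),
    (torch.tensor(0), torch.ones(3)),
)
print(i, x)

# A body that grows its carried tensor breaks the shape rule
try:
    reference_while_loop(
        lambda x: x.sum() < 10,
        lambda x: torch.cat([x, x]),
        (torch.ones(1),),
    )
except ValueError as err:
    print("rejected:", err)
```

The same check would catch a subtler dtype slip in the Mandelbrot body, such as writing `iter_count + 1.0`, which promotes the int32 counter to float.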