How to Fix torch.compile Performance Regressions Caused by Graph Breaks



When you upgrade PyTorch or refactor model code, you might notice your torch.compile-optimized model suddenly running slower than before. The culprit is often unexpected graph breaks that prevent proper fusion and optimization.


import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)
    
    def forward(self, x):
        x = self.linear(x)
        print(f"Shape: {x.shape}")  # This causes a graph break!
        return x.relu()

model = MyModel()
compiled_model = torch.compile(model)

x = torch.randn(32, 128)
output = compiled_model(x)


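Because of the print call, torch.compile splits this forward pass into two separately compiled graphs (one for the linear layer, one for the relu), with a return to the Python interpreter in between. In a model this small the cost is negligible, but the same pattern inside a larger model, or inside a loop, can noticeably reduce or even erase the speedup.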


Step 1: Understanding Graph Breaks in torch.compile


torch.compile uses TorchDynamo to capture your model's Python code as computation graphs. A backend compiler (Inductor by default) then turns those graphs into optimized kernels, such as Triton kernels on GPU. When Dynamo encounters code it cannot trace, it inserts a "graph break": the model is split into separately compiled graphs, and execution returns to the Python interpreter between them.
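One way to see this splitting directly is a custom backend that records every graph Dynamo hands it. The sketch below is illustrative rather than an official debugging tool: counting_backend and fn are hypothetical names, and torch._dynamo.graph_break() forces a break deterministically in place of a print or .item() call.

```python
import torch

captured = []  # FX graphs handed to the backend, one per compiled region

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend once for every graph it captures.
    captured.append(gm)
    return gm.forward  # run the captured graph without further optimization

def fn(x):
    y = torch.sin(x)
    torch._dynamo.graph_break()  # explicit break, standing in for print/.item()
    return torch.cos(y)

compiled_fn = torch.compile(fn, backend=counting_backend)
inp = torch.randn(8)
out = compiled_fn(inp)

print(f"graphs captured: {len(captured)}")
```

Two graphs are captured: one holding sin before the break and one holding cos after it.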


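Common causes of graph breaks (and related slowdowns):

  • Print statements or logging calls
  • Dynamic control flow (if statements based on tensor values)
  • Operations with data-dependent output shapes, such as torch.nonzero
  • Unsupported operations, such as calls into opaque C/C++ extensions
  • Calling .item() or otherwise converting tensors to Python scalars
  • In-place operations on inputs (these usually compile without a break, but add copy-back work and can prevent CUDA graphs from being used)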


Each graph break adds overhead and prevents cross-boundary optimizations like kernel fusion.


Step 2: Identifying Graph Breaks Using Debug Flags


PyTorch provides several environment variables and debugging tools to diagnose compilation issues.


Using TORCH_LOGS

$ TORCH_LOGS="+dynamo" python your_script.py


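This produces verbose Dynamo debug output, including where each graph break occurs and why. To see only graph-break reports, use TORCH_LOGS="graph_breaks" instead. The exact log format varies between PyTorch versions; a graph-break entry looks roughly like this:

[WARNING] torch._dynamo.convert_frame: [Stack 0] Graph break: print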


Using torch._dynamo.explain

import torch
import torch.nn as nn

class ProblematicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)
    
    def forward(self, x):
        x = self.conv1(x)
        # Graph break here
        if x.max() > 0.5:
            x = x * 2
        x = self.conv2(x)
        return x

model = ProblematicModel()

# Trace with Dynamo and report graphs and breaks (no Inductor code generation)
explanation = torch._dynamo.explain(model)(torch.randn(1, 3, 32, 32))

print(f"Graph break count: {explanation.graph_break_count}")
print(f"Break reasons: {[r.reason for r in explanation.break_reasons]}")


Example output (illustrative; the reason text differs between PyTorch versions):

Graph break count: 1
Break reasons: ['dynamic control flow: data-dependent if statement']


Using Verbose Logging

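import logging

import torch

# Enable detailed compilation logs. torch._dynamo.config.log_level no longer
# exists in current releases; the torch._logging API replaces it.
torch._dynamo.config.verbose = True
torch._logging.set_logs(dynamo=logging.INFO)

compiled_model = torch.compile(model, backend="inductor")  # model from the example above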


Step 3: Fixing Common Graph Break Patterns


Pattern 1: Removing Debug Print Statements


Before (causes graph break):

def forward(self, x):
    x = self.layer1(x)
    print(f"After layer1: {x.shape}")  # Graph break!
    x = self.layer2(x)
    return x


After (no graph break):

def forward(self, x):
    x = self.layer1(x)
    # Use torch._dynamo.is_compiling() to conditionally print
    if not torch._dynamo.is_compiling():
        print(f"After layer1: {x.shape}")
    x = self.layer2(x)
    return x


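torch._dynamo.is_compiling() returns True while Dynamo traces the function, so the print is dropped from the compiled graph and the forward pass compiles as a single graph. The message is printed only when the model runs in eager mode; the compiled model never prints it.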


Pattern 2: Replacing Data-Dependent Control Flow


Before (causes graph break):

def forward(self, x):
    x = self.encoder(x)
    # Data-dependent branching
    if x.mean() > 0:
        x = self.path_a(x)
    else:
        x = self.path_b(x)
    return x


After (no graph break):

def forward(self, x):
    x = self.encoder(x)
    # Use torch.where to select between the two results without Python branching
    result_a = self.path_a(x)
    result_b = self.path_b(x)
    x = torch.where(x.mean() > 0, result_a, result_b)
    return x


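This evaluates both paths and uses torch.where to select one result, so torch.compile can trace the entire forward pass as one graph. The trade-off is that both branches always run, which pays off only when they are cheap relative to the cost of the graph break.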
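The sketch below checks the rewrite on a toy model (Router and its branching method are hypothetical): fullgraph=True makes torch.compile raise if anything still breaks the graph, and the output is compared against the original Python-branching version. backend="eager" is used only to keep the check fast.

```python
import torch
import torch.nn as nn

class Router(nn.Module):
    """Hypothetical model: an encoder followed by one of two heads."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.path_a = nn.Linear(8, 8)
        self.path_b = nn.Linear(8, 8)

    def branching(self, x):
        # Eager reference using Python control flow (a graph break under Dynamo)
        x = self.encoder(x)
        return self.path_a(x) if x.mean() > 0 else self.path_b(x)

    def forward(self, x):
        x = self.encoder(x)
        return torch.where(x.mean() > 0, self.path_a(x), self.path_b(x))

torch.manual_seed(0)
model = Router()
x = torch.randn(4, 8)

# fullgraph=True raises if forward still contains a graph break
compiled = torch.compile(model, fullgraph=True, backend="eager")
with torch.no_grad():
    out = compiled(x)
    ref = model.branching(x)
print(torch.allclose(out, ref))
```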


Pattern 3: Avoiding Tensor to Python Scalar Conversions


Before (causes graph break):

def forward(self, x, mask):
    x = self.process(x)
    # Converting tensor to scalar
    num_valid = mask.sum().item()  # Graph break!
    scale = 1.0 / num_valid
    x = x * scale
    return x


After (no graph break):

def forward(self, x, mask):
    x = self.process(x)
    # Keep as tensor operation
    num_valid = mask.sum()
    scale = 1.0 / (num_valid + 1e-8)  # Add epsilon for stability
    x = x * scale
    return x


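Keeping the computation in tensor space lets Dynamo trace it into the same graph. Note that the epsilon changes one edge case: an all-zero mask now yields a very large scale (about 1e8) instead of raising ZeroDivisionError.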


Step 4: Using AOTAutograd for Detailed Analysis


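AOTAutograd (Ahead-Of-Time Autograd) is the stage of torch.compile between Dynamo's graph capture and the backend compiler: it traces the forward and backward passes ahead of time into graphs of ATen operators. The aot_eager backend runs Dynamo and AOTAutograd without Inductor, which helps isolate where a problem originates, and TORCH_LOGS="aot_graphs" prints the graphs AOTAutograd produces.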


import torch
import torch.nn as nn

class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(100, 100)
        self.linear2 = nn.Linear(100, 10)
    
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        # Intentional graph break: Dynamo cannot trace print()
        print("after relu")
        x = self.linear2(x)
        return x

model = DebugModel()

# aot_eager runs Dynamo and AOTAutograd but executes the resulting graphs
# eagerly, skipping Inductor code generation
compiled = torch.compile(
    model,
    backend="aot_eager",
    fullgraph=True  # Require a single graph: any break raises an error
)

x = torch.randn(32, 100)
try:
    output = compiled(x)
except Exception as e:
    print(f"Graph break detected: {e}")


The fullgraph=True flag forces compilation to fail if any graph breaks occur, making debugging explicit.
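The aot_eager run above checks for breaks but does not show what AOTAutograd produced. A minimal inspecting backend can be built with aot_autograd(fw_compiler=...); inspect_fw and fn are hypothetical names, and the forward compiler in this sketch just prints the ATen-level graph and runs it unchanged.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

aten_graphs = []

def inspect_fw(gm: torch.fx.GraphModule, example_inputs):
    # gm is the ATen-level graph produced by AOTAutograd
    aten_graphs.append(gm.code)
    print(gm.code)
    return make_boxed_func(gm.forward)  # AOTAutograd uses a boxed calling convention

inspect_backend = aot_autograd(fw_compiler=inspect_fw)

def fn(x):
    return torch.nn.functional.gelu(x) + x.sin()

compiled = torch.compile(fn, backend=inspect_backend)
inp = torch.randn(4)  # no input requires grad, so only a forward graph is built
out = compiled(inp)
```

The printed code calls operators such as torch.ops.aten.gelu.default, which is the level at which Inductor later works.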


Step 5: Profiling Performance Impact


Quantify the performance impact of graph breaks using torch.profiler:

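import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

class SlowModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(10)])
    
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # i is a Python int, so this condition is resolved at trace time
            if i % 2 == 0:
                # Explicit graph break standing in for a print() or .item() call;
                # inside a loop, Dynamo may fall back to eager for the whole frame
                torch._dynamo.graph_break()
        return x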

model = SlowModel().cuda()
compiled_model = torch.compile(model)

x = torch.randn(128, 512).cuda()

# Warmup
for _ in range(10):
    _ = compiled_model(x)

# Profile
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    for _ in range(100):
        output = compiled_model(x)

# Analyze
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


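Each compiled graph runs as its own region; in recent PyTorch versions these appear in the profiler output as rows named like "Torch-Compiled Region". More compiled regions per iteration than expected, with CPU-side gaps between them, indicates graph-break overhead. Compare against a version of the model without the breaks to quantify the cost.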


Step 6: Systematic Debugging Workflow


Here's a complete debugging script you can adapt:

import logging
import time
from contextlib import contextmanager

import torch
import torch.nn as nn

@contextmanager
def compile_diagnostics():
    """Context manager for comprehensive torch.compile debugging"""
    # Store original settings
    original_verbose = torch._dynamo.config.verbose
    original_suppress_errors = torch._dynamo.config.suppress_errors
    
    # Enable detailed logging (torch._dynamo.config.log_level was replaced
    # by the torch._logging API)
    torch._dynamo.config.verbose = True
    torch._dynamo.config.suppress_errors = False
    torch._logging.set_logs(dynamo=logging.DEBUG, graph_breaks=True)
    
    try:
        yield
    finally:
        # Restore settings; set_logs() with no arguments returns to the defaults
        torch._dynamo.config.verbose = original_verbose
        torch._dynamo.config.suppress_errors = original_suppress_errors
        torch._logging.set_logs()

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )
    
    def forward(self, x):
        return self.net(x)

def benchmark_model(model, x, warmup=50, iterations=200):
    """Benchmark compiled vs eager mode"""
    # Warmup
    for _ in range(warmup):
        _ = model(x)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    start = time.perf_counter()
    for _ in range(iterations):
        _ = model(x)
    
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    elapsed = time.perf_counter() - start
    return elapsed / iterations

# Test setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TestModel().to(device)
x = torch.randn(128, 256).to(device)

# Baseline: eager mode
eager_time = benchmark_model(model, x)
print(f"Eager mode: {eager_time*1000:.3f} ms")

# Test compilation with diagnostics
with compile_diagnostics():
    compiled_model = torch.compile(model, mode="max-autotune")
    compiled_time = benchmark_model(compiled_model, x)
    print(f"Compiled mode: {compiled_time*1000:.3f} ms")

speedup = eager_time / compiled_time
print(f"Speedup: {speedup:.2f}x")

# Check for graph breaks
explanation = torch._dynamo.explain(model)(x)
if explanation.graph_break_count > 0:
    print(f"\nWarning: {explanation.graph_break_count} graph breaks detected!")
    print("Reasons:")
    for reason in explanation.break_reasons:
        print(f"  - {reason}")


Run this script to establish baseline performance and identify issues:

$ python debug_compile.py


Additional Tips and Edge Cases


Handling Dynamic Shapes

If your model accepts variable-sized inputs, use dynamic shapes configuration:

# Allow dynamic batch size
compiled_model = torch.compile(
    model,
    dynamic=True  # Enable dynamic shape support
)

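# Or mark specific dimensions as dynamic before the first call
torch._dynamo.mark_dynamic(x, 0)  # dimension 0 (batch) of x is dynamic

# Or treat every dimension as dynamic by default (global setting)
torch._dynamo.config.assume_static_by_default = False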


Dealing with Custom CUDA Kernels

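Dynamo cannot trace into opaque custom kernels, so calling them directly often causes a graph break. Register them as custom operators (torch.library.custom_op, PyTorch 2.4+) with a fake implementation so torch.compile can treat each one as a single graph node:

import torch
from torch.library import custom_op

@custom_op("mylib::custom_kernel", mutates_args=())
def custom_kernel(x: torch.Tensor) -> torch.Tensor:
    # Placeholder: call your custom CUDA kernel here
    return x * 2

# Register a fake implementation for tracing: it only describes the output's
# shape and dtype, without running the real kernel
@custom_kernel.register_fake
def _(x):
    return torch.empty_like(x)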


Debugging Inductor Backend Issues

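If the default inductor backend fails or misbehaves, compare it with simpler backends to isolate the stage at fault: eager runs only Dynamo's graph capture, aot_eager adds AOTAutograd, and inductor adds code generation. The cudagraphs backend wraps the captured graphs in CUDA graphs and requires a GPU.

# Test different backends, from least to most transformation
backends = ["eager", "aot_eager", "inductor", "cudagraphs"]

for backend in backends:
    torch._dynamo.reset()  # start each backend from a clean compilation state
    try:
        compiled = torch.compile(model, backend=backend)
        output = compiled(x)
        print(f"{backend}: Success")
    except Exception as e:
        print(f"{backend}: Failed - {e}")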


Common Pitfall: In-Place Operations

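In-place operations on model inputs usually do not break the graph, because torch.compile functionalizes them, but the result must then be copied back into the caller's tensor, and under mode="reduce-overhead" a mutated input can prevent CUDA graphs from being used. Prefer out-of-place operations: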

# Bad: modifying input in-place
def forward(self, x):
    x.mul_(2)  # In-place operation
    return self.net(x)

# Good: create new tensor
def forward(self, x):
    x = x * 2  # Out-of-place operation
    return self.net(x)
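The sketch below checks the first half of that claim on a toy function (scale_inplace is a hypothetical name): with fullgraph=True nothing breaks, and the mutation is still visible to the caller afterwards. It uses the aot_eager backend so no code generation is needed.

```python
import torch

def scale_inplace(x):
    x.mul_(2)  # mutates the caller's tensor
    return x.sum()

inp = torch.ones(3)
compiled = torch.compile(scale_inplace, fullgraph=True, backend="aot_eager")
total = compiled(inp)
print(inp, total)  # the mutation is applied to inp after the compiled call
```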


Monitoring Compilation Cache

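torch.compile keeps compiled code in memory for the lifetime of the process, and Inductor additionally caches generated kernels on disk. Reset the in-memory state when debugging:

# Discard Dynamo's compiled code and guards for this process
torch._dynamo.reset()

# Show Inductor's on-disk cache location (configurable via TORCHINDUCTOR_CACHE_DIR)
import os
print(f"Cache dir: {os.environ.get('TORCHINDUCTOR_CACHE_DIR', 'default: torchinductor_<user> under the system temp directory')}")


Code compiled earlier in the same process can mask graph break issues during development; after torch._dynamo.reset(), the next call is traced from scratch.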


Using torch.export for Stricter Checking

For production deployment, consider torch.export, which has stricter requirements: the whole model must be captured as a single graph.

from torch.export import export

# Export raises an error instead of splitting the graph at unsupported code
try:
    exported = export(model, (x,))
    print("Model successfully exported as a single graph")
except Exception as e:
    print(f"Export failed: {e}")


Code that torch.compile tolerates by inserting a graph break, such as data-dependent control flow, makes torch.export fail instead, so these problems surface during development.


Troubleshooting Checklist


Run through this checklist when facing performance regressions:


  1. Enable verbose logging: TORCH_LOGS="+dynamo" python script.py
  2. Check graph break count with torch._dynamo.explain
  3. Profile with torch.profiler to quantify impact
  4. Remove print/log statements in forward pass
  5. Replace data-dependent control flow with tensor operations
  6. Avoid .item(), .tolist(), and scalar conversions
  7. Test with fullgraph=True to enforce single graph
  8. Clear cache with torch._dynamo.reset() between tests
  9. Try dynamic=True for variable input shapes
  10. Consider alternative backends if inductor fails
