When you upgrade PyTorch or refactor model code, you might notice your torch.compile-optimized model suddenly running slower than before. The culprit is often unexpected graph breaks that prevent proper fusion and optimization.
```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        x = self.linear(x)
        print(f"Shape: {x.shape}")  # This causes a graph break!
        return x.relu()

model = MyModel()
compiled_model = torch.compile(model)
x = torch.randn(32, 128)
output = compiled_model(x)
```
Running this code produces unexpected slowdowns because the print statement forces torch.compile to split the computation graph into multiple pieces.
Step 1: Understanding Graph Breaks in torch.compile
torch.compile works by tracing your model's operations and fusing them into optimized kernels (CUDA or CPU, depending on the device). When it encounters operations it cannot trace or optimize, it creates a "graph break": it splits your model into separate compiled chunks with Python overhead between them.
Common causes of graph breaks:
- Print statements or logging calls
- Dynamic control flow (if statements based on tensor values)
- Data-dependent shapes
- Unsupported operations
- Calling .item() or otherwise converting tensors to Python scalars
- In-place operations on inputs
Each graph break adds overhead and prevents cross-boundary optimizations like kernel fusion.
Step 2: Identifying Graph Breaks Using Debug Flags
PyTorch provides several environment variables and debugging tools to diagnose compilation issues.
Using TORCH_LOGS
```shell
$ TORCH_LOGS="+dynamo" python your_script.py
```
This produces detailed output showing where graph breaks occur:
```
[WARNING] torch._dynamo.convert_frame: [Stack 0] Graph break: print
```
Using torch._dynamo.explain
```python
import torch
import torch.nn as nn

class ProblematicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)

    def forward(self, x):
        x = self.conv1(x)
        # Graph break here
        if x.max() > 0.5:
            x = x * 2
        x = self.conv2(x)
        return x

model = ProblematicModel()

# Trace the model and report graph breaks without running a backend
explanation = torch._dynamo.explain(model)(torch.randn(1, 3, 32, 32))
print(explanation.graph_break_count)  # Shows number of breaks
print(explanation.break_reasons)      # Lists specific reasons
```
Output example:
```
Graph break count: 1
Break reasons: ['dynamic control flow: data-dependent if statement']
```
Using Verbose Logging
```python
import torch
import logging

# Enable detailed compilation logs
torch._dynamo.config.verbose = True
torch._logging.set_logs(dynamo=logging.INFO)

compiled_model = torch.compile(model, backend="inductor")
```
Step 3: Fixing Common Graph Break Patterns
Pattern 1: Removing Debug Print Statements
Before (causes graph break):
```python
def forward(self, x):
    x = self.layer1(x)
    print(f"After layer1: {x.shape}")  # Graph break!
    x = self.layer2(x)
    return x
```
After (no graph break):
```python
def forward(self, x):
    x = self.layer1(x)
    # Use torch._dynamo.is_compiling() to conditionally print
    if not torch._dynamo.is_compiling():
        print(f"After layer1: {x.shape}")
    x = self.layer2(x)
    return x
```
The is_compiling() check ensures print statements only run during eager mode, not during compilation tracing.
Pattern 2: Replacing Data-Dependent Control Flow
Before (causes graph break):
```python
def forward(self, x):
    x = self.encoder(x)
    # Data-dependent branching
    if x.mean() > 0:
        x = self.path_a(x)
    else:
        x = self.path_b(x)
    return x
```
After (no graph break):
```python
def forward(self, x):
    x = self.encoder(x)
    # Blend both paths based on the condition
    condition = (x.mean() > 0).float()
    result_a = self.path_a(x)
    result_b = self.path_b(x)
    x = condition * result_a + (1 - condition) * result_b
    return x
```
This computes both paths and blends the results based on the condition, allowing torch.compile to trace the entire graph. The trade-off is that both branches are always evaluated, so this only pays off when the branches are cheap relative to the graph break overhead.
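A quick eager-mode sanity check that the blend matches the branchy version; path_a and path_b here are toy stand-ins for the submodules above:

```python
import torch

# Toy stand-ins for self.path_a / self.path_b
def path_a(t):
    return t * 3.0

def path_b(t):
    return t - 1.0

def branchy(x):
    # Data-dependent control flow (breaks the graph under torch.compile)
    return path_a(x) if x.mean() > 0 else path_b(x)

def blended(x):
    # Traceable equivalent: compute both, select via the condition
    condition = (x.mean() > 0).float()
    return condition * path_a(x) + (1 - condition) * path_b(x)

for x in (torch.full((4,), 2.0), torch.full((4,), -2.0)):
    assert torch.allclose(branchy(x), blended(x))
```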
Pattern 3: Avoiding Tensor to Python Scalar Conversions
Before (causes graph break):
```python
def forward(self, x, mask):
    x = self.process(x)
    # Converting tensor to scalar
    num_valid = mask.sum().item()  # Graph break!
    scale = 1.0 / num_valid
    x = x * scale
    return x
```
After (no graph break):
```python
def forward(self, x, mask):
    x = self.process(x)
    # Keep as tensor operation
    num_valid = mask.sum()
    scale = 1.0 / (num_valid + 1e-8)  # Add epsilon for stability
    x = x * scale
    return x
```
Keeping computations in tensor space allows continuous optimization.
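A small eager check that the two versions agree numerically (up to the epsilon), using a toy mask:

```python
import torch

mask = torch.tensor([1, 1, 0, 1], dtype=torch.bool)

# Scalar version: leaves tensor space, breaks the graph under compile
scale_scalar = 1.0 / mask.sum().item()

# Tensor version: stays traceable end to end
scale_tensor = 1.0 / (mask.sum() + 1e-8)

# The epsilon perturbs the result only negligibly
assert abs(float(scale_tensor) - scale_scalar) < 1e-6
```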
Step 4: Using AOTAutograd for Detailed Analysis
AOTAutograd (Ahead-Of-Time Autograd) is the layer of torch.compile that captures forward and backward graphs ahead of time. You can run it with a debugging backend to understand how your model is being decomposed.
```python
import torch
import torch.nn as nn

class DebugModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(100, 100)
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        # Intentional graph break for demonstration
        torch._dynamo.graph_break()
        x = self.linear2(x)
        return x

model = DebugModel()

# Compile with a debugging backend
compiled = torch.compile(
    model,
    backend="aot_eager",  # Runs AOTAutograd without Inductor codegen
    fullgraph=True        # Enforce a single graph (errors on breaks)
)

x = torch.randn(32, 100)
try:
    output = compiled(x)
except Exception as e:
    print(f"Graph break detected: {e}")
```
The fullgraph=True flag forces compilation to fail if any graph breaks occur, making debugging explicit.
Step 5: Profiling Performance Impact
Quantify the performance impact of graph breaks using torch.profiler:
```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

class SlowModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(512, 512) for _ in range(10)])

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # .item() forces a graph break inside the loop
            if i % 2 == 0:
                x = x / (x.norm().item() + 1e-8)
        return x

model = SlowModel().cuda()
compiled_model = torch.compile(model)
x = torch.randn(128, 512).cuda()

# Warmup
for _ in range(10):
    _ = compiled_model(x)

# Profile
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True
) as prof:
    for _ in range(100):
        output = compiled_model(x)

# Analyze
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```
Look for time spent in Python between kernel launches and repeated torch._dynamo frame-evaluation entries in the profiler output; these indicate graph break overhead.
Step 6: Systematic Debugging Workflow
Here's a complete debugging script you can adapt:
```python
import torch
import torch.nn as nn
import logging
import time
from contextlib import contextmanager

@contextmanager
def compile_diagnostics():
    """Context manager for comprehensive torch.compile debugging"""
    # Store original settings
    original_verbose = torch._dynamo.config.verbose
    original_suppress_errors = torch._dynamo.config.suppress_errors
    # Enable detailed logging
    torch._dynamo.config.verbose = True
    torch._dynamo.config.suppress_errors = False
    torch._logging.set_logs(dynamo=logging.DEBUG)
    try:
        yield
    finally:
        # Restore settings
        torch._dynamo.config.verbose = original_verbose
        torch._dynamo.config.suppress_errors = original_suppress_errors
        torch._logging.set_logs()  # Reset logging to defaults

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        return self.net(x)

def benchmark_model(model, x, warmup=50, iterations=200):
    """Measure average per-iteration latency"""
    # Warmup
    for _ in range(warmup):
        _ = model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        _ = model(x)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return elapsed / iterations

# Test setup
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TestModel().to(device)
x = torch.randn(128, 256).to(device)

# Baseline: eager mode
eager_time = benchmark_model(model, x)
print(f"Eager mode: {eager_time*1000:.3f} ms")

# Test compilation with diagnostics
with compile_diagnostics():
    compiled_model = torch.compile(model, mode="max-autotune")
    compiled_time = benchmark_model(compiled_model, x)
print(f"Compiled mode: {compiled_time*1000:.3f} ms")

speedup = eager_time / compiled_time
print(f"Speedup: {speedup:.2f}x")

# Check for graph breaks
explanation = torch._dynamo.explain(model)(x)
if explanation.graph_break_count > 0:
    print(f"\nWarning: {explanation.graph_break_count} graph breaks detected!")
    print("Reasons:")
    for reason in explanation.break_reasons:
        print(f" - {reason}")
```
Run this script to establish baseline performance and identify issues:
```shell
$ python debug_compile.py
```
Additional Tips and Edge Cases
Handling Dynamic Shapes
If your model accepts variable-sized inputs, use dynamic shapes configuration:
```python
# Allow dynamic shapes
compiled_model = torch.compile(
    model,
    dynamic=True  # Enable dynamic shape support
)

# Or mark specific dimensions as dynamic before compiling
torch._dynamo.mark_dynamic(x, 0)  # Treat dim 0 (the batch dim) as dynamic
```
Dealing with Custom CUDA Kernels
Custom CUDA operations often cause graph breaks. Wrap them properly:
```python
import torch
from torch.library import custom_op

# torch.library.custom_op requires PyTorch 2.4+
@custom_op("mylib::custom_kernel", mutates_args=())
def custom_kernel(x: torch.Tensor) -> torch.Tensor:
    # Your custom CUDA kernel would go here
    return x * 2

# Register a fake implementation so torch.compile can trace shapes
@custom_kernel.register_fake
def _(x):
    return torch.empty_like(x)
```
Debugging Inductor Backend Issues
If the default inductor backend has issues, try alternatives:
```python
# Test different backends
backends = ["inductor", "aot_eager", "cudagraphs"]
for backend in backends:
    try:
        compiled = torch.compile(model, backend=backend)
        output = compiled(x)
        print(f"{backend}: Success")
    except Exception as e:
        print(f"{backend}: Failed - {e}")
```
Common Pitfall: In-Place Operations
In-place operations on model inputs can prevent optimization:
```python
# Bad: modifying the input in-place
def forward(self, x):
    x.mul_(2)  # In-place operation on a graph input
    return self.net(x)

# Good: create a new tensor
def forward(self, x):
    x = x * 2  # Out-of-place operation
    return self.net(x)
```
Monitoring Compilation Cache
torch.compile caches compiled graphs. Clear the cache when debugging:
```python
# Clear compilation cache
torch._dynamo.reset()

# Inductor's on-disk cache location is controlled by an environment variable
import os
print(f"Cache dir: {os.environ.get('TORCHINDUCTOR_CACHE_DIR', '(default: /tmp/torchinductor_<user>)')}")
```
Cache pollution from previous runs can mask graph break issues during development.
Using torch.export for Stricter Checking
For production deployment, use torch.export which has stricter requirements:
```python
from torch.export import export

# This will fail loudly on any graph breaks
try:
    exported = export(model, (x,))
    print("Model successfully exported - no graph breaks!")
except Exception as e:
    print(f"Export failed - graph breaks present: {e}")
```
Graph breaks that torch.compile tolerates will cause torch.export to fail, making them easier to catch during development.
Troubleshooting Checklist
Run through this checklist when facing performance regressions:
- Enable verbose logging: TORCH_LOGS="+dynamo" python script.py
- Check the graph break count with torch._dynamo.explain
- Profile with torch.profiler to quantify the impact
- Remove print/log statements from the forward pass
- Replace data-dependent control flow with tensor operations
- Avoid .item(), .tolist(), and other scalar conversions
- Test with fullgraph=True to enforce a single graph
- Clear the cache with torch._dynamo.reset() between tests
- Try dynamic=True for variable input shapes
- Consider alternative backends if inductor fails