Three Framework Problems
ML Frameworks

Purpose
Why does the framework—the layer between your math and your hardware—silently constrain every decision that follows?
Neural networks are defined by mathematics (matrix multiplications, gradient computations, activation functions), but mathematics does not execute itself. Between the equations and the silicon lies a translation layer that decides how operations are scheduled on hardware, how memory is allocated across the compute hierarchy, and how gradients flow backward through the computational graph. The framework is this translation layer, and the translation is not neutral. An eager-mode framework that prioritizes debugging flexibility sacrifices the graph-level optimizations that can halve inference latency. A framework lacking support for the target accelerator renders the hardware investment useless. A framework with a rich training API but no export path to edge devices means the model cannot reach the deployment target it was designed for. Architecture choices are at least visible: engineers debate model size, layer count, and attention mechanisms explicitly. Framework choices are more insidious because they operate below the level of daily attention, silently determining which optimizations are possible, which hardware is reachable, and which deployment paths exist. In the AI Triad (Introduction), the framework is the invisible mediator between Algorithm and Machine, and its design choices—baked into its compilation stack, memory management, and operator libraries—are difficult to reverse. Migrating between frameworks invalidates data pipelines, serving infrastructure, model checkpoints, and team expertise, typically requiring months of engineering effort for production systems. Framework selection is therefore an infrastructure commitment that determines what the system can do on the hardware it must run on.
- Explain how ML frameworks solve three core problems: execution (computational graphs), differentiation (automatic differentiation), and abstraction (hardware-optimized operations)
- Compare eager, static, and hybrid (JIT) execution strategies using the Compilation Continuum and Dispatch Overhead principles to determine when compilation benefits outweigh costs
- Describe the nn.Module abstraction pattern for hierarchical composition, automatic parameter discovery, and mode-dependent behavior
- Analyze how the memory wall drives framework optimization strategies including kernel fusion, mixed-precision training, and activation checkpointing
- Evaluate major framework architectures (TensorFlow, PyTorch, JAX) based on their execution models, differentiation approaches, and deployment trade-offs
- Evaluate framework selection trade-offs by matching model requirements, hardware constraints, and deployment targets across the cloud-to-edge spectrum
Two lines of code: model = transformer(...) followed by loss.backward(). Between them, invisible to the programmer, the framework orchestrates billions of floating-point operations across memory hierarchies, computes exact gradients through millions of parameters using automatic differentiation (the systematic application of the chain rule to compute derivatives), schedules thousands of GPU kernel launches, and manages gigabytes of intermediate state. The simplicity is an illusion. Those two lines trigger machinery as complex as a compiler, because that is exactly what a modern ML framework is.
The architectures defined in Network Architectures specify what computations neural networks perform, but knowing what to compute is entirely different from knowing how to compute it efficiently. A transformer’s attention mechanism (introduced in Network Architectures) requires coordinating computation across memory hierarchies and accelerator cores in patterns that naive implementations would execute 100\(\times\) slower than optimized ones. Implementing these operations from scratch for every model would make deep learning economically infeasible. ML frameworks exist to bridge this gap by translating high-level model definitions into hardware-specific execution plans that extract maximum performance from silicon.
A framework is to machine learning what a compiler is to traditional programming. A C compiler translates human-readable code into optimized machine instructions, managing register allocation, instruction scheduling, and memory layout. An ML framework translates high-level model definitions into hardware-specific execution plans, managing operator fusion, memory reuse, and device placement. This analogy is more than metaphor: modern frameworks literally include compilers, as we will see throughout this chapter.
Every ML framework, regardless of API or design philosophy, must solve three core problems. First, the execution problem: when and how should computation happen? Should operations execute immediately as written (eager execution1), or should the framework build a complete description first—a computational graph2 (a structured representation of operations and their dependencies)—and optimize before executing (graph execution)? This choice shapes debugging capability, optimization potential, and deployment flexibility. Second, the differentiation problem: how should the framework compute gradients automatically? As established in Neural Computation, training (the complex orchestration detailed in Model Training) requires derivatives of a loss function with respect to millions or billions of parameters, and manual differentiation is error-prone at this scale. Frameworks must implement automatic differentiation systems that compute exact gradients for arbitrary compositions of operations while managing the memory overhead of storing intermediate values. Third, the hardware abstraction problem: how should the framework target diverse hardware from a single interface? The same model definition should run on CPUs, GPUs, Tensor Processing Units (TPUs), and mobile devices, each with different memory constraints and optimal execution patterns.
1 Eager Execution: This mode executes each operation sequentially and immediately, which enables direct debugging with standard tools but sacrifices the global view needed for graph-level optimizations. Without seeing the full sequence of computations, the framework cannot fuse operations or pre-plan memory, forfeiting potential speedups of over 30 percent that compilers like torch.compile can provide.
2 Computational Graph: The “optimize before executing” distinction in the triggering sentence is the key design choice. By capturing the full program as a data structure (pioneered by Theano in 2010), the framework can fuse multiple operations into a single GPU kernel before any code runs, reducing overhead by over 10\(\times\). The engineering cost of this visibility is that the executed program differs from the source code, making debugging significantly harder, a trade-off every graph-based framework must justify against the performance gain.
These three problems are deeply interconnected. The execution model determines when differentiation occurs and what optimizations are possible. The abstraction layer must support both execution styles across all hardware targets. Solving any one problem in isolation leads to frameworks that excel in narrow contexts but fail in broader deployment. Because these problems are ultimately about translating mathematics into efficient hardware execution, a useful perspective is to view frameworks not as libraries but as compilers.
Systems Perspective 1.1: The ML Compiler
Your “Source Code” is the model architecture (the \(O\) term). The framework’s job is to take this high-level math and compile it into a series of hardware-specific kernel launches that:
- Minimize Data Movement (\(D_{\text{vol}}\)) through techniques like kernel fusion.
- Maximize Utilization (\(\eta\)) by matching operations to specialized hardware units like Tensor Cores.
- Minimize Overhead (\(L_{\text{lat}}\)) through efficient asynchronous dispatch and graph capture.
Choosing a framework means choosing the compiler that determines how efficiently a model uses hardware.
With these three problems in mind, we can now define what a machine learning framework fundamentally is.
Definition 1.1: Machine Learning Frameworks
Machine Learning Frameworks are software systems that translate high-level mathematical model definitions into hardware-optimized execution plans by managing the computational graph, automatic differentiation, kernel dispatch, and memory allocation across the hardware hierarchy.
- Significance (Quantitative): Frameworks directly determine the system efficiency (\(\eta\)) term in the iron law. XLA’s operator fusion, for example, eliminates intermediate memory writes between consecutive elementwise operations: fusing a matrix multiplication, bias add, and rectified linear unit (ReLU) into a single kernel reduces the total data movement (\(D_{\text{vol}}\)) by 2–3\(\times\) vs. three separate kernel launches, yielding observed end-to-end speedups of 1.5–2\(\times\) on transformer training without any model changes.
- Distinction (Durable): Unlike a numerical library such as NumPy, which executes each operation immediately (eager evaluation), an ML framework can defer execution to analyze the full computational graph and apply global optimizations: operator fusion, memory layout transformations, and parallel scheduling. These optimizations are impossible when operations are evaluated one at a time.
- Common Pitfall: A frequent misconception is that frameworks are interchangeable API wrappers. Framework choice determines which hardware optimizations are available: a PyTorch model using the default eager execution mode cannot benefit from graph-level fusion until explicitly compiled with torch.compile(), and the resulting throughput difference can exceed 2\(\times\) on the same hardware.
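The 2–3\(\times\) data-movement figure in the Significance bullet can be checked with back-of-envelope arithmetic. A minimal sketch, assuming square 1024\(\times\)1024 fp32 matrices and counting only tensor reads and writes to memory:

```python
# Back-of-envelope sketch (illustrative, assumed 1024-cubed GEMM):
# compare memory traffic for matmul -> bias -> ReLU as three kernels
# vs. one fused kernel. Weights A (m x k) and B (k x n) are read once
# in both cases; only element counts are compared.
m = n = k = 1024
inputs = m * k + k * n           # A and B reads (elements), same either way

# Unfused: matmul writes mn; bias reads mn, writes mn; ReLU reads mn, writes mn
unfused = inputs + 5 * m * n
# Fused: the bias/ReLU epilogue stays in registers; the activation is written once
fused = inputs + m * n

print(f"traffic ratio: {unfused / fused:.2f}x")  # ~2.33x less data movement
```

The result lands inside the 2–3\(\times\) range quoted above; the exact ratio depends on matrix shapes, since the weight reads are fixed while the epilogue traffic scales with the output size.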
The compiler metaphor is not decorative. An ML framework translates logical intent into physical execution under the constraints of the iron law, deciding how to partition computation across memory hierarchies, when to trade numerical precision for throughput, and how to schedule operations so that the dominant term (data movement, computation, or overhead) is minimized. The framework is where the governing physics developed throughout this book becomes executable code.
The scale of this translation is not obvious from the API surface. A single call to loss.backward() triggers operation recording, memory allocation for gradients, reverse-order graph traversal, and hardware-optimized kernel dispatch—machinery that would require hundreds of lines of manual calculus for even a three-layer network. For a contemporary language model, the framework additionally orchestrates billions of floating-point operations across distributed hardware, coordinating memory hierarchies, communication protocols, and numerical precision. Building this from scratch would be economically prohibitive for most organizations, which is why the history of ML frameworks is a history of progressively automating these layers.
The three problems—execution, differentiation, and abstraction—did not emerge simultaneously. Each arose as a response to scaling limitations in the previous generation of tools. Tracing this evolution reveals why modern frameworks are designed as they are and why the particular trade-offs they embody were, in hindsight, inevitable.
The Ladder of Abstraction
In 1979, writing a matrix multiplication in Fortran that saturated the hardware required deep knowledge of cache lines, register scheduling, and vector units. By 2016, a single line of Python (torch.matmul(A, B)) achieved the same peak throughput without the programmer knowing anything about the silicon. That compression of effort did not happen in one step; it accumulated across four decades of abstraction, each layer solving a bottleneck that made the previous generation impractical for scaling. The result is a Ladder of Abstraction where each rung automates what the rung below exposed.
- Solving Performance (1979–1992): The Basic Linear Algebra Subprograms (BLAS)3 and LAPACK4 solved the problem of Hardware Primitives. They provided standardized, highly optimized implementations of matrix operations such as general matrix multiply (GEMM)5. This layer ensures that C = A @ B runs at near-peak silicon speed, regardless of the language calling it.
3 BLAS (Basic Linear Algebra Subprograms): The 1979 API specification that forms the bottom rung of the ladder described here. By decoupling C = A @ B from its hardware implementation, BLAS forced vendors to compete on optimized libraries (NVIDIA cuBLAS, Intel MKL) for a fixed set of primitives. Every framework above it on the ladder inherits this bargain: a single BLAS call from any language can saturate an A100, achieving over 312 TFLOPS for GEMM alone, without the framework knowing anything about the silicon.
4 LAPACK (Linear Algebra PACKage): Extends BLAS by providing a standard API for higher-level routines (SVD, eigendecomposition, least-squares) that vendors implement with chip-specific code layered on top of fast GEMM kernels. This layered design is the architectural pattern every ML framework inherits: high-level operations delegate downward to hand-tuned primitives, so a vendor-optimized LAPACK call can execute over 10\(\times\) faster than a naive implementation without the framework author writing a single line of hardware-specific code.
5 GEMM: The single operation that the “near-peak silicon speed” claim rests on. Hardware vendors hand-tune GEMM for their specific chips because every layer in a neural network reduces to matrix multiplication, making this one routine the performance floor for all frameworks above it on the ladder. The catch: GEMM achieves peak throughput only when matrix dimensions satisfy strict alignment constraints (multiples of eight for NVIDIA Tensor Cores), and violating these rules drops a framework from over 90 percent to roughly 30 percent of \(R_{\text{peak}}\).
- Solving Usability (2006): NumPy6 solved the problem of Developer Velocity. By wrapping low-level BLAS routines in high-level Python, it allowed scientists to write code in a friendly language while executing it in optimized C/Fortran. This “Vectorization” pattern, where the slow language handles logic and the fast language handles loops, became the standard contract for scientific computing.
6 NumPy (Numerical Python): In 2005, Travis Oliphant unified two competing Python array libraries (Numeric and Numarray) into a single package, giving the scientific computing community one BLAS-backed array standard at the moment it needed to scale. The “vectorization” contract this created (write logic in Python, execute loops in C/Fortran via BLAS) became the design template for every ML framework that followed: PyTorch tensors and TensorFlow arrays are direct descendants, extending the same n-dimensional array abstraction to GPUs. Python’s dominance in ML is a direct inheritance from this consolidation decision.
- Solving Differentiation (2015–present): Deep Learning Frameworks (Theano7, TensorFlow, PyTorch) solved the problem of Gradient Computation. While NumPy required manual derivation of backpropagation gradients (error-prone and slow), these frameworks introduced Automatic Differentiation via the computational graph (Rumelhart et al. 1986). This turned the chain rule into a software primitive, allowing researchers to define forward passes and get backward passes for free.
7 Theano: Developed at the Montreal Institute for Learning Algorithms (MILA) under Yoshua Bengio starting in 2007, Theano was the first framework to compile symbolic mathematical expressions into optimized CPU and GPU code via computational graphs (Bergstra et al. 2010). Its key insight – that a Python-defined computation graph could be compiled to CUDA without the researcher writing GPU code – became the architectural template for TensorFlow (2015) and influenced PyTorch’s autograd design. Theano was retired in 2017, but every modern framework inherits its core abstraction.
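The vectorization contract described in the NumPy rung above can be seen directly in code. A minimal sketch, assuming NumPy is installed; exact timings vary by machine:

```python
# Sketch of the "vectorization contract": Python expresses the logic,
# BLAS-backed NumPy executes the loop in optimized C/Fortran.
import time
import numpy as np

a = np.random.rand(1_000_000)
b = np.random.rand(1_000_000)

# Pure-Python loop: the interpreter dispatches every multiply-add
t0 = time.perf_counter()
slow = sum(x * y for x, y in zip(a, b))
t_loop = time.perf_counter() - t0

# Vectorized: one call, the loop runs inside the BLAS dot kernel
t0 = time.perf_counter()
fast = a @ b
t_dot = time.perf_counter() - t0

assert np.isclose(slow, fast)
print(f"loop: {t_loop*1e3:.1f} ms, vectorized: {t_dot*1e3:.2f} ms")
```

On typical hardware the vectorized call is orders of magnitude faster, which is the entire bargain the ladder's NumPy rung rests on.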
As Figure 1 illustrates, this progression reveals a critical insight: frameworks exist to bridge the gap between mathematical intent and silicon reality. As we move up the ladder, we gain productivity but lose transparency—a trade-off we explore in the Execution Problem (Section 1.2).
Each generation abstracted away details that consumed engineering effort in the previous one, yet each abstraction introduced new trade-offs. BLAS hid assembly-level optimization but fixed the interface. NumPy hid memory management but required manual differentiation. Modern frameworks hide gradient computation but introduce the execution model choice we examine next.
All modern frameworks converge on the same three core problems: how to execute computation, how to differentiate it, and how to abstract across hardware. We begin with the most visible of these: the execution problem, because its resolution determines what optimizations the other two problems can exploit.
Execution Problem
Consider two engineers writing the same neural network. The first debugs interactively, printing tensor shapes after each operation, inspecting intermediate values, and stepping through code with pdb. The second waits 30 seconds for compilation, then watches the model run 3\(\times\) faster with no ability to inspect any intermediate state. Both are correct; they have simply made different choices about the execution problem, the question of whether operations should execute immediately as written or be recorded for later execution. This choice creates a cascade of engineering trade-offs that shape every aspect of framework behavior, from debugging workflows to deployment options to peak hardware utilization.
Why execution strategy matters: The memory wall
To understand why execution strategy matters so much, consider the memory wall (first introduced in Neural Computation), the growing gap between processor computational speed and memory bandwidth. Modern GPUs can perform arithmetic far faster than they can fetch data. On an A100 GPU with 312 TFLOPS of compute and 2.0 TB/s of memory bandwidth, element-wise operations like ReLU achieve less than one percent of peak compute capacity, not because the hardware is slow, but because they spend nearly all their time waiting for data. The Roofline Model formalizes this trade-off, showing exactly when operations are memory bound vs. compute bound.
The memory wall creates a critical classification: operations are either compute-bound (limited by arithmetic throughput, like large matrix multiplications) or memory-bound (limited by data movement, like activation functions and normalization). Most individual neural network operation types (activations, normalizations, element-wise operations) are memory bound, though the large matrix multiplications that dominate total compute time can be compute bound.
The key optimization for memory-bound operations is kernel fusion, combining multiple operations into a single GPU function (called a kernel)8 to avoid intermediate memory traffic. Fusing a sequence of LayerNorm, Dropout, and ReLU into one kernel can yield 5\(\times\) speedup by eliminating intermediate writes between operations. FlashAttention9 fuses the entire attention computation, reducing HBM traffic by 10–20\(\times\) and achieving 2–4\(\times\) wall-clock speedup.
8 Kernel (GPU): In GPU programming, a kernel is the function dispatched to execute in parallel across thousands of threads. Each kernel launch incurs 5–20 \(\mu\)s of CPU-side overhead for parameter assembly and GPU signaling, which means that small, unfused operations spend more time on launch overhead (\(L_{\text{lat}}\)) than on useful arithmetic. Reducing kernel count through fusion is therefore a direct attack on the overhead term of the iron law.
9 FlashAttention: Kernel fusion taken to its logical extreme, fusing the entire attention computation (Q, K, V projections, softmax, output) into a single kernel that tiles data to fit in SRAM (introduced in Network Architectures). By reducing HBM traffic 10–20\(\times\), FlashAttention transforms a memory-bound operation into a compute-bound one, demonstrating that framework-level fusion can shift an operation’s position on the Roofline Model from bandwidth-limited to throughput-limited.
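The memory-bound vs. compute-bound split falls out of simple arithmetic intensity calculations. A sketch using the A100 figures quoted above (312 TFLOPS, 2.0 TB/s); the element counts are illustrative:

```python
# Sketch: classify operations against the Roofline ridge point of an
# A100-class GPU (assumed figures from the text: 312 TFLOPS compute,
# 2.0 TB/s HBM bandwidth).
ridge = 312e12 / 2.0e12          # FLOPs per byte needed to saturate compute

def intensity(flops, bytes_moved):
    return flops / bytes_moved

# ReLU on n fp32 elements: 1 FLOP each, read 4 bytes + write 4 bytes
n = 1_000_000
relu = intensity(n, 8 * n)                  # 0.125 FLOP/byte -> memory bound

# Square GEMM (d x d): 2*d^3 FLOPs, three fp32 matrices of d^2 elements
d = 4096
gemm = intensity(2 * d**3, 3 * 4 * d**2)    # ~683 FLOP/byte -> compute bound

print(f"ridge={ridge:.0f}, relu={relu}, gemm={gemm:.0f}")
```

ReLU sits three orders of magnitude below the ridge point, which is why fusing it into a neighboring kernel is essentially free arithmetic, while a large GEMM sits well above it.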
A framework can only fuse operations it can see together. If operations execute immediately one at a time (eager execution), the framework cannot fuse them. If operations are recorded first into a graph (deferred execution), the framework can analyze and optimize the entire computation. This is why execution strategy matters so much: it determines what optimizations are even possible.
The computational graph
Kernel fusion is the key optimization for memory-bound operations, but fusion requires seeing multiple operations together. How do frameworks represent computation in a way that makes this visibility possible? The answer is the computational graph, a directed acyclic graph (DAG) where nodes represent operations and edges represent data dependencies. This graph is the framework’s internal model of the computation.
To ground this abstraction, examine Figure 2: computing \(z = x \times y\) maps onto two input nodes (\(x\) and \(y\)), one operation node (multiplication), and one output node (\(z\)). The execution problem asks: when is this graph constructed, and when is it executed?
Real machine learning models require much more complex graph structures. Figure 3 extends this representation to show a neural network computation graph alongside the system components that reason about it. In the left panel, notice how data flows through six operation nodes in a directed acyclic graph—each node’s output becomes the next node’s input. The right panel reveals what the framework gains by having this graph: it can query the structure to plan memory allocation for each tensor’s lifetime, and it can assign operations to devices based on data dependencies rather than execution order. The critical insight is that the graph exists independently of execution, enabling the framework to optimize before any arithmetic occurs.
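The graph abstraction can be captured in a few lines. A minimal sketch (illustrative class names, not any framework's API) of a DAG whose nodes hold an operation and references to their inputs, evaluated by walking dependencies before consumers:

```python
# Minimal sketch of a computational graph as a DAG: nodes hold an op and
# references to their input nodes; evaluation is a post-order traversal.
class Node:
    def __init__(self, op, inputs=(), value=None):
        self.op, self.inputs, self.value = op, inputs, value

def evaluate(node):
    # Every dependency is computed before its consumer
    if node.value is None:
        args = [evaluate(i) for i in node.inputs]
        node.value = node.op(*args)
    return node.value

# z = x * y, as in Figure 2: two input nodes, one operation node
x = Node("input", value=3.0)
y = Node("input", value=4.0)
z = Node(lambda a, b: a * b, inputs=(x, y))

print(evaluate(z))  # 12.0
```

Because the graph exists as a data structure before `evaluate` runs, a framework holding it can answer structural questions (tensor lifetimes, fusion candidates, device placement) before any arithmetic occurs.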
This graph representation is more than a visualization; it is the data structure that enables both efficient execution and automatic differentiation. The answer to when this graph is constructed creates a design choice with cascading implications:
- For debugging: Can you print intermediate values? Step through code with a debugger?
- For optimization: Can the framework see multiple operations at once to fuse them?
- For deployment: Can the model run without a Python interpreter?
- For flexibility: Can control flow depend on computed tensor values?
No single execution model optimizes all these dimensions. Frameworks must choose their position in this trade-off space, and practitioners must understand these trade-offs to select appropriate tools and write efficient code. The following sections examine how different execution strategies navigate these constraints.
Three execution strategies
The computational graph representation enables global optimization, but it raises a critical design question: when should the framework build this graph? Consider a simple operation like y = x * 2. Two distinct approaches exist:
- Immediate execution: Perform the multiplication right now, storing the result in y. Natural and debuggable, but the framework sees only one operation at a time.
- Deferred execution: Record the intention to multiply, building a graph of operations. Execute later when explicitly requested. Less intuitive, but the framework sees the complete computation, enabling optimization.
Neither approach dominates; each embodies different trade-offs between flexibility and optimization potential. Modern frameworks have explored three primary execution strategies: eager execution with dynamic graphs, static computation graphs, and hybrid approaches that combine just-in-time (JIT) compilation with eager development. We examine each through its systems implications.
Eager execution with dynamic graphs
Example 1.1: Eager vs. Graph Execution Code Comparison
PyTorch (Eager Execution):
import torch
x = torch.tensor([1.0, 2.0])
y = x * 2
print(f"Intermediate value: {y}") # Works immediately
z = y.sum()

TensorFlow 1.x (Static Graph):
import tensorflow as tf
x = tf.placeholder(tf.float32)
y = x * 2
# print(y) -> Prints Tensor("mul:0"...), not value!
z = tf.reduce_sum(y)
with tf.Session() as sess:
result = sess.run(z, feed_dict={x: [1.0, 2.0]})

Eager execution runs operations immediately as encountered, building the computation graph dynamically during execution. When a programmer writes y = x * 2, the multiplication happens instantly and the result is available for immediate use.
This provides the flexibility of normal programming: developers can print intermediate values, use conditionals based on computed results, and debug with standard tools. The framework records operations as they happen, constructing a dynamic graph that reflects the actual execution path taken.
For gradient computation, the framework records a history of operations in what is called an autograd tape10, a transient data structure built during execution. Each tensor operation creates a node that records: the operation performed, references to input tensors, and how to compute gradients. These nodes form a directed acyclic graph (DAG) of operations built during forward pass execution, not before.
10 Autograd Tape: A transient data structure built during forward execution, where each node records the operation type, input tensor references, saved intermediate values, and the backward function for chain rule application. The tape’s memory cost scales linearly with network depth and is destroyed after the backward pass. For deep models, this transient graph can consume more memory than the model weights themselves, which is why activation checkpointing (trading recomputation for memory) becomes necessary for training models that would otherwise exhaust accelerator memory.
Consider this example using PyTorch, which implements eager execution as its default mode. Listing 1 shows how operations are recorded as they execute.
import torch
x = torch.tensor([1.0], requires_grad=True)
y = x * 2 # Executes immediately; records MulBackward node
z = y + 1 # Executes immediately; records AddBackward node
# The autograd tape exists NOW, built during execution

After these two operations, the framework has constructed an autograd tape with two nodes: one for the multiplication and one for the addition. The tape records that z depends on y, and y depends on x.
Calling z.backward() traverses this tape in reverse topological order, applying the chain rule at each node:
- Compute \(\frac{\partial z}{\partial z} = 1\) (seed gradient)
- Call AddBackward0.backward() \(\rightarrow \frac{\partial z}{\partial y} = 1\)
- Call MulBackward0.backward() \(\rightarrow \frac{\partial z}{\partial x} = 2\)
- Accumulate the gradient in x.grad
After backward() completes, the autograd tape is destroyed to free memory. The next forward pass builds a completely new tape. This design enables memory-efficient training: the system pays for gradient computation storage only during the backward pass.
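The tape mechanics just described can be sketched in plain Python (a toy model, not PyTorch's actual implementation), for the same computation z = (x * 2) + 1:

```python
# Minimal sketch of an autograd tape: forward execution appends backward
# functions in order; the backward pass walks them in reverse, applying
# the chain rule, then destroys the tape.
tape = []  # backward functions, in execution order

def mul(x, c):
    tape.append(lambda g: g * c)     # d(x*c)/dx = c
    return x * c

def add(x, c):
    tape.append(lambda g: g * 1.0)   # d(x+c)/dx = 1
    return x + c

x = 1.0
y = mul(x, 2.0)       # executes immediately; records a MulBackward-like node
z = add(y, 1.0)       # executes immediately; records an AddBackward-like node

grad = 1.0            # seed: dz/dz = 1
for backward in reversed(tape):
    grad = backward(grad)
tape.clear()          # tape is destroyed after the backward pass

print(grad)  # dz/dx = 2.0
```

The reverse iteration is the reverse topological traversal described above, and clearing the tape mirrors the memory-reclamation step: the next forward pass starts from an empty list.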
War Story 1.1: The Silent Gradient Killer
The Optimization: Seeking to reduce memory allocations, an engineer wrote x += 1 instead of x = x + 1.
The Failure: In-place operations modify the data directly in memory. However, the autograd tape (the computational graph) often needs the original value of x to compute gradients for previous layers. By overwriting x, the engineer destroyed the history needed for the chain rule.
The Consequence: The framework did not crash. Instead, it computed gradients using the modified value of x, resulting in mathematically incorrect updates. The model trained, but its loss plateaued at a high value. The team spent weeks debugging hyperparameters, never suspecting that a “memory optimization” had silently corrupted the calculus.
The Systems Lesson: Frameworks are graph construction engines, and in-place operations violate the immutability required for automatic differentiation. Writing x += 1 does not merely add a number: it sabotages the graph’s history (Paszke et al. 2019).
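The failure mode can be reproduced with a toy tape (a hypothetical mini-autograd, not PyTorch internals): the key detail is that the backward function saves a *reference* to its input, not a copy.

```python
# Sketch of the silent-gradient-killer failure mode: the forward pass
# saves a reference to its input for the backward pass, so an in-place
# mutation rewrites the history the chain rule depends on.
saved = []

def square_forward(x):
    saved.append(x)                 # reference, not a copy
    return [v * v for v in x]

def square_backward(grad_out):
    x = saved.pop()                 # d(v^2)/dv = 2v, using the saved input
    return [2 * v * g for v, g in zip(x, grad_out)]

x = [3.0]
y = square_forward(x)               # y = [9.0]; the tape now references x
x[0] += 1.0                         # in-place "optimization" mutates history
grad = square_backward([1.0])

print(grad)  # [8.0] -- silently wrong; the true gradient at x=3 is [6.0]
```

Nothing crashes: the backward pass runs to completion and returns a number, just the wrong one, which is exactly why the bug above survived weeks of hyperparameter debugging. (Modern PyTorch catches many such cases with tensor version counters, but not all.)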
Follow this “define-by-run” execution model step by step in Figure 4. Notice the alternating pattern: define, execute, define, execute. Each operation completes entirely before the next begins, which is why standard Python debuggers work—a developer can set a breakpoint between any two operations and inspect the actual tensor values. This contrasts sharply with static graphs, where all operations must be defined before any execution occurs.
Systems implications: Flexibility
The dynamic autograd tape enables capabilities impossible with static graphs. Conditionals and loops can depend on tensor values computed during execution, enabling algorithms like beam search, dynamic recurrent neural network (RNN) lengths, or adaptive computation that adjust their behavior based on intermediate results. Different iterations can process tensors of different sizes without redefining the computation—essential for natural language processing where sentence lengths vary. Because operations execute immediately in standard Python, developers can print tensors, inspect values, and use standard debuggers (pdb, breakpoints) to diagnose errors in the same way they would debug any Python program.
Systems implications: Overhead
This flexibility comes with performance costs that map directly to the iron law (Iron Law of ML Systems). Each forward pass rebuilds the autograd tape from scratch, adding Python object creation, reference counting, and node linking overhead to \(L_{\text{lat}}\) on every iteration. Every operation goes through Python dispatch—function lookup, argument parsing, type checking—costing ~10μs per operation, which becomes significant for models with thousands of operations. Because the graph is built during execution, the framework cannot see across operations to fuse kernels, so each operation launches its own GPU kernel, inflating both \(O\) and \(D_{\text{vol}}\). The autograd tape itself stores references to all intermediate tensors and Function nodes, increasing memory consumption by 2–3\(\times\) compared to forward-only execution and adding pressure to \(D_{\text{vol}}\). Together, these costs create a performance ceiling that becomes visible as models grow smaller and dispatch overhead dominates computation.
For a typical ResNet-50 forward pass, eager execution overhead adds approximately 5–10 ms compared to an optimized compiled version, with the majority spent in Python dispatch and tape construction rather than actual computation.
The dispatch tax: Python overhead vs. GPU reality
Eager execution’s performance ceiling is driven by a fundamental systems mismatch: the speed of the host-side interpreter vs. the speed of the device-side silicon. We quantify this using The Dispatch Tax, defined as the fraction of time spent in the host-side orchestration (Python) vs. actual device execution (GPU).
Every operation in an eager framework (like standard PyTorch) must pay a fixed “Tax” of approximately 10 \(\mu\)s for Python to lookup the function, check tensor types, and launch the kernel.
- For small operations (for example, a ReLU on a small vector), the kernel might execute in only 1 \(\mu\)s. The dispatch tax is 90 percent, meaning the GPU spends the vast majority of its time waiting for the next command.
- For large operations (for example, a massive \(4096\times4096\) matrix multiply), the kernel executes for 100 \(\mu\)s. The dispatch tax drops to 9 percent, and the system becomes compute bound.
The dispatch tax explains why models with many small layers run significantly slower than their raw FLOP count predicts. To reach the “Titan” standard of efficiency, frameworks must move from Kernel-by-Kernel Dispatch to Graph-Level Execution, where the dispatch tax is paid once for the entire graph rather than per operation. The hybrid JIT and compilation strategies in Section 1.2.4.2 exist precisely to address this overhead.
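The arithmetic behind these percentages, sketched with the assumed ~10 \(\mu\)s per-launch overhead from the text:

```python
# Sketch of the dispatch-tax arithmetic: fraction of wall time spent in
# host-side orchestration rather than device execution (assumed ~10 us
# of launch overhead per kernel).
def dispatch_tax(overhead_us, kernel_us):
    return overhead_us / (overhead_us + kernel_us)

small = dispatch_tax(10, 1)     # tiny ReLU kernel: ~91% -> dispatch bound
large = dispatch_tax(10, 100)   # big GEMM kernel:  ~9%  -> compute bound

# Graph-level execution pays the tax once for N ops instead of N times
n_ops = 1000
eager_overhead = n_ops * 10     # 10,000 us of cumulative dispatch
graph_overhead = 10             # one capture/launch (idealized)
print(f"{small:.0%} vs {large:.0%}; graph saves {eager_overhead - graph_overhead} us")
```

The saving from graph-level execution scales linearly with operation count, which is why models built from many small layers benefit most from compilation.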
The overhead costs of eager execution raise a natural question: what if we could see the entire computation before executing any of it? This is precisely what static computation graphs provide.
Static computation graphs
Static graph execution defines the complete computational graph as a symbolic representation first, then executes it separately. This “define-then-run” execution model means the graph exists before any computation occurs, enabling aggressive ahead-of-time optimization. The key insight is that if the framework sees the entire computation before running it, the framework can analyze, transform, and optimize the graph globally—a visibility impossible when operations execute immediately one at a time.
Two-phase execution
Static graphs implement a clear separation between graph construction and execution. Listing 2 illustrates the two phases using TensorFlow 1.x, which pioneered this approach: symbolic definition creates placeholders and operations without computation, while explicit execution triggers actual arithmetic:
```python
# Phase 1: Graph Construction (symbolic, no computation)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Define graph symbolically
x = tf.placeholder(tf.float32, shape=[1])  # Just a placeholder
y = x * 2   # Not executed, just recorded
z = y + 1   # Still no execution
print(y)    # Prints tensor metadata, not a computed value

# At this point, nothing has been computed

# Phase 2: Graph Execution (actual computation)
with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: [1.0]})
    # Now computation happens: result = [3.0]
```

Compare this with the dynamic model by examining Figure 5. Notice the clear boundary between phases: in the definition phase (left), the framework builds a complete blueprint without touching any data; in the execution phase (right), data flows through an already-optimized graph. This separation enables the framework to answer questions during the definition phase that are impossible to answer operation-by-operation: "Which intermediate tensors can share memory?" "Which operations can fuse into a single kernel?" "What is the total memory footprint?" By the time execution begins, these optimizations are already baked in.
The key difference from eager execution is that during construction, x, y, and z are not tensors containing values but rather symbolic nodes in a graph. Operations like * and + add nodes to the graph definition without performing any arithmetic. The print(y) line in the code example would reveal this distinction—it would print tensor metadata, not a computed value. Execution is triggered explicitly through sess.run(), at which point the framework analyzes the complete graph, optimizes it, and executes the optimized version with the provided input data.
Ahead-of-time optimization
Because the framework has the complete graph before execution, it can perform optimizations impossible in eager mode. The kernel fusion opportunity introduced in Section 1.2.1 becomes actionable here: because the framework sees y = x * 2 and z = y + 1 together in the graph, it can fuse them into z = x * 2 + 1, eliminating the intermediate y and halving memory traffic. With the full graph visible, the compiler can also calculate exact memory requirements for all tensors before execution, pre-allocating memory in a single pass and reusing buffers where lifetimes do not overlap. Tensor layouts can be transformed globally (for example, NCHW to NHWC) to match hardware preferences without runtime copying. Dead-code elimination (DCE)11 removes operations whose results are never consumed, and constant folding pre-computes operations on constant values at graph construction time, so the cost is paid once rather than on every forward pass.
11 Dead Code Elimination (DCE): Removes graph nodes whose results are never consumed by any downstream operation. In ML graphs, dead code arises from debugging operations left in production (print nodes, assertions), unused conditional branches, and gradient computations for frozen layers. For large transformer models, DCE eliminates 5–15 percent of graph nodes, reducing both \(O\) (fewer operations) and \(L_{\text{lat}}\) (fewer kernel launches). The DAG structure makes this safe: the framework verifies no downstream node depends on a candidate before removing it.
These optimizations map directly to iron law terms: kernel fusion reduces \(D_{\text{vol}}\) by eliminating intermediate memory writes, constant folding reduces \(O\) by computing values once, memory pre-allocation reduces \(L_{\text{lat}}\) by avoiding runtime allocation overhead, and dead code elimination reduces both \(O\) and \(D_{\text{vol}}\). Concretely, in large transformer models, constant folding and dead code elimination can reduce total FLOPs by 5–10 percent before the first batch even arrives.
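These passes can be made concrete in miniature. The toy optimizer below (a sketch, not any framework's actual implementation) runs constant folding and dead-code elimination over a dict-based DAG of named nodes:

```python
def optimize(graph, outputs):
    """graph maps name -> ("const", value), ("input", []), or (op, [input names])."""
    g = dict(graph)
    # Constant folding: evaluate ops whose inputs are all constants.
    changed = True
    while changed:
        changed = False
        for name, node in list(g.items()):
            if node[0] in ("add", "mul") and all(g[i][0] == "const" for i in node[1]):
                a, b = (g[i][1] for i in node[1])
                g[name] = ("const", a + b if node[0] == "add" else a * b)
                changed = True
    # Dead-code elimination: keep only nodes reachable from the outputs.
    live, stack = set(), list(outputs)
    while stack:
        n = stack.pop()
        if n not in live:
            live.add(n)
            if g[n][0] not in ("const", "input"):  # leaves have no inputs
                stack.extend(g[n][1])
    return {name: node for name, node in g.items() if name in live}

graph = {
    "two":   ("const", 2.0),
    "one":   ("const", 1.0),
    "scale": ("mul", ["two", "one"]),  # constant subexpression: foldable
    "x":     ("input", []),
    "y":     ("mul", ["x", "scale"]),
    "dbg":   ("add", ["y", "one"]),    # result never consumed: dead
}
opt = optimize(graph, ["y"])
assert opt["scale"] == ("const", 2.0)  # folded once, before any "execution"
assert "dbg" not in opt                # eliminated by DCE
```

The DAG structure makes the DCE step safe, exactly as the footnote describes: a node is removed only when no live node reaches it.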
Compilation frameworks like XLA (Accelerated Linear Algebra)12 (Google 2025) take this further, compiling the TensorFlow graph to optimized machine code for specific hardware. For a transformer encoder block, XLA can achieve 1.5–2\(\times\) speedup over unoptimized execution through aggressive fusion and hardware-specific code generation.
12 XLA (Accelerated Linear Algebra): The “optimized machine code” in the triggering sentence means XLA fuses an entire subgraph into one kernel, eliminating both launch overhead (\(L_{\text{lat}}\)) and intermediate memory writes (\(D_{\text{vol}}\)). The 1.5–2\(\times\) speedup for transformer blocks is modest because their large GEMM operations are already compute bound, leaving little overhead for fusion to remove. Memory-bound models see 3–10\(\times\) gains, where fusion hides the relative cost of many small, sequential operations.
Systems implications
Static graphs achieve high performance through ahead-of-time optimization. Kernel fusion reduces memory bandwidth requirements (often the bottleneck for ML workloads), and hardware-specific compilation enables near-peak utilization.
The cost of this performance is reduced flexibility. Standard Python control flow (if, for) cannot depend on computed tensor values in static graphs. TensorFlow provides graph-level control flow primitives (tf.cond and tf.while_loop) that support data-dependent conditions, but these require special syntax that diverges from standard Python, making code harder to write and reason about. Debugging is difficult because stack traces point to graph construction code, not execution code. Error messages often reference symbolic node names rather than the actual operations that failed.
Hybrid approaches: JIT compilation
Can we have both eager debugging and graph optimization? JIT compilation attempts this by capturing computation at runtime. The core trade-off is fidelity vs. generality. Tracing captures the exact execution path taken during a sample run, producing high fidelity to that specific input but missing branches not taken. Source-level compilation (scripting) analyzes the full program structure, preserving all control flow branches but requiring a restricted language subset. Both approaches produce an intermediate representation (IR)13 that enables the same ahead-of-time optimizations available to static graphs: operator fusion, constant folding, dead code elimination, and buffer reuse.
13 Intermediate Representation (IR): The “intermediate” captures this format’s architectural role: a language-independent layer that decouples the frontend (Python capture) from the backend (hardware code generation), exactly as LLVM IR decouples C/Rust/Swift frontends from x86/ARM backends. ML frameworks adopted this compiler pattern because it reduces the \(O(M \times N)\) cost of supporting \(M\) frontends and \(N\) backends to \(O(M + N)\): a single graph capture mechanism (TorchDynamo, tf2xla) can target multiple hardware backends without rewriting the capture logic.
The eager-vs.-compiled trade-off has a direct iron law consequence. JIT compilation amortizes the \(L_{\text{lat}}\) (dispatch overhead) across the compiled region. Longer compiled regions mean more overhead amortized per operation, which explains why graph breaks are performance-critical: each break forces a return to eager dispatch, resetting the amortization.
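The amortization argument reduces to one division. The sketch below (using the ~10 μs dispatch figure from earlier in the section) shows why region length, and therefore graph breaks, are performance-critical:

```python
def per_op_overhead_us(ops_in_region: int, dispatch_us: float = 10.0) -> float:
    """Dispatch cost amortized across one compiled region: the
    tax is paid once per region instead of once per operation."""
    return dispatch_us / ops_in_region

# Eager mode: every op is its own "region" of length 1.
assert per_op_overhead_us(1) == 10.0
# One compiled region covering 200 ops: 0.05 us per op.
assert per_op_overhead_us(200) == 0.05
# A graph break splitting it into two 100-op regions doubles the
# per-op overhead: 2 dispatches / 200 ops = 0.1 us per op.
assert 2 * per_op_overhead_us(100) / 2 == 0.1
```

Each break adds one more dispatch for the same amount of useful work, which is why later subsections treat graph-break minimization as a first-order optimization target.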
PyTorch’s TorchScript exemplifies both strategies. Tracing executes a function once with example inputs and records every tensor operation into a static computation graph. Listing 3 demonstrates the approach: the traced module becomes a compiled artifact that can be serialized, optimized, and executed independently of the Python interpreter:
```python
import torch

def forward(x):
    y = x * 2
    z = y + 1
    return z

# Trace the function by running it once
x_example = torch.tensor([1.0])
traced = torch.jit.trace(forward, x_example)

# traced is now a compiled TorchScript module
# Can serialize: torch.jit.save(traced, "model.pt")
# Can optimize: fusion, constant folding
# Can run without Python interpreter
```

The critical limitation of tracing reveals the fidelity-generality trade-off concretely. Because tracing records a single execution path, it cannot handle data-dependent control flow. Listing 4 illustrates a silent correctness failure.
```python
def conditional_forward(x):
    if x.sum() > 0:  # Data-dependent condition
        return x * 2
    else:
        return x * 3

traced = torch.jit.trace(conditional_forward, torch.tensor([1.0]))
# Tracing captures ONLY the x.sum() > 0 branch
# If a later input has sum <= 0, the traced version
# still executes the x * 2 branch
```

Tracing records whichever branch executed for the example input. Subsequent executions always follow the traced path regardless of input values, silently producing incorrect results for inputs that would have taken the other branch. This failure mode is particularly dangerous because it produces no error, only wrong outputs. In production, such bugs can persist for months before anyone notices that a small fraction of inputs are being misclassified—and by then, debugging is a forensic exercise.
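The failure mode can be reproduced without PyTorch. The minimal tracer below (a pure-Python sketch, not TorchScript) records the operations executed for one example input and then replays only that recording:

```python
def trace(fn, example):
    """Record the ops executed for `example`, then replay them blindly."""
    ops = []

    class Recorder:
        def __init__(self, v):
            self.v = v
        def __mul__(self, k):
            ops.append(lambda x, k=k: x * k)  # record the op, not the branch
            return Recorder(self.v * k)
        def __gt__(self, k):
            return self.v > k  # branch resolved with the concrete example value

    fn(Recorder(example))  # one concrete run fixes the execution path

    def replay(x):
        for op in ops:
            x = op(x)
        return x
    return replay

def conditional(x):
    return x * 2 if x > 0 else x * 3  # data-dependent control flow

traced = trace(conditional, 1.0)  # traces only the x > 0 branch
print(traced(1.0))    # 2.0: correct
print(traced(-1.0))   # -2.0: silently wrong (eager would give -3.0)
```

The `if` disappears from the recording because it was evaluated against a concrete value during tracing; only the winning branch's tensor operations were ever observed.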
The alternative, scripting, achieves generality by analyzing Python source code directly and compiling it to TorchScript IR without executing. The scripting compiler parses the abstract syntax tree (AST), converts supported operations to IR operations, and preserves the branching structure so that both branches of a conditional exist in the compiled representation. The cost of this generality is a restricted Python subset: type annotations are required where inference fails, arbitrary Python objects and standard library modules are excluded, and dynamic metaprogramming is forbidden.
Tracing suits feed-forward models without conditionals (ResNet, VGG, Vision Transformer) and models where control flow depends only on hyperparameters fixed at trace time. Scripting suits models with data-dependent control flow (RNN variants, recursive networks, adaptive computation) and deployment to environments without a Python interpreter. The following examples demonstrate scripting syntax (Listing 5), control flow preservation (Listing 6), IR inspection (Listing 7), and language restrictions (Listing 8).
```python
@torch.jit.script
def forward(x):
    y = x * 2
    z = y + 1
    return z

# Compiles Python source code to TorchScript IR
# No example inputs needed
# Preserves control flow structure
```

The key advantage of scripting appears when handling conditionals. Unlike tracing, which captures only one branch, scripting preserves both paths in the IR.
```python
@torch.jit.script
def conditional_forward(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x * 2
    else:
        return x * 3

# Both branches preserved in IR
# Correct branch executes based on runtime input values
```

To understand what the compiler produces, we can inspect the generated intermediate representation directly.
```python
@torch.jit.script
def example(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

# Inspect generated IR:
print(example.graph)
# graph(%x : Tensor):
#   %1 : int = prim::Constant[value=2]()
#   %2 : Tensor = aten::mul(%x, %1)
#   %3 : int = prim::Constant[value=1]()
#   %4 : Tensor = aten::add(%2, %3, %3)
#   return (%4)
```

However, scripting imposes constraints on what Python constructs are supported.
```python
@torch.jit.script
def invalid_script(x):
    import numpy as np            # ERROR: Cannot import arbitrary modules
    result = np.array([1, 2, 3])  # ERROR: NumPy not supported
    print(f"Debug: {x}")          # ERROR: f-strings not supported
    return result

# Valid alternative:
@torch.jit.script
def valid_script(x: torch.Tensor) -> torch.Tensor:
    # Use TorchScript-compatible operations
    result = torch.tensor([1, 2, 3], dtype=x.dtype, device=x.device)
    return result
```

Scripting requires a restricted Python subset because TorchScript must statically analyze code that Python normally interprets dynamically. Function signatures and variables need explicit type annotations when type inference fails, and only tensor operations, numeric types, and standard containers (lists, dicts, tuples) are permitted—no arbitrary Python objects, no standard library modules like os or sys, and no dynamic class modification or metaprogramming. These constraints are the price of compilation: every feature that makes Python flexible also makes it unpredictable for a compiler.
The TorchScript IR represents operations using the aten namespace for core tensor operations, the prim namespace for primitives and control flow, static types for every value, and static single-assignment (SSA) form, where each variable is assigned exactly once to simplify compiler analysis. This IR enables optimizations independent of Python: operator fusion combines adjacent operations into single kernels, constant folding evaluates constant expressions at compile time, dead code elimination removes unused operations, and memory optimization reuses buffers when possible. Table 1 summarizes the key trade-offs between these two approaches.
| Aspect | Tracing | Scripting |
|---|---|---|
| Input requirement | Example inputs needed | No inputs needed |
| Control flow | Cannot handle data-dependent | Supports data-dependent |
| Conversion ease | Simpler (just run function) | Harder (restricted Python) |
| Type annotations | Not required | Required when inference fails |
| Error detection | Runtime (wrong results) | Compile time (syntax errors) |
| Best for | Feed-forward models | Models with conditionals |
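The SSA property is easy to see in miniature. The sketch below (illustrative, not TorchScript's actual lowering pass) flattens a nested expression into SSA form, assigning each intermediate value a fresh name exactly once:

```python
def to_ssa(expr, out=None):
    """expr: nested tuples like ("add", ("mul", "x", 2), 1).
    Returns (name of the result value, list of SSA instructions)."""
    out = [] if out is None else out
    if not isinstance(expr, tuple):   # leaf: an input name or a constant
        return str(expr), out
    op, lhs, rhs = expr
    a, _ = to_ssa(lhs, out)
    b, _ = to_ssa(rhs, out)
    name = f"%{len(out) + 1}"         # fresh SSA name, assigned exactly once
    out.append(f"{name} = {op}({a}, {b})")
    return name, out

# x * 2 + 1, mirroring the printed TorchScript graph earlier in the section
result, instrs = to_ssa(("add", ("mul", "x", 2), 1))
print("\n".join(instrs))
# %1 = mul(x, 2)
# %2 = add(%1, 1)
```

Because every `%n` is written once and never reassigned, the compiler can reason about value lifetimes, reuse buffers, and reorder operations without tracking mutation, which is exactly why SSA simplifies the analyses listed above.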
Modern compilation: torch.compile
The previous approaches force a choice: write flexible code (eager execution) or fast code (static graphs). Modern JIT compilation attempts to eliminate this trade-off by automatically compiling eager code into optimized graphs with minimal developer intervention.
PyTorch 2.0’s torch.compile (Ansel et al. 2024) represents this approach: developers write natural Python code that executes eagerly during development, but the framework automatically captures and compiles hot paths into optimized kernels for production. Listing 9 shows the basic usage pattern:
```python
@torch.compile
def forward(x):
    return x * 2 + 1

# First call: captures execution, compiles optimized kernel (~100 ms)
result1 = forward(torch.tensor([1.0]))

# Subsequent calls with the same input shape reuse the compiled code;
# a new shape (for example, torch.randn(10, 10)) triggers recompilation
result2 = forward(torch.tensor([2.0]))
```

The compilation overhead in this example (approximately 100 ms to compile the first time, microseconds to reuse) illustrates why torch.compile is so effective. The deeper question is why compilation helps so much. The answer lies in understanding the physics of software overhead. Dispatch costs that seem negligible for a single operation—a few microseconds here and there—compound dramatically across the thousands of operations in a forward pass. The following analysis quantifies this phenomenon.
Napkin Math 1.1: The Physics of Software Overhead
The Constants of Latency:
- Python Dispatch: ~10 μs per operation.
- Kernel Launch: ~5 μs per operation.
- Memory Access (VRAM): ~1 μs.
Scenario 1: Eager Mode (The "Tiny Op" Trap). Consider a simple activation block: y = relu(x + bias).
Operations: two (Add, ReLU).
Execution:
- Launch AddKernel: 15 μs overhead; read/write memory: \(2N\) bytes.
- Launch ReLUKernel: 15 μs overhead; read/write memory: \(2N\) bytes.
Total Overhead: 30 μs.
Total Memory Traffic: \(4N\) bytes.
Scenario 2: Compiled Mode (Fusion). The compiler fuses this into one kernel: FusedAddRelu.
Execution:
- Launch FusedAddRelu: 15 μs overhead; read/write memory: \(2N\) bytes (intermediate result stays in registers).
Total Overhead: 15 μs (2\(\times\) speedup).
Total Memory Traffic: \(2N\) bytes (2\(\times\) bandwidth efficiency).
The Conclusion: Compilation is not magic; it is overhead amortization. For small, element-wise operations such as LayerNorm, Gaussian Error Linear Unit (GELU), and Add, overhead often exceeds compute time by 10–100\(\times\). Fusing them is the only way to use the hardware effectively.
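The napkin math can be checked mechanically. The sketch below recomputes both scenarios from the stated constants (10 μs Python dispatch plus 5 μs kernel launch per kernel):

```python
DISPATCH_US = 10.0   # Python dispatch per operation
LAUNCH_US = 5.0      # kernel launch per operation
PER_KERNEL_US = DISPATCH_US + LAUNCH_US

def eager(num_ops, bytes_per_op):
    """One kernel per op: overhead and memory traffic scale with op count."""
    return num_ops * PER_KERNEL_US, num_ops * bytes_per_op

def fused(num_ops, bytes_per_op):
    """One launch for the whole region; intermediates stay in registers,
    so memory traffic is one read + one write regardless of op count."""
    return PER_KERNEL_US, bytes_per_op

N = 1_000_000  # element count; each op reads and writes 2N bytes
eager_overhead, eager_traffic = eager(2, 2 * N)
fused_overhead, fused_traffic = fused(2, 2 * N)
assert eager_overhead == 30.0 and fused_overhead == 15.0  # 2x overhead win
assert eager_traffic == 2 * fused_traffic                 # 4N vs 2N bytes
```

Note that the memory-traffic win grows with the length of the fused chain, while the launch-overhead win grows with the number of kernels eliminated; both favor longer fused regions.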
See this tax play out concretely in Figure 6. Notice how eager execution (top) creates “gaps” where the GPU sits idle while Python dispatches the next kernel. The blue compute regions are short; the red dispatch regions are comparatively long. Compilation (bottom) fuses these operations into a single kernel launch, eliminating the gaps entirely so the GPU spends nearly all its time computing rather than waiting.
The natural question is: can this fusion happen automatically? PyTorch 2.0’s torch.compile14 attempts exactly this by capturing eager code and compiling it into fused kernels without requiring users to write custom CUDA.15
14 torch.compile: It enables this automatic fusion by intercepting Python bytecode (via TorchDynamo) to extract a computational graph from unmodified eager code. This graph is then compiled into optimized kernels, trading a one-time compilation delay for a permanent 1.3–\(2\times\) throughput gain on transformer models by reducing kernel launch overhead.
15 CUDA (Compute Unified Device Architecture): NVIDIA’s parallel computing platform (2007) serving as the foundational layer between high-level Python operations and GPU silicon. When PyTorch executes torch.matmul(A, B), the call traverses the framework’s dispatcher, selects a cuBLAS kernel, and launches it on the GPU. Each launch incurs 5–20 \(\mu\)s of CPU-side overhead. For small operations, this dispatch overhead (\(L_{\text{lat}}\)) exceeds the useful compute time, which is why compilation (fusing \(N\) operations into one kernel launch) yields speedups proportional to the reduction in launch count rather than the reduction in arithmetic.
Architecture: Three-stage compilation pipeline
torch.compile consists of three coordinated components, each handling a distinct phase of the compilation process:
TorchDynamo (graph capture): Intercepts Python bytecode execution using CPython’s PEP 523 frame evaluation API. Unlike torch.jit.trace, which records a single execution path and silently ignores alternative branches, TorchDynamo captures operations during execution but inserts graph breaks when it encounters unsupported code (print statements, arbitrary Python), ensuring correctness rather than silent failure. At a break, the current graph is finalized for compilation, the unsupported code executes eagerly, and a new graph begins afterward.

FX Graph (intermediate representation): Operations captured by TorchDynamo are converted to FX graph format, PyTorch’s node-based directed acyclic graph where each node represents an operation with explicit inputs and outputs. The FX graph serves as PyTorch’s analog to LLVM IR: a standardized representation that separates frontend (Python code capture) from backend (hardware-specific code generation). This design allows different backends such as TorchInductor, Open Neural Network Exchange (ONNX) Runtime, and TensorRT to consume FX graphs and enables optimization passes such as dead code elimination, constant folding, and pattern matching for fusion opportunities.
TorchInductor16 (code generation): The default backend that compiles FX graphs to optimized machine code. For CUDA GPUs, TorchInductor generates Triton17 kernels, a Python-based GPU kernel language that compiles to Parallel Thread Execution (PTX)18. For CPUs, it generates C++ code with vectorization instructions (AVX2, AVX-512). TorchInductor applies three key optimizations: kernel fusion (combining operations to reduce memory traffic), memory layout optimization (choosing tensor layouts that minimize access overhead), and autotuning (measuring performance across implementation variants to select the fastest).
16 TorchInductor: The use of Triton to generate GPU code is a deliberate trade-off that prioritizes fast JIT compilation speed over achieving maximum hardware performance. This makes on-the-fly optimization practical for an eager-execution framework, even if the resulting kernels are 5–20 percent slower than highly optimized, hand-written CUDA.
17 Triton: TorchInductor generates Triton because its Python-like syntax provides a simpler, more stable compilation target than raw CUDA, making automated code generation tractable. This abstraction allows the compiler to handle complex GPU details like memory coalescing automatically, a requirement for performing kernel fusion. The accepted trade-off is achieving 80–95 percent of hand-tuned CUDA performance in exchange for enabling the compiler to effectively autotune kernels and reduce development time from weeks to hours.
18 PTX: An intermediate representation (IR) from NVIDIA that serves as a stable compilation target for high-level GPU languages like Triton. This allows TorchInductor to generate portable code, as the NVIDIA driver—not the framework—is responsible for the final translation to hardware-specific machine code (SASS). This forward compatibility, however, can result in performance that is 10–15 percent slower than kernels hand-tuned for a specific GPU architecture.
The generated code is cached on disk: TorchInductor maintains its own compilation cache, and Triton kernels are additionally cached in ~/.triton/cache/. Subsequent runs with the same input shapes can skip compilation and directly execute cached code.
Execution flow
The first execution follows a multi-step process: TorchDynamo intercepts bytecode and records operations into FX graph, FX graph is passed to TorchInductor for compilation (5–30 seconds for transformer models), and compiled code is cached and executed. Subsequent executions with the same input shapes dispatch directly to compiled code with microseconds overhead. If input shapes change, TorchInductor must recompile for the new shapes (shape specialization). PyTorch maintains separate compiled versions for each unique shape configuration.
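Shape specialization behaves like a cache keyed on input shapes. The sketch below (a pure-Python analogy, not TorchDynamo's actual guard machinery) compiles once per distinct shape and dispatches to cached code afterward:

```python
compile_count = 0

def compile_for_shape(fn, shape):
    """Stand-in for the expensive (seconds-long) backend compile."""
    global compile_count
    compile_count += 1
    return fn  # the "compiled" artifact specialized to this shape

def make_compiled(fn):
    cache = {}
    def wrapper(x):              # x: a list standing in for a tensor
        key = (len(x),)          # guard: dispatch on the input shape
        if key not in cache:
            cache[key] = compile_for_shape(fn, key)  # cold path
        return cache[key](x)     # hot path: cheap dict lookup + call
    return wrapper

double = make_compiled(lambda x: [v * 2 for v in x])
double([1, 2, 3]); double([4, 5, 6])  # same shape: compiled once, reused
double([1, 2])                        # new shape: second compilation
assert compile_count == 2
```

Each distinct shape pays the compile cost once, which is why stable input shapes are a precondition for amortizing compilation effectively.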
Graph breaks: Causes and detection
Graph breaks occur when torch.compile encounters code it cannot compile, forcing execution to fall back to eager mode. Understanding graph break causes provides the foundation for achieving good performance.
Data-dependent control flow requires tensor values unavailable at compile time, as shown in Listing 10.
@torch.compile
def conditional_compute(x):
if x.sum() > 0: # Graph break: tensor value needed
return x * 2
else:
return x * 3
# Creates two compiled regions: operations before
# and after the if statement
# The if statement itself executes eagerlyTorchDynamo creates a graph break: operations before the if statement are compiled, the if statement executes eagerly (evaluating which branch to take), and the chosen branch is compiled as a separate region.
Unsupported operations also cause graph breaks, as Listing 11 demonstrates: operations like print force a break, splitting compiled code into two regions with eager execution in between.
```python
@torch.compile
def debug_compute(x):
    y = x * 2
    print(f"y = {y}")  # Graph break: I/O operation
    z = y + 1
    return z

# Creates two compiled regions: before and after print
```

Common unsupported operations include I/O (print, file operations), custom Python objects, and calls to non-PyTorch libraries. Each graph break incurs overhead: tensors must be marshalled from compiled code back to Python (possibly copying from GPU to CPU), the eager operation executes, and results are marshalled into the next compiled region.
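The effect of breaks on region structure can be sketched directly. Given an op sequence with break points, the splitter below produces the compiled regions that a TorchDynamo-style capture would hand to the backend (illustrative, not the real partitioner):

```python
def split_regions(ops, is_break):
    """Split an op list into compiled regions at graph breaks;
    the break ops themselves run eagerly between regions."""
    regions, current = [], []
    for op in ops:
        if is_break(op):
            if current:
                regions.append(current)
            current = []
        else:
            current.append(op)
    if current:
        regions.append(current)
    return regions

ops = ["mul", "print", "add", "relu"]
regions = split_regions(ops, lambda op: op == "print")
assert regions == [["mul"], ["add", "relu"]]
# Two compiled regions: dispatch overhead is paid once per region,
# plus marshalling across the eager `print` between them.
```

Every additional break shortens the average region, and by the amortization argument earlier in the section, raises the effective per-operation dispatch cost.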
Shape changes prevent compiled code reuse, as Listing 12 illustrates.
```python
@torch.compile
def variable_length(x, length: int):
    return x[:, :length]  # Shape changes each call

# Each unique length triggers recompilation
x = torch.randn(4, 16)
for i in range(10):
    result = variable_length(x, i)  # 10 recompilations
```

Detect graph breaks using Listing 13.
Setting the TORCH_LOGS environment variable to graph_breaks prints each break's location and reason during execution:

```bash
TORCH_LOGS="graph_breaks" python train.py
```

A typical report looks like: Graph break in user code at file.py:15 / Reason: call to unsupported function print. Minimizing graph breaks is key to performance: move unsupported operations outside compiled regions, replace data-dependent control flow with conditional execution (torch.where), or accept eager execution for inherently dynamic sections.
Compilation modes and backends
As a project matures from prototyping to production, engineers progressively increase compilation aggressiveness. The default mode (mode='default') applies moderate optimization with fast compilation (5–30 seconds for transformer models), making it suitable for development and training where compilation overhead is amortized over many iterations. When deploying an inference server with fixed input shapes, mode='reduce-overhead' minimizes Python interpreter overhead by aggressively capturing operations and enabling CUDA graphs that batch kernel launches, improving throughput by 20–40 percent over the default. For production training that will run for days, mode='max-autotune' generates and benchmarks multiple implementation variants for each operation, increasing compilation time (minutes to hours for large models) but improving runtime performance by 10–30 percent. This progression—default for development, reduce-overhead for inference, max-autotune for long training runs—mirrors the Compilation Continuum principle we formalize later.
The compilation mode controls how aggressively to optimize; the backend controls what target to optimize for. TorchInductor (the default) generates Triton kernels for CUDA and C++ for CPU, providing the best general-purpose performance for both training and inference. When cross-platform deployment is required, the ONNX Runtime backend exports the FX graph to ONNX format, enabling execution on CPUs, GPUs, mobile, and edge devices—though limited ONNX operation coverage may cause more graph breaks. For maximum inference throughput on NVIDIA GPUs, the TensorRT backend compiles to NVIDIA’s inference engine with aggressive int8 quantization, layer fusion, and kernel autotuning, often achieving 1.5–2\(\times\) speedup over TorchInductor. The trade-off is clear: each backend narrows the target to unlock deeper optimization, echoing the flexibility-vs.-performance axis that distinguishes eager from graph execution.
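In code, the progression described above is a one-argument change. The fragment below is illustrative (MyModel is a hypothetical module, and the onnxrt and tensorrt backends require their respective packages to be installed):

```python
import torch

model = MyModel()  # hypothetical module standing in for your network

# Development and training: fast compilation, moderate optimization
model_dev = torch.compile(model)  # mode='default'

# Inference server with fixed shapes: CUDA graphs, minimal Python overhead
model_infer = torch.compile(model, mode="reduce-overhead")

# Long production training runs: benchmark kernel variants up front
model_train = torch.compile(model, mode="max-autotune")

# Alternate backends target different deployment stacks
model_onnx = torch.compile(model, backend="onnxrt")
```

The mode and backend arguments are orthogonal: one sets how hard the compiler works, the other sets what it targets.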
Practical example: Measuring speedup
Listing 14 implements correct GPU benchmarking methodology, incorporating CUDA synchronization, warmup iterations to exclude compilation time, and sufficient iterations to amortize measurement overhead:
```python
import torch
import time

def forward(x, w):
    return torch.matmul(x, w).relu()

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 512, device="cuda")

# Eager mode benchmark
torch.cuda.synchronize()  # Ensure pending GPU operations complete
start = time.time()
for _ in range(100):
    y = forward(x, w)
torch.cuda.synchronize()  # Wait for GPU kernel completion
eager_time = time.time() - start

# Compiled mode benchmark
forward_compiled = torch.compile(forward)
forward_compiled(x, w)  # Warmup: trigger compilation
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    y = forward_compiled(x, w)
torch.cuda.synchronize()
compiled_time = time.time() - start

print(f"Speedup: {eager_time / compiled_time:.2f}x")
# Typical: 2-5x speedup for matrix operations
```

Critical benchmarking details: (1) Use torch.cuda.synchronize() because CUDA operations are asynchronous; without synchronization, timing measures only kernel launch time, not execution time. (2) Warm up compilation by calling the compiled function once before timing to exclude compilation from measurements. (3) Run 100+ iterations to amortize measurement overhead.
Systems implications
First execution includes compilation time: 5–10 s for small models, 30–60 s for BERT-base transformers, 5–10 min for GPT-3 scale models. This overhead is amortized across training (compile once, train for thousands of iterations) but impacts development iteration time. Compiled kernels are cached on disk; subsequent runs skip compilation.
Compilation adds memory overhead: 100–500 MB for FX graph construction, 500 MB–2 GB peak during Triton compilation, and 10–100 MB per compiled graph on disk. Runtime memory usage is similar to eager mode (kernel fusion can reduce intermediate tensors, but compiled code may allocate temporary buffers). Compiled models typically use 90–110 percent of eager mode memory.
Errors in compiled code produce stack traces pointing to generated code, not source Python code. Print statements inside compiled regions cause graph breaks (executed eagerly, not compiled). For debugging, remove @torch.compile to revert to eager execution, fix bugs, then re-enable compilation. Use TORCH_COMPILE_DEBUG=1 for verbose compilation logs.
When to use torch.compile
The decision follows directly from the compilation cost model. Long training runs amortize compilation overhead across hundreds of iterations, and stable architectures with fixed control flow minimize graph breaks—making training the strongest use case. Inference is equally compelling: a deployed model compiles once at startup and serves thousands of requests, where mode='reduce-overhead' minimizes per-request overhead. Compilation should be deferred, however, during rapid prototyping, where the overhead slows iteration time and the architecture has not yet stabilized. Models with frequent graph breaks or dynamic shape changes prevent effective compilation, and debugging is harder in compiled mode because error locations point to generated code rather than source Python. The practical strategy is to develop in eager mode, stabilize the architecture, then enable compilation for training and deployment.
Comparison of execution models
Table 2 contrasts the three execution models across six dimensions, revealing that hybrid JIT compilation achieves most of static graph performance while preserving much of eager execution’s flexibility:
| Aspect | Eager + Autograd Tape (PyTorch default) | Static Graph (TensorFlow 1.x) | JIT Compilation (torch.compile) |
|---|---|---|---|
| Execution Model | Immediate | Deferred | Hybrid |
| Graph Construction | During forward pass | Before execution | First execution (cached) |
| Optimization | None (per-operation) | Ahead-of-time | JIT compilation |
| Dynamic Control Flow | Full support | Limited (static unroll) | Partial (graph breaks) |
| Debugging | Easy (standard Python) | Difficult (symbolic) | Moderate (mixed) |
| Performance | Baseline | High (optimized) | High (compiled regions) |
Eager mode’s primary value is in the “Workflow Iteration” loop (ML Workflow): it allows using standard Python debuggers (like PDB) to inspect variables mid-execution, whereas graph-mode debugging often requires specialized framework tools. This immediate feedback accelerates the prototyping phase of the ML lifecycle.
Beyond these core execution trade-offs, Table 3 highlights additional systems-level distinctions between static and dynamic approaches:
| Aspect | Static Graphs | Dynamic Graphs |
|---|---|---|
| Memory Management | Precise allocation planning, optimized memory usage | Flexible but potentially less efficient |
| Hardware Utilization | Can generate highly optimized hardware-specific code | May sacrifice hardware-specific optimizations |
| Research Velocity | Slower iteration due to define-then-run requirement | Faster prototyping and model experimentation |
| Integration with Legacy Code | More separation between definition and execution | Natural integration with imperative code |
These trade-offs are not binary choices. Modern frameworks offer a spectrum of options, which raises the quantitative question of where on this spectrum a given project should operate.
Quantitative principles of execution
These execution models present a spectrum of trade-offs, but engineers need more than intuition to navigate them. Two quantitative principles formalize the decision. The Compilation Continuum Principle establishes when the performance gains from compilation justify its development cost, expressed as a ratio of production executions to development iterations. The Dispatch Overhead Law quantifies the per-operation cost of framework flexibility, revealing why small operations in eager mode can spend more time in Python overhead than in actual computation. Together, these principles transform framework selection from subjective preference into measurable engineering analysis.
The compilation continuum principle
The Execution Problem demands a quantitative principle: when should a project compile?
The execution models form a continuum from maximum flexibility to maximum optimization, visualized in Equation 1:
\[ \text{Eager} \xrightarrow{\text{tracing}} \text{JIT} \xrightarrow{\text{AOT}} \text{Static Graph} \xrightarrow{\text{synthesis}} \text{Custom Hardware} \tag{1}\]
Each step rightward sacrifices flexibility for performance. The practical question is where on this continuum a given project should operate. The optimal compilation strategy depends on the ratio of development iterations to production executions (Equation 2):
\[ \text{Compilation Benefit} = \frac{N_{\text{prod}} \cdot (T_{\text{eager}} - T_{\text{compiled}})}{T_{\text{compile}} + N_{\text{dev}} \cdot T_{\text{compile}}} \tag{2}\]
Where:
- \(N_{\text{prod}}\) = number of production executions (dimensionless count: inference requests, training steps)
- \(N_{\text{dev}}\) = number of development iterations requiring recompilation (dimensionless count)
- \(T_{\text{eager}}\) = time per execution in eager mode (seconds)
- \(T_{\text{compiled}}\) = time per execution in compiled mode (seconds)
- \(T_{\text{compile}}\) = one-time compilation cost (seconds)
Decision Rule: Compile when \(\text{Compilation Benefit} > 1\). The ratio is dimensionless.
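Equation 2 can be applied directly as a decision helper. A minimal sketch (the function name and the example workloads are illustrative, not taken from any framework):

```python
def compilation_benefit(n_prod, n_dev, t_eager, t_compiled, t_compile):
    """Equation 2: compile when the returned ratio exceeds 1."""
    time_saved = n_prod * (t_eager - t_compiled)
    compile_cost = t_compile + n_dev * t_compile  # production build + dev recompiles
    return time_saved / compile_cost

# Research prototyping: many dev iterations, few production executions
proto = compilation_benefit(n_prod=100, n_dev=200,
                            t_eager=0.010, t_compiled=0.007, t_compile=30.0)

# Training run: millions of steps, a handful of recompiles
train = compilation_benefit(n_prod=2_000_000, n_dev=5,
                            t_eager=0.010, t_compiled=0.007, t_compile=30.0)

print(f"prototyping benefit: {proto:.5f}")  # far below 1: stay eager
print(f"training benefit:    {train:.1f}")  # far above 1: compile
```

The same per-execution saving flips the decision once production executions outnumber development recompiles by orders of magnitude.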
Table 4 provides representative throughput data across execution modes and model architectures:
| Model | Eager (img/sec) | torch.compile (img/sec) | TensorRT (img/sec) | Compile Time (seconds) |
|---|---|---|---|---|
| ResNet-50 | 1,450 | 2,150 | 3,800 | 15–30 |
| BERT-Base | 380 | 520 | 890 | 30–60 |
| ViT-B/16 | 620 | 950 | 1,650 | 25–45 |
| GPT-2 (124M) | 180 | 260 | 420 | 45–90 |
These throughput differences across execution modes raise a practical question—which framework execution strategy best serves each workload archetype.
Lighthouse 1.1: Framework Strategy by Archetype
| Archetype | Dominant Iron Law Term | Optimal Framework Strategy | Rationale |
|---|---|---|---|
| ResNet-50 (Compute Beast) | \(\frac{O}{R_{\text{peak}} \cdot \eta}\) (Compute) | TensorRT (inference), torch.compile (training) | Kernel fusion maximizes MFU; compute-bound workloads benefit most from optimization |
| GPT-2 (Bandwidth Hog) | \(\frac{D_{\text{vol}}}{BW}\) (Memory Bandwidth) | torch.compile | Kernel fusion reduces HBM round-trips; keeps data in cache to mitigate bandwidth |
| DLRM (Sparse Scatter) | \(\frac{D_{\text{vol}}}{BW}\) (Random Access) + \(T_{\text{network}}\) | Eager with specialized kernels (FBGEMM) | Embedding lookups are inherently irregular and dynamic; compilation gains are small |
| DS-CNN (Tiny Constraint) | \(L_{\text{lat}}\) (Overhead) | AOT compilation (TFLite, ONNX) | Sub-ms inference; every microsecond of Python overhead is unacceptable |
Key insight: Compilation benefits scale with how much of the workload is optimizable. Compute Beasts (Table 4: ResNet-50 sees 2.6\(\times\) speedup from TensorRT) benefit most. Sparse Scatter workloads gain little because their bottleneck (embedding lookups) is inherently irregular.
This principle has concrete implications across three regimes. In research prototyping (\(N_{\text{dev}} \gg N_{\text{prod}}\)), teams should stay eager. If the architecture changes every few minutes, compilation overhead dominates. A 30-second compile time with ten iterations/hour means five minutes lost to compilation per hour, often more than the runtime savings.
For training runs (\(N_{\text{prod}} \gg N_{\text{dev}}\)), compilation pays off. A typical training run executes millions of forward/backward passes, so even 60 seconds of compilation amortizes to microseconds per step. From Table 4, torch.compile provides ~48 percent speedup on ResNet-50 (2,150 vs. 1,450 img/sec); this pays off after the breakeven point in Equation 3:
\[ N_{\text{breakeven}} = \frac{T_{\text{compile}}}{T_{\text{eager}} - T_{\text{compiled}}} = \frac{30\text{s}}{(1/1450 - 1/2150)\text{s/img}} \approx 134{,}000 \text{ images} \tag{3}\]
For ImageNet (1.28M training images), compilation pays off within the first epoch.
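The breakeven arithmetic in Equation 3 can be checked in a few lines (throughputs taken from Table 4):

```python
# ResNet-50 throughputs from Table 4 (images/sec) and a 30 s compile time
eager_ips, compiled_ips, t_compile = 1450, 2150, 30.0

# Seconds saved per image once compiled
saving_per_img = 1 / eager_ips - 1 / compiled_ips

n_breakeven = t_compile / saving_per_img
print(f"breakeven after ~{n_breakeven:,.0f} images")  # ~134,000
```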
For production inference (\(N_{\text{dev}} \approx 0\), \(N_{\text{prod}} \rightarrow \infty\)), teams should maximize compilation. With no development iterations and potentially millions of requests, every optimization matters. Using mode='max-autotune' despite hour-long compilation is worthwhile because the cost is amortized over the deployment lifetime.
These three regimes create distinct regions in the compilation decision space. Figure 7 maps out these regions so engineers can identify where each strategy wins. Watch for the crossover points: the steep eager line (highest per-execution cost) eventually overtakes JIT’s moderate slope, while the gentlest compiled line (lowest per-execution cost but largest upfront investment) wins only after millions of executions. The slopes reveal per-execution cost; the vertical offsets reveal compilation overhead. A project’s position on the x-axis determines which line it should be on.
The dispatch overhead law
A second principle emerges from the Dispatch Overhead Equation (Equation 4): when does framework overhead, rather than compute or memory, dominate execution time? Let \(N_{\text{ops}}\) be the number of operations (count), \(t_{\text{dispatch}}\) the per-operation dispatch overhead (seconds), and \(T_{\text{compute}}\) and \(T_{\text{memory}}\) the total compute and memory times (seconds). Framework overhead dominates when operations are small relative to dispatch cost:
\[ \text{Overhead Ratio} = \frac{N_{\text{ops}} \cdot t_{\text{dispatch}}}{T_{\text{compute}} + T_{\text{memory}}} \tag{4}\]
When Overhead Ratio \(> 1\), the model is overhead-bound. Compilation provides maximum benefit for overhead-bound workloads because it eliminates per-operation dispatch.
From the case study in Section 1.9, we can quantify this effect: the cumulative latency of per-operation dispatch creates what is effectively a dispatch tax on execution. We define \(T_{\text{hw}}\) as hardware execution time and \(T_{\text{sw}}\) as software overhead time; both are measured in seconds.
Napkin Math 1.2: The Dispatch Tax
Problem: When does Python overhead kill performance?
Scenario one: Small multilayer perceptron (MLP) (overhead-bound)
- Compute: 6 small matrix/element-wise operations.
- Hardware time: \(T_{\text{hw}} \approx 2.6\) μs (mostly memory latency).
- Software overhead: \(T_{\text{sw}} \approx 6 \text{ ops} \times 5.0\) μs/op \(= 30\) μs.
- Ratio: \(30 / 2.6 \approx 11.5\).
- Conclusion: The system spends 92 percent of its time waiting for Python. Compilation yields a ~13\(\times\) speedup.
Scenario two: GPT-3 layer (compute-bound)
- Compute: Huge matrix multiplications.
- Hardware time: \(T_{\text{hw}} \approx 100\) ms \(= 100{,}000\) μs.
- Software overhead: \(T_{\text{sw}} \approx 50\) μs.
- Ratio: \(50 / 100{,}000 = 0.0005\).
- Conclusion: Python overhead is negligible. Compilation helps only via kernel fusion (memory bandwidth), not dispatch elimination.
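A hypothetical helper makes the two scenarios reproducible (the GPT-3 overhead is modeled here as 10 dispatches at 5 μs each, matching the ~50 μs total above):

```python
def overhead_ratio(n_ops, t_dispatch_us, t_hw_us):
    """Equation 4: software dispatch time relative to hardware time."""
    return (n_ops * t_dispatch_us) / t_hw_us

# Scenario one: small MLP, 6 ops at 5 us dispatch, 2.6 us of hardware time
mlp = overhead_ratio(n_ops=6, t_dispatch_us=5.0, t_hw_us=2.6)

# Scenario two: GPT-3 layer, ~50 us of dispatch against 100 ms of compute
gpt3 = overhead_ratio(n_ops=10, t_dispatch_us=5.0, t_hw_us=100_000.0)

print(f"MLP ratio:   {mlp:.1f}")   # ~11.5 -> overhead-bound
print(f"GPT-3 ratio: {gpt3:.4f}")  # 0.0005 -> compute-bound
```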
The principle’s implication is that small models benefit disproportionately from compilation. A 100-parameter toy model might see 10\(\times\) speedup from torch.compile, while a 175 B-parameter model sees only 1.3\(\times\). This explains why compilation matters most for efficient inference on smaller, deployed models.
The dispatch tax analysis reveals that small operations are overhead-bound regardless of hardware capability. This observation matters most at the extreme edge of the deployment spectrum, where the entire Python runtime is itself an unacceptable overhead.
Frameworks for the edge: TinyML and micro-runtimes
The compilation continuum reaches its extreme at the far edge. While cloud frameworks like PyTorch and TensorFlow 2.x prioritize flexibility through eager execution, TinyML19 systems operating on microcontrollers (MCUs) with kilobytes of memory cannot afford the overhead of a Python interpreter or a dynamic runtime.
19 TinyML: Systems designed for microcontrollers (MCUs) that cannot afford the memory or processing overhead of a Python interpreter. Instead of flexible eager execution, frameworks compile models ahead-of-time (AOT) into self-contained C/C++ executables with no dynamic memory allocation. This is a hard requirement, as a single malloc() failure on a device with just 256 KB of RAM is unrecoverable.
Lighthouse 1.2: Smart Doorbell (TinyML)
The Constraint: A standard PyTorch runtime occupies ~500 MB. The Python interpreter itself occupies ~20 MB. Both are orders of magnitude larger than the target device’s entire memory.
The Framework Solution: Micro-frameworks like TensorFlow Lite Micro (TFLM) and PyTorch ExecuTorch solve this through Extreme AOT Compilation:
- Static memory planning: The framework calculates the exact memory address for every tensor at compile time. There is no dynamic `malloc()` or garbage collection.
- Kernel specialization: Only the specific kernels used by the model (for example, Conv2D, DepthwiseConv) are compiled into the binary. Unused code is stripped away.
- No-interpreter execution: The model is converted into a flat sequence of function calls or a simple “Command Buffer” that the MCU executes directly in C/C++.
The Silicon Contract: On TinyML devices, the contract is strictly Memory-Bound. The framework’s primary job is to ensure the model’s intermediate activations (the “working set”) fit within the MCU’s tiny SRAM.
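What static memory planning computes can be sketched in a few lines: given each tensor’s size and its live interval in the operation schedule, the planner derives the peak working set that must fit in SRAM. (The three-tensor model below is purely illustrative.)

```python
def peak_working_set(tensors):
    """tensors: list of (size_bytes, first_op, last_op) live intervals.
    Returns the peak number of simultaneously live bytes."""
    last = max(end for _, _, end in tensors)
    peak = 0
    for op in range(last + 1):
        live = sum(size for size, start, end in tensors if start <= op <= end)
        peak = max(peak, live)
    return peak

# Illustrative 3-op model: input image, conv activation, pooled activation
tensors = [
    (32 * 32 * 1, 0, 1),  # input lives through the conv (op 1)
    (16 * 16 * 8, 1, 2),  # conv output lives through the pool (op 2)
    (8 * 8 * 8, 2, 2),    # pooled output
]
peak = peak_working_set(tensors)
print(f"peak working set: {peak} bytes")

sram = 256 * 1024
assert peak <= sram, "model would not fit in 256 KB of SRAM"
```

Because the schedule is fixed at compile time, this peak is known before deployment; the runtime simply carves the addresses out of one static arena.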
These micro-runtimes represent the “Pure AOT” endpoint of the continuum. By sacrificing all dynamic flexibility, they enable machine learning to run on devices consuming milliwatts of power, fulfilling the Energy-Movement Invariant (formalized in Data Engineering) by keeping all data movement local to the chip.
The spectrum of execution strategies, from dynamic eager execution to static graph compilation and specialized micro-runtimes, requires developers to make deliberate trade-offs. The following checkpoint summarizes the key decision points before we address the second core problem.
Checkpoint 1.1: Execution Models
The choice of execution mode determines both developer velocity and model performance.
- Debuggability vs. speed: Eager execution provides immediate results and standard Python debugging at baseline performance; static graphs deliver ahead-of-time optimized execution at the cost of symbolic, harder-to-debug programs.
- The modern compromise: JIT compilation (torch.compile) captures most of static-graph performance while preserving most of eager mode’s flexibility, falling back to eager execution at graph breaks when dynamic control flow prevents capture.
The execution problem determines when computation happens and what optimizations are possible. Neural network training, however, requires a capability that no amount of clever scheduling can provide: the ability to compute gradients automatically.
Consider what training actually requires: for each of millions of parameters, compute how a tiny change would affect the loss. Doing this manually for even a simple three-layer network requires deriving and implementing dozens of partial derivatives. For a modern transformer with billions of parameters, manual differentiation is economically impossible. A framework that executes efficiently but cannot differentiate can run inference but cannot learn.
Differentiation Problem
The differentiation problem asks: how should frameworks compute gradients20 automatically? Neural network training requires derivatives of a scalar loss \(\mathcal{L}\) with respect to millions or billions of parameters, making manual differentiation impractical. Because a single scalar loss depends on all parameters, reverse-mode automatic differentiation (AD)21 is the optimal strategy: one backward pass computes all parameter gradients simultaneously, while forward mode would require a separate pass for each parameter. All major ML frameworks therefore implement reverse-mode AD by default (Baydin et al. 2018).
20 Automatic Differentiation (AD): The “automatically” in the triggering sentence is the key word: AD mechanizes the chain rule as a graph traversal, eliminating the manual derivative computation that made scaling beyond toy networks impractical. The systems trade-off that makes this feasible is the choice of reverse mode, which exploits the many-to-one topology of training (many parameters, one scalar loss) to compute all gradients in a single backward pass. Forward mode would require one pass per parameter, making billion-parameter training computationally impossible.
21 Reverse-Mode AD: The \(O(1)\)-vs.-\(O(N)\) asymmetry mentioned earlier has a concrete price: reverse mode must store every intermediate value from the forward pass for use during the backward traversal. For a billion-parameter transformer, these stored activations can consume 3–4\(\times\) more memory than the weights themselves. This memory cost is the reason frameworks provide activation checkpointing and gradient accumulation, techniques that trade recomputation time for the memory that reverse-mode AD demands.
Building on the backpropagation algorithm introduced in Neural Computation (where we established that gradients flow backward through the computation graph via the chain rule), this section shifts focus from the mathematics to the systems engineering of differentiation: how frameworks represent computation graphs, manage memory for intermediate values, and orchestrate the backward pass efficiently across accelerators. The framework’s role is not to perform calculus but to manage the bookkeeping at scale, which is required for the training algorithms detailed in Model Training. Listing 15 illustrates the core idea with a simple three-operation function:
```python
def f(x):
    a = x * x     # Square
    b = sin(x)    # Sine
    return a * b  # Product
```

Frameworks decompose this function into elementary operations, each with a known local derivative, and then combine these local derivatives via the chain rule to compute gradients through arbitrary compositions. The systems challenge is implementing this efficiently: the framework must record the computation graph during the forward pass, store intermediate values, and execute the backward pass with minimal memory overhead. The following subsections trace how production frameworks solve each of these problems.
Forward and reverse mode differentiation
Two primary approaches to automatic differentiation exist, and the choice between them (forward mode vs. reverse mode) determines whether gradient computation scales with the number of inputs or the number of outputs, a distinction that explains why neural network training universally uses one mode over the other.
Forward mode
Neural network training universally uses reverse mode (covered next), but forward mode illuminates why reverse mode is necessary. Forward mode automatic differentiation computes derivatives alongside the original computation, tracking how changes propagate from input to output. This approach mirrors manual derivative computation, making it intuitive to understand and implement.
Forward mode’s memory requirements are its strength: the method stores only the original value, a single derivative value, and temporary results. Memory usage stays constant regardless of computation depth, making forward mode particularly suitable for embedded systems, real-time applications, and memory-bandwidth-limited systems. However, this comes with a computational cost. Forward mode doubles the Ops term (in iron law terms) for each input parameter whose derivative is requested. For a model with \(N\) parameters, forward mode multiplies total computation by \(N\), because each parameter requires a separate forward pass. Reverse mode, by contrast, adds a constant factor of approximately 2–3\(\times\) regardless of \(N\). This asymmetry explains why forward mode is never used for training neural networks, where \(N\) ranges from millions to hundreds of billions. This combination of computational scaling with input count but constant memory creates a specific niche: forward mode excels in scenarios with few inputs but many outputs, such as sensitivity analysis, feature importance computation, and online learning with single-example updates.
To see the mechanism concretely, consider computing both the value and derivative of \(f(x) = x^2 \sin(x)\). Listing 16 shows how forward mode propagates derivative computations alongside every operation, applying the chain rule and product rule at each step:
```python
def f(x):  # Computing both value and derivative
    # Step 1: x -> x²
    a = x * x   # Value: x²
    da = 2 * x  # Derivative: 2x

    # Step 2: x -> sin(x)
    b = sin(x)   # Value: sin(x)
    db = cos(x)  # Derivative: cos(x)

    # Step 3: Combine using product rule
    result = a * b             # Value: x² * sin(x)
    dresult = a * db + b * da  # Derivative: x²*cos(x) + sin(x)*2x
    return result, dresult
```

Forward mode achieves this systematic derivative computation by augmenting each number with its derivative value, creating what mathematicians call a “dual number.” Listing 17 traces a concrete execution with x = 2.0, revealing how each intermediate result carries both its value and derivative through the computation:
```python
x = 2.0   # Initial value
dx = 1.0  # We're tracking the derivative with respect to x

# Step 1: x²
a = 4.0   # (2.0)²
da = 4.0  # 2 * 2.0

# Step 2: sin(x)
b = 0.909    # sin(2.0)
db = -0.416  # cos(2.0)

# Final result
result = 3.636   # 4.0 * 0.909
dresult = 1.972  # 4.0 * (-0.416) + 0.909 * 4.0 = -1.664 + 3.636
```

The dual number trace demonstrates the 2\(\times\) computational overhead per input: every arithmetic operation (multiply, sine, product rule combination) is performed twice, once for the value and once for the derivative. For this single-input function, the overhead is acceptable. For a neural network with \(N = 100{,}000{,}000\) parameters, computing all gradients would require 100 million such passes, which is why forward mode is restricted to the few-input applications described earlier.
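The dual-number idea generalizes into a tiny class: overload arithmetic so every operation propagates a (value, derivative) pair. This is a pedagogical sketch, not any framework’s API:

```python
import math

class Dual:
    """A value paired with its derivative w.r.t. one chosen input."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

def sin(d):
    # Chain rule: (sin u)' = cos(u) * u'
    return Dual(math.sin(d.val), math.cos(d.val) * d.dot)

# f(x) = x² * sin(x) at x = 2.0, seeding dx = 1.0
x = Dual(2.0, 1.0)
result = x * x * sin(x)
print(round(result.val, 3))  # ≈ 3.637
print(round(result.dot, 3))  # ≈ 1.973
```

One `Dual` seed tracks one input; differentiating with respect to \(N\) parameters would require \(N\) seeded passes, which is exactly the scaling problem reverse mode avoids.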
Forward mode’s strength in single-input analysis becomes its fatal weakness for training. A neural network has one scalar loss but millions of parameters, and forward mode would require a separate pass for each one—an intractable \(O(N)\) cost that explains why no production framework uses forward mode for training. Forward mode remains useful for targeted analyses such as sensitivity analysis (how does changing one pixel affect the prediction?) and feature importance (which input dimensions most influence the output?), where the number of inputs of interest is small.
Given forward mode’s \(O(N)\) scaling with parameter count, we need an entirely different approach for training. Reverse mode provides exactly this: by propagating gradients backward from output to input, it computes all \(N\) parameter gradients in a single pass.
Reverse mode
Why does every modern ML framework default to reverse mode for training? The answer is computational asymmetry, one of the most consequential design decisions in machine learning software.
A neural network has one scalar loss but millions of parameters. Forward mode computes one parameter’s gradient per pass, requiring \(n\) passes for \(n\) parameters. Reverse mode computes all \(n\) gradients in a single backward pass. For a model with 100 million parameters, that is the difference between 100 million forward passes and exactly one backward pass, a speedup proportional to the parameter count.
This asymmetry makes reverse mode the only viable option for training. Consider a function where \(x\) influences the output through two distinct paths. Listing 18 defines such a function, and Listing 19 traces its forward and backward computation for a concrete input.
```python
def f(x):
    a = x * x   # First operation: square x
    b = sin(x)  # Second operation: sine of x
    c = a * b   # Third operation: multiply results
    return c
```

```python
# --- Forward pass: compute and store values ---
x = 2.0    # Input value
a = 4.0    # x * x = 2.0 * 2.0 = 4.0
b = 0.909  # sin(2.0) ≈ 0.909
c = 3.637  # a * b = 4.0 * 0.909 ≈ 3.637

# --- Backward pass: propagate gradients from output ---
dc/dc = 1.0  # Seed gradient

# Through multiplication c = a * b
dc/da = b  # ∂(a*b)/∂a = b = 0.909
dc/db = a  # ∂(a*b)/∂b = a = 4.0

# Combine contributions from both paths through x
# Path 1: x -> x²     -> c   contribution: 2x * dc/da
# Path 2: x -> sin(x) -> c   contribution: cos(x) * dc/db
dc/dx = (2 * x * dc/da) + (cos(x) * dc/db)
      = (2 * 2.0 * 0.909) + (cos(2.0) * 4.0)
      = 3.636 + (-0.416 * 4.0)
      = 1.972  # 3.636 - 1.664 = 1.972
```

The critical observation is that this single backward pass computed dc/dx regardless of how many paths connect x to c. In a neural network, each weight can affect the loss through thousands of paths across layers, and reverse mode handles them all in one traversal. This is why training a 175 B parameter model like GPT-3 is feasible at all: reverse mode’s O(1) backward passes (relative to parameter count) keep gradient computation tractable.
Translating this mathematical elegance into a working system requires solving a concrete engineering problem: the backward pass needs values computed during the forward pass, so the framework must decide what to store, when to store it, and when to free it. Modern frameworks accomplish this through computational graphs and automatic gradient accumulation22.
22 Gradient Accumulation: A direct answer to the “when to free it” question: the framework breaks a large logical batch into smaller mini-batches processed sequentially, freeing activation memory after each mini-batch’s backward pass and accumulating only the small gradient tensors. This lets a system simulate a batch size of 4,096 using the memory footprint of a 64-sample batch, trading sequential compute time for a 64\(\times\) reduction in peak activation memory. Without this technique, many production training configurations would exceed accelerator memory on the first batch.
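The footnote’s equivalence claim—summing micro-batch gradients reproduces the full-batch gradient—can be verified on a toy least-squares model (pure Python, illustrative data):

```python
# Loss: L(w) = sum_i (w*x_i - y_i)^2, so dL/dw = sum_i 2*x_i*(w*x_i - y_i)
def grad(w, xs, ys):
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys))

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.3, 13.9, 16.2]
w = 0.5

# Full batch in one pass (all 8 activations live at once)
full = grad(w, xs, ys)

# Same batch as four micro-batches of two, accumulating one scalar
acc = 0.0
for i in range(0, len(xs), 2):
    acc += grad(w, xs[i:i + 2], ys[i:i + 2])  # this slice's activations can now be freed

print(abs(full - acc) < 1e-9)  # True: identical gradient, smaller peak memory
```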
Listing 20 illustrates this with a two-layer network, showing both the forward computation that stores intermediate values and the backward pass that consumes them to produce gradients for every parameter simultaneously.
```python
def simple_network(x, w1, w2):
    hidden = x * w1             # First layer
    activated = max(0, hidden)  # ReLU activation
    output = activated * w2     # Second layer
    return output

# --- Forward pass stores intermediates ---
# x=1.0, w1=2.0, w2=3.0
# hidden=2.0, activated=2.0, output=6.0

# --- Backward pass consumes them ---
d_output = 1.0     # Seed gradient
d_w2 = activated   # = 2.0
d_activated = w2   # = 3.0
d_hidden = d_activated * (1 if hidden > 0 else 0)  # ReLU gate: 3.0
d_w1 = x * d_hidden   # = 3.0
d_x = w1 * d_hidden   # = 6.0
```

Three implementation requirements emerge from this example. First, the framework must track dependencies between operations to determine the correct reverse traversal order. Second, intermediate values (hidden, activated) must persist in memory until the backward pass consumes them. Third, every operation needs both a forward implementation and a corresponding backward rule. These requirements define the engineering surface of any AD system, and the second requirement, memory persistence, turns out to be the dominant cost.
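The three requirements can be seen working together in a minimal tape-style autograd sketch (a toy scalar implementation, not PyTorch’s actual machinery):

```python
class Var:
    """Scalar that records how it was produced, for reverse traversal."""
    def __init__(self, val, parents=()):
        self.val = val
        self.parents = parents  # (Var, local_gradient) pairs saved for backward
        self.grad = 0.0

    def __mul__(self, other):
        # Requirement 2: each input's value persists as a saved local gradient
        return Var(self.val * other.val,
                   parents=[(self, other.val), (other, self.val)])

    def backward(self, seed=1.0):
        # Requirements 1 and 3: walk recorded dependencies in reverse,
        # applying each operation's backward rule via the chain rule
        self.grad += seed
        for parent, local in self.parents:
            parent.backward(seed * local)

x = Var(2.0)
y = x * x * x  # y = x³
y.backward()
print(x.grad)  # 12.0 = 3x² at x = 2
```

Every path from `y` back to `x` contributes its chain-rule product, and the contributions accumulate into `x.grad`, mirroring how frameworks sum gradients across all paths through a weight.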
Memory management strategies
A 175 B parameter model in FP16 requires 350 GB just for weights, far exceeding any single GPU’s memory. Weights, however, are only the beginning: reverse mode AD also stores every intermediate activation from the forward pass for use during the backward pass. For a 100-layer network processing a batch of 64 images, these stored activations can consume 8–12 GB on top of the model weights, gradients, and optimizer state. Memory, not compute, is the binding constraint on what models a framework can train.
The problem scales linearly with depth. Listing 21 shows how each layer in a deeper network adds another activation tensor that must persist until the backward pass reaches that layer.
```python
def deep_network(x, w1, w2, w3):
    # Forward pass - must store intermediates
    hidden1 = x * w1
    activated1 = max(0, hidden1)  # Store for backward
    hidden2 = activated1 * w2
    activated2 = max(0, hidden2)  # Store for backward
    output = activated2 * w3
    return output
```

Frameworks attack this memory wall with two primary strategies. The first is activation checkpointing (also called gradient checkpointing): rather than storing every activation, the framework stores only selected checkpoints and recomputes the intermediate activations during the backward pass. Model Training examines checkpointing strategies in detail, including optimal checkpoint placement algorithms. Listing 22 shows the pattern: save activations at checkpoint boundaries, recompute everything between them.
```python
# Conceptual representation of checkpointing
checkpoint1 = save_for_backward(activation1)
# Intermediate activations can be recomputed
checkpoint2 = save_for_backward(activation4)
# Framework balances storage vs recomputation
```

The second strategy is operation fusion23. Rather than executing matrix multiplication, bias addition, and ReLU as three separate operations, each writing intermediate results to memory, frameworks fuse them into a single kernel. This eliminates intermediate memory allocations entirely and achieves 2–3\(\times\) speedup on modern GPUs by keeping data in registers and caches.
23 Operation Fusion: The 2–3\(\times\) speedup cited in the triggering sentence arises from a specific hardware fact: GPU registers and L1 cache deliver 10–100\(\times\) higher bandwidth than HBM. When matmul, bias, and ReLU execute as separate kernels, each writes its output to HBM and the next reads it back, a round-trip that dominates execution time for memory-bound operations. Fusing them into one kernel keeps intermediates in registers, converting three HBM round-trips into zero.
The backward pass itself benefits from hardware-specific optimization. Rather than directly translating the mathematical definition of a convolution gradient into code, frameworks implement specialized backward kernels that exploit memory access patterns and hardware capabilities of modern accelerators (Chetlur et al. 2014). These optimizations, checkpointing, fusion, and specialized kernels, work together to make training practical for architectures that would otherwise exhaust GPU memory in a single forward pass.
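A back-of-envelope accounting of the fusion win: each boundary between unfused kernels costs one HBM write plus one read-back of the intermediate tensor (illustrative model, FP32 elements):

```python
def intermediate_traffic(n_elems, n_kernels, bytes_per_elem=4):
    """HBM bytes spent on intermediates between chained kernels.
    Each boundary between separate kernels costs one write plus one read-back."""
    boundaries = n_kernels - 1
    return boundaries * 2 * n_elems * bytes_per_elem

n = 1_000_000  # elements in the activation tensor

unfused = intermediate_traffic(n, n_kernels=3)  # matmul | bias | relu
fused = 0  # one kernel: intermediates stay in registers

print(f"unfused intermediate traffic: {unfused / 1e6:.0f} MB")  # 16 MB
print(f"fused intermediate traffic:   {fused} bytes")
```

Because those 16 MB would otherwise cross the slow HBM interface twice per boundary, eliminating them explains most of the cited 2–3\(\times\) speedup for memory-bound chains.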
Framework implementation of automatic differentiation
Checkpointing, fusion, and specialized kernels solve the systems problems of AD. Practitioners, however, never interact with these mechanisms directly. Instead, frameworks expose AD through high-level APIs that hide the underlying machinery behind simple method calls. A PyTorch training loop—optimizer.zero_grad(), forward pass, loss.backward(), optimizer.step()—appears to be four function calls. Behind each call, however, the framework tracks all operations during the forward pass, builds and maintains the computation graph, manages memory for intermediate values, schedules gradient computations efficiently, and interfaces with hardware accelerators. The same graph machinery extends to advanced scenarios: nested torch.autograd.grad calls compute second-order derivatives for techniques like natural gradient descent, and mixed-precision contexts (autocast) select reduced-precision kernels for compute-intensive operations while maintaining FP32 for numerical stability.
PyTorch autograd internals
The autograd system is the framework component that solves the differentiation problem described in Three Framework Problems. Three systems principles govern its design: the data structure that enables efficient gradient computation, the memory cost of maintaining that data structure, and the control mechanisms that production systems require. Understanding these principles explains why training consumes 100\(\times\) more memory than inference for the same model, and why frameworks provide specific mechanisms to manage that cost.
Principle 1: The reverse-linked graph structure
During the forward pass, the autograd system constructs a reverse-linked graph of Function nodes. Each node records the operation performed and stores references to the tensors it needs for gradient computation. This graph is the data structure that makes reverse-mode automatic differentiation possible: regardless of how many parameters a model has, a single backward pass through this graph computes all gradients. For a model with \(N\) parameters, reverse-mode AD requires \(O(1)\) backward passes (compared to \(O(N)\) for forward-mode), which is why every major framework implements this approach.
Concretely, every tensor produced by a differentiable operation stores a grad_fn attribute pointing to the Function that created it. Each Function links to its inputs through next_functions, forming a chain from the loss back to the leaf parameters. Listing 23 illustrates this structure for a simple computation:
```python
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x * 3
z = y.pow(2)

# Traverse the reverse-linked graph
print(z.grad_fn)                 # PowBackward0
print(z.grad_fn.next_functions)  # -> MulBackward0
print(z.grad_fn.next_functions[0][0].next_functions)  # -> AccumulateGrad (leaf)
```

The traversal reveals the chain: PowBackward0 (for z = y**2) links to MulBackward0 (for y = x * 3), which terminates at AccumulateGrad for the leaf tensor x. Leaf tensors are the endpoints of the graph where gradients accumulate into the .grad attribute rather than propagating further. The tuple format (Function, index) tracks which output of a multi-output operation each connection corresponds to.
This reverse-linked structure has a critical systems implication: the entire graph must remain in memory from the time a tensor is created until the backward pass consumes it. The graph itself is lightweight (pointers and metadata), but the tensors it references are not.
The graph structure thus introduces a second implication: memory consumption scales with model depth.
Principle 2: The memory-compute trade-off
Every activation saved for the backward pass persists in memory until consumed by gradient computation. This is the primary reason training memory dwarfs inference memory. Computing the gradient of most operations requires values from the forward pass: multiplication needs both inputs (\(\frac{\partial}{\partial x}(x \cdot y) = y\)), exponentiation needs the base (\(\frac{\partial}{\partial x}(x^2) = 2x\)), and softmax needs its output values. The autograd system stores these tensors in each Function node’s saved_tensors attribute.
For a network with \(N_L\) layers, the system must save approximately \(N_L\) activation tensors, one per layer, for the entire batch. Consider a concrete example: ResNet-50 has 25.6 M parameters (~102 MB in FP32) and processes batch size 64 with \(224\times224\) images. The memory breakdown reveals the scale of this trade-off. Forward activations alone consume approximately 8–12 GB (varying by implementation and checkpointing strategy). Parameter gradients add another ~102 MB (the same size as the parameters themselves), and Adaptive Moment Estimation (Adam) optimizer state contributes ~205 MB for its two momentum buffers per parameter. The total training footprint reaches 10–15 GB, compared to just ~102 MB for inference alone.
This 100\(\times\) ratio between training and inference memory quantifies why the Data Movement (\(D_{\text{vol}}\)) term dominates training latency in the iron law. During training, the framework must write all activations to memory during the forward pass and read them back during the backward pass, doubling the memory traffic compared to inference alone. For a complete derivation of the four-component training memory equation (\(M_{total} = M_{weights} + M_{gradients} + M_{optimizer} + M_{activations}\)) and worked examples at larger model scales, see The true cost of training memory.
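The four-component equation can be sketched as a calculator. The ResNet-50 numbers come from the text; the ~10 GB activation term is the quoted estimate for batch size 64, not a derivation:

```python
def training_memory_gb(n_params, bytes_per_param=4,
                       optimizer_buffers=2, activations_gb=10.0):
    """M_total = M_weights + M_gradients + M_optimizer + M_activations."""
    weights = n_params * bytes_per_param / 1e9
    gradients = weights                      # same shape as the weights
    optimizer = optimizer_buffers * weights  # Adam: two momentum buffers
    return weights + gradients + optimizer + activations_gb

total = training_memory_gb(n_params=25.6e6)  # ResNet-50, FP32
inference = 25.6e6 * 4 / 1e9                 # weights only, ~0.102 GB

print(f"training:  {total:.2f} GB")
print(f"inference: {inference:.3f} GB")
print(f"ratio:     {total / inference:.0f}x")
```

The activation term dwarfs everything else, which is why the management strategies below target activations rather than weights.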
Frameworks provide two primary mechanisms to manage this trade-off. Gradient checkpointing (Chen et al. 2016) trades recomputation for memory: instead of saving all activations, the framework saves only a subset and recomputes the rest during the backward pass. This typically reduces activation memory by 50–90 percent at the cost of 20–33 percent additional compute (with optimal \(\sqrt{n}\) checkpoint placement). In iron law terms, checkpointing increases the \(O\) term (recomputation) to reduce the \(D_{\text{vol}}\) term (memory traffic). Tensor detachment provides a complementary mechanism: calling .detach() on a tensor removes it from the computation graph entirely, preventing the framework from saving activations through that path. This is essential for transfer learning, where pretrained layers should not accumulate gradients, and reduces the \(D_{\text{vol}}\) term by eliminating unnecessary activation storage.
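To see why checkpointing trades recomputation for memory without changing the result, the following toy sketch differentiates a chain of `sin` layers in pure Python (`sin` is chosen because its derivative `cos(a)` genuinely needs the saved input `a`). It illustrates the idea only; no framework implements checkpointing this way.

```python
import math

def forward_full(x, n):
    """Standard reverse mode: save every layer input for the backward pass."""
    saved = []
    for _ in range(n):
        saved.append(x)
        x = math.sin(x)
    return x, saved

def backward_full(saved):
    grad = 1.0
    for a in reversed(saved):
        grad *= math.cos(a)   # d(sin a)/da = cos a needs the saved input
    return grad

def backward_checkpointed(x, n, k):
    """Save only every k-th layer input; recompute each segment on demand."""
    checkpoints = []
    xi = x
    for i in range(n):                 # forward: keep checkpoint inputs only
        if i % k == 0:
            checkpoints.append((i, xi))
        xi = math.sin(xi)
    grad, recomputed = 1.0, 0
    for start, cx in reversed(checkpoints):  # backward: replay each segment
        seg = []
        for _ in range(min(k, n - start)):
            seg.append(cx)
            cx = math.sin(cx)
            recomputed += 1
        for a in reversed(seg):
            grad *= math.cos(a)
    return grad, len(checkpoints), recomputed

_, saved = forward_full(0.5, 100)
g_full = backward_full(saved)
g_ck, n_saved, recomputed = backward_checkpointed(0.5, 100, 10)
print(g_full == g_ck, len(saved), n_saved, recomputed)
```

With 100 layers and checkpoints every 10 layers, the gradient is identical, but peak storage drops from 100 saved activations to roughly 10 checkpoints plus one 10-element segment, at the cost of one extra forward pass (100 recomputed `sin` evaluations).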
Mixed-precision training offers a third approach, reducing activation memory by storing values in lower precision formats. The detailed trade-offs of mixed precision are examined later in this chapter.
Principle 3: Extensibility and control
Production training systems require fine-grained control over gradient flow that goes beyond the default backward pass. Three categories of control arise in practice. First, selective gradient computation: transfer learning and fine-tuning require freezing subsets of parameters, which the framework supports through requires_grad=False flags and the .detach() mechanism described earlier. Second, gradient inspection and modification: debugging vanishing or exploding gradients, implementing per-tensor gradient clipping, and logging gradient statistics all require intercepting gradients mid-computation, which frameworks expose through hook APIs. Third, custom differentiation rules: operations not in the framework’s built-in library (custom CUDA kernels, novel activation functions, domain-specific operations) require user-defined forward and backward implementations.
These control mechanisms share a common systems design: they are callback-based extensions that the autograd engine invokes at specific points during graph traversal, without modifying the core differentiation algorithm. This extensibility pattern allows the framework to maintain a single optimized backward pass while supporting arbitrarily complex gradient manipulation. The following examples demonstrate these mechanisms in practice, showing how to inspect and control PyTorch’s autograd system.
Retaining the computation graph
By default, backward() frees the graph after use. To run multiple backward passes (for multi-loss optimization or higher-order derivatives), use retain_graph=True at the cost of doubled memory, as shown in Listing 24.
x = torch.tensor([2.0], requires_grad=True)
y = x**2
# First backward pass - graph is freed by default
y.backward()
print(x.grad) # tensor([4.])
# Second backward on SAME y fails - graph was freed
# y.backward() # RuntimeError: graph already freed!
# Solution: retain_graph=True keeps graph for multiple passes
x.grad.zero_()
y = x**2
y.backward(retain_graph=True) # First pass, keep graph
y.backward() # Second pass works, graph freed after this
Gradient accumulation behavior
Gradients accumulate across backward passes by default. As Listing 25 demonstrates, without calling zero_grad(), successive backward passes sum their gradients:
x = torch.tensor([1.0], requires_grad=True)
# First backward pass
y = x * 2
y.backward()
print(x.grad) # tensor([2.])
# Second backward pass (without zero_grad)
y = x * 3
y.backward()
print(x.grad) # tensor([5.]) = 2 + 3 (accumulated!)
# Reset gradients
x.grad.zero_()
y = x * 3
y.backward()
print(x.grad) # tensor([3.])
Custom autograd functions
When implementing custom operations, the developer explicitly specifies what to save for the backward pass and how to compute gradients. Listing 26 shows the pattern:
class MultiplyAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z):
        # Save tensors needed for backward
        ctx.save_for_backward(x, y)
        return x * y + z

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        x, y = ctx.saved_tensors
        # Compute gradients using chain rule
        grad_x = grad_output * y  # ∂L/∂x = ∂L/∂out * ∂out/∂x
        grad_y = grad_output * x  # ∂L/∂y = ∂L/∂out * ∂out/∂y
        grad_z = grad_output      # ∂L/∂z = ∂L/∂out * 1
        return grad_x, grad_y, grad_z

# Usage
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = torch.tensor([1.0], requires_grad=True)
output = MultiplyAdd.apply(x, y, z)
output.backward()
print(x.grad, y.grad, z.grad)  # tensor([3.]), tensor([2.]), tensor([1.])
Gradient hooks
Register hooks on tensors to inspect or modify gradients during backpropagation, as shown in Listing 27:
def gradient_hook(grad):
    print(f"Gradient: {grad}")
    # Modify gradient (e.g., gradient clipping)
    return grad.clamp(-1.0, 1.0)

x = torch.tensor([2.0], requires_grad=True)
x.register_hook(gradient_hook)
y = x * 10
y.backward()
# Prints: Gradient: tensor([10.])
# x.grad contains clamped value: tensor([1.])
Detach vs. data
Use .detach() to safely break gradient flow. Listing 28 illustrates how the legacy .data attribute can silently corrupt gradient computation through in-place operations:
x = torch.tensor([1.0], requires_grad=True)
y = x * 2
# SAFE: .detach() creates a new tensor that shares storage
# but is not part of the computation graph
z_safe = y.detach()
z_safe.mul_(100) # In-place op on detached tensor
# y's data IS modified (shared storage), but autograd graph is intact
# DANGEROUS: .data bypasses autograd entirely
# In-place modifications corrupt the computation graph
z_unsafe = y.data
z_unsafe.mul_(100) # This modifies y's underlying storage!
# y.backward() now computes wrong gradients
# Best practice: always use .detach() for inference
with torch.no_grad():
    inference_output = model(x).detach()
These three principles connect directly to the framework’s role as a compiler for the Silicon Contract. The reverse-linked graph determines which operations the backward pass must execute (the \(O\) term). The memory-compute trade-off governs how much data the framework must move through the memory hierarchy (the \(D_{\text{vol}}\) term). And the extensibility mechanisms allow engineers to tune both terms for their specific workload. The interaction between autograd memory management and numerical precision leads naturally to mixed-precision training, which further reduces the \(D_{\text{vol}}\) term.
Mixed-precision training support
Mixed precision exploits a hardware asymmetry to improve two iron law terms simultaneously: Tensor Cores execute FP16 matrix multiplications at 2\(\times\) the throughput of FP32 (increasing effective \(O/R_{\text{peak}}\)), while FP16 activations halve the memory footprint (reducing \(D_{\text{vol}}\)). Improving both terms simultaneously is rare; most optimizations improve one at the expense of the other.
Frameworks exploit this through automatic mixed-precision APIs that select reduced precision for compute-intensive operations while maintaining FP32 where numerical stability demands it. Inside these APIs, frameworks automatically apply precision rules: matrix multiplications and convolutions use FP16 for bandwidth efficiency, while numerically sensitive operations like softmax and layer normalization remain in FP32. This selective precision maintains accuracy while achieving speedups on modern GPUs with specialized hardware units. Because FP16 has a narrower dynamic range than FP32, gradients can underflow to zero during backpropagation. Loss scaling addresses this by multiplying the loss by a large factor before the backward pass, then dividing gradients by the same factor afterward.
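The underflow-and-rescue mechanism can be illustrated with a toy simulation. The sketch below models an FP16 cast simply as flushing magnitudes below the smallest normal FP16 value (\(2^{-14}\)) to zero; real FP16 keeps subnormals down to about \(6 \times 10^{-8}\), so this is a deliberate simplification of the practical effect.

```python
FP16_MIN_NORMAL = 2.0 ** -14   # ~6.1e-5, smallest normal FP16 magnitude

def to_fp16ish(v):
    """Toy FP16 cast: flush magnitudes below the normal range to zero.
    (A simplification; real FP16 has subnormals down to ~6e-8.)"""
    return 0.0 if abs(v) < FP16_MIN_NORMAL else v

true_grad = 3e-6   # a small but meaningful gradient

# Without loss scaling: the gradient underflows in the FP16 backward pass.
naive = to_fp16ish(true_grad)

# With loss scaling: scale the loss (hence all gradients) up before the
# backward pass, then unscale the result in FP32 afterward.
scale = 2.0 ** 14
scaled = to_fp16ish(true_grad * scale)   # survives: 0.049 is in range
recovered = scaled / scale

print(naive, recovered)   # 0.0 vs the original 3e-6
```

Because the scale factor is a power of two, scaling and unscaling only shift the floating-point exponent, so the recovered gradient is bit-exact.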
Frameworks also support multiple precision formats including FP16, BF16,24 and TF32, each with different trade-offs between range and precision. BF16 maintains FP32’s dynamic range, simplifying training by eliminating most gradient underflow issues and removing the need for loss scaling entirely. Model Training examines the mechanics of mixed-precision training in detail, including loss scaling algorithms, memory savings analysis, and numerical stability considerations. Listing 29 demonstrates PyTorch’s mixed precision API: the autocast context manager automatically selects FP16 for compute-intensive operations while GradScaler prevents gradient underflow by dynamically scaling loss values.
24 BFloat16 Design Rationale: Developed by Google Brain circa 2018 specifically for TPU training stability, BF16 preserves FP32’s eight-bit exponent range while halving memory footprint—an explicit trade-off of mantissa precision (7 bits vs. FP16’s 10) for dynamic range. The critical consequence is loss scaling elimination: FP16’s five-bit exponent causes gradient underflow for values below \(6 \times 10^{-5}\), requiring manual loss scaling to keep gradients in range. BF16’s FP32-matched exponent makes this entire class of training instability impossible, which is why BF16 and FP16 are not interchangeable: BF16 is preferred when training stability matters; FP16 is preferred when numerical precision matters more than gradient stability.
import torch
from torch.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler("cuda")

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    # Framework automatically selects precision per operation
    with autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    # GradScaler handles gradient scaling for numerical stability
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
BF16 training typically does not require loss scaling, as Listing 30 demonstrates.
# BF16 training typically does not require loss scaling
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
loss.backward()  # No GradScaler needed
optimizer.step()
Resuming training after interruption requires restoring model weights and optimizer state together: momentum buffers, adaptive learning rates, and gradient statistics. For Adam in mixed precision, optimizer state adds roughly four times the weight memory (two 4-byte FP32 moment buffers per 2-byte FP16 parameter), quintupling the total footprint: a 7B-parameter model requires approximately 70 GB (14 GB weights + 56 GB optimizer state). Checkpoint size therefore bounds recovery speed after failure, connecting fault tolerance directly to the iron law’s \(D_{\text{vol}}\) term.
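The 70 GB figure is straightforward arithmetic, sketched below under the byte counts stated above (2-byte FP16 weights, two 4-byte FP32 moment buffers per parameter). Implementations that also keep an FP32 master copy of the weights would add more.

```python
def checkpoint_size_gb(n_params, weight_bytes=2, optimizer_bytes=8):
    """Mixed-precision Adam: FP16 weights (2 B/param) plus two FP32
    moment buffers (8 B/param). Returns sizes in GB."""
    weights = n_params * weight_bytes / 1e9
    optim = n_params * optimizer_bytes / 1e9
    return weights, optim, weights + optim

# 7B-parameter model from the text
w, o, total = checkpoint_size_gb(7e9)
print(f"weights {w:.0f} GB + optimizer {o:.0f} GB = {total:.0f} GB")
```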
Model Training covers optimizer memory requirements and optimization strategies for large-scale training, where checkpoint size becomes a binding constraint. Frameworks provide the state_dict() interface to access optimizer state for serialization (Listing 31), and resuming training requires loading both model parameters and optimizer state (Listing 32).
import torch
import torch.nn as nn
import torch.optim as optim
model = nn.Linear(10, 5)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# After training steps, optimizer accumulates state
loss = model(torch.randn(3, 10)).sum()
loss.backward()
optimizer.step()
# Access state for checkpointing
state = optimizer.state_dict()
# Contains: {'state': {...}, 'param_groups': [{'lr': 0.001, ...}]}

# Saving checkpoint
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")
# Resuming training
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
The mathematics of automatic differentiation were established decades before deep learning’s resurgence. What changed was the systems engineering. Before framework automation, implementing gradient computation for a single fully connected layer meant writing separate forward and backward functions, manually tracking intermediate values, and verifying mathematical correctness across dozens of operations. A modern transformer involves hundreds of operations with complex dependencies; manual gradient derivation for attention, layer normalization, and residual connections would require months of careful work per architecture variant.
The breakthrough was turning this manual process into software infrastructure. A single matrix multiplication requires different gradient computations depending on which inputs require gradients, tensor shapes, hardware capabilities, and memory constraints. Autograd systems handle these variations transparently, which is why the rate of architectural innovation accelerated after frameworks matured. The mathematics did not change; software engineering made the mathematics practical to apply at scale.
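What that infrastructure automates can be sketched in a few dozen lines. The scalar reverse-mode tape below (in the spirit of minimal autodiff demos) records each operation’s inputs and local gradient rule, then replays the rules in reverse topological order. It is a teaching sketch, not any framework’s implementation.

```python
class Value:
    """Minimal scalar reverse-mode autodiff: each operation records its
    inputs and a closure holding the local backward rule."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_rule = None   # set by the op that produced this node

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def rule(g):                  # chain rule for multiplication
            self.grad += g * other.data
            other.grad += g * self.data
        out._backward_rule = rule
        return out

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def rule(g):                  # addition passes gradients through
            self.grad += g
            other.grad += g
        out._backward_rule = rule
        return out

    def backward(self):
        # Build a topological order, then apply each rule in reverse.
        order, seen = [], set()
        def visit(v):
            if v not in seen:
                seen.add(v)
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            if v._backward_rule:
                v._backward_rule(v.grad)

# Same function as the MultiplyAdd listing earlier: out = x * y + z
x, y, z = Value(2.0), Value(3.0), Value(1.0)
out = x * y + z
out.backward()
print(x.grad, y.grad, z.grad)   # 3.0 2.0 1.0
```

Everything a production autograd system adds — tensors instead of scalars, memory reuse, kernel dispatch, hooks — layers on top of exactly this record-and-replay core.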
Memory management in gradient computation
The memory strategies from Section 1.3.1.2 (checkpointing, gradient accumulation) exist because reverse-mode differentiation requires preserving computational history. As Listing 21 demonstrated, each layer adds an activation tensor that persists until the backward pass consumes it, creating a memory wave that peaks at the start of backpropagation and recedes as gradients are computed. Modern frameworks track the lifetime of each intermediate value automatically, freeing memory as soon as it is no longer needed. Even with precise lifetime tracking, however, a deeper problem remains: the cost of acquiring memory from the GPU in the first place.
The cost of raw GPU memory allocation provides a critical engineering lesson: production systems require Memory Abstraction. Requesting memory directly from a GPU is a high-latency operation that can synchronize the entire device, creating an allocation bottleneck that stalls computation. To solve this, modern frameworks implement Caching Allocators. Instead of communicating with the hardware for every new tensor, the framework requests large blocks of memory upfront and manages its own internal pool. This abstraction is critical because it prevents memory fragmentation, the scenario where free memory is available but scattered in pieces too small to hold a large tensor, allowing models to push the physical limits of the hardware without constant system-level overhead.
Systems Perspective 1.2: Caching Allocator and Utilization
- Allocation Latency: cudaMalloc is a synchronous operation that costs 10–100 microseconds. In a training loop with thousands of operations per second, this latency would dominate execution time. The caching allocator pays this cost once, then serves subsequent requests in nanoseconds from its pool.
- Fragmentation: A “Swiss cheese” memory pattern reduces Effective Capacity. If 10 GB is free but the largest contiguous block is 1 GB, a 2 GB tensor cannot be allocated. By binning allocations into standard sizes (powers of 2), the allocator ensures that freed memory can be reused for future requests, keeping Utilization high.
When “OOM” (Out of Memory) errors appear despite nvidia-smi showing free memory, fragmentation is often the culprit. The allocator cannot find a contiguous block large enough for the requested tensor.
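A toy version of this design, assuming power-of-two binning and counting hypothetical backend allocations, might look like the following. Real allocators such as PyTorch’s are far more elaborate (block splitting, per-stream pools, expandable segments).

```python
class CachingAllocator:
    """Toy caching allocator: round sizes up to a power of two and reuse
    freed blocks from per-size free lists, so the expensive backend call
    (cudaMalloc in a real framework) happens only on cache misses."""

    def __init__(self):
        self.free_lists = {}      # rounded size -> list of cached blocks
        self.backend_allocs = 0   # how often we had to go to the "driver"

    @staticmethod
    def _round_up(size):
        return 1 << (size - 1).bit_length()

    def malloc(self, size):
        size = self._round_up(size)
        if self.free_lists.get(size):
            return size, self.free_lists[size].pop()  # hit: nanoseconds
        self.backend_allocs += 1                      # miss: ~10-100 us
        return size, object()                         # stand-in for a block

    def free(self, size, block):
        self.free_lists.setdefault(size, []).append(block)

alloc = CachingAllocator()
for _ in range(1000):                # a training loop's worth of tensors
    size, block = alloc.malloc(5000) # rounds up to 8192
    alloc.free(size, block)
print(alloc.backend_allocs)          # 1: every later request is a pool hit
```

The binning is also why the allocator resists fragmentation: a freed 8192-byte block satisfies any future request between 4097 and 8192 bytes.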
Production system integration challenges
A training iteration that takes 300 ms in profiling may take 500 ms in production because the AD system must coordinate with the memory allocator, the device manager, the operation scheduler, and the optimizer on every single step. Each gradient computation can trigger data movement between CPU and GPU, memory allocation for intermediate tensors, and kernel launches on accelerators. These system interactions dominate wall-clock time for small models and remain significant even at scale. The gap between what the programmer writes (a five-line training loop) and what the system executes (dozens of memory allocations, kernel launches, and synchronization points) is the central tension of AD system design.
Beyond sequential overhead, the AD system must also exploit concurrency. Modern networks frequently contain independent branches—two convolutional paths processing the same input before merging, as in Inception-style architectures. On a GPU with sufficient resources, the framework’s scheduler can execute both branch backward passes on separate CUDA streams, reducing backward pass time by up to 30–40 percent. The AD system therefore tracks dependencies for two purposes: correctness (computing the right gradients) and performance (scheduling independent computations concurrently). Frameworks hide this complexity behind loss.backward(), but the scheduling, memory allocation, and data movement decisions behind that call determine whether training runs at 40 percent or 80 percent of peak hardware utilization.
The memory and system integration challenges examined earlier (caching allocators, activation storage, and checkpoint overhead) affect all frameworks. Yet how frameworks implement automatic differentiation in the first place varies significantly, with consequences for both optimization potential and developer experience. The distinction between tape-based and transform-based autodiff captures this architectural divergence.
Systems Perspective 1.3: Tape-based vs. Transform-based Autodiff
PyTorch (Tape-based): Records each operation onto a dynamic tape (the grad_fn chain) as the forward pass executes, then walks the tape in reverse to compute gradients. Arbitrary Python control flow is trivially differentiable, but the graph exists only after the fact, limiting whole-program optimization.
JAX (Transform-based): Treats automatic differentiation as a high-level function transformation (grad(f)). Because JAX sees the mathematical function before execution, it can easily chain other transformations like jit(grad(f)) or vmap(grad(f)), producing highly optimized, compiled kernels that often outperform dynamic frameworks on specialized hardware like TPUs.
JAX25 exemplifies the transform-based approach, where composable function transformations replace imperative tape recording.
25 JAX: The “transform-based” distinction matters because JAX’s grad, jit, and vmap are not library calls but algebraic transformations on pure functions, composable in any order. A chain like jit(grad(vmap(f))) compiles into a single XLA kernel because functional purity (no side effects, no mutation) lets the compiler reason about the entire program mathematically. The payoff is over 90 percent hardware utilization on TPUs; the cost is that any impurity (printing, mutation, unkeyed randomness) silently vanishes after the first trace.
How different frameworks implement AD
The execution models covered in Section 1.2, namely eager, static graph, and hybrid, directly shape how each framework implements automatic differentiation:
- PyTorch (Paszke et al. 2019) builds its autograd tape dynamically during forward execution, providing immediate debugging at the cost of graph-level optimization. The grad_fn chain mechanism detailed in Section 1.3.2.1 enables flexible control flow but requires storing the complete graph until backward pass completion.
- TensorFlow (in its 1.x incarnation) performed symbolic differentiation during graph construction, enabling ahead-of-time optimization. Modern TensorFlow 2.x uses eager execution by default but provides tf.function for graph compilation when performance matters.
- JAX (Frostig et al. 2018) transforms functions rather than tracking operations. The jax.grad() transformation returns a new function that computes gradients, enabling composition with jax.vmap() for vectorization and jax.jit() for compilation. This approach requires pure functions but enables composable program transformations that chain differentiation, vectorization, and compilation in a single expression.
These implementation differences have direct practical consequences for framework selection, which Section 1.6 examines in detail.
A recurring tension runs through every AD design decision: mathematical correctness demands storing computational history, but hardware imposes strict memory limits. Every framework resolves this tension differently, choosing which activations to checkpoint, which operations to fuse, and how aggressively to trade recomputation for memory. These choices determine which models can train on which hardware, making AD system design one of the most consequential engineering decisions in any framework.
Checkpoint 1.2: The Systems Cost of Gradients
Training is inherently more expensive than inference because of Automatic Differentiation.
The execution and differentiation problems together enable the training loop: the execution model determines when computation happens, while automatic differentiation computes the gradients that drive learning. Both problems, however, quietly assume something that cannot be taken for granted: that the same code can run across diverse hardware. A model trained on an NVIDIA A100 must serve inference on a mobile phone’s ARM CPU, a Google TPU, or a microcontroller with kilobytes of memory. The same torch.matmul call must dispatch to cuBLAS on one device and a hand-tuned ARM NEON kernel on another. This hardware diversity creates the third problem.
Abstraction Problem
The hardware diversity described earlier is not merely inconvenient; it is architecturally fundamental. A GPU offers 1,000\(\times\) the parallelism of a CPU but has different memory semantics. A TPU provides higher throughput but requires static shapes. A microcontroller has kilobytes where a server has gigabytes. The abstraction problem is precisely this: frameworks must hide this complexity behind a single programming interface while still enabling efficient utilization of each target’s unique capabilities.
The problem decomposes into two interacting dimensions. The first is data representation: how should frameworks represent tensors, parameters, and computational state in ways that work across hardware? The second is execution mapping: how should high-level operations translate to hardware-specific implementations? These dimensions are not independent concerns. The way data is represented (memory layout, precision, device placement) directly affects what execution strategies are possible. A tensor stored in row-major format on a GPU requires different kernels than one in column-major format on a CPU. A model quantized to INT8 enables entirely different execution paths than FP32.
Solving the abstraction problem requires sophisticated software infrastructure: tensor representations that encode both mathematical semantics and hardware constraints, intermediate representations that enable hardware-specific compilation, and runtime systems that manage data movement across the memory hierarchy.
To make this concrete, trace what must happen when a programmer writes model(input). The framework must answer five questions in rapid succession: What is the data? (tensor shape, memory layout, numeric precision), Where does it live? (device placement and the bandwidth hierarchy connecting CPU, GPU, and accelerator memory), How does it arrive fast enough? (data pipelines that sustain hundreds of MB/s to keep the accelerator fed), How does it scale beyond one device? (parameter synchronization and distributed execution contexts), and What actually runs on the hardware? (kernel dispatch, scheduling, and resource optimization). The following sub-sections answer these questions in order, building from the data container up to the hardware execution layer.
Data structures and tensor abstractions
A ResNet-50 forward pass touches 25.6 million parameters, produces intermediate activations at every layer, and must coordinate memory across CPU and GPU address spaces. How do frameworks organize all of this data so that a single Python call like model(input) executes millions of operations without the programmer managing a single pointer? Answering this question requires solving four problems in sequence: defining a universal data container (tensors), placing it on the right device (memory management), feeding data fast enough (data pipelines), and dispatching the right hardware kernel (core operations). We trace this path from data representation to hardware execution.
Computational graphs specify the logical flow of operations, but data structures determine how those operations access and manipulate data in physical memory. This distinction matters because the same mathematical operation can differ by an order of magnitude in throughput depending on whether data is contiguous in cache, pinned for DMA transfer, or scattered across pages.
The first step is the data container itself. Framework data structures must sustain memory bandwidth (hundreds of GB/s on modern GPUs), accommodate architectures from 1D sequences to 5D video tensors, and hide device management behind clean APIs. Tensors are the universal answer.
Tensors
At the foundation of every framework’s data representation lies a single abstraction: the tensor.
Definition 1.2: Tensor
Tensors are \(n\)-dimensional arrays with explicit shape, data type, and memory layout metadata that allow ML frameworks to map mathematical operations directly onto hardware vector units without intermediate data transformation.
- Significance (Quantitative): Tensor memory footprint is fully deterministic from its metadata: a contiguous FP32 tensor of shape \([1024, 1024]\) occupies exactly \(1024 \times 1024 \times 4 = 4\) MB. Non-contiguous layouts (for example, from a transpose operation) require explicit .contiguous() calls before certain CUDA kernels can execute, adding a memory-copy overhead that can dominate the \(L_{\text{lat}}\) term for tensors under 1 MB.
- Distinction (Durable): Unlike a Python list or generic NumPy array, a framework tensor carries device placement metadata (CPU vs. GPU), dtype (FP32, BF16, INT8), and stride information that enables zero-copy view operations and CUDA kernel dispatch without any runtime type checking or data movement.
- Common Pitfall: A frequent misconception is that tensor operations are always in-place. Framework tensor operations return new tensors by default, allocating fresh GPU memory for each intermediate result. In a long computation graph, these intermediate allocations accumulate and can exhaust GPU memory before any weights are updated.
Every computation in a neural network operates on tensors.26 Training batches, activation maps, parameter gradients, and optimizer states are all tensors. This unified representation lets frameworks optimize a single data structure for hardware rather than managing separate containers for each role.
26 Tensor: From Latin tendere (“to stretch”), coined in its mathematical sense by physicist Woldemar Voigt in 1898 for objects defined by how they transform under coordinate changes. ML inherited the term because framework tensors are similarly defined by transformation behavior: transposing changes strides but not data, reshaping changes metadata without moving bytes. This transformation-centric design is also the source of layout sensitivity: choosing NCHW when the target accelerator prefers NHWC (or vice versa) can halve computational throughput, because misaligned memory access patterns break hardware coalescing.
The tensor abstraction consumes far more memory than model weights alone suggest. Engineers who estimate memory from parameter count alone allocate accordingly and encounter out-of-memory errors that seem inexplicable. The following napkin-math estimate quantifies what we call the administrative tax: the shadow tensors for gradients, optimizer momentum, and stored activations that accompany every weight tensor.
Napkin Math 1.3: The Administrative Tax
Problem: Why does GPU utilization drop when training small models?
The Math (The Hidden Tax):
Model Weights: 2 GB.
Gradients: 2 GB (same size as weights).
Optimizer States (Adam): 8 GB (two FP32 moment buffers per FP16 parameter, \(4 \times\) the weight memory, for momentum and velocity).
Activations: For a batch size of 32 and a 100-layer network, the framework must store every intermediate layer output for the backward pass.
\[ \text{Activations} \approx \text{Batch} \times \text{Layers} \times \text{Width}^{2} \times 2 \text{ bytes} \] For a 1024-width model: \(32 \times 100 \times 1024^{2} \times 2 \approx \mathbf{6.7 \text{ GB}}\). (Each layer’s activation is a \(\text{Width} \times \text{Width}\) matrix per sample—appropriate for transformer-style models where intermediate projections scale with hidden dimension squared.)
The Systems Conclusion: A 2 GB model has an “Administrative Tax” of ~17 GB (2 GB gradients + 8 GB optimizer + 6.7 GB activations) before the first batch is even processed. During training, Data Movement includes saving and retrieving these activations, which is why training is often 3–4\(\times\) slower than pure inference.
Tensor structure and dimensions
A tensor generalizes scalars, vectors, and matrices to arbitrary dimensions. The hierarchy is straightforward: a scalar is a rank-0 tensor (single value), a vector is rank-1 (sequence of values), and a matrix is rank-2 (rows and columns). Higher ranks extend this pattern through nesting, so a rank-3 tensor is a stack of matrices—compare all four ranks side by side in Figure 8 to see how each level adds a new axis of organization.
This rank hierarchy maps directly onto ML data. A color image is a rank-3 tensor: height × width × 3 channels (red, green, blue). Figure 9 breaks this apart, stacking the three color channels to illustrate how a single photograph becomes a three-layer numerical grid. Stacking a batch of \(N\) images adds a fourth dimension, producing a rank-4 tensor of shape \([N, 3, H, W]\). Every convolutional layer in a vision model consumes and produces tensors of exactly this shape, which is why the tensor abstraction is so central to framework design.
Framework tensors carry more than raw numbers. Each tensor stores metadata that the runtime uses to validate operations and select fast execution paths: a shape tuple (for example, [64, 3, 224, 224] for a batch of images), a dtype (float32, float16, int8), and a device tag (CPU, cuda:0). A matrix multiplication, for instance, checks shape compatibility at dispatch time and uses the dtype to route to the correct hardware kernel, whether a standard FP32 GEMM or a Tensor Core FP16 path.
Memory layout implementation introduces distinct challenges in tensor design. While tensors provide an abstraction of multi-dimensional data, physical computer memory remains linear. Stride patterns address this disparity by creating mappings between multi-dimensional tensor indices and linear memory addresses. These patterns significantly impact computational performance by determining memory access patterns during tensor operations. Figure 10 makes this concrete with a \(2\times3\) tensor: follow the same six values as they map into two different linear orderings—row-major and column-major—and note how the stride values change to compensate.
These memory layout patterns are crucial for framework performance optimization. Row-major layout (used by NumPy, PyTorch) stores elements row by row, making row-wise operations more cache-friendly. Column-major layout (used by some BLAS libraries) stores elements column by column, optimizing column-wise access patterns. The stride values encode this layout information: in row-major layout for a \(2\times3\) tensor, moving to the next row requires skipping three elements (stride[0]=3), while moving to the next column requires skipping one element (stride[1]=1).
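The stride-to-offset mapping is mechanical enough to sketch directly. The helper below measures strides in elements (frameworks report bytes or elements depending on the API) and shows why a transpose is pure metadata: swapping strides reaches the same storage without moving a byte.

```python
def row_major_strides(shape):
    """Strides (in elements) for a contiguous row-major (C-order) tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def offset(index, strides):
    """Map a multi-dimensional index to a linear memory offset."""
    return sum(i * s for i, s in zip(index, strides))

shape = (2, 3)
strides = row_major_strides(shape)       # [3, 1]: next row skips 3 elements
print(strides, offset((1, 2), strides))  # element [1][2] lives at offset 5

# A transpose is just a metadata change: swap shape and strides.
t_strides = strides[::-1]                # [1, 3]: a column-major view
print(offset((2, 1), t_strides))         # transposed [2][1] -> same offset 5
```

The transposed view is exactly the non-contiguous case the tensor definition warned about: its strides no longer match row-major order, which is why some kernels demand a .contiguous() copy first.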
These memory layout details have direct performance implications. When a convolution kernel accesses weight values, row-major layout means consecutive weights along the output channel dimension are contiguous in memory—enabling efficient vectorized loads. Column-major layout would scatter those same weights across memory, forcing slower gather operations. Careful alignment of stride patterns with hardware memory hierarchies maximizes cache efficiency and memory throughput, with optimal layouts achieving 80–90 percent of theoretical memory bandwidth (1.5–3.0 TB/s on modern data-center GPUs like the A100 and H100) compared to suboptimal patterns that may achieve only 20–30 percent utilization.
Tensor implementations use type systems to control numerical precision and memory consumption. The standard choice in machine learning has been 32-bit floating-point numbers (float32), offering a balance of precision and efficiency. Modern frameworks extend this with multiple numeric types for different needs. Integer types support indexing and embedding operations. Reduced-precision types like 16-bit floating-point numbers enable efficient mobile deployment. Eight-bit integers allow fast inference on specialized hardware.
The choice of numeric type affects both model behavior and computational efficiency. Neural network training typically requires float32 precision to maintain stable gradient computations. Inference tasks can often use lower precision (int8 or even int4), reducing memory usage and increasing processing speed. Mixed-precision training approaches combine these benefits by using float32 for critical accumulations while performing most computations at lower precision.
Type conversions between different numeric representations require careful management. Operating on tensors with different types demands explicit conversion rules to preserve numerical correctness. These conversions introduce computational costs and risk precision loss. Frameworks provide type casting capabilities but rely on developers to maintain numerical precision across operations.
Tensors answer the first question—what is the data?—by encoding shape, layout, and precision into a single abstraction. A perfectly shaped tensor on the wrong device, however, or one that must cross a 60\(\times\) bandwidth gap to reach the GPU, can erase every layout optimization. The next question is where data lives and how it moves.
Device and memory management
Tensors and their memory layouts establish what the framework computes with. Where that data physically resides, and how it moves between locations, determines whether computation happens at full speed or crawls.
Frameworks as the operating system interface
While the high-level API focuses on math, the framework’s backend functions as the Operating System of the Single-Machine Stack. It manages the two critical resources of a single node: compute scheduling and data movement.
The CUDA Runtime serves as this OS layer, providing the low-level primitives for launching kernels and managing device memory. The framework coordinates with this runtime to implement Direct Memory Access (DMA) over the PCIe bus. As established in Hardware Acceleration, the bandwidth gap between the host (CPU) and device (GPU) is the primary “Data Loading Bottleneck.” Frameworks mitigate this through pinned memory (page-locked memory) that allows the GPU to read directly from CPU RAM via DMA without interrupting the processor. This “HW/OS” interface is what makes high-throughput training loops possible on a single machine.
Every tensor resides on a specific device, and cross-device operations incur transfer costs that can dominate execution time. PCIe 4.0 delivers 32 GB/s between CPU and GPU, while HBM2e provides 2.0 TB/s within the GPU. This bandwidth gap, exceeding 60\(\times\), means a single misplaced tensor transfer can erase the entire speedup from GPU acceleration.

Why does this matter for framework design? Because the framework must track where every tensor lives and enforce that operations only combine tensors on the same device. When data must move, the framework must decide whether to block execution or overlap the transfer with other work. These decisions, invisible to most users, determine whether a training loop achieves 30 percent or 80 percent of theoretical hardware throughput.
Three systems principles govern effective device and memory management: understanding the bandwidth hierarchy that constrains data movement, overlapping computation with communication to hide transfer latency, and using fine-grained synchronization to maintain correctness without sacrificing concurrency. The remainder of this section develops each principle, with quantitative analysis grounded in the iron law’s data movement term.
Principle 1: The device bandwidth hierarchy
The cost of moving data between devices varies by orders of magnitude depending on the interconnect.27 Before examining optimization strategies, we need to understand these costs quantitatively. Table 6 shows transfer times for a \(1000\times1000\) float32 tensor (4 MB)—roughly the size of a typical activation tensor in a moderately sized model. The numbers reveal why careless device placement can erase any speedup from GPU acceleration:
27 NVLink: NVIDIA’s high-bandwidth GPU-to-GPU interconnect (see Hardware Acceleration), providing 600 GB/s bidirectional bandwidth (NVLink 3.0 on A100) compared to 64 GB/s for PCIe 4.0 x16. This ~10\(\times\) bandwidth advantage determines whether tensor parallelism is practical for a given model size: splitting a model across GPUs connected by PCIe can make the \(D_{\text{vol}}/BW\) communication term dominate total training time, erasing the benefit of additional compute.
| Interconnect | Bandwidth | Transfer Time | Relative to Compute |
|---|---|---|---|
| PCIe 3.0 x16 | 16 GB/s | 0.25 ms | 10\(\times\) slower than GPU compute |
| PCIe 4.0 x16 | 32 GB/s | 0.125 ms | 5\(\times\) slower than GPU compute |
| NVLink 3.0 | 600 GB/s bidirectional | 0.007 ms | Comparable to GPU compute |
| GPU Memory | 2039 GB/s | 0.002 ms | Optimal |
These numbers connect directly to the iron law of performance. Every cross-device transfer inflates the data movement term (\(D_{\text{vol}}/BW\)) at a fraction of the available on-device bandwidth. A PCIe 4.0 transfer at 32 GB/s means moving a 1 GB activation tensor adds approximately 31 ms to the data movement cost, equivalent to roughly 9.8 trillion operations on a GPU delivering 312 TFLOPS. For a model forward pass taking 0.5 ms on GPU, transferring inputs and outputs over PCIe 3.0 doubles the total latency. When batches are small or models are lightweight, transfer overhead can exceed computation time entirely.
The systems implication is clear: every tensor should reside on the device where it will be consumed, and transfers should occur only when unavoidable. Frameworks track device placement for every tensor and raise errors when operations attempt to combine tensors from different devices, enforcing this discipline at the API level.
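The transfer costs in Table 6 follow from simple arithmetic. A back-of-envelope sketch, assuming idealized peak bandwidth with no latency or protocol overhead:

```python
# Time to move a payload over an interconnect at its peak bandwidth.
def transfer_ms(bytes_moved: float, bandwidth_gb_s: float) -> float:
    """Milliseconds to move `bytes_moved` at `bandwidth_gb_s` GB/s."""
    return bytes_moved / (bandwidth_gb_s * 1e9) * 1e3

four_mb = 1000 * 1000 * 4   # the 1000x1000 float32 tensor from Table 6
assert abs(transfer_ms(four_mb, 16) - 0.25) < 1e-9    # PCIe 3.0 x16
assert abs(transfer_ms(four_mb, 32) - 0.125) < 1e-9   # PCIe 4.0 x16

# A 1 GB activation over PCIe 4.0 costs ~31 ms; at 312 TFLOPS that
# window corresponds to roughly 9.8 trillion foregone operations.
ms = transfer_ms(1e9, 32)
assert round(ms) == 31
assert abs(312e12 * (ms / 1e3) - 9.75e12) < 1e10
```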
Principle 2: Overlapping computation and communication
When transfers are unavoidable, the next optimization is to hide their latency by executing them concurrently with computation. Modern GPUs contain independent hardware units for computation (SM clusters) and data transfer (copy engines), enabling true simultaneous execution. The framework abstraction that exposes this hardware parallelism is the CUDA stream: an independent execution queue where operations execute sequentially within a stream but concurrently across streams.
Without explicit concurrency control, the GPU serializes all operations on a single default stream, leaving execution units idle while data transfers complete. By placing data transfers on one stream and computation on another, the effective latency approaches the theoretical minimum of \(\max(\text{compute\_time}, \text{transfer\_time})\) rather than their sum. Stream-based overlap effectively hides the \(D_{\text{vol}}/BW\) penalty when computation is the longer operation (see Listing 33):
compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()

# Transfer next batch while computing current batch
with torch.cuda.stream(transfer_stream):
    next_batch = next_batch_cpu.to("cuda", non_blocking=True)

with torch.cuda.stream(compute_stream):
    output = model(current_batch)
    loss = criterion(output, labels)

# Pinned memory enables non_blocking transfers
x_pinned = torch.randn(1000, 1000).pin_memory()
x_gpu = x_pinned.to("cuda", non_blocking=True)  # Asynchronous

# Regular memory requires blocking transfer
y_regular = torch.randn(1000, 1000)
y_gpu = y_regular.to("cuda", non_blocking=True)  # Still blocks

The non_blocking=True flag enables asynchronous transfers that return immediately without waiting for completion. This works only when the source tensor uses pinned memory (page-locked memory that enables DMA transfers). Without pinned memory, the transfer blocks even when non_blocking=True is specified, because the GPU’s copy engine cannot initiate a DMA transfer from pageable host memory.
This overlap principle extends naturally to pipeline parallelism within a single node. Different model stages on separate GPUs can process different microbatches concurrently, with each stage’s computation overlapping the next stage’s data reception (see Listing 34):
# Pipeline parallelism: overlap stages across microbatches
stages = [Stage1().cuda(), Stage2().cuda(), Stage3().cuda()]
streams = [torch.cuda.Stream() for _ in stages]
events = [
    [torch.cuda.Event() for _ in range(num_microbatches)]
    for _ in stages
]

for mb in range(num_microbatches):
    for stage_idx, (stage, stream) in enumerate(zip(stages, streams)):
        with torch.cuda.stream(stream):
            if stage_idx > 0:
                # Wait for previous stage to complete this microbatch
                events[stage_idx - 1][mb].wait()
            output = stage(inputs[stage_idx][mb])
            events[stage_idx][mb].record()

Extending this pattern across multiple machines requires distributed training techniques that constitute an advanced topic, but the preceding single-node implementation illustrates the core synchronization principles that underlie all pipeline-parallel systems.
With computation and communication overlapping effectively, the remaining challenge is ensuring correctness when operations complete out of order.
Principle 3: Synchronization and correctness
Concurrent execution introduces ordering constraints. When one stream’s output becomes another stream’s input, the system must enforce a happens-before relationship without unnecessarily serializing independent work. Two synchronization mechanisms exist, with dramatically different performance implications.
Full device synchronization (torch.cuda.synchronize()) blocks all streams and the CPU until every queued operation completes. This creates a global serialization point that eliminates all overlap benefits. CUDA events provide the alternative: fine-grained synchronization that blocks only the dependent stream, allowing other streams and the CPU to continue execution (see Listing 35):
# Create streams and event
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()
event = torch.cuda.Event()

# Stream 1: producer
with torch.cuda.stream(stream1):
    result1 = expensive_computation(data1)
    event.record()  # Mark completion point

# Stream 2: consumer (waits only for stream1's event)
with torch.cuda.stream(stream2):
    event.wait()  # Block stream2 until event is recorded
    result2 = dependent_computation(result1)  # Safe to use result1

The performance difference between these approaches is not incremental but categorical. Full synchronization after every operation converts a concurrent pipeline into a sequential one, entirely negating the hardware parallelism that streams expose. Event-based synchronization preserves the concurrent execution model while enforcing only the dependencies that correctness requires.
Device placement discipline
Every tensor carries a device attribute, and frameworks enforce a strict rule: operations can only combine tensors on the same device. Mixing cuda:0 and cuda:1 tensors raises a RuntimeError, preventing silent cross-device transfers. The .to() method moves tensors between devices and is a no-op when the tensor is already on the target: calling .to("cuda") on a tensor already resident on that GPU returns the same object without copying. Module .to() recursively moves all parameters and buffers, ensuring the entire model hierarchy lands on a single device. Three placement principles prevent transfer bottlenecks: (1) allocate tensors on the target device from the start rather than creating on CPU and transferring, (2) reuse GPU memory across iterations rather than re-allocating, and (3) colocate all inputs, labels, and model parameters on the same device to eliminate implicit transfers. Violating any of these principles inserts PCIe transfers into the critical path, which at 32 GB/s can dominate a training iteration that otherwise runs at 2.0 TB/s on-device.
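The enforcement pattern can be sketched in a few lines. The Tensor class below is hypothetical, for illustration only; PyTorch implements the same same-device check inside its operator dispatcher:

```python
from dataclasses import dataclass

# Minimal sketch of the placement rule: an operation may only combine
# tensors whose device tags match. (Hypothetical Tensor class.)
@dataclass
class Tensor:
    data: list
    device: str = "cpu"

    def to(self, device: str) -> "Tensor":
        if device == self.device:
            return self  # already there: no copy, same object returned
        return Tensor(list(self.data), device)  # explicit, visible transfer

def add(a: Tensor, b: Tensor) -> Tensor:
    # Fail loudly instead of inserting a hidden PCIe copy.
    if a.device != b.device:
        raise RuntimeError(f"device mismatch: {a.device} vs {b.device}")
    return Tensor([x + y for x, y in zip(a.data, b.data)], a.device)

x = Tensor([1.0, 2.0]).to("cuda:0")
y = Tensor([3.0, 4.0])  # still on CPU
try:
    add(x, y)
    raised = False
except RuntimeError:
    raised = True
assert raised                                    # mixing devices fails loudly
assert add(x, y.to("cuda:0")).data == [4.0, 6.0]  # explicit move, then compute
assert x.to("cuda:0") is x                       # no-op move, same object
```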
Synchronization patterns
As Listing 35 demonstrated, event-based synchronization preserves parallelism by enforcing only the dependencies that correctness requires. A common mistake in production code is inserting torch.cuda.synchronize() calls for debugging and forgetting to remove them, silently converting an overlapped pipeline into a serialized one.
Profiling transfer bottlenecks
When overlap is insufficient, profiling reveals where time is lost. NVIDIA provides two complementary tools: Nsight Systems (nsys profile) captures system-wide timelines correlating CPU activity, GPU kernel execution, and memory transfers, identifying which kernels dominate runtime. Nsight Compute (ncu) provides kernel-level analysis with hardware counters, revealing why those kernels underperform. Table 7 lists the key metrics to examine when optimizing ML kernels.
| Metric | Meaning | Optimization Target |
|---|---|---|
| SM Occupancy | Active warps/maximum warps | Increase parallelism if low |
| Memory Throughput | Achieved/peak bandwidth | Optimize memory access patterns |
| Compute Throughput | Achieved/peak FLOPS | Reduce memory bottlenecks |
| Tensor Core Active | Time in Tensor Core ops | Verify mixed-precision utilization |
Data pipelines and loading
Streams and events answer the second question—where does data live, and how does it move?—by overlapping transfers with computation so that the GPU rarely stalls on a single tensor. Scheduling alone, however, cannot help if data arrives too slowly in the first place. The third question is how does data arrive fast enough? The core systems principle is straightforward: the data pipeline must sustain the accelerator’s consumption rate. A GPU processing 1,000 images per second at 224 \(\times\) 224 resolution requires approximately 151 MB/s of sustained data throughput. If the pipeline cannot maintain this rate, the accelerator idles and the effective utilization term in the iron law drops below 1.
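The 151 MB/s figure follows from simple arithmetic. The FP32 line below is an added illustration, assuming the decoded bytes are normalized to float32 before transfer:

```python
# Required pipeline throughput: 1,000 images/s of 224x224 RGB,
# one byte per channel after JPEG decode.
bytes_per_image = 224 * 224 * 3
required_mb_s = 1000 * bytes_per_image / 1e6
assert round(required_mb_s) == 151   # ~151 MB/s sustained

# After conversion to float32 the same stream is 4x larger.
assert round(required_mb_s * 4) == 602
```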
Frameworks address this throughput requirement through three mechanisms. The first is parallel worker processes: the DataLoader spawns multiple CPU processes, each independently loading and preprocessing samples. Because data loading involves disk I/O and CPU-bound transformations (decoding, augmentation, normalization), a single process cannot saturate a modern GPU. Multiple workers overlap I/O wait times with preprocessing computation, collectively sustaining throughput that no single process could achieve. When num_workers > 0, the DataLoader distributes sample indices across workers through a shared queue, and workers push completed samples to a data queue that the main process assembles into batches.
The second mechanism is prefetching. The prefetch_factor parameter (default 2) controls how many batches each worker prepares in advance. With four workers and prefetch_factor=2, the pipeline maintains eight batches in flight, ensuring the GPU never stalls waiting for data. While the model processes batch \(N\) on the GPU, workers simultaneously load and preprocess batch \(N+1\) through \(N+8\) on CPUs, effectively hiding data loading latency behind computation. The cost is memory consumption proportional to batch size times prefetch depth.
The third mechanism is pinned memory for DMA transfers. The pin_memory=True option allocates batch data in page-locked (pinned) host memory rather than pageable memory. Pageable memory can be swapped to disk by the operating system, forcing the CUDA runtime to first copy data to a temporary pinned buffer before initiating the GPU transfer. Pinned memory bypasses this intermediate copy, enabling direct memory access (DMA) transfers where the GPU’s memory controller reads directly from host memory while the CPU continues other work. For a batch of 64 images at 224\(\times\)224\(\times\)3 in FP32 (39 MB), pinned memory transfer takes approximately 1.2 ms over PCIe 4.0 x16 (32 GB/s) compared to ~3.0 ms with pageable memory, a 2–3\(\times\) speedup. The cost is reduced available system memory, as pinned pages cannot be swapped.
These three mechanisms appear together in the DataLoader configuration. Understanding how each parameter connects to the underlying systems principle helps practitioners diagnose data pipeline bottlenecks. Listing 36 shows a typical setup where num_workers enables parallel loading, prefetch_factor controls pipeline depth, and pin_memory enables DMA transfers:
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,               # Parallel worker processes (mechanism 1)
    prefetch_factor=2,           # Batches prepared ahead per worker (mechanism 2)
    pin_memory=True,             # Page-locked memory for DMA (mechanism 3)
    worker_init_fn=seed_worker,  # Reproducible augmentation per worker
)

# Pipeline effect: while GPU processes batch N,
# 4 workers load batches N+1..N+8 into pinned memory,
# ready for DMA transfer when the GPU finishes.

A practical starting point is setting num_workers equal to the number of available CPU cores. The optimal value depends on whether loading is I/O-bound or CPU-bound. For I/O-bound workloads such as reading images from network storage, more workers overlap disk latency and improve throughput. For CPU-bound workloads involving heavy augmentation, the benefit saturates once all cores are in use. Too many workers waste memory, since each maintains a copy of the Dataset object.
Worker process management introduces several subtle issues. Because workers are separate processes, random number generators used in data augmentation must be explicitly seeded per worker via worker_init_fn to ensure reproducibility. Without proper seeding, workers may produce identical augmentation sequences, reducing effective data diversity. Shared state between workers presents a separate challenge: each worker has its own memory space, so modifications to global variables in one worker do not propagate to others or to the main process. For large datasets where caching matters, memory-mapped files or shared memory regions that persist across processes are the standard solution.
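A minimal sketch of per-worker seeding. The helper names below are hypothetical, not PyTorch's exact recipe (which typically derives seeds from torch.initial_seed() inside the worker_init_fn hook), but the principle is the same: each worker gets a distinct, deterministic seed:

```python
import random

# Hypothetical factory: derive a distinct per-worker seed from a shared
# base seed, so augmentation streams differ across workers while the
# whole run stays reproducible.
def seed_worker_factory(base_seed: int):
    def seed_worker(worker_id: int) -> None:
        random.seed(base_seed + worker_id)  # per-worker, deterministic
    return seed_worker

init = seed_worker_factory(base_seed=42)

def draw(worker_id: int) -> list:
    """Simulate a worker's augmentation randomness after seeding."""
    init(worker_id)
    return [random.random() for _ in range(3)]

assert draw(0) == draw(0)  # reproducible: same worker, same stream
assert draw(0) != draw(1)  # diverse: different workers, different streams
```

Without this step, forked workers can inherit identical RNG state and produce the duplicated augmentation sequences described above.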
The DataLoader wraps a Dataset object that defines how individual samples are accessed. PyTorch supports two dataset paradigms. Map-style datasets implement __len__ and __getitem__, enabling random access to samples by index—this pattern works well for datasets that fit in memory or support efficient random access on disk. Iterable-style datasets implement __iter__ instead, yielding samples sequentially for streaming data sources where random access is impractical. The choice between paradigms determines whether the DataLoader can shuffle samples (map-style only) or must process them in arrival order (iterable-style).
A final detail is collation: the collate_fn parameter determines how individual samples are combined into batches. The default collation stacks tensors along a new batch dimension, which works when all samples have identical shapes. For variable-length data such as text sequences, custom collation handles padding, sorting by length, or creating attention masks—directly affecting both memory usage and training throughput.
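A sketch of such a custom collate function for variable-length token sequences, using NumPy arrays in place of torch tensors; pad_collate is a hypothetical name:

```python
import numpy as np

# Pad each sequence to the batch maximum and emit a boolean mask marking
# real tokens, the shape a DataLoader's collate_fn would return.
def pad_collate(batch):
    max_len = max(len(seq) for seq in batch)
    padded = np.zeros((len(batch), max_len), dtype=np.int64)
    mask = np.zeros((len(batch), max_len), dtype=bool)
    for i, seq in enumerate(batch):
        padded[i, : len(seq)] = seq
        mask[i, : len(seq)] = True
    return padded, mask

tokens, mask = pad_collate([[5, 6, 7], [8], [9, 10]])
assert tokens.shape == (3, 3)              # batch stacked to max length
assert tokens[1].tolist() == [8, 0, 0]     # short sequence zero-padded
assert mask.sum() == 6                     # six real tokens in the batch
```

Sorting the batch by length before padding (not shown) reduces wasted padding and is a common refinement.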
DataLoaders, Datasets, and collation functions answer the third question—how does data arrive fast enough?—by sustaining accelerator-rate throughput through parallelism, prefetching, and DMA. These structures, however, handle only ephemeral data: samples flow through the pipeline once per epoch and are discarded. The fourth question asks how frameworks manage data that persists—the model’s own weights—especially when those weights exceed the memory of any single device.
Parameter structures
A GPT-3 scale model stores 175 billion parameters, occupying 350 GB in FP16. Managing these parameters across devices, keeping gradients synchronized, and maintaining optimizer state (which can triple the memory footprint, as the Administrative Tax notebook showed) is a core framework responsibility.
Because parameters persist throughout training and inference, frameworks organize them into compact structures that minimize memory while enabling fast read and write access (Li et al. 2014). During multi-GPU training, frameworks may replicate parameters across devices for parallel computation while keeping a synchronized master copy. Synchronizing multi-billion parameter models can require transferring tens of GB of gradients per step, which is why frameworks implement gradient compression and efficient communication patterns like ring all-reduce.
Parameter structures must also adapt to varying precision requirements. Training typically uses FP32 for gradient stability, but inference and large-scale training increasingly use FP16 or INT8. Frameworks implement type casting and mixed-precision management to enable these optimizations without compromising numerical accuracy.
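The memory arithmetic behind these figures, assuming Adam's two FP32 moment tensors per parameter (decimal GB, as in the text):

```python
# Rough memory accounting for a 175B-parameter model.
params = 175e9

fp16_weights_gb = params * 2 / 1e9       # 2 bytes per FP16 parameter
adam_moments_gb = params * 2 * 4 / 1e9   # two FP32 tensors per parameter

assert fp16_weights_gb == 350.0   # the 350 GB quoted above
assert adam_moments_gb == 1400.0  # optimizer state alone is 4x the
                                  # FP16 weights, before gradients and
                                  # FP32 master weights are counted
```

This is why optimizer state, not the weights themselves, often dominates training memory.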
Distributed execution contexts
The computational graph defines what to compute, but where and how that computation runs across devices is the job of execution contexts. On a single node, execution contexts manage CUDA streams and events (discussed earlier in this chapter) to overlap computation and data transfer across GPUs.
When training scales beyond a single machine, these same abstractions extend to manage process groups and communication primitives. Frameworks use constructs like ProcessGroup (PyTorch) or Mesh (JAX) to define how devices communicate, maintaining state for collective operations such as AllReduce that synchronize gradients across thousands of GPUs. This includes partitioning computational graphs, synchronizing gradients, and redistributing data as needed.
We introduce these concepts here because they shape framework API design even for single-node code. The implementation details of distributed training—including gradient compression, communication topologies, and fault tolerance—constitute advanced topics that build on these single-node foundations.
When models exceed single-device memory, frameworks combine multiple parallelism strategies simultaneously. A GPT-3 scale model, for instance, cannot fit on a single GPU—its 175 B parameters alone require 350 GB in FP16, far exceeding any GPU’s memory. How do practitioners train such models? By distributing computation across multiple devices using three complementary strategies. Figure 11 lays out how large-scale training distributes computation across three orthogonal dimensions to overcome this constraint. In the figure, look for how each dimension addresses a different scaling need: Data Parallelism (replicating the model across columns) scales throughput by processing different batches in parallel; Pipeline Parallelism (splitting layers across rows) distributes a single model’s depth across devices; and Model Parallelism (sharding tensors within each cluster) partitions individual layers that are too large for one device. This “3D” approach allows frameworks to scale beyond the memory limits of any single device. Model Training examines these parallelism strategies in depth, including their implementation trade-offs and communication patterns.
The data structures examined so far—tensors, device managers, data pipelines, parameter structures, and distributed execution contexts—define what data a framework manages and where it lives. What remains is the final question of what actually runs on the hardware.
Core operations
When an engineer writes y = torch.matmul(x, w), the gap between Python and the GPU is larger than it appears: a single line of Python becomes thousands of parallel GPU threads, bridged by three distinct layers working in coordination. Figure 12 breaks this bridge into its three layers—read from bottom to top to follow the path from hardware to application: hardware abstraction operations manage computing platform complexity, basic numerical operations implement mathematical computations, and system-level operations coordinate resources and execution.
Hardware abstraction operations
The hardware abstraction layer isolates framework code from platform-specific details. It solves three concrete problems: selecting the right compute kernel, moving data through the memory hierarchy, and coordinating execution across processing units.
Compute kernel management
The kernel manager dispatches each operation to the fastest available implementation for the current hardware. When a framework encounters a matrix multiplication, it selects among AVX-512 vector instructions on modern CPUs, cuBLAS on NVIDIA GPUs, or dedicated tensor processing instructions on AI accelerators. The dispatch decision depends on input dimensions, data layout, and hardware capabilities. A \(4096\times4096\) GEMM on an A100 GPU routes to cuBLAS Tensor Core kernels that sustain up to 312 TFLOPS in FP16, while the same operation on a CPU falls back to an AVX-512 path at roughly 2 TFLOPS. When no specialized kernel exists, the manager falls back to a generic implementation rather than failing.
Memory system abstraction
The memory abstraction layer moves tensors between device types (CPU registered memory, GPU pinned memory, unified memory) and transforms data layouts to match hardware preferences. A convolutional layer, for example, may store activations in NCHW format (batch, channels, height, width) on NVIDIA GPUs but convert to NHWC for Apple’s Metal backend. Alignment requirements vary from 4 bytes on CPUs to 128 bytes on some accelerators, and misaligned access can halve effective memory bandwidth. The layer also enforces cache coherency when multiple execution units read and write the same tensor, preventing silent data corruption during concurrent operations.
Execution control
The execution controller coordinates work across multiple processing units and memory spaces. On a modern GPU, this means managing dozens of concurrent CUDA streams: when two independent convolutions are both ready to execute, the controller launches them on separate streams so they overlap on the GPU’s streaming multiprocessors, improving utilization from as low as 40 percent (sequential) to over 80 percent (concurrent). The controller inserts synchronization barriers only where true data dependencies exist, tracks event completions to trigger dependent operations, and routes hardware errors (ECC failures, timeout watchdogs) to the framework’s error handling path.
Basic numerical operations
With hardware abstraction managing the platform-specific details, frameworks build a layer of mathematical operations on top. General Matrix Multiply (GEMM) dominates ML computation (see General matrix multiply (GEMM) for arithmetic intensity analysis and the roofline implications). The operation C = \(\alpha\)AB + \(\beta\)C accounts for the vast majority of arithmetic in neural networks: a single ResNet-50 forward pass performs approximately 4.1 billion floating-point operations, nearly all of which reduce to GEMM. Frameworks optimize GEMM through cache-aware tiling (splitting matrices into blocks that fit in L1/L2 cache), loop unrolling for instruction-level parallelism, and shape-specific kernels. Fully connected layers use standard dense GEMM, while convolutional layers use im2col transformations that reshape input patches into matrix columns, converting convolution into GEMM.
Beyond GEMM, frameworks implement BLAS operations (AXPY for vector addition, GEMV for matrix-vector products) and element-wise operations (activation functions, normalization). Element-wise operations are individually cheap but collectively expensive due to memory bandwidth. Each operation reads and writes the full tensor, so a sequence of five element-wise operations on a 100 MB tensor moves 1 GB of data. Fusing those five operations into a single kernel reduces memory traffic to 200 MB, a 5\(\times\) bandwidth savings that directly translates to faster execution.
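The fusion arithmetic can be made explicit:

```python
# Memory traffic for five element-wise ops on a 100 MB tensor: unfused,
# each op reads and writes the full tensor; fused, the kernel reads the
# input once and writes the output once.
tensor_mb = 100
unfused_mb = 5 * (tensor_mb + tensor_mb)  # 5 reads + 5 writes
fused_mb = tensor_mb + tensor_mb          # 1 read + 1 write

assert unfused_mb == 1000                 # 1 GB moved
assert fused_mb == 200                    # 200 MB moved
assert unfused_mb // fused_mb == 5        # the 5x bandwidth savings
```

Because element-wise kernels are bandwidth-bound, this traffic reduction translates almost directly into wall-clock speedup.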
Numerical precision adds another dimension. Training in FP32 uses 4 bytes per parameter; quantizing to INT8 reduces this to 1 byte, cutting memory by 4\(\times\) and enabling 2–4\(\times\) throughput improvements on hardware with INT8 acceleration. Training typically requires FP32 for gradient stability, while inference runs at FP16 or INT8 with minimal accuracy loss. Frameworks maintain separate kernel implementations for each precision format and handle mixed-precision workflows where different layers operate at different bit widths within a single forward pass.
System-level operations
Hardware abstraction and numerical operations provide the building blocks; system-level operations orchestrate them. The system layer ties scheduling, memory management, and resource optimization into a coherent execution engine.
The operation scheduler analyzes the computational graph to find parallelism while respecting data dependencies. In a static graph, the scheduler sees the full dependency structure before execution begins and can plan an optimal ordering. In a dynamic graph, dependencies emerge at runtime, forcing the scheduler to make greedy decisions. Concretely, when a ResNet block produces two independent branch outputs, the scheduler launches both branches simultaneously rather than serializing them, reducing idle cycles on the GPU’s streaming multiprocessors.
The memory manager allocates and reclaims GPU memory across the computational graph’s lifetime. Model parameters (a 7B-parameter model consumes approximately 14 GB in FP16) persist for the entire training run, while activation tensors live only until the backward pass consumes them. PyTorch’s caching allocator maintains a memory pool, subdividing and reusing freed blocks without returning them to CUDA, which avoids the 1 ms overhead of cudaMalloc calls. For models that exceed GPU memory, the manager applies gradient checkpointing: discarding selected activations during the forward pass and recomputing them during the backward pass, trading roughly 20–33 percent additional compute for 60 percent or more memory savings (with optimal checkpoint placement).
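An idealized sketch of the checkpointing trade-off described above. This is only the arithmetic, not PyTorch's torch.utils.checkpoint API:

```python
import math

# Keeping a checkpoint every sqrt(L) layers stores O(sqrt(L)) activations
# instead of O(L); the backward pass then re-runs each segment's forward
# once, costing roughly one extra forward pass in total.
def checkpoint_costs(n_layers: int):
    stride = max(1, round(math.sqrt(n_layers)))
    stored = n_layers // stride   # activations kept live
    recomputed = n_layers         # layers re-run during backward
    return stored, recomputed

stored, recomputed = checkpoint_costs(100)
assert stored == 10        # 90% fewer stored activations (idealized)
assert recomputed == 100   # ~one extra forward pass; if backward costs
                           # ~2x forward, that is ~33% extra compute
```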
The resource optimizer integrates these scheduling and memory decisions. When two matrix multiplications with different shapes are ready to execute, it selects the algorithm variant (Winograd, Strassen, or standard tiled GEMM) that best fits each shape and the current memory pressure. A poorly scheduled graph wastes compute; a poorly managed memory pool triggers out-of-memory errors on hardware that theoretically has capacity to spare.
The preceding sections examined what happens beneath the API surface: tensors manage data layout, streams overlap computation with communication, and kernel dispatch routes operations to hardware. These mechanisms operate at the level of individual tensors and operations—the raw materials of machine learning computation. Practitioners, however, rarely write code at this level. A ResNet-50 has 25.6 million parameters organized into dozens of layers; manually tracking each tensor, registering it with an optimizer, and handling device placement would be error-prone and tedious. The abstraction problem is not fully solved by hardware-level mechanisms alone; it also requires a programming model that organizes these low-level primitives into the clean APIs that practitioners actually use.
Checkpoint 1.3: Hardware Abstraction
The abstraction problem is the bridge between portable code and efficient execution.
Individual operations—matrix multiplications, activations, normalizations—are the atoms of deep learning computation. Building models from individual operations, however, would be like building a house from individual atoms. Frameworks need an organizational abstraction that lets engineers compose operations into reusable, nestable building blocks. That abstraction is the module.
nn.Module Abstraction
The hardware-facing half of the abstraction problem—tensors, kernels, streams, and memory managers—makes individual operations fast on diverse silicon. A ResNet-50, however, contains fifty layers, each with multiple parameter tensors, buffers, and mode-dependent behaviors. Manually wiring each tensor to the correct device, registering it with an optimizer, toggling dropout behavior between training and inference, and serializing state for checkpointing—for every layer—would drown practitioners in bookkeeping that has nothing to do with model design. The upper layer of the abstraction problem is organizational: composing thousands of low-level primitives into the clean, composable APIs that practitioners actually use.
Every major framework answers this question through a module abstraction that bundles parameters, forward computation, and state management into a single reusable unit. PyTorch’s nn.Module28 provides an instructive case study because its design patterns recur across frameworks: Keras uses similar layer abstractions, JAX’s Flax employs analogous module structures, and TensorFlow’s functional API shares conceptual parallels. Rather than catalog its API, we extract three enduring design principles that every framework must address regardless of its syntax or programming paradigm.
28 nn.Module: The “design patterns recur” claim holds because nn.Module solves a universal organizational problem: it automatically registers any assigned submodule or parameter into a hierarchical tree, enabling a single .to('cuda') call to recursively place millions of parameters onto a GPU. Keras layers, JAX Flax modules, and TensorFlow’s tf.Module all implement the same tree-walking pattern. Without it, managing model state would require manual bookkeeping that scales linearly with architectural depth, a cost that grows prohibitive for models with hundreds of layers.
Principle 1: Automatic parameter discovery
A modern neural network may contain millions of trainable parameters spread across dozens of layers. Without automation, a programmer would need to enumerate every parameter tensor and pass it to the optimizer manually, an error-prone process that scales poorly with model complexity. Frameworks solve this through automatic parameter discovery: the system walks the module tree, collecting every parameter tensor so the optimizer can update them in a single call.
This is a graph traversal problem at its core. When a developer assigns an nn.Parameter as a class attribute, the framework's attribute-interception machinery (PyTorch overrides __setattr__ on nn.Module) registers the tensor in an internal dictionary. A call to .parameters() then performs a recursive depth-first traversal of the module tree, yielding every registered parameter. The same pattern appears in every major framework: Keras layers maintain a trainable_weights list, JAX's Flax modules use init() to return a nested parameter dictionary, and TensorFlow's tf.Module provides trainable_variables. The mechanism differs but the principle is universal.
The systems consequence is significant. Automatic parameter discovery enables optimizer.step() to update millions of parameters in a single vectorized operation, keeping the operations-per-parameter term efficient by avoiding per-parameter Python dispatch. Without this abstraction, each parameter update would require a separate Python function call, and the interpreter overhead alone would dominate training time for large models. Listing 37 demonstrates the core mechanism: attribute assignment triggers registration, and .parameters() returns all discovered tensors.
```python
import torch
import torch.nn as nn


class CustomLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(output_size, input_size)
        )
        self.bias = nn.Parameter(torch.randn(output_size))
        self.register_buffer("running_mean", torch.zeros(output_size))

    def forward(self, x):
        return torch.matmul(x, self.weight.t()) + self.bias


layer = CustomLayer(10, 20)
# Framework discovers both parameters automatically:
for name, param in layer.named_parameters():
    print(f"{name}: shape {param.shape}")
```

The distinction between parameters and buffers illustrates a subtlety of the discovery mechanism. Parameters carry requires_grad=True and participate in gradient computation. Buffers, registered through register_buffer(), travel with the model during device transfers but remain excluded from gradient updates. This separation is essential for normalization layers, where running statistics must persist across batches but must not receive gradients. The same dual-track design appears in Keras (via non_trainable_weights) and Flax (via state vs. params).
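The registration-and-traversal mechanism itself can be stripped to a few dozen lines of plain Python. This is a toy sketch, not PyTorch's implementation: Param, Module, Linear, and MLP below are hypothetical stand-ins showing how attribute interception plus depth-first traversal yields automatic discovery with dotted-path names.

```python
class Param:
    """Toy stand-in for nn.Parameter: a named, trainable tensor."""
    def __init__(self, shape):
        self.shape = shape


class Module:
    """Toy stand-in for nn.Module: registers on attribute assignment."""
    def __init__(self):
        # Bypass our own __setattr__ while the registries are created.
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Attribute interception: route assignments into registries.
        if isinstance(value, Param):
            self._params[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def named_parameters(self, prefix=""):
        # Recursive depth-first traversal, yielding dotted-path names.
        for name, p in self._params.items():
            yield f"{prefix}{name}", p
        for name, m in self._modules.items():
            yield from m.named_parameters(prefix=f"{prefix}{name}.")


class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = Param((n_out, n_in))
        self.bias = Param((n_out,))


class MLP(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(10, 20)
        self.fc2 = Linear(20, 5)


names = [n for n, _ in MLP().named_parameters()]
print(names)  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

The same tree structure that makes discovery automatic also produces the dotted-path checkpoint keys discussed later in this section.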
Systems Perspective 1.4: Cross-Framework Parameter Discovery
| Framework | Parameter Access | Non-Trainable State |
|---|---|---|
| PyTorch | model.parameters() | register_buffer() |
| Keras | layer.trainable_weights | layer.non_trainable_weights |
| JAX/Flax | params = model.init(key, x) | Separate state dict |
| TensorFlow | module.trainable_variables | module.non_trainable_variables |
Despite syntactic differences, all frameworks solve the same problem: enabling optimizers to discover and update trainable parameters while preserving non-trainable state across forward passes.
Principle 2: Mode-dependent behavior
Training and inference require different computational behavior from the same model graph. During training, dropout layers randomly zero elements with probability \(p\) to regularize the network, while during inference those same layers must perform identity mapping to produce deterministic outputs. Batch normalization uses per-batch statistics during training but switches to accumulated running statistics during inference. If these behavioral changes are left to the programmer, forgetting a single mode switch produces silently incorrect predictions in production.
Frameworks solve this with a state flag that propagates through the module hierarchy. A single call to .eval() on the root module recursively sets self.training = False on every descendant, and each layer queries this flag to select its behavior. This is an instance of a broader systems principle: the same computation graph must produce different execution behavior depending on context. Compilers face the same challenge when the same source code must produce debug builds (with bounds checking and symbol tables) vs. release builds (with aggressive optimization). The flag-propagation pattern ensures correctness by centralizing the mode decision at the root rather than requiring per-layer coordination.
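The flag-propagation pattern can be sketched in a few lines of plain Python. This is a toy illustration with hypothetical Module and Dropout classes; real frameworks discover children automatically through the registration mechanism rather than via an explicit list.

```python
import random


class Module:
    """Toy module with a recursively propagated training flag."""
    def __init__(self):
        self.training = True
        self.children = []

    def train(self, mode=True):
        self.training = mode
        for child in self.children:  # propagate to every descendant
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Dropout(Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, xs):
        if self.training:  # stochastic zeroing (with inverted scaling)
            return [0.0 if random.random() < self.p else x / (1 - self.p)
                    for x in xs]
        return list(xs)    # identity mapping during inference


net = Module()
drop = Dropout(p=0.5)
net.children.append(drop)

net.eval()  # one root call flips every descendant
print(drop.training)           # False
print(drop([1.0, 2.0, 3.0]))   # [1.0, 2.0, 3.0]
```

A single root-level call is enough to make every dropout layer deterministic, which is exactly the correctness guarantee the prose describes.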
This principle extends to parameter freezing for transfer learning. Setting requires_grad=False on specific parameters excludes them from gradient computation, effectively creating a third behavioral mode where some parameters train while others remain fixed. Selective freezing achieves computational savings by pruning the backward pass graph: frozen parameters need no gradient storage, reducing memory consumption proportionally.
Principle 3: Hierarchical composition and serialization
Complex models compose from reusable submodules, creating a tree structure. A ResNet is not implemented as a monolithic block of operations but as a hierarchy: the root module contains a sequence of residual blocks, each block contains convolution layers and normalization layers, and each layer contains parameter tensors. This hierarchical composition must support two critical operations: recursive parameter collection for training and state serialization for checkpointing and deployment.
Hierarchical composition mirrors the hardware memory hierarchy in a systems-relevant way: each submodule’s parameters can be loaded independently, enabling model parallelism across devices. When a model is too large for a single GPU, the framework can assign different subtrees of the module hierarchy to different devices, with the tree structure providing natural partition boundaries.
The state dictionary mechanism provides the serialization half of this principle. The state_dict() method produces a flat key-value mapping of the full module tree, where dotted path names (for example, blocks.0.conv1.weight) encode the hierarchy. This flat structure enables efficient serialization: a 7B-parameter model’s approximately 14 GB FP16 checkpoint can be written as a sequential byte stream, maximizing storage bandwidth utilization. The inverse operation, load_state_dict(), reconstructs the hierarchy from the flat mapping, enabling checkpoint recovery and cross-framework model exchange via formats like ONNX. Listing 38 demonstrates how the module tree enables both recursive parameter access and hierarchical state serialization.
```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return torch.relu(x + residual)


class ResNet(nn.Module):
    def __init__(self, num_blocks, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels, 7, padding=3)
        self.blocks = nn.ModuleList(
            [ResidualBlock(channels) for _ in range(num_blocks)]
        )
        self.fc = nn.Linear(channels, 10)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = x.mean(dim=[2, 3])  # Global average pooling
        return self.fc(x)


model = ResNet(num_blocks=4)
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total}")
# state_dict() flattens the tree: 'blocks.0.conv1.weight', etc.
print(list(model.state_dict().keys())[:4])
```

The hierarchical structure also enables module-level traversal for systematic operations. Methods like .named_modules() iterate the entire tree, supporting bulk transformations such as replacing all BatchNorm layers with GroupNorm or applying Xavier initialization to every Linear layer. These traversal operations depend on the same tree structure that enables parameter discovery, illustrating how a single design decision propagates benefits across multiple use cases.
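The dotted-path convention itself is simple enough to sketch directly. The flatten_state and unflatten_state helpers below are hypothetical illustrations of the round trip between a module tree and a flat checkpoint mapping; they are not the framework's serialization code.

```python
def flatten_state(tree, prefix=""):
    """Flatten a nested dict into dotted-path keys, state_dict-style."""
    flat = {}
    for name, value in tree.items():
        key = f"{prefix}{name}"
        if isinstance(value, dict):
            flat.update(flatten_state(value, prefix=f"{key}."))
        else:
            flat[key] = value
    return flat


def unflatten_state(flat):
    """Inverse operation: rebuild the hierarchy from dotted-path keys."""
    tree = {}
    for key, value in flat.items():
        node = tree
        *path, leaf = key.split(".")
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


state = {"blocks": {"0": {"conv1": {"weight": [1, 2], "bias": [3]}}}}
flat = flatten_state(state)
print(flat)
# {'blocks.0.conv1.weight': [1, 2], 'blocks.0.conv1.bias': [3]}
assert unflatten_state(flat) == state
```

Because the flat mapping is order-independent and self-describing, it can be streamed to disk sequentially and reconstructed into any module tree whose structure matches the keys.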
These three principles, automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization, are not PyTorch-specific. Every framework must solve them. Keras layers, JAX’s Flax modules, and even functional approaches all address the same problems of parameter management, state tracking, and compositional design. The differences lie not in what problems they solve but in how they prioritize among competing solutions. Two practical patterns built on these principles deserve attention: selective parameter freezing for transfer learning (Listing 39) and module hooks for non-invasive inspection (Listing 40).
```python
import torch
import torch.nn as nn

# Freeze all parameters in a pretrained model
pretrained_model = torch.hub.load(
    "pytorch/vision", "resnet18", pretrained=True
)
for param in pretrained_model.parameters():
    param.requires_grad = False

# Replace final layer with trainable parameters
pretrained_model.fc = nn.Linear(512, 10)  # New layer is trainable

# Only fc.parameters() will receive gradients during training
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, pretrained_model.parameters()),
    lr=0.001,
)
```

Forward and backward hooks intercept intermediate computations without modifying model code, enabling gradient flow diagnosis and activation monitoring. Listing 40 illustrates both hook types.
```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))


# Forward hook to inspect activations
def forward_hook(module, input, output):
    print(
        f"Layer: {module.__class__.__name__}, "
        f"Output shape: {output.shape}, "
        f"mean={output.mean():.3f}, "
        f"std={output.std():.3f}"
    )


# Backward hook to inspect gradients
def backward_hook(module, grad_input, grad_output):
    print(f"Gradient norm: {grad_output[0].norm():.3f}")


# Register hooks on a specific layer
handle_fwd = model[0].register_forward_hook(forward_hook)
handle_bwd = model[0].register_full_backward_hook(backward_hook)

# Execute forward and backward pass
x = torch.randn(32, 10)
y = model(x)
loss = y.sum()
loss.backward()

# Remove hooks when done
handle_fwd.remove()
handle_bwd.remove()
```

Together, these patterns—parameter discovery, freezing, and hooks—demonstrate how the three principles translate into practical APIs. The preceding nn.Module patterns illustrate PyTorch’s approach to the abstraction problem. PyTorch, however, is only one of several major frameworks, and its choices (mutable state, class inheritance, eager execution by default) are not the only valid design points. TensorFlow centralizes state differently, and JAX avoids mutable state entirely. These are not superficial API differences; they reflect deeply different answers to the three problems we examined at the chapter’s start.
Framework Platform Analysis
Each major framework represents a distinct point in the design space defined by the three core problems: TensorFlow prioritizes the Abstraction Problem through its comprehensive deployment ecosystem, PyTorch prioritizes the Execution Problem through its dynamic graph approach, and JAX reframes the Differentiation Problem through composable function transformations. These differences are architectural, reflecting fundamental capability trade-offs that determine what each framework can and cannot do well.
TensorFlow: The graph-first production machine
TensorFlow’s architecture reflects a comprehensive solution to the Abstraction Problem: targeting diverse hardware, from cloud TPUs to microcontrollers, through a single interface. Google’s production environment demanded this breadth because the same model often needed to serve predictions on TPU pods in the data center, on Android phones via TensorFlow Lite, and in web browsers through TensorFlow.js. This deployment diversity drove the choice of a Static Graph (or “Define-and-Run”) design. By requiring the model to be represented as a complete computational graph before execution, TensorFlow enables ahead-of-time (AOT) compilation and optimization for each target platform.
The graph-first approach prioritizes the Deployment Spectrum: because the framework sees the entire graph, it can perform aggressive optimizations like constant folding, operator fusion, and memory layout optimization before the first byte of data is processed. TensorFlow’s dominance in complex production ecosystems traces directly to this ahead-of-time optimization capability. Figure 13 maps the full training-to-deployment pipeline—trace how a model flows from data preprocessing through distributed training on the left, then fans out to serving, mobile (TF Lite), browser (TF.js), and language bindings on the right.
While TensorFlow 2.0 introduced eager execution to bridge the gap between research and production, its core strength remains the robust, compiled path from research to global-scale deployment. Model Training examines how TensorFlow’s distribution strategies enable large-scale training, while Model Serving covers its production serving infrastructure.
PyTorch: The eager research standard
Where TensorFlow’s graph-first approach prioritizes production optimization, PyTorch makes the opposite trade-off: it prioritizes developer experience. PyTorch’s architecture represents a sharply different answer to the Execution Problem, built on Dynamic Graphs (or “Define-by-Run”). Instead of building a blueprint before execution, PyTorch builds the computational graph on-the-fly as the code runs. Facebook AI Research (FAIR) adopted this design because researchers need immediate feedback when experimenting with novel architectures; the define-then-run cycle of static graphs introduced a compilation delay that slowed the rapid prototyping essential to research workflows.
PyTorch’s approach won the broader research community for the same reason: it treats deep learning as standard Python programming. Developers can use Python loops, conditionals, and debuggers (like pdb) directly within a model’s forward pass, with no special syntax, no separate compilation step, and no waiting to see if the code works. Eager execution enables rapid iteration and intuitive model design, which is essential for the trial-and-error nature of frontier AI research.
PyTorch’s answer to the Differentiation Problem is the tape-based autograd system examined in Section 1.3.2.1: flexible and debuggable, but harder to optimize globally because the tape is rebuilt each iteration. Its answer to the Abstraction Problem is more pragmatic than comprehensive: strong GPU support through cuBLAS and cuDNN, but deployment to mobile, edge, and browser environments requires exporting through ONNX or specialized runtimes rather than a native path.
The trade-off is therefore a more fragmented deployment path. Because the graph is dynamic, the framework cannot easily perform global optimizations before execution. A model that works perfectly in development may hit performance walls in production when dispatch overhead dominates small operations. To bridge this research-to-production gap, PyTorch introduced TorchScript and PyTorch 2.0 (with torch.compile), which allow developers to capture a dynamic model and turn it into an optimized, static representation for deployment. This evolution shows PyTorch moving toward the production end of the compilation continuum while preserving the eager experience that made it dominant in research.
JAX: The functional transformation engine
PyTorch’s eager execution and TensorFlow’s graph compilation represent two points on a spectrum, yet both share an imperative programming heritage where computation proceeds as a sequence of stateful operations. JAX represents a radically different approach, one built on functional programming principles and composable program transformations rather than computational graphs (Bradbury et al. 2018). Developed by Google Research, JAX has gained significant traction in research settings, particularly for work requiring custom differentiation, advanced optimization research, and large-scale distributed training.
JAX’s architecture reframes the Differentiation Problem entirely. Google Research built JAX on a key observation: if functions are pure (no side effects, no mutable state), the compiler can safely reorder, fuse, and parallelize any operation, because outputs depend only on inputs. This constraint, borrowed from functional programming, is what makes JAX’s composable transformations possible. Rather than implementing automatic differentiation as a tape-based system (PyTorch) or a graph transformation pass (TensorFlow), JAX treats differentiation as one of several composable function transformations. The jax.grad function does not compute gradients directly; it returns a new function that computes gradients. This subtle distinction enables arbitrary compositions: differentiating a differentiated function yields higher-order derivatives, vectorizing a gradient computation (vmap(grad(f))) parallelizes across examples, and compiling a vectorized gradient to XLA (jit(vmap(grad(f)))) eliminates Python overhead entirely.
JAX’s functional paradigm requires a genuine mental shift from “tracking state through objects” to “transforming pure functions.” The conceptual introduction here covers JAX’s core design; transformation composition, pytree handling, and XLA tracing mechanics each warrant dedicated study for production use.
Transformations over state
While PyTorch and TensorFlow build computational graphs (dynamically or statically), JAX transforms functions. The core insight is that automatic differentiation, vectorization, and JIT compilation are all program transformations that can compose. Listing 41 demonstrates this composable approach.
```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)


# Transform: compute gradients
grad_fn = jax.grad(loss_fn)

# Transform: vectorize over batch dimension
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# Transform: compile to XLA
fast_batched_grad = jax.jit(batched_grad)
# Compose all three: fast, batched gradient computation
```

This functional approach requires pure functions (no side effects) and immutable data (arrays cannot be modified in place). These constraints may seem restrictive coming from PyTorch’s mutable object model, but they enable formal guarantees: the compiler can safely reorder, fuse, and parallelize operations because function outputs depend only on inputs. The restriction is the feature; purity is what makes transformation composition possible.
Key transformations
JAX’s power emerges from composition. Start with a loss function f and apply jax.grad to obtain a new function that computes gradients—unlike PyTorch’s tape-based autograd, grad returns a function, not a value, supporting both forward-mode (jacfwd) and reverse-mode (jacrev) differentiation. Wrap that gradient function in jax.jit and JAX traces it once, compiles to optimized XLA machine code, caches the result, and eliminates Python overhead on subsequent calls. Apply jax.vmap to the compiled gradient function and it automatically vectorizes across a batch dimension, transforming single-example code into batched code without manual reshaping. Finally, jax.pmap maps the vectorized, compiled gradient function across multiple GPUs or TPUs, automatically handling inter-device communication. The result—pmap(jit(vmap(grad(f))))—expresses distributed, compiled, batched gradient computation as a single composed expression. No other framework offers this level of compositional power.
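The core idea that differentiation returns a function rather than a value can be illustrated with a toy forward-mode AD built on dual numbers. This is a deliberate simplification: JAX's grad defaults to reverse mode and handles arbitrary pytrees, while the Dual and grad definitions below are hypothetical sketches of the transformation pattern only.

```python
class Dual:
    """Dual number a + b*eps with eps**2 = 0: carries value and derivative."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule carried alongside the value.
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)

    __rmul__ = __mul__


def grad(f):
    """Return a NEW function that computes f'(x): a transformation,
    not a gradient value."""
    def df(x):
        out = f(Dual(x, 1.0))  # seed the tangent with dx/dx = 1
        return out.dot
    return df


def f(x):
    return x * x * x + 2 * x  # f(x) = x^3 + 2x, so f'(x) = 3x^2 + 2


print(grad(f)(2.0))  # 14.0
```

Because grad(f) is just another function, it can be passed to further transformations, which is the composition property the surrounding text describes.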
Ecosystem and libraries
JAX’s minimalist core delegates neural network abstractions to companion libraries (Flax, Haiku, Equinox) and optimization to Optax. This separation reflects the functional philosophy: the core provides transformations, while libraries build conventional abstractions on top. The ecosystem is younger and smaller than PyTorch’s or TensorFlow’s, which affects the availability of pre-built components for production use.
Trade-offs and use cases
The functional constraints that JAX imposes become advantages in specific domains. Custom differentiation—higher-order gradients, custom VJP/JVP rules—composes cleanly because pure functions make differentiation rules predictable. Research on optimization algorithms benefits from transformations that let researchers manipulate gradient computation as naturally as they manipulate data. Large-scale distributed training, particularly on TPUs, uses XLA compilation to extract maximum hardware utilization. Scientific computing with AD requirements benefits from functional purity that enables mathematical reasoning about code. JAX requires more upfront investment than PyTorch: the functional paradigm has a learning curve, state management requires explicit patterns, and debugging compiled code is harder than eager execution. Teams should choose JAX when its strengths align with project requirements, not as a default.
Quantitative platform performance analysis
The preceding sections described each framework’s design philosophy in qualitative terms: graph-first vs. eager-first, stateful vs. functional. Design philosophy claims, however, are only meaningful when backed by measurement. Table 8 quantifies how the architectural choices of TensorFlow, PyTorch, and JAX translate to system characteristics. When examining this comparison, note particularly the differences in execution mode, compilation optimization potential, and distributed scalability—these dimensions most directly impact production deployment decisions.
| Aspect | TensorFlow | PyTorch | JAX |
|---|---|---|---|
| Graph Type | Static (1.x), Dynamic (2.x) | Dynamic | Functional transformations |
| Programming Model | Imperative (2.x), Symbolic (1.x) | Imperative | Functional |
| Core Data Structure | Tensor (mutable) | Tensor (mutable) | Array (immutable) |
| Execution Mode | Eager (2.x default), Graph | Eager | Just-in-time compilation |
| Automatic Differentiation | Reverse mode | Reverse mode | Forward and Reverse mode |
| Hardware Acceleration | CPU, GPU, TPU | CPU, GPU | CPU, GPU, TPU |
| Compilation Optimization | XLA: 3–10\(\times\) speedup | TorchScript: 2\(\times\) | XLA: 3–10\(\times\) speedup |
| Memory Efficiency | 70–90 percent (workload dependent) | 70–90 percent (varies) | 75–95 percent (with XLA fusion) |
| Distributed Scalability | High (1024+ GPUs) | High | Very High (1024+ GPUs) |
An important caveat applies to these numbers: GPU utilization and compilation speedups vary significantly by model architecture, batch size, and operation mix. JAX/XLA achieves higher utilization for TPU workloads through aggressive fusion, while PyTorch and TensorFlow perform similarly for most deep learning workloads. These framework-level generalizations provide useful orientation but cannot substitute for profiling specific workloads on target hardware.
How do these architectural differences look in practice? Listing 42 implements the same neural network (a single linear layer mapping ten inputs to one output) across all three frameworks, revealing how design philosophy shapes even the simplest code:
```python
# PyTorch - Dynamic, Pythonic
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


# TensorFlow/Keras - High-level API
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(10,))]
)

# JAX - Functional approach
import jax.numpy as jnp
from jax import random


def simple_net(params, x):
    return jnp.dot(x, params["w"]) + params["b"]


key = random.PRNGKey(0)
params = {
    "w": random.normal(key, (10, 1)),
    "b": random.normal(key, (1,)),
}
```

These three implementations solve the same mathematical problem but reveal distinct answers to the Three Problems. The differences are not cosmetic; they shape debugging workflows, deployment options, and optimization potential.
PyTorch binds state and computation together through class inheritance (nn.Module), solving the Execution Problem through eager evaluation: the graph builds as Python runs, making standard debuggers and control flow work naturally. The cost is that no optimizer sees the full computation before execution begins.
TensorFlow/Keras inverts this priority through the Sequential API, which declares structure without executing it, solving the Abstraction Problem first: the same declaration compiles to server GPUs, mobile NPUs, or browser WebGL backends. Eager mode (default in TensorFlow 2.x) recovers some of PyTorch’s debugging flexibility, but production deployment still relies on graph capture for optimization.
JAX makes the most radical trade-off by treating the model as a pure function29 with immutable data and no internal state. This functional purity solves the Differentiation Problem most elegantly: grad, vmap (automatic vectorization), and jit (just-in-time compilation30) are composable transformations on stateless functions, not infrastructure bolted onto an object system. The cost is explicit parameter management and a programming model unfamiliar to most engineers.
29 Pure Function: Has no side effects and always returns the same output for the same inputs. In JAX, purity is not a style preference but a compiler requirement: jax.jit traces the function once and caches the compiled result, so any side effect (printing, modifying global state, random number generation without explicit key threading) would execute only during the first trace and silently vanish from subsequent calls. This constraint is the cost JAX pays for composable, whole-program optimization.
30 Just-in-Time (JIT) Compilation: Translates high-level code into optimized machine code at runtime, specializing for the actual data shapes and hardware present. The trade-off is compilation latency: the first execution pays a one-time cost (5–30 seconds for transformer models) while subsequent calls with the same shapes execute cached compiled code with microsecond dispatch overhead. Shape changes trigger recompilation, which is why dynamic sequence lengths in language models can degrade JIT performance unless the framework pads to fixed shape buckets.
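The shape-specialization behavior described in the footnote can be mimicked with a cache keyed on input shape. This is a deliberately crude sketch: "compilation" is just a counter, list length stands in for a tensor shape signature, and the jit decorator below is hypothetical rather than any framework's API.

```python
def jit(f):
    """Toy JIT: 'compile' once per input shape, then reuse the cache."""
    cache = {}

    def wrapped(xs):
        shape = len(xs)  # stand-in for a tensor shape signature
        if shape not in cache:
            wrapped.compile_count += 1  # expensive trace+compile would go here
            cache[shape] = f            # cache the 'compiled' function
        return cache[shape](xs)

    wrapped.compile_count = 0
    return wrapped


@jit
def total(xs):
    return sum(xs)


total([1, 2, 3])      # first call at length 3: compiles
total([4, 5, 6])      # same shape: served from cache
total([1, 2, 3, 4])   # new shape: triggers recompilation
print(total.compile_count)  # 2
```

The same mechanics explain why variable sequence lengths are costly under JIT: every new shape pays the compile path, which is why padding to fixed shape buckets keeps the cache hot.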
No framework optimizes all three problems simultaneously; each makes deliberate trade-offs that shape everything from API design to performance characteristics. PyTorch prioritizes the Execution Problem (eager debugging, dynamic graphs) at the cost of optimization potential. TensorFlow prioritizes the Abstraction Problem (unified deployment from cloud to microcontroller) at the cost of development flexibility. JAX reframes the Differentiation Problem (composable function transformations) at the cost of a steeper learning curve. These are the same design tensions examined in the preceding subsections, now visible even in a ten-line program. Exploratory research favors PyTorch’s debugging immediacy, production deployment favors TensorFlow’s optimization depth, and algorithmic research favors JAX’s composable transformations. Each philosophy shapes code syntax, team workflows, debugging practices, and deployment pipelines, which is why framework migration costs are measured in engineer-months rather than engineer-days.
These design differences are not arbitrary; they reflect which term of the iron law each framework prioritizes. TensorFlow’s graph compilation minimizes the Overhead term through ahead-of-time optimization, PyTorch’s eager execution minimizes the developer iteration overhead at the cost of runtime optimization, and JAX’s XLA backend minimizes the Data Movement term through aggressive operation fusion.
Quantitative framework efficiency comparison
How large are these differences in practice? Table 9 compares major frameworks across efficiency dimensions using benchmark workloads representative of production deployment scenarios.
| Framework | Inference Latency (ms) | Memory Usage (MB) | Energy (mJ/inference) | Model Size Reduction | Hardware Utilization (%) |
|---|---|---|---|---|---|
| TensorFlow | 45 | 2,100 | 850 | None | 35 |
| TensorFlow Lite | 12 | 180 | 120 | 4\(\times\) (quantized) | 65 |
| TensorFlow Lite Micro | 8 | 32 | 45 | 8\(\times\) (pruned+quant) | 75 |
| PyTorch | 52 | 1,800 | 920 | None | 32 |
| PyTorch Mobile | 18 | 220 | 180 | 3\(\times\) (quantized) | 58 |
| ONNX Runtime | 15 | 340 | 210 | 2\(\times\) (optimized) | 72 |
| TensorRT | 3 | 450 | 65 | 2\(\times\) (precision opt) | 88 |
| Apache TVM | 6 | 280 | 95 | 3\(\times\) (compiled) | 82 |
The efficiency data reveals several important patterns. First, specialized inference frameworks (TensorRT, Apache TVM) achieve roughly 8–17\(\times\) lower latency than general-purpose training frameworks (PyTorch, TensorFlow) on identical hardware, demonstrating that framework selection has quantitative performance implications beyond qualitative design preferences. Second, mobile-optimized variants (TF Lite, PyTorch Mobile) reduce memory requirements by roughly 10\(\times\) compared to their full counterparts while maintaining accuracy within one percent through quantization and graph optimization. Third, hardware utilization varies dramatically: TensorRT achieves 88 percent GPU utilization through aggressive kernel fusion while vanilla PyTorch achieves only 32 percent, a 2.75\(\times\) efficiency gap that directly translates to cost differences in production deployment.
These efficiency gaps, significant in the data center, become existential as we move beyond the server room. A 17\(\times\) latency difference between PyTorch and TensorRT is an optimization opportunity on a cloud GPU; on a microcontroller with 256 KB of RAM, a framework that requires 1.8 GB of memory simply cannot run at all. The question shifts from “which framework is fastest?” to “which framework fits?”
Deployment Targets
As ML models move from cloud servers to edge devices, the efficiency gaps measured earlier transform from optimization opportunities into hard deployment constraints. The three core problems reweight dramatically at the edge. The execution problem shifts from “eager vs. graph” to “can we execute at all within 10 ms and 50 KB?” The differentiation problem often disappears entirely, since edge devices run inference only. The abstraction problem intensifies: targeting ARM vs. x86, mobile NPUs vs. edge TPUs, microcontrollers with kilobytes of memory.
Table 10 summarizes framework choices by deployment target:
| Environment | Primary Frameworks | Key Optimizations | Typical Constraints |
|---|---|---|---|
| Cloud/Server | PyTorch, TensorFlow, JAX | Distributed training, mixed precision | Throughput, cost |
| Edge | TensorFlow Lite, ONNX Runtime | Quantization (INT8), static graphs | Latency <10 ms, limited memory |
| Mobile | TF Lite, Core ML, PyTorch Mobile | NPU acceleration, model compression | Battery, thermal, app size limits |
| Microcontroller (TinyML) | TF Lite Micro, uTensor | 4-bit quantization, static allocation | <256 KB RAM, no dynamic memory |
Table 10 reveals a fragmented landscape: different deployment targets favor different frameworks. The Smart Doorbell KWS model from Section 1.2.6 exemplifies the Microcontroller tier, where TF Lite Micro’s extreme AOT compilation is the only viable path. This fragmentation creates a practical problem when organizations train in one framework but deploy on a target best served by another.
The Open Neural Network Exchange (ONNX)31 format addresses this fragmentation by enabling model portability across frameworks: train in PyTorch, deploy via TensorFlow Lite or ONNX Runtime. This standardization eliminates manual conversion when moving between development and production environments. Figure 14 captures this hub-and-spoke interoperability model—notice how ONNX sits at the center, accepting models from any training framework on the left and dispatching them to specialized runtimes on the right. Detailed deployment optimization (quantization, pruning, hardware-specific compilation) appears in Model Compression and Model Serving.
31 ONNX: The “fragmentation” ONNX addresses is that the best training framework (often PyTorch for research velocity) rarely matches the best serving runtime (often TensorRT for latency, TF Lite for mobile). ONNX defines a hardware-agnostic graph representation that decouples the two, eliminating the engineer-months of manual model conversion that would otherwise be required each time a deployment target changes. The accepted trade-off is that ONNX export can lose framework-specific optimizations or custom operators, requiring fallback implementations.
ONNX reduces the cost of framework fragmentation, but it does not eliminate the initial selection decision. With the deployment landscape mapped and interoperability options understood, we can now address the practical question: given a specific project’s requirements, how should an engineer select a framework?
Framework Selection
Framework selection is a constrained optimization problem across technical capabilities, operational requirements, and organizational factors. No single framework dominates across all criteria, which means the goal is not to find the “best” framework but to find the one whose trade-offs align with the project’s constraints.
The framework selection trade-off space
Framework selection involves three interconnected tensions. The first is between development velocity and production performance: eager execution (PyTorch) prioritizes iteration speed, while graph compilation (TensorFlow/XLA, JAX/JIT) prioritizes runtime optimization. Research teams that need to test ten architecture variants per day cannot afford minutes of compilation between experiments; production teams that deploy a single model for months cannot afford the throughput penalty of eager dispatch. The optimal point shifts as a project moves through its lifecycle.
This velocity-performance tension leads directly to the second: flexibility vs. optimization depth. Dynamic graphs enable the arbitrary control flow that makes eager development fast, but they limit the compiler’s scope. Static graphs constrain expressiveness but enable aggressive fusion and hardware-specific code generation. As Table 3 demonstrated, this trade-off cascades through memory management, utilization, and debugging workflows—it is not a single design decision but a system-wide constraint.
The flexibility-optimization tension, in turn, exposes a third: ecosystem breadth vs. specialization. General-purpose frameworks cover broad operation sets but underperform specialized runtimes. TensorRT achieves 88 percent GPU utilization vs. PyTorch’s 32 percent (Table 9) precisely because it optimizes for a narrower problem. ONNX bridges this gap through standardized interchange, but the underlying trade-off remains: the more a runtime specializes, the faster it runs and the less it supports.
Systems Perspective 1.5: Framework Selection Constraints
The TensorFlow ecosystem illustrates how these axes interact concretely. Its three variants (TensorFlow, TensorFlow Lite, TensorFlow Lite Micro) trace a single design philosophy across progressively tighter constraints, a pattern that generalizes to any framework family. Table 11 quantifies the trade-offs.
| | TensorFlow | TensorFlow Lite | TensorFlow Lite for Microcontrollers |
|---|---|---|---|
| Training | Yes | No | No |
| Inference | Yes (but inefficient on edge) | Yes (and efficient) | Yes (and even more efficient) |
| How Many Ops | ~1400 | ~130 | ~50 |
| Native Quantization Tooling | No | Yes | Yes |
The principle is progressive constraint leading to progressive optimization: fewer supported operations enable smaller binaries, tighter memory budgets, and native quantization.
Framework selection criteria
Three dimensions structure systematic framework evaluation: what the model requires (supported operations and graph semantics), what the software environment provides (OS, memory management, accelerator delegation), and what the hardware physically permits (compute, memory, power). Each dimension acts as a filter—hard constraints eliminate candidates, and soft preferences rank the survivors.
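The filter-then-rank logic can be sketched in a few lines of Python. The candidate data below is purely illustrative (names and numbers are placeholders, not measured values); the point is the structure: hard constraints eliminate, soft preferences rank.

```python
# Hypothetical candidate data for illustration -- not measured values.
CANDIDATES = [
    {"name": "Full framework", "ops": 1400, "needs_os": True,  "min_ram_kb": 4_000_000},
    {"name": "Lite runtime",   "ops": 130,  "needs_os": True,  "min_ram_kb": 300},
    {"name": "Micro runtime",  "ops": 50,   "needs_os": False, "min_ram_kb": 20},
]

def select(candidates, required_ops, has_os, ram_kb):
    """Hard constraints eliminate candidates; a soft preference ranks survivors."""
    viable = [c for c in candidates
              if c["ops"] >= required_ops          # model requirements
              and (has_os or not c["needs_os"])    # software dependencies
              and c["min_ram_kb"] <= ram_kb]       # hardware constraints
    # Soft preference: smallest footprint among the survivors.
    return sorted(viable, key=lambda c: c["min_ram_kb"])

# A bare-metal MCU with 256 KB RAM and ~40 required ops:
print([c["name"] for c in select(CANDIDATES, 40, has_os=False, ram_kb=256)])
# A cloud server training a model that needs ~1000 ops:
print([c["name"] for c in select(CANDIDATES, 1000, has_os=True, ram_kb=80_000_000)])
```

The same three filters reappear in the subsections that follow; only the data changes per project.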
Model requirements
The first question is whether a framework can express the models a project requires. Examine Table 11: notice how operator count drops from approximately \(10^3\) (full TensorFlow) to \(10^2\) (TensorFlow Lite) to \(10^1\) (TensorFlow Lite Micro). Each reduction eliminates training capability and general-purpose operations while adding native quantization tooling. The engineering principle is that expressiveness and efficiency trade against each other: fewer supported operations enable tighter code generation, smaller binaries, and hardware-specific optimization paths. This progressive constraint model applies to any framework family, not just TensorFlow. The choice between dynamic and static computational graphs further shapes which optimizations each constraint level permits.
Systems Perspective 1.6: Dynamic vs. Static Computational Graphs
Dynamic graphs (PyTorch’s default) preserve Python control flow and line-by-line debugging but limit the compiler’s view to one operation at a time; static graphs expose the whole program for fusion and hardware-specific code generation. Modern frameworks bridge the two with tracing and JIT compilation (`torch.compile`, `tf.function`) to recover optimization potential.
Software dependencies
Once model requirements are satisfied, the framework must integrate with the target software environment. Table 12 reveals how operating system requirements, memory management, and accelerator support vary across TensorFlow variants.
| | TensorFlow | TensorFlow Lite | TensorFlow Lite for Microcontrollers |
|---|---|---|---|
| Needs an OS | Yes | Yes | No |
| Memory Mapping of Models | No | Yes | Yes |
| Delegation to accelerators | Yes | Yes | No |
The key distinctions follow the same progressive constraint pattern. TensorFlow Lite Micro eliminates the OS requirement entirely, enabling bare-metal execution on microcontrollers (though it integrates with RTOSes like FreeRTOS and Zephyr when available). Both Lite variants support memory-mapped model access from flash storage, avoiding the RAM overhead of loading full models. Accelerator delegation drops out at the microcontroller tier, where specialized hardware is rarely available. Each software dependency removed is a deployment target gained.
Hardware constraints
Software compatibility alone does not guarantee deployment; the framework must fit within physical hardware limits. Table 13 quantifies this final constraint dimension.
| | TensorFlow | TensorFlow Lite | TensorFlow Lite for Microcontrollers |
|---|---|---|---|
| Base Binary Size | A few MB (varies by platform and build configuration) | Tens to hundreds of KB | On the order of 10 KB |
| Base Memory Footprint | Several MB (minimum runtime overhead) | Hundreds of KB | Tens of KB |
| Optimized Architectures | X86, TPUs, GPUs | Arm Cortex A, x86 | Arm Cortex M, DSPs, MCUs |
Binary size spans three orders of magnitude: from MB (full TensorFlow) to tens of KB (TensorFlow Lite Micro). Memory footprint follows the same pattern. Processor architecture support shifts correspondingly from x86/GPU/TPU (data center) through Arm Cortex-A (mobile/edge) to Arm Cortex-M, DSPs, and MCUs (embedded). These are not arbitrary engineering tiers—they mirror the physical constraints (Light Barrier, power wall, memory wall) that carve the deployment spectrum into distinct paradigms (Physical Constraints: Why Paradigms Exist). The engineering lesson generalizes beyond TensorFlow: every framework family that spans deployment tiers makes analogous trade-offs between capability and resource footprint, and the framework’s job is to make those trade-offs navigable rather than invisible.
Production-ready evaluation factors
Technical specifications establish necessary but not sufficient conditions for selection. Production deployments also require evaluating operational factors: migration cost (typically three to six engineer-months for production systems), maintenance burden, and deployment success rates.
These hardware constraints cascade into performance trade-offs that are tightly coupled. Inference latency (tens of milliseconds for mobile image classification, sub-millisecond for industrial control), memory footprint (MB for full TensorFlow down to tens of KB for TF Lite Micro), power consumption (INT8 inference consuming several-fold less energy than FP32), and hardware utilization (operator fusion improving FLOPS utilization from 10–20 percent to 60–80 percent of peak) are not independent dimensions. Quantization simultaneously reduces memory, latency, and energy at the cost of precision, and framework selection determines which of these optimization levers are available in the first place. Scalability introduces a further concern: consistent deployment from microcontrollers to servers, smooth prototype-to-production transitions, and version management across deployed fleets all depend on the framework’s deployment toolchain. The three-dimension methodology illustrated here—model requirements, software dependencies, hardware constraints—applies to any framework ecosystem, not just TensorFlow.
Development support and long-term viability assessment
What determines whether a framework remains viable five years into a production deployment? Technical capabilities are necessary but not sufficient. Community composition shapes framework evolution in measurable ways: PyTorch’s academic community drives research-oriented features and reproducibility tools, though production tooling (PyTorch Lightning, TorchServe) has historically lagged; TensorFlow’s enterprise community emphasizes production reliability through TFX pipelines, TensorBoard visualization, and TensorFlow Model Analysis; JAX’s smaller community concentrates on mathematical rigor, producing specialized research tools (composable transformations, custom VJP rules) but with a steeper onboarding curve.
A framework’s practical utility, however, often depends more on its surrounding ecosystem than on its core capabilities. Hugging Face provides consistent model APIs across all three major frameworks, making pretrained model availability a near-commodity. Cross-framework tools (Weights & Biases, MLflow for experiment tracking; ONNX Runtime for serving) reduce lock-in, while framework-native tools (XLA, TorchScript, TensorFlow Serving) offer deeper optimization at the cost of portability. Cloud ML services (SageMaker, Google AI Platform, Azure ML) provide native integration for specific frameworks, creating operational advantages that compound over time.
These compounding effects make framework migration progressively harder. Integration with existing CI/CD pipelines, monitoring infrastructure, and cloud providers creates operational inertia that resists change. The measurable indicators of viability—contributor diversity (single-company dependence is a risk), backward compatibility track record, and hiring pool alignment with organizational needs—should be evaluated before commitment, not after. The mitigation strategy is defensive: use standardized formats (ONNX), maintain framework-agnostic data pipelines, and document framework-specific customizations to preserve future flexibility.
We have now examined the three core problems individually, compared how major frameworks resolve them, and established criteria for choosing among alternatives. What remains is to see all three problems interact inside a single execution—to watch the machinery we have studied operate as one integrated system.
Anatomy of a Training Step
The concepts developed throughout this chapter—eager vs. graph execution, reverse-mode autodiff, tensor abstractions, kernel dispatch—remain abstract until we see them interact inside a real execution. To solidify understanding, we trace a single training step through the PyTorch stack, revealing how seven lines of Python trigger the execution, differentiation, and abstraction machinery simultaneously.
Listing 43 presents a minimal training iteration for a two-layer multilayer perceptron. Though only seven lines of Python, this code exercises the entire framework stack: tensor allocation, kernel dispatch, autograd recording, gradient computation, and parameter updates. Tracing each phase reveals the three problems in action and connects the quantitative principles developed earlier to concrete execution.
# Single training step for a 2-layer MLP
x = torch.randn(32, 784, device="cuda") # Input batch
y = torch.randint(0, 10, (32,), device="cuda") # Labels
# Forward pass
h = torch.relu(x @ W1 + b1) # Hidden layer
logits = h @ W2 + b2 # Output layer
loss = F.cross_entropy(logits, y)
# Backward pass
loss.backward()
# Parameter update
optimizer.step()
Phase 1: Forward pass (solving the execution problem)
When h = torch.relu(x @ W1 + b1) executes, PyTorch’s eager execution triggers immediate computation:
Python Dispatch (~1 μs): The Python interpreter calls `torch.matmul`, which routes through PyTorch’s dispatcher to select the CUDA backend.
Kernel Selection (~0.5 μs): cuBLAS selects an optimized GEMM kernel based on the matrix dimensions, \((32 \times 784) \cdot (784 \times 256)\). For these dimensions, it might choose a tiled algorithm optimized for L2 cache.
Kernel Launch (~5 μs): The selected kernel is queued to the GPU’s command buffer. The CPU continues immediately (asynchronous execution).
GPU Execution (~15μs):
- Load W1 from HBM32 to L2 cache (~200 GB/s effective bandwidth)
- Perform matrix multiply in tensor cores (if available)
- Write result to HBM
32 HBM (High Bandwidth Memory): Provides 2–3 TB/s bandwidth on modern GPUs (introduced in Network Architectures). HBM bandwidth determines whether operations are memory bound or compute bound, and its 80 GB capacity on an A100 sets the hard ceiling on model size: weights, activations, gradients, and optimizer state must all fit simultaneously during training. When they do not, the framework must resort to offloading, checkpointing, or model parallelism, each adding complexity to what the programmer perceives as a single loss.backward() call.
- Autograd Recording: Simultaneously, PyTorch’s autograd engine records a `MmBackward` node on the tape, storing references to `x` and `W1` for gradient computation.
The bias addition and ReLU follow similar patterns, each adding a node to the autograd tape.
Phase 2: Backward pass (solving the differentiation problem)
When loss.backward() executes:
Tape Traversal: The autograd engine traverses the recorded graph in reverse topological order.
Gradient Computation: For each node, it calls the registered backward function:
- `CrossEntropyBackward`: computes \(\frac{\partial \mathcal{L}}{\partial \text{logits}}\) using the softmax derivative
- `MmBackward` (W2): computes \(\frac{\partial \mathcal{L}}{\partial W_2} = h^T \cdot \frac{\partial \mathcal{L}}{\partial \text{logits}}\) and \(\frac{\partial \mathcal{L}}{\partial h}\)
- `ReluBackward`: applies the ReLU derivative mask (zero where \(h \leq 0\))
- `MmBackward` (W1): computes \(\frac{\partial \mathcal{L}}{\partial W_1}\) and \(\frac{\partial \mathcal{L}}{\partial x}\)
Gradient Accumulation: Gradients are accumulated into the `.grad` attributes of leaf tensors.
Memory Management: After each backward node completes, its saved tensors are freed, allowing memory reuse.
Phase 3: Memory traffic analysis (the physics at work)
Applying the Dispatch Overhead Equation (Equation 4) to this step, Table 14 breaks down the FLOPs, memory traffic, and arithmetic intensity for each operation:
| Component | FLOPs | Memory Traffic | Arithmetic Intensity |
|---|---|---|---|
| MatMul (x @ W1) | 2\(\times\) 32\(\times\) \(784\times256\) = 12.8M | 0.9 MB | 13.7 |
| ReLU | \(32\times256\) = 8K | 66 KB | 0.125 |
| MatMul (h @ W2) | 2\(\times\) 32\(\times\) \(256\times10\) = 164K | 44 KB | 3.7 |
| Cross-entropy | ~1K | 3 KB | 0.4 |
| Backward (2\(\times\) forward) | ~26M | ~3.2 MB | ~8.0 |
Total: ~40M FLOPs, ~5 MB memory traffic. On an A100:
- \(T_{\text{compute}} \approx 40\text{M FLOPs} / 312\ \text{TFLOPS} \approx 0.13\ \mu\text{s}\)
- \(T_{\text{memory}} \approx 5\ \text{MB} / 2\ \text{TB/s} \approx 2.5\ \mu\text{s}\)
- \(T_{\text{overhead}} \approx 6\ \text{ops} \times 5\ \mu\text{s} \approx 30\ \mu\text{s}\)
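These back-of-envelope numbers can be reproduced in a few lines. The peak figures (312 TFLOPS, 2 TB/s, 5 μs per launch) are the A100 estimates used in the text; `matmul_cost` and `step_time` are illustrative helpers, not a profiler:

```python
def matmul_cost(m, k, n, bytes_per_elem=4):
    """FLOPs and memory traffic for an (m x k) @ (k x n) matmul, assuming
    each operand and the result cross HBM exactly once."""
    flops = 2 * m * k * n
    traffic = bytes_per_elem * (m * k + k * n + m * n)
    return flops, traffic

def step_time(flops, traffic, n_ops,
              peak_flops=312e12, bandwidth=2e12, launch_overhead=5e-6):
    """The three time terms from the text, in seconds."""
    t_compute = flops / peak_flops
    t_memory = traffic / bandwidth
    t_overhead = n_ops * launch_overhead
    return t_compute, t_memory, t_overhead

f1, m1 = matmul_cost(32, 784, 256)    # x @ W1
print(f1, round(f1 / m1, 1))          # ~12.8M FLOPs, arithmetic intensity ~13.7

# Whole-step estimate from the text: ~40M FLOPs, ~5 MB traffic, 6 kernels.
tc, tm, to = step_time(40e6, 5e6, n_ops=6)
print(to > tm > tc)                   # overhead-bound: launches dominate
```

Changing `n_ops` (fusion) or the model size shifts which term dominates, which is exactly the lever `torch.compile` pulls.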
The training step is overhead-bound. For small models, Python dispatch and kernel launch dominate. This explains why:
- `torch.compile` provides 2–3\(\times\) speedup by fusing operations and reducing kernel launches
- Larger batch sizes help amortize per-step dispatch overhead
- Production training uses much larger models where compute dominates
Phase 4: Hardware abstraction (solving the abstraction problem)
The same Python code runs on different hardware through abstraction layers:
- CUDA GPU: cuBLAS GEMM kernels, CUDA streams for async execution
- CPU: Intel MKL or OpenBLAS, OpenMP for parallelism
- TPU: XLA compilation to TPU-specific HLO operations
- Apple Silicon: Metal Performance Shaders via MPS backend
Each backend implements the same tensor operations with hardware-specific optimizations. The framework’s abstraction layer (Section 1.4) ensures identical numerical results (within floating-point tolerance) across platforms. This is the abstraction problem solved: a single loss.backward() call triggers completely different code paths depending on hardware, yet produces mathematically equivalent gradients.
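The dispatch idea (one logical operation, many backend kernels) reduces to a lookup table. The sketch below is a toy model of that mechanism; the backends and kernels here are placeholders, not real PyTorch internals:

```python
# (op, device) -> kernel: a toy model of a framework dispatcher.
KERNELS = {}

def register(op, device):
    """Decorator that registers a kernel implementation for (op, device)."""
    def wrap(fn):
        KERNELS[(op, device)] = fn
        return fn
    return wrap

@register("relu", "cpu")
def relu_cpu(xs):
    return [max(0.0, x) for x in xs]   # stand-in for an MKL/OpenBLAS kernel

@register("relu", "cuda")
def relu_cuda(xs):
    # Placeholder: a real backend would launch a CUDA kernel here.
    return [max(0.0, x) for x in xs]

def dispatch(op, device, *args):
    """Route the logical op to the backend-specific kernel."""
    try:
        return KERNELS[(op, device)](*args)
    except KeyError:
        raise NotImplementedError(f"{op} has no kernel registered for {device}")

print(dispatch("relu", "cpu", [-1.0, 2.0]))   # [0.0, 2.0]
```

An unsupported (op, device) pair raising `NotImplementedError` is the toy analogue of a missing operator on a deployment target—the failure mode the abstraction problem exists to manage.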
Systems Perspective 1.7: The Three Problems in Action
- Execution: Eager mode enables line-by-line debugging but incurs dispatch overhead
- Differentiation: Autograd tape records operations during forward, replays in reverse during backward
- Abstraction: Same code runs on GPU/CPU/TPU through backend-specific kernel implementations
Understanding this flow enables informed optimization: fuse operations to reduce overhead, use appropriate batch sizes, and match model scale to hardware capabilities.
This detailed trace through a single training step demonstrates how deeply the three core problems interact. Even simple code exercises the full framework stack, and seemingly minor decisions—device placement, batch size, compilation mode—cascade through execution, differentiation, and abstraction layers in ways that are difficult to predict without systems-level understanding. The following section catalogs the most common misconceptions that arise when engineers lack this understanding.
Fallacies and Pitfalls
Framework selection involves subtle trade-offs where intuitions from conventional software engineering fail. The memory wall, kernel fusion constraints, and deployment target diversity create pitfalls that waste months of engineering effort and cause production systems to miss latency targets by 10\(\times\) or more.
Fallacy: “All frameworks provide equivalent performance for the same model architecture.”
Engineers assume that ResNet-50 yields identical performance across frameworks since the mathematics is the same. In production, framework implementation matters enormously. Table 9 shows PyTorch achieves 52 ms inference at 32 percent hardware utilization while TensorRT delivers 3 ms at 88 percent utilization—a 17\(\times\) performance gap on identical hardware. The difference arises from kernel fusion depth, graph optimization strategies, and memory access patterns that vary dramatically between frameworks. Organizations that assume equivalence miss latency service level agreements (SLAs) and require costly last-minute framework migrations.
Pitfall: Choosing frameworks based on popularity rather than project requirements.
Engineers assume the most popular framework works for any project. In reality, deployment constraints dominate. Table 9 shows PyTorch Mobile requires 220 MB of memory while TensorFlow Lite Micro runs in 32 KB, a roughly 7,000\(\times\) difference. Teams that prototype edge applications with PyTorch face either memory bloat that exceeds device capacity or two- to three-month framework migrations after development completes. Evaluate deployment targets per Section 1.7 before selecting a training framework.
Fallacy: “Framework abstractions eliminate the need for systems knowledge.”
Engineers assume high-level APIs handle all optimization automatically. The Roofline Model (The roofline model) proves otherwise: element-wise operations like ReLU achieve arithmetic intensity of 0.125 FLOPs/byte, using under 0.1 percent of an A100’s peak compute regardless of framework sophistication. Section 1.2.1 explains why: memory bandwidth, not compute, is the bottleneck for most operations. Engineers who lack this understanding leave 80–90 percent of hardware capacity unused, directly translating to 5–10\(\times\) higher inference costs at production scale.
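The roofline arithmetic behind that sub-0.1-percent figure takes three lines, using the same A100-class peak numbers as elsewhere in this chapter:

```python
def attainable_flops(intensity, peak_flops=312e12, bandwidth=2e12):
    """Roofline model: achieved performance is capped by either the
    compute peak or memory bandwidth times arithmetic intensity."""
    return min(peak_flops, intensity * bandwidth)

relu_intensity = 0.125                       # FLOPs per byte for element-wise ReLU
achieved = attainable_flops(relu_intensity)  # memory-bound: 250 GFLOPS
fraction = achieved / 312e12
print(f"{fraction:.2%}")                     # well under 0.1% of peak compute
```

No framework sophistication changes this outcome; only raising arithmetic intensity (e.g., fusing ReLU into the preceding matmul) moves the operation up the roofline.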
Pitfall: Ignoring vendor lock-in from framework-specific formats.
Engineers assume framework migration is straightforward since models are “just math.” Converting TensorFlow SavedModel to PyTorch requires rewriting custom operations, validating numerical equivalence across 10,000+ test cases, and retraining when operations lack exact equivalents—typically three to six engineer-months for production systems. ONNX (Section 1.7) provides portability but supports only 80–85 percent of operations. Organizations that ignore this during initial framework selection face costly migrations when deployment requirements change or better frameworks emerge.
Pitfall: Selecting development frameworks without evaluating production infrastructure.
Engineers assume training framework choice is independent of deployment infrastructure. In practice, framework-infrastructure mismatches impose substantial operational overhead. TensorFlow Serving provides atomic model swaps with zero downtime; PyTorch deployments often require container restarts imposing thirty- to sixty-second outages. TensorFlow integrates natively with monitoring tools; PyTorch requires custom instrumentation adding two to four weeks of development. Per Section 1.6, evaluate the complete deployment stack during framework selection, including serving infrastructure, monitoring, and operational tooling.
Fallacy: “Increasing batch size is a free throughput optimization within framework memory limits.”
Engineers assume that if memory is available, larger batches always improve throughput. The Dispatch Overhead Equation (Equation 4) reveals hidden costs. A 7B parameter model in FP16 consumes 14 GB, leaving 66 GB on an A100-80 GB. Increasing batch size from 8 to 32 quadruples activation memory for transformers due to attention’s \(O(S^2)\) scaling, potentially triggering recomputation strategies that reduce throughput by 20–30 percent despite the larger batch. Teams that blindly maximize batch size often achieve lower throughput than smaller batches that avoid these memory management pathways.
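A rough model of the quadratic activation term makes the scaling concrete. The constants below (sequence length 4096, 32 heads, FP16) are illustrative; real activation memory depends on the implementation, and attention variants that avoid materializing score matrices change the picture entirely:

```python
def attention_score_bytes(batch, seq_len, n_heads, bytes_per_elem=2):
    """Attention score matrices alone: one (seq_len x seq_len) map per head
    per example -- the O(S^2) term, scaling linearly with batch size."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem

b8 = attention_score_bytes(8, 4096, 32)
b32 = attention_score_bytes(32, 4096, 32)
print(b32 // b8)       # 4: batch 8 -> 32 quadruples this term
print(b32 / 2**30)     # score maps alone at batch 32: 32 GiB in FP16
```

When this term outgrows free HBM, the framework falls back to recomputation or offloading, which is how a "bigger batch" ends up slower.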
Pitfall: Treating compilation overhead as negligible.
Engineers assume compilation overhead is a one-time cost that pays off quickly. Table 4 shows torch.compile achieves 48 percent higher ResNet-50 throughput but incurs fifteen to sixty seconds compilation overhead per graph change. For a 10,000-image experiment with 10 code changes: Eager completes in 6.9 seconds while Compiled requires 304.7 seconds (including 10 \(\times\) 30 s recompilation overhead). Teams that enable compilation during rapid prototyping waste hours waiting for recompilations that negate any throughput gains.
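The amortization argument is a one-line break-even computation. The per-step time, speedup, and compile cost below follow the figures quoted above; `total_time` is an illustrative helper, not a benchmark harness:

```python
def total_time(n_steps, per_step_eager, speedup, compile_cost, n_recompiles=1):
    """Compare eager vs. compiled wall-clock time for a workload."""
    eager = n_steps * per_step_eager
    compiled = n_recompiles * compile_cost + n_steps * per_step_eager / speedup
    return eager, compiled

# Rapid prototyping: 10 code changes force 10 recompiles (30 s each).
eager, compiled = total_time(10_000, 0.00069, 1.48, 30.0, n_recompiles=10)
print(round(eager, 1), round(compiled, 1))   # ~6.9 s vs ~304.7 s

# Long production run: one compile amortized over many steps.
eager, compiled = total_time(10_000_000, 0.00069, 1.48, 30.0)
print(compiled < eager)                      # compilation now pays off
```

The crossover depends entirely on how often the graph changes, which is why the same feature is a win in production and a tax during prototyping.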
Summary
Machine learning frameworks exist to solve three fundamental problems that would otherwise make deep learning impractical:
The Execution Problem: When and how should computation happen? Frameworks navigate the trade-off between eager execution (immediate, debuggable, flexible) and graph execution (deferred, optimizable, deployable). Modern hybrid approaches like `torch.compile` attempt to provide both flexibility during development and optimization for production.
The Differentiation Problem: How do we compute gradients automatically? Frameworks implement reverse-mode automatic differentiation that computes exact gradients for arbitrary operation compositions. This transforms the mathematical chain rule into a software primitive, enabling training on billions of parameters with a single `loss.backward()` call.
The Abstraction Problem: How do we target diverse hardware from a single interface? Frameworks provide tensor abstractions, intermediate representations, and runtime systems that hide hardware complexity while enabling efficient utilization across CPUs, GPUs, TPUs, and specialized accelerators.
These problems are interconnected and constrained by the iron law of performance (Iron Law of ML Systems): execution strategy determines dispatch overhead (\(L_{\text{lat}}\)), differentiation determines memory traffic (\(D_{\text{vol}}\)), and abstraction determines hardware utilization (\(\eta\)). The memory wall makes data movement often more expensive than computation, explaining why frameworks invest in kernel fusion, activation checkpointing, mixed-precision training, and compilation pipelines.
Key Takeaways: The Layer Between Math and Hardware
- Three problems define every framework: Execution (how to run), differentiation (how to train), and abstraction (how to express). TensorFlow prioritizes abstraction for deployment breadth, PyTorch prioritizes execution for research velocity, and JAX reframes differentiation through composable function transformations. These are infrastructure commitments, not tooling preferences.
- The memory wall drives optimization: Compute has grown approximately 1000\(\times\) faster than memory bandwidth. Kernel fusion, activation checkpointing, mixed-precision training, and data layout optimizations all target the data movement term (\(D_{\text{vol}}\)) in the iron law, not the compute term.
- Compilation pays off only at scale: The Compilation Continuum principle (Equation 2) quantifies when compilation benefits exceed costs. Research prototyping favors eager mode; production training and inference favor progressive compilation from JIT to AOT. The Dispatch Overhead Law (Equation 4) explains why small models benefit disproportionately.
- The nn.Module pattern is widely adopted: Automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization appear across major frameworks, enabling million-parameter optimization in a single `optimizer.step()` call regardless of API syntax.
- Framework choice constrains deployment by orders of magnitude: A 17\(\times\) latency gap (PyTorch vs. TensorRT) and 7,040\(\times\) memory gap (PyTorch Mobile vs. TFLite Micro) on identical models demonstrate that frameworks are not interchangeable. Deployment target must be evaluated before framework selection.
Understanding framework internals transforms how practitioners approach performance debugging and optimization. When a training job runs slower than expected, engineers who understand execution graphs can identify whether the bottleneck lies in eager-mode overhead, insufficient kernel fusion, or suboptimal memory layout. When deployment fails on target hardware, the compilation pipeline reveals whether the issue is operator support, quantization compatibility, or runtime configuration. This knowledge is essential for diagnosing and resolving performance issues in production systems.


