<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>Forem: Shah Fahad</title>
    <description>The latest articles on Forem by Shah Fahad (@sfahad).</description>
    <link>https://forem.com/sfahad</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3783516%2F5f595b19-bafe-4d58-85e5-0d83a323d253.jpg</url>
      <title>Forem: Shah Fahad</title>
      <link>https://forem.com/sfahad</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://forem.com/feed/sfahad"/>
    <language>en</language>
    <item>
      <title>CUDA Graphs in LLM Inference: Deep Dive</title>
      <dc:creator>Shah Fahad</dc:creator>
      <pubDate>Sat, 21 Feb 2026 07:09:21 +0000</pubDate>
      <link>https://forem.com/sfahad/cuda-graphs-in-llm-inference-deep-dive-36pb</link>
      <guid>https://forem.com/sfahad/cuda-graphs-in-llm-inference-deep-dive-36pb</guid>
      <description>&lt;h2&gt;
  
  
  Why CUDA Graphs Matter for LLM Inference
&lt;/h2&gt;

&lt;p&gt;LLM inference -- especially the token generation (decode) phase -- is &lt;strong&gt;often dominated by CPU overhead rather than GPU compute&lt;/strong&gt;. Each decode step generates a single token per sequence: the actual GPU work (small matmuls, attention over one query) can finish in microseconds, but the CPU can spend tens of microseconds &lt;em&gt;per kernel launch&lt;/em&gt; on launch bookkeeping, driver calls, and synchronization. With hundreds of kernel launches per transformer forward pass, this CPU overhead can become the bottleneck (though at higher batch sizes or with heavier kernels, decode can still become GPU-bound).&lt;/p&gt;

&lt;p&gt;Making matters worse, the CPU isn't just launching kernels -- it's also preparing data for the next batch: updating token IDs, managing the KV cache block table, running the scheduler, and handling request arrivals/completions. All of this competes for CPU time with kernel launches, amplifying the bottleneck. The GPU ends up sitting idle between launches, throughput drops, latency rises, and expensive GPU cycles are wasted on nothing.&lt;/p&gt;

&lt;p&gt;CUDA graphs solve this by &lt;strong&gt;recording the entire kernel sequence once&lt;/strong&gt; and &lt;strong&gt;replaying it with a single CPU call&lt;/strong&gt;. The driver overhead is paid once at capture time; every subsequent replay amortizes hundreds of per-kernel launches into a single replay launch, largely avoiding the repeated per-kernel launch bookkeeping. For decode-heavy workloads, this can eliminate the majority of per-step overhead.&lt;/p&gt;

&lt;p&gt;This post walks through how CUDA graphs work in the context of LLM serving -- why decode is a natural fit, why context/mixed batches are harder, and how TensorRT-LLM (TRT-LLM) implements both monolithic and piecewise CUDA graph strategies.&lt;/p&gt;




&lt;h2&gt;
  
  
  Table of Contents
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;1. CUDA Graphs Fundamentals&lt;/li&gt;
&lt;li&gt;2. Generation (Decode) CUDA Graphs&lt;/li&gt;
&lt;li&gt;3. KV Cache with Static Addresses&lt;/li&gt;
&lt;li&gt;4. Why Context &amp;amp; Mixed Batches Are Hard&lt;/li&gt;
&lt;li&gt;5. Piecewise CUDA Graphs (torch.compile)&lt;/li&gt;
&lt;li&gt;6. Configuration Guide&lt;/li&gt;
&lt;/ul&gt;




&lt;h2&gt;
  
  
  1. CUDA Graphs Fundamentals
&lt;/h2&gt;

&lt;p&gt;A CUDA graph captures a sequence of GPU operations (kernel launches, memory copies) into a single replayable unit.&lt;/p&gt;

&lt;h3&gt;
  
  
  What Gets Captured (Fixed)
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+--------------------------------------------------------------------+
| CUDA Graph Recording                                               |
|                                                                    |
| +----------+      +----------+      +----------+      +----------+ |
| | Kernel A |      | Kernel B |      | Kernel C |      | Kernel D | |
| |grid(4,1) |-----&amp;gt;|grid(8,1) |-----&amp;gt;|grid(4,1) |-----&amp;gt;|grid(2,1) | |
| |@0x100 -&amp;gt; |      |@0x200 -&amp;gt; |      |@0x300 -&amp;gt; |      |@0x400 -&amp;gt; | |
| |  0x200   |      |  0x300   |      |  0x400   |      |  0x500   | |
| +----------+      +----------+      +----------+      +----------+ |
+--------------------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Baked into the graph:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Which kernels to launch, in what order&lt;/li&gt;
&lt;li&gt;Memory addresses (pointers) each kernel reads/writes&lt;/li&gt;
&lt;li&gt;Kernel launch parameters (grid dims, block dims, shared memory)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;NOT baked (can change between replays):&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The actual data at those addresses&lt;/li&gt;
&lt;li&gt;Data-dependent control flow inside kernels (loops, branches)&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Replay Contract
&lt;/h3&gt;

&lt;p&gt;On replay, the entire sequence launches with minimal CPU overhead. The user's responsibility is to place correct data at the captured addresses before each replay.&lt;/p&gt;
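&lt;p&gt;This contract is easy to model in plain Python -- a toy sketch, not the CUDA API: dict keys stand in for fixed device addresses, and the "graph" is just a recorded list of operations over those buffers.&lt;/p&gt;

```python
# Toy model of the CUDA graph replay contract (pure Python, not the CUDA API).
# "Buffers" stand in for fixed device addresses; the graph records which
# operation reads/writes which buffer, never the data itself.

buffers = {"in": 0, "tmp": 0, "out": 0}  # fixed "addresses" (dict keys)

def kernel_a():      # reads @in, writes @tmp
    buffers["tmp"] = buffers["in"] * 2

def kernel_b():      # reads @tmp, writes @out
    buffers["out"] = buffers["tmp"] + 1

graph = [kernel_a, kernel_b]   # "captured" once: ops + addresses, no data

def replay():
    for k in graph:            # one CPU call replays the whole chain
        k()

buffers["in"] = 10             # place fresh data at the captured address
replay()
assert buffers["out"] == 21

buffers["in"] = 7              # same graph, new data, no re-capture
replay()
assert buffers["out"] == 15
```

&lt;p&gt;The graph never changed between the two replays; only the data at the captured "addresses" did.&lt;/p&gt;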

&lt;h3&gt;
  
  
  Why It's Fast
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+----------------------------+
| Without CUDA Graph (eager) |
|                            |
| CPU -- launch --&amp;gt; Kernel A |
| CPU &amp;lt;-- wait ----+         |
| CPU -- launch --&amp;gt; Kernel B |
| CPU &amp;lt;-- wait ----+         |
| CPU -- launch --&amp;gt; Kernel C |
| CPU &amp;lt;-- wait ----+         |
| CPU -- launch --&amp;gt; Kernel D |
|                            |
| = 4x CPU round-trips       |
+----------------------------+

+------------------------------------------+
| With CUDA Graph                          |
|                                          |
| CPU -- replay --&amp;gt; [ Kernel A, B, C, D ]  |
|                                          |
| = 1 launch, entire chain executes on GPU |
+------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  2. Generation (Decode) CUDA Graphs
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Why Decode Is Well-Suited
&lt;/h3&gt;

&lt;p&gt;In decode, each sequence contributes exactly &lt;strong&gt;1 new token&lt;/strong&gt; per step. Total tokens = batch size. This makes the input shape predictable.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+---------------------------------------------------------------+
| Decode step N                                                 |
|                                                               |
| seq0: 1 token  \                                              |
| seq1: 1 token   \                                             |
|                   &amp;gt;-- batch_size = 4, shape = [4, hidden_dim] |
| seq2: 1 token   /                                             |
| seq3: 1 token  /                                              |
+---------------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Pre-allocated Static Buffers
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+-----------------------------------------------------------------+
| Input token IDs buffer (pre-allocated, max_batch_size = 4096)   |
|                                                                 |
| [ token_0 ][ token_1 ][ token_2 ][ token_3 ] ... [ token_4095 ] |
|   @addr_0    @addr_1    @addr_2    @addr_3          @addr_4095  |
|                                                                 |
|   fixed addresses -- same every replay                          |
+-----------------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Multiple Graphs for Different Batch Sizes
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Captured graphs (one per supported batch size, typically powers of two):

  batch_size   grid size     reads
  ----------   ---------     -----
       1  --&amp;gt;  (1, ...)  --&amp;gt; addr_0
       2  --&amp;gt;  (2, ...)  --&amp;gt; addr_0..1
       4  --&amp;gt;  (4, ...)  --&amp;gt; addr_0..3
       8  --&amp;gt;  (8, ...)  --&amp;gt; addr_0..7
       :
    4096  --&amp;gt;  (4096,..) --&amp;gt; addr_0..4095
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;At runtime, 5 active sequences map to the batch_size=8 graph, padded with 3 dummy sequences.&lt;/p&gt;
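&lt;p&gt;Graph selection is then a small lookup -- a sketch assuming the captured sizes are the powers of two listed above:&lt;/p&gt;

```python
import bisect

# Captured batch sizes, sorted ascending (powers of two up to a max).
CAPTURED_SIZES = [1, 2, 4, 8, 16, 32, 64]

def select_graph(num_active):
    """Smallest captured batch size that fits num_active sequences."""
    idx = bisect.bisect_left(CAPTURED_SIZES, num_active)
    if idx == len(CAPTURED_SIZES):
        raise ValueError("batch exceeds largest captured graph")
    size = CAPTURED_SIZES[idx]
    return size, size - num_active   # (graph to replay, dummy sequences)

assert select_graph(5) == (8, 3)     # 5 active: batch_size=8 graph, pad 3
assert select_graph(8) == (8, 0)     # exact fit, no padding
```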

&lt;h3&gt;
  
  
  Intermediate Activations Have Stable Addresses
&lt;/h3&gt;

&lt;p&gt;During capture, intermediate tensors are allocated from a graph-private memory pool, giving them stable device addresses:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+----------------------------------------------------------+
| Transformer layer (captured; all addresses fixed)        |
|                                                          |
| [QKV Projection] ----&amp;gt; [Attention] ----&amp;gt; [Output Proj]   |
|  in @A, out @B          in @B, out @C    in @C, out @D   |
|                                               |          |
|                                               v          |
| [FFN Layer 1] --------&amp;gt; [FFN Layer 2] ----&amp;gt; (next layer) |
|  in @D, out @E           in @E, out @F                   |
+----------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;On replay, the same chain executes at the same addresses. Intermediate buffers are never freed between replays -- they persist in the graph's memory pool. This is why &lt;strong&gt;each captured batch size has its own set of stable-address buffers&lt;/strong&gt;, and capturing many batch sizes consumes significant GPU memory.&lt;/p&gt;

&lt;h3&gt;
  
  
  What the Runtime Updates Before Each Replay
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+-----------------------------------------------------+
| 1. input_token_ids[0:B]  &amp;lt;-- new token IDs          |
| 2. position_ids[0:B]     &amp;lt;-- new positions          |
| 3. sequence_lengths[0:B] += 1                       |
| 4. block_table           &amp;lt;-- update if new KV block |
+-----------------------------------------------------+
| 5. &amp;gt;&amp;gt;&amp;gt; REPLAY GRAPH &amp;lt;&amp;lt;&amp;lt;                             |
+-----------------------------------------------------+
| 6. new_logits &amp;lt;-- output_buffer[0:B]                |
+-----------------------------------------------------+
| B = batch_size                                      |
+-----------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
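&lt;p&gt;Put together, one decode step can be sketched in Python. Buffer names and the FakeGraph stand-in are illustrative, not TRT-LLM's API; the batch is assumed to already be padded to a captured size:&lt;/p&gt;

```python
# Schematic decode step (buffer and graph names are illustrative).
def decode_step(graphs, state, new_tokens, new_positions):
    B = len(new_tokens)
    # 1-4: refresh data at the captured addresses
    state["input_token_ids"][0:B] = new_tokens
    state["position_ids"][0:B] = new_positions
    for i in range(B):
        state["sequence_lengths"][i] += 1
    # (block_table updated here only if a sequence crossed a block boundary)
    # 5: one CPU call replays the whole forward pass
    graphs[B].replay()
    # 6: read logits from the fixed output buffer
    return state["output_buffer"][0:B]

class FakeGraph:
    """Stand-in for a captured forward pass: reads and writes fixed buffers."""
    def __init__(self, state):
        self.state = state
    def replay(self):
        ids = self.state["input_token_ids"]
        self.state["output_buffer"] = [t * 10 for t in ids]

state = {"input_token_ids": [0] * 4, "position_ids": [0] * 4,
         "sequence_lengths": [5, 5, 5, 5], "output_buffer": [0] * 4}
graphs = {2: FakeGraph(state)}

out = decode_step(graphs, state, [3, 4], [6, 7])
assert out == [30, 40]
assert state["sequence_lengths"] == [6, 6, 5, 5]
```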






&lt;h2&gt;
  
  
  3. KV Cache with Static Addresses
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Apparent Contradiction
&lt;/h3&gt;

&lt;p&gt;KV cache grows every step (new K,V written for each token), yet CUDA graphs require fixed addresses. The solution: &lt;strong&gt;paged/block-based KV cache&lt;/strong&gt; with an &lt;strong&gt;indirection table&lt;/strong&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Block-Based KV Cache Pool
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+-------------------------------------------------------------+
| KV cache pool (pre-allocated; addresses never change)       |
|                                                             |
| [ Block 0 ][ Block 1 ][ Block 2 ][ Block 3 ][ Block 4 ] ... |
|   @blk_0     @blk_1     @blk_2     @blk_3     @blk_4        |
|  32 slots   32 slots   32 slots   32 slots   32 slots       |
|                                                             |
| each block holds K,V for a fixed number of tokens (e.g. 32) |
+-------------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Block Table (Indirection)
&lt;/h3&gt;

&lt;p&gt;Each sequence has a block table mapping logical positions to physical blocks:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Logical positions&lt;/th&gt;
&lt;th&gt;Physical block&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;tokens 0–31&lt;/td&gt;
&lt;td&gt;Block 7&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;tokens 32–63&lt;/td&gt;
&lt;td&gt;Block 12&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;tokens 64–95&lt;/td&gt;
&lt;td&gt;Block 3 (partially filled, e.g. up to 82)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;em&gt;Sequence 0's block table at fixed address @tbl_0&lt;/em&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  How Attention Kernel Uses Indirection
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Inside the attention kernel (pseudo-code):
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;each&lt;/span&gt; &lt;span class="n"&gt;past&lt;/span&gt; &lt;span class="n"&gt;token&lt;/span&gt; &lt;span class="n"&gt;position&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;seq_id&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
    &lt;span class="n"&gt;block_idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;block_table&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;seq_id&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;    &lt;span class="c1"&gt;# read from @tbl_0
&lt;/span&gt;    &lt;span class="n"&gt;offset&lt;/span&gt;    &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;
    &lt;span class="n"&gt;K_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;kv_cache_pool&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;block_idx&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;offset&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;              &lt;span class="c1"&gt;# indirect lookup into pool
&lt;/span&gt;    &lt;span class="n"&gt;V_i&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;kv_cache_pool&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;block_idx&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="n"&gt;offset&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;score&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="nf"&gt;dot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Q&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;K_i&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Step-by-Step: How KV Cache Grows Within CUDA Graph
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Buffer&lt;/th&gt;
&lt;th&gt;Step N&lt;/th&gt;
&lt;th&gt;Step N+1&lt;/th&gt;
&lt;th&gt;Notes&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;block_table&lt;/code&gt; @tbl_0&lt;/td&gt;
&lt;td&gt;&lt;code&gt;[7, 12, 3]&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;[7, 12, 3]&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Same address, same indices&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;seq_length&lt;/code&gt; @len_0&lt;/td&gt;
&lt;td&gt;&lt;code&gt;82&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;83&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Same address, incremented&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;kv_pool Block 3, slot 18&lt;/td&gt;
&lt;td&gt;K,V for token 82&lt;/td&gt;
&lt;td&gt;K,V for token 82&lt;/td&gt;
&lt;td&gt;Unchanged&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;kv_pool Block 3, slot 19&lt;/td&gt;
&lt;td&gt;&lt;em&gt;(empty)&lt;/em&gt;&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;K,V for token 83&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;
&lt;strong&gt;NEW&lt;/strong&gt; — written by kernel&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The kernel wrote to a different slot because &lt;code&gt;sequence_length&lt;/code&gt; told it to. All addresses remain fixed -- only the data changes.&lt;/p&gt;
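&lt;p&gt;The write-slot arithmetic is the same indirection in reverse -- a sketch with block_size = 32, matching the example above:&lt;/p&gt;

```python
BLOCK_SIZE = 32

def kv_write_target(block_table, seq_len):
    """Physical (block, slot) where K,V for the token at position seq_len go."""
    block_idx = block_table[seq_len // BLOCK_SIZE]
    slot = seq_len % BLOCK_SIZE
    return block_idx, slot

table = [7, 12, 3]                             # sequence 0's block table
assert kv_write_target(table, 82) == (3, 18)   # step N:   Block 3, slot 18
assert kv_write_target(table, 83) == (3, 19)   # step N+1: Block 3, slot 19
```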

&lt;h3&gt;
  
  
  Why This Doesn't Violate CUDA Graph Rules
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;What's fixed (baked in graph)&lt;/th&gt;
&lt;th&gt;What changes (data at fixed addrs)&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;kv_cache_pool&lt;/code&gt; base address&lt;/td&gt;
&lt;td&gt;Which blocks are assigned (block_table data)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;block_table&lt;/code&gt; buffer address&lt;/td&gt;
&lt;td&gt;The integer block indices&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;code&gt;sequence_length&lt;/code&gt; buffer address&lt;/td&gt;
&lt;td&gt;The actual length values&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Kernel grid dimensions&lt;/td&gt;
&lt;td&gt;Data-dependent loops inside kernel iterate more/fewer times&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;




&lt;h2&gt;
  
  
  4. Why Context &amp;amp; Mixed Batches Are Hard
&lt;/h2&gt;

&lt;h3&gt;
  
  
  The Core Problem: Variable Total Token Count
&lt;/h3&gt;

&lt;p&gt;In decode, total tokens = batch size (each sequence = 1 token). In context/mixed, total tokens varies wildly:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Batch type&lt;/th&gt;
&lt;th&gt;Sequences&lt;/th&gt;
&lt;th&gt;Total tokens&lt;/th&gt;
&lt;th&gt;Predictable?&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Decode&lt;/td&gt;
&lt;td&gt;&lt;code&gt;seq₀(1) + seq₁(1) + seq₂(1)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;3&lt;/td&gt;
&lt;td&gt;Yes — always = batch_size&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Context&lt;/td&gt;
&lt;td&gt;&lt;code&gt;seq₀(137) + seq₁(2048)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;2185&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Mixed&lt;/td&gt;
&lt;td&gt;&lt;code&gt;seq₀(512 prefill) + seq₁(1 decode)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;513&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Problem 1: Kernel Grid Dimensions Depend on Total Tokens
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight cpp"&gt;&lt;code&gt;&lt;span class="c1"&gt;// Kernel launch -- grid dims are a function of input shape&lt;/span&gt;
&lt;span class="n"&gt;dim3&lt;/span&gt; &lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;total_tokens&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;TILE_M&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;TILE_M&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hidden_dim&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;TILE_N&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;TILE_N&lt;/span&gt;&lt;span class="p"&gt;);&lt;/span&gt;
&lt;span class="n"&gt;matmul_kernel&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&amp;lt;&amp;lt;&lt;/span&gt;&lt;span class="n"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;block&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&amp;gt;&amp;gt;&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;total_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;hidden_dim&lt;/span&gt;&lt;span class="p"&gt;);&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;total_tokens&lt;/th&gt;
&lt;th&gt;grid size&lt;/th&gt;
&lt;th&gt;Implication&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;512&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(4, …)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;4 blocks — one graph&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;3072&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(24, …)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;24 blocks — &lt;strong&gt;different&lt;/strong&gt; graph required&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The grid is baked at capture time. Different total tokens = different grid = different graph.&lt;/p&gt;
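&lt;p&gt;The same ceil-division in Python shows why 512 and 3072 tokens cannot share a graph. TILE_M = 128 is an assumed tile size, chosen to be consistent with the table above:&lt;/p&gt;

```python
TILE_M = 128   # assumed tile size; matches the table (512 tokens = 4 blocks)

def grid_x(total_tokens):
    return (total_tokens + TILE_M - 1) // TILE_M   # ceil division

assert grid_x(512) == 4
assert grid_x(3072) == 24
assert grid_x(513) == 5   # one extra token already forces a different grid
```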

&lt;h3&gt;
  
  
  Problem 2: Attention Grid Depends on Max Context Seq Length and Num Context Requests
&lt;/h3&gt;

&lt;p&gt;For MLP, every token is independent: &lt;code&gt;output[i] = MLP(input[i])&lt;/code&gt;. Fix total_tokens and you're done.&lt;/p&gt;

&lt;p&gt;For attention, the kernel grid depends on &lt;strong&gt;two per-iteration variables&lt;/strong&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+--------------------------------------------------------------+
| TRT-LLM attention grid (simplified call chain)               |
|                                                              |
| Python (trtllm.py)                                           |
|   max_ctx_seq_len = seq_lens[:num_contexts].max()            |
|                             |                                |
|                             v                                |
| C++ (fmhaRunner / fused_multihead_attention_v2)              |
|   |                   |                   |                  |
|   v                   v                   v                  |
|   grid.x              grid.y              grid.z             |
|   ceil(s/unroll)      num_heads           num_ctx_requests   |
|   [VARIES]            [FIXED]             [VARIES]           |
|                                                              |
|   --&amp;gt; grid = ( ceil(s/unroll), num_heads, num_ctx_requests ) |
+--------------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Grid = &lt;code&gt;(ceil(max_ctx_seq_len / unroll_step), num_heads, num_context_requests)&lt;/code&gt;&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;TRT-LLM uses a &lt;strong&gt;padded tiling strategy&lt;/strong&gt;: the grid is sized for the longest context request, and shorter requests have their extra tiles skip computation (the kernel checks &lt;code&gt;cu_seqlens&lt;/code&gt; internally):&lt;/p&gt;

&lt;p&gt;&lt;em&gt;Padded tiling: 3 context requests, &lt;code&gt;seq_lens = [64, 128, 256]&lt;/code&gt;, &lt;code&gt;unroll_step = 64&lt;/code&gt;.&lt;br&gt;
Grid = &lt;code&gt;(4, num_heads, 3)&lt;/code&gt; — sized for longest request (256).&lt;/em&gt;&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;Tile 0&lt;/th&gt;
&lt;th&gt;Tile 1&lt;/th&gt;
&lt;th&gt;Tile 2&lt;/th&gt;
&lt;th&gt;Tile 3&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;strong&gt;Req 0&lt;/strong&gt; (64 tokens)&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;skip&lt;/td&gt;
&lt;td&gt;skip&lt;/td&gt;
&lt;td&gt;skip&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;strong&gt;Req 1&lt;/strong&gt; (128 tokens)&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;skip&lt;/td&gt;
&lt;td&gt;skip&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;strong&gt;Req 2&lt;/strong&gt; (256 tokens)&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;td&gt;compute&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;
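&lt;p&gt;The tile map above can be computed directly -- a sketch of the padded tiling with unroll_step = 64:&lt;/p&gt;

```python
UNROLL_STEP = 64

def padded_tiling(seq_lens):
    """Grid.x sized for the longest request; shorter requests skip extra tiles."""
    def ceil_div(a, b):
        return (a + b - 1) // b
    grid_x = ceil_div(max(seq_lens), UNROLL_STEP)
    tile_map = []
    for s in seq_lens:
        active = ceil_div(s, UNROLL_STEP)
        tile_map.append(["compute"] * active + ["skip"] * (grid_x - active))
    return grid_x, tile_map

gx, tiles = padded_tiling([64, 128, 256])
assert gx == 4                                    # grid = (4, num_heads, 3)
assert tiles[0] == ["compute", "skip", "skip", "skip"]
assert tiles[2] == ["compute"] * 4
```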

&lt;p&gt;Even with this padded approach, the grid changes per iteration because &lt;strong&gt;both &lt;code&gt;max_ctx_seq_len&lt;/code&gt; and &lt;code&gt;num_context_requests&lt;/code&gt;&lt;/strong&gt; change depending on which requests the scheduler assigns to the context phase:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Iteration&lt;/th&gt;
&lt;th&gt;Context requests&lt;/th&gt;
&lt;th&gt;max_len&lt;/th&gt;
&lt;th&gt;grid&lt;/th&gt;
&lt;th&gt;What changed&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;1&lt;/td&gt;
&lt;td&gt;32&lt;/td&gt;
&lt;td&gt;128&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(2, heads, 32)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;1&lt;/td&gt;
&lt;td&gt;128&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(2, heads, 1)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;grid.z&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;3&lt;/td&gt;
&lt;td&gt;2&lt;/td&gt;
&lt;td&gt;256&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(4, heads, 2)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;grid.x and z&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;em&gt;Different iterations produce different grids/launch parameters — the combination space explodes across multiple variables (e.g., &lt;code&gt;max_ctx_seq_len&lt;/code&gt;, &lt;code&gt;num_context_requests&lt;/code&gt;, and sequence-length distributions), making “one reusable CUDA graph” impractical.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;A CUDA graph captured with one grid would produce &lt;strong&gt;incorrect results&lt;/strong&gt; if replayed with a different grid/launch configuration (missing tiles = unprocessed tokens; extra tiles = out-of-bounds/garbage work). To make this safe, you’d need to capture graphs for many combinations or pad/standardize to a fixed worst-case launch shape.&lt;/p&gt;
&lt;h3&gt;
  
  
  Why Decode Attention Doesn't Have This Problem
&lt;/h3&gt;

&lt;p&gt;In decode, every sequence has exactly &lt;strong&gt;1 query token&lt;/strong&gt;, and decode attention uses a different kernel path with &lt;code&gt;grid = (batch_size, num_heads)&lt;/code&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;batch_size&lt;/code&gt; is fixed per captured graph (one graph per supported batch size)&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;num_heads&lt;/code&gt; is a model constant&lt;/li&gt;
&lt;li&gt;Variable KV cache lengths are handled by data-dependent loops &lt;strong&gt;inside&lt;/strong&gt; the kernel (loop over &lt;code&gt;sequence_length[i]&lt;/code&gt;) -- the grid doesn't change&lt;/li&gt;
&lt;/ul&gt;
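&lt;p&gt;The contrast is easy to state in code (num_heads = 32 and unroll_step = 64 are assumed values): the decode grid is a pure function of the captured batch size, while the context grid moves with whatever the scheduler picked:&lt;/p&gt;

```python
NUM_HEADS = 32
UNROLL_STEP = 64

def decode_grid(batch_size):
    return (batch_size, NUM_HEADS)            # no per-step variables

def context_grid(ctx_seq_lens):
    x = (max(ctx_seq_lens) + UNROLL_STEP - 1) // UNROLL_STEP
    return (x, NUM_HEADS, len(ctx_seq_lens))  # varies every iteration

# Decode: identical grid every step, however long the KV caches grow.
assert decode_grid(8) == (8, 32)

# Context: the three iterations from the table above all differ.
assert context_grid([128] * 32) == (2, 32, 32)
assert context_grid([128]) == (2, 32, 1)
assert context_grid([256, 100]) == (4, 32, 2)
```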
&lt;h3&gt;
  
  
  Where Each Layer Type Falls
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Layer&lt;/th&gt;
&lt;th&gt;Shape&lt;/th&gt;
&lt;th&gt;Capturable?&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Layer norm&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;[total_tokens, hidden]&lt;/code&gt; — flat&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Q, K, V projections&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;[total_tokens, hidden]&lt;/code&gt; — flat matmuls&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;
&lt;strong&gt;Fused attention&lt;/strong&gt; (Q@K^T, softmax, scores@V)&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;per-sequence, variable tiles&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;
&lt;strong&gt;No&lt;/strong&gt; — grid varies&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Output projection&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;[total_tokens, hidden]&lt;/code&gt; — flat matmul&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;MLP&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;[total_tokens, hidden]&lt;/code&gt; — flat matmuls&lt;/td&gt;
&lt;td&gt;Yes&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;


&lt;h2&gt;
  
  
  5. Piecewise CUDA Graphs (torch.compile)
&lt;/h2&gt;
&lt;h3&gt;
  
  
  Two Separate CUDA Graph Systems
&lt;/h3&gt;

&lt;p&gt;TRT-LLM uses &lt;strong&gt;two independent&lt;/strong&gt; CUDA graph systems -- understanding this distinction is critical:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;                  Python model forward()
                          |
            +-------------+-------------+
            |                           |
            v                           v
+-------------------------+ +-------------------------+
| torch.compile           | | Native CUDA Graph       |
| (Dynamo tracing)        | | (stream capture)        |
+-------------------------+ +-------------------------+
| Traces Python -&amp;gt; FX     | | Records GPU kernels     |
| Decomposes to ATen ops  | | on the CUDA stream      |
| Custom ops -&amp;gt; split pt  | | Captures everything     |
+-------------------------+ +-------------------------+
| Result: Pieces          | | Result: One monolithic  |
| [graph][eager][graph]...| | graph of full fwd pass  |
+-------------------------+ +-------------------------+
            |                           |
            v                           v
  Used for: mixed/context    Used for: decode-only
  (attn grid varies)         (attn grid fixed)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Generation-only (decode)&lt;/strong&gt;: Uses &lt;strong&gt;native &lt;code&gt;torch.cuda.CUDAGraph&lt;/code&gt;&lt;/strong&gt; capture. This records every kernel launch on the CUDA stream at the driver level -- including FlashAttention. It doesn't need to "understand" the kernels; it just records them. This works because decode attention's grid depends only on &lt;code&gt;batch_size&lt;/code&gt; (fixed per capture).&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Piecewise (mixed/context)&lt;/strong&gt;: Uses &lt;strong&gt;torch.compile&lt;/strong&gt; to trace the model into an FX graph, then TRT-LLM's custom backend splits at attention boundaries and captures each non-attention piece as a CUDA graph. Attention runs eagerly.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Piecewise Architecture
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;+--------------------------------------------------------+
| CUDA GRAPH -- piece 1                     [captured]   |
|   layer_norm -&amp;gt; qkv_projection                         |
|   pre-allocates output buffer @ addr_X                 |
+--------------------------------------------------------+
|                         |                              |
|                         v                              |
+--------------------------------------------------------+
| EAGER -- not graphed                 [runs every time] |
|   flash_attention(q, k, v, cu_seqlens, ...)            |
|   writes result IN-PLACE to addr_X                     |
+--------------------------------------------------------+
|                         |                              |
|                         v                              |
+--------------------------------------------------------+
| CUDA GRAPH -- piece 2                     [captured]   |
|   reads from addr_X                                    |
|   output_proj -&amp;gt; layer_norm -&amp;gt; mlp_up -&amp;gt;               |
|   activation -&amp;gt; mlp_down -&amp;gt; residual_add               |
+--------------------------------------------------------+
|                         |                              |
|                         v                              |
|                 ... next layer ...                     |
+--------------------------------------------------------+
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The in-place attention design is critical: attention writes into a buffer pre-allocated by piece 1, ensuring piece 2's captured graph reads from the correct fixed address.&lt;/p&gt;
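&lt;p&gt;To make the fixed-address contract tangible, here is a hypothetical, GPU-free Python sketch: plain lists stand in for the device buffers whose addresses a CUDA graph bakes in at capture time. None of these names come from TRT-LLM; the point is only the dataflow discipline the diagram shows.&lt;/p&gt;

```python
# Hypothetical, GPU-free sketch of the fixed-address contract: plain
# lists stand in for the device buffers whose addresses a CUDA graph
# bakes in at capture time. None of these names are TRT-LLM APIs.
HIDDEN = 4
buf_qkv = [0.0] * HIDDEN   # piece 1's pre-allocated output, "addr_X"
buf_out = [0.0] * HIDDEN   # piece 2's static output buffer

def piece1_replay(x):
    # Captured piece: always writes to the same buffer (fixed address).
    for i in range(HIDDEN):
        buf_qkv[i] = x[i] * 2.0

def attention_eager():
    # Eager step: must mutate addr_X in place; returning a fresh tensor
    # would leave piece 2's captured reads pointing at stale memory.
    for i in range(HIDDEN):
        buf_qkv[i] += 1.0

def piece2_replay():
    # Captured piece: reads addr_X, writes its own static buffer.
    for i in range(HIDDEN):
        buf_out[i] = buf_qkv[i] + 10.0

piece1_replay([1.0] * HIDDEN)   # buf_qkv becomes [2.0, ...]
attention_eager()               # buf_qkv becomes [3.0, ...] in place
piece2_replay()                 # buf_out becomes [13.0, ...]
```

&lt;p&gt;Swap &lt;code&gt;attention_eager&lt;/code&gt; for one that returns a new list instead of mutating &lt;code&gt;buf_qkv&lt;/code&gt;, and &lt;code&gt;piece2_replay&lt;/code&gt; would silently read stale data: that is exactly the failure mode the in-place design prevents.&lt;/p&gt;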

&lt;h3&gt;
  
  
  Why Attention Is Excluded
&lt;/h3&gt;

&lt;p&gt;Attention is excluded from CUDA graph capture for a &lt;strong&gt;correctness&lt;/strong&gt; reason, not a tracing limitation.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;The tracing works fine.&lt;/strong&gt; TRT-LLM registers a FakeTensor implementation for the attention custom op, so &lt;code&gt;torch.compile&lt;/code&gt; in fullgraph mode traces the entire forward pass into one FX graph without graph breaks.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;The exclusion is a deliberate choice.&lt;/strong&gt; TRT-LLM's &lt;code&gt;piecewise_optimizer.py&lt;/code&gt; explicitly identifies attention ops and excludes them from CUDA graph pieces:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# tensorrt_llm/_torch/compilation/piecewise_optimizer.py
&lt;/span&gt;&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;is_call_function&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;node&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ops&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;trtllm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;attn_custom_op_inplace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;default&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ops&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;trtllm&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mla_custom_op_inplace&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;default&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]):&lt;/span&gt;
    &lt;span class="n"&gt;exclude_modules_id&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# ← excluded from CUDA graph capture
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The reason: replay correctness.&lt;/strong&gt; If attention were captured in a CUDA graph, the kernel's grid dimensions would be baked in. But attention's grid depends on the per-sequence query distribution, not just total tokens:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Kernel source&lt;/th&gt;
&lt;th&gt;grid.x&lt;/th&gt;
&lt;th&gt;grid.y&lt;/th&gt;
&lt;th&gt;grid.z&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;fused_multihead_attention_v2.cpp&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;ceil(params.s / mUnrollStep)&lt;/code&gt; — &lt;strong&gt;varies&lt;/strong&gt;
&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;params.h&lt;/code&gt; (heads) — fixed&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;params.b&lt;/code&gt; (batch) — &lt;strong&gt;varies&lt;/strong&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;triton_attention.py&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;num_prefill&lt;/code&gt; — &lt;strong&gt;varies&lt;/strong&gt;
&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;n_heads&lt;/code&gt; — fixed&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;ceil(max(seq_len) / SEQ_BLOCK)&lt;/code&gt; — &lt;strong&gt;varies&lt;/strong&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;unfusedAttentionKernels.cu&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;ceil(q_length / 32.0f)&lt;/code&gt; — &lt;strong&gt;varies&lt;/strong&gt;
&lt;/td&gt;
&lt;td&gt;&lt;/td&gt;
&lt;td&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;For the same &lt;code&gt;total_tokens=4096&lt;/code&gt;, different sequence distributions can produce different grids/launch metadata. A captured graph replays the capture-time launch configuration; unless you pad/standardize to that same configuration, replaying on a different distribution would be incorrect. MLP doesn't have this problem because its grid depends primarily on &lt;code&gt;total_tokens&lt;/code&gt;.&lt;/p&gt;
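&lt;p&gt;A toy reproduction of the &lt;code&gt;triton_attention.py&lt;/code&gt; grid formula from the table makes this concrete (the &lt;code&gt;SEQ_BLOCK&lt;/code&gt; of 64 and the head count are assumed values, not taken from the source):&lt;/p&gt;

```python
# Toy reproduction of the triton_attention.py grid formula from the
# table above. SEQ_BLOCK = 64 and N_HEADS = 32 are assumed values.
SEQ_BLOCK = 64
N_HEADS = 32

def ceil_div(a, b):
    return -(-a // b)

def attn_grid(seq_lens):
    # grid = (num_prefill, n_heads, ceil(max(seq_len) / SEQ_BLOCK))
    return (len(seq_lens), N_HEADS, ceil_div(max(seq_lens), SEQ_BLOCK))

# Three batches with identical total_tokens = 4096 but different
# per-sequence distributions produce three different grids:
grid_a = attn_grid([1024, 1024, 1024, 1024])   # (4, 32, 16)
grid_b = attn_grid([4000, 32, 32, 32])         # (4, 32, 63)
grid_c = attn_grid([2048, 2048])               # (2, 32, 32)
```

&lt;p&gt;A graph captured with &lt;code&gt;grid_a&lt;/code&gt; would replay with &lt;code&gt;grid.z = 16&lt;/code&gt; even when the live batch needs 63 blocks, which is why attention stays eager.&lt;/p&gt;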

&lt;h3&gt;
  
  
  What &lt;code&gt;capture_num_tokens&lt;/code&gt; Controls
&lt;/h3&gt;

&lt;p&gt;This option pre-captures piecewise graphs at specific total token counts. At runtime, the actual token count is padded &lt;strong&gt;up&lt;/strong&gt; to the next captured value.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;capture_num_tokens: [1, 2, 4, 8, ..., 8192]

Runtime: 4160 total tokens → pad up to the next captured value (e.g., 5120)
  - Waste: (5120 - 4160) / 5120 = 18.75% extra compute
  - Benefit: CUDA graph replay for MLP pieces (zero launch overhead)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
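&lt;p&gt;The pad-up lookup is a simple sorted-list search. Here is a sketch of the logic (the capture list below is hypothetical, chosen to include 5120 so it matches the example above):&lt;/p&gt;

```python
import bisect

def pad_to_capture(total_tokens, capture_num_tokens):
    """Pad up to the next captured size; None means fall back to eager.
    capture_num_tokens must be sorted ascending. Illustrative sketch only."""
    idx = bisect.bisect_left(capture_num_tokens, total_tokens)
    if idx == len(capture_num_tokens):
        return None
    return capture_num_tokens[idx]

# Hypothetical sorted capture list that includes 5120, as in the example:
captures = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
            1024, 2048, 4096, 5120, 8192]
padded = pad_to_capture(4160, captures)   # 5120
waste = (padded - 4160) / padded          # 0.1875, i.e. 18.75% padding
```

&lt;p&gt;Anything beyond the largest capture point (here, 8192) returns &lt;code&gt;None&lt;/code&gt; and runs eagerly, which is why the capture list must cover the runtime maximum.&lt;/p&gt;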



&lt;h3&gt;
  
  
  Graph Type Summary
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Graph Type&lt;/th&gt;
&lt;th&gt;Capture Mechanism&lt;/th&gt;
&lt;th&gt;What It Captures&lt;/th&gt;
&lt;th&gt;When Used&lt;/th&gt;
&lt;th&gt;Key Parameter&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Generation-only&lt;/td&gt;
&lt;td&gt;Native &lt;code&gt;torch.cuda.CUDAGraph&lt;/code&gt;
&lt;/td&gt;
&lt;td&gt;Full forward pass (including attention)&lt;/td&gt;
&lt;td&gt;Pure decode iterations&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;cuda_graph_config.batch_sizes&lt;/code&gt; or &lt;code&gt;max_batch_size&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Piecewise&lt;/td&gt;
&lt;td&gt;&lt;code&gt;torch.compile&lt;/code&gt; + native capture per piece&lt;/td&gt;
&lt;td&gt;All non-attention ops (attention runs eager)&lt;/td&gt;
&lt;td&gt;Mixed/context iterations&lt;/td&gt;
&lt;td&gt;&lt;code&gt;torch_compile_config.capture_num_tokens&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Memory vs. Coverage Trade-off
&lt;/h3&gt;

&lt;p&gt;Each piecewise capture at token count N pre-allocates intermediate buffers of size &lt;code&gt;[N, hidden_dim]&lt;/code&gt; per piece per layer. Capturing at large N (e.g., 8192) can consume enough GPU memory to shrink KV cache capacity below usable levels. In some setups, pushing &lt;code&gt;capture_num_tokens&lt;/code&gt; too high (e.g., up to 8192) with aggressive &lt;code&gt;kv_cache_free_gpu_mem_fraction&lt;/code&gt; can shrink the KV cache max length enough to cause warmup failures.&lt;/p&gt;
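&lt;p&gt;A back-of-envelope sizing sketch shows why this matters. All numbers below are assumed for illustration (32 layers, 2 graph pieces per layer, one &lt;code&gt;[N, hidden_dim]&lt;/code&gt; fp16 buffer per piece); real captures may share a memory pool, so treat this as an upper-bound estimate, not TRT-LLM's actual accounting:&lt;/p&gt;

```python
# Back-of-envelope sizing only. Hypothetical shape: 32 layers, 2 graph
# pieces per layer, one [N, hidden_dim] fp16 buffer per piece. Real
# captures may share a memory pool, so this is an upper-bound sketch.
BYTES_FP16 = 2

def capture_buffer_bytes(capture_tokens, hidden_dim=4096,
                         layers=32, pieces_per_layer=2):
    per_piece = capture_tokens * hidden_dim * BYTES_FP16
    return per_piece * pieces_per_layer * layers

one_big = capture_buffer_bytes(8192)   # 4 GiB for the 8192-token capture
total = sum(capture_buffer_bytes(n) for n in (1024, 2048, 4096, 8192))
```

&lt;p&gt;Under these assumed numbers, the single 8192-token capture alone claims 4 GiB of buffers that are no longer available to the KV cache, which is exactly how an aggressive &lt;code&gt;kv_cache_free_gpu_mem_fraction&lt;/code&gt; plus large capture points can push warmup into failure.&lt;/p&gt;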




&lt;h2&gt;
  
  
  6. Configuration Guide
&lt;/h2&gt;

&lt;h3&gt;
  
  
  TensorRT-LLM &lt;code&gt;llm_api_options_yaml&lt;/code&gt; Settings
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight yaml"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Generation-only CUDA graphs (decode phase)&lt;/span&gt;
&lt;span class="na"&gt;cuda_graph_config&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt;
  &lt;span class="na"&gt;enable_padding&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt; &lt;span class="kc"&gt;true&lt;/span&gt;
  &lt;span class="na"&gt;max_batch_size&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt; &lt;span class="m"&gt;4096&lt;/span&gt;    &lt;span class="c1"&gt;# or explicit batch_sizes list&lt;/span&gt;

&lt;span class="c1"&gt;# Piecewise CUDA graphs (context/mixed phases, requires torch.compile)&lt;/span&gt;
&lt;span class="na"&gt;torch_compile_config&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt;
  &lt;span class="na"&gt;enable_piecewise_cuda_graph&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt; &lt;span class="kc"&gt;true&lt;/span&gt;
  &lt;span class="na"&gt;capture_num_tokens&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt; &lt;span class="pi"&gt;[&lt;/span&gt;&lt;span class="nv"&gt;1&lt;/span&gt;&lt;span class="pi"&gt;,&lt;/span&gt; &lt;span class="nv"&gt;2&lt;/span&gt;&lt;span class="pi"&gt;,&lt;/span&gt; &lt;span class="nv"&gt;4&lt;/span&gt;&lt;span class="pi"&gt;,&lt;/span&gt; &lt;span class="nv"&gt;...&lt;/span&gt;&lt;span class="pi"&gt;]&lt;/span&gt;   &lt;span class="c1"&gt;# Must cover runtime max_num_tokens!&lt;/span&gt;
  &lt;span class="na"&gt;enable_userbuffers&lt;/span&gt;&lt;span class="pi"&gt;:&lt;/span&gt; &lt;span class="kc"&gt;false&lt;/span&gt;             &lt;span class="c1"&gt;# Default is true; disable if needed&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Key Principles for &lt;code&gt;capture_num_tokens&lt;/code&gt;
&lt;/h3&gt;

&lt;ol&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Must cover &lt;code&gt;max_num_tokens&lt;/code&gt;&lt;/strong&gt;: If the runtime scheduler can produce up to N total tokens, the largest capture point must be &amp;gt;= N. Otherwise, iterations exceeding the max fall back to eager.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Dense where iterations cluster&lt;/strong&gt;: Use iteration logs to find the hot zone. Pack capture points there to minimize padding waste.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Sparse where few iterations land&lt;/strong&gt;: Ramp-up and transition regions need minimal captures (powers of 2 suffice).&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Fewer captures = less memory&lt;/strong&gt;: Each capture pre-allocates intermediate buffers sized &lt;code&gt;[capture_tokens, hidden_dim]&lt;/code&gt; per piece. On memory-constrained systems, fewer large captures may be preferable.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;
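&lt;p&gt;The four principles above can be combined into a simple list-building heuristic. This is a hypothetical sketch, not a TRT-LLM API: a sparse power-of-two backbone for coverage (principles 1 and 3), plus quantiles of the observed iteration token counts for density in the hot zone (principle 2), with &lt;code&gt;hot_points&lt;/code&gt; capping the list size (principle 4):&lt;/p&gt;

```python
import statistics

def build_capture_list(iter_token_counts, max_num_tokens, hot_points=8):
    """Sketch of the four principles above; not a TRT-LLM API."""
    sizes = set()
    # Principles 1 + 3: sparse power-of-two backbone, ending exactly at
    # max_num_tokens so no iteration falls back to eager.
    for e in range(max_num_tokens.bit_length()):
        sizes.add(2 ** e)
    sizes.add(max_num_tokens)
    # Principle 2: dense points where observed iterations cluster
    # (quantiles of the iteration log's token counts, rounded up by one).
    for q in statistics.quantiles(iter_token_counts, n=hot_points + 1):
        sizes.add(min(int(q) + 1, max_num_tokens))
    # Principle 4: hot_points bounds how many extra captures (and thus
    # how much pre-allocated buffer memory) the hot zone adds.
    return sorted(sizes)

observed = [3000] * 10 + [3100] * 10    # hypothetical iteration log
captures = build_capture_list(observed, 8192)
```

&lt;p&gt;Feed it real iteration-log token counts instead of the toy &lt;code&gt;observed&lt;/code&gt; list to get capture points packed where your workload actually lands.&lt;/p&gt;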

&lt;h3&gt;
  
  
  TorchCompileConfig Defaults (TensorRT-LLM)
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Field&lt;/th&gt;
&lt;th&gt;Default&lt;/th&gt;
&lt;th&gt;Notes&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;torch_compile_config&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;None&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Torch compile completely off unless explicitly set&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;enable_piecewise_cuda_graph&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;False&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Must opt-in&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;capture_num_tokens&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;code&gt;None&lt;/code&gt; (auto: max 3072)&lt;/td&gt;
&lt;td&gt;Auto-generated: &lt;code&gt;[1,2,4,...,128,256,512,...,3072]&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;enable_userbuffers&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;True&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Enabled by default when torch compile is on&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;enable_fullgraph&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;True&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Full graph compilation in torch.compile&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;enable_inductor&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;False&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Inductor backend disabled by default&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Checking Coverage at Runtime
&lt;/h3&gt;

&lt;p&gt;Parse the iteration log and compute:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;total_tokens_per_iter = numCtxTokens + numGenRequests

For each iteration:
  - If numCtxTokens == 0: uses generation-only CUDA graph (match on numGenRequests)
  - If numCtxTokens &amp;gt; 0:  uses piecewise CUDA graph (match on total_tokens)

Hit rate = iterations with total_tokens &amp;lt;= max(capture_num_tokens) / total iterations
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Target: &lt;strong&gt;&amp;gt;95% hit rate&lt;/strong&gt; on piecewise graphs for meaningful benefit.&lt;/p&gt;
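&lt;p&gt;The classification and hit-rate rules above can be sketched in a few lines of Python. The iteration tuples and capture lists here are made up for illustration; only the matching logic mirrors the scheme described:&lt;/p&gt;

```python
import bisect

def cuda_graph_hit_rate(iters, capture_batch_sizes, capture_num_tokens):
    """iters: (numCtxTokens, numGenRequests) pairs parsed from the log.
    Both capture lists must be sorted ascending. Illustrative sketch."""
    hits = 0
    for ctx_tokens, gen_requests in iters:
        if ctx_tokens == 0:
            # Pure decode: generation-only graph, matched on batch size.
            pool, key = capture_batch_sizes, gen_requests
        else:
            # Mixed/context: piecewise graph, matched on total tokens.
            pool, key = capture_num_tokens, ctx_tokens + gen_requests
        # Hit if some captured value is at least `key` (pad-up semantics).
        if bisect.bisect_left(pool, key) != len(pool):
            hits += 1
    return hits / len(iters)

rate = cuda_graph_hit_rate(
    [(0, 8), (512, 4), (9000, 0)],         # three example iterations
    capture_batch_sizes=[1, 2, 4, 8, 16],
    capture_num_tokens=[256, 512, 1024],
)                                          # third iteration misses: 2/3
```

&lt;p&gt;Run this over a full production log; if the result sits below the 95% target, widen or densify the capture lists where the misses cluster.&lt;/p&gt;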

</description>
      <category>ai</category>
      <category>deeplearning</category>
      <category>llm</category>
      <category>performance</category>
    </item>
  </channel>
</rss>
