<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>Forem: Ahmed Elnaggar</title>
    <description>The latest articles on Forem by Ahmed Elnaggar (@ahmed_elnaggar).</description>
    <link>https://forem.com/ahmed_elnaggar</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3700950%2Ffcc3958d-302e-4d73-90cb-fee4604631b2.jpg</url>
      <title>Forem: Ahmed Elnaggar</title>
      <link>https://forem.com/ahmed_elnaggar</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://forem.com/feed/ahmed_elnaggar"/>
    <language>en</language>
    <item>
      <title>Fine-Tune Any HuggingFace Model like Gemma on TPUs with TorchAX</title>
      <dc:creator>Ahmed Elnaggar</dc:creator>
      <pubDate>Mon, 27 Apr 2026 08:45:25 +0000</pubDate>
      <link>https://forem.com/gde/fine-tune-any-huggingface-model-like-gemma-on-tpus-with-torchax-5g21</link>
      <guid>https://forem.com/gde/fine-tune-any-huggingface-model-like-gemma-on-tpus-with-torchax-5g21</guid>
      <description>&lt;h2&gt;
  
  
  What if you could fine-tune any HuggingFace model on TPUs — using PyTorch code?
&lt;/h2&gt;

&lt;p&gt;Here is what the end result looks like:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax.train&lt;/span&gt;

&lt;span class="c1"&gt;# One function: forward → loss → gradients → optimizer update
&lt;/span&gt;&lt;span class="n"&gt;step_fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make_train_step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Training loop
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;step_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;labels&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Your PyTorch model. JAX's training primitives. Running on TPU. No rewrite needed.&lt;/p&gt;

&lt;p&gt;In the &lt;a href="https://dev.to/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0"&gt;first part of this series&lt;/a&gt;, we ran HuggingFace models on JAX for fast inference. Now we take the next step: &lt;strong&gt;training&lt;/strong&gt;. We will instruction-tune Gemma 3 1B on the Databricks Dolly 15k dataset using LoRA and torchax's functional training API — all on a free Colab TPU.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_tutorial.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Full Tutorial In Colab" width="117" height="20"&gt;&lt;/a&gt; &lt;a href="https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_quickstart.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Quick Start In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;




&lt;h2&gt;
  
  
  Why Train on TPUs?
&lt;/h2&gt;

&lt;p&gt;Google's Tensor Processing Units (TPUs) are purpose-built for matrix operations — the bread and butter of deep learning. Free Colab gives you access to a TPU v2-8 with ~15GB of high-bandwidth memory. That is enough to fine-tune a 1B parameter model with LoRA.&lt;/p&gt;

&lt;p&gt;But training on TPUs traditionally meant rewriting your model in JAX (Flax, Equinox) or using PyTorch/XLA. &lt;strong&gt;torchax&lt;/strong&gt; offers a third path: keep your PyTorch model, but use JAX's functional training primitives.&lt;/p&gt;

&lt;h3&gt;
  
  
  How torchax Training Differs from Standard PyTorch
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Standard PyTorch&lt;/th&gt;
&lt;th&gt;torchax&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;loss.backward()&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;jax.value_and_grad(loss_fn)(params, ...)&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;optimizer.step()&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;optax.apply_updates(params, updates)&lt;/code&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Model holds its own state&lt;/td&gt;
&lt;td&gt;Params and buffers are separate pytrees&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Eager execution&lt;/td&gt;
&lt;td&gt;JIT-compiled training steps&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The key difference: &lt;strong&gt;functional training&lt;/strong&gt;. Instead of calling &lt;code&gt;loss.backward()&lt;/code&gt; and &lt;code&gt;optimizer.step()&lt;/code&gt; on a stateful model, torchax separates the model into immutable weight pytrees and passes them through pure functions. This is what enables JAX's &lt;code&gt;jax.jit&lt;/code&gt; to compile the entire training step into a single optimized program.&lt;/p&gt;




&lt;h2&gt;
  
  
  Prerequisites &amp;amp; Setup
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;What you need:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Python 3.10+&lt;/li&gt;
&lt;li&gt;Basic familiarity with PyTorch and HuggingFace transformers&lt;/li&gt;
&lt;li&gt;A Google Colab account (free tier works with LoRA)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Zero-setup option:&lt;/strong&gt; Click the Colab badge above. The notebook handles all installation automatically.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Local setup:&lt;/strong&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="c"&gt;# PyTorch CPU (torchax handles the accelerator via JAX)&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;torch &lt;span class="nt"&gt;--index-url&lt;/span&gt; https://download.pytorch.org/whl/cpu

&lt;span class="c"&gt;# JAX + all training dependencies in a single pip call&lt;/span&gt;
pip &lt;span class="nb"&gt;install&lt;/span&gt; &lt;span class="nt"&gt;-U&lt;/span&gt; &lt;span class="s1"&gt;'jax[tpu]'&lt;/span&gt; torchax transformers flax peft datasets optax   &lt;span class="c"&gt;# TPU&lt;/span&gt;
&lt;span class="c"&gt;# pip install -U 'jax[cuda12]' torchax transformers flax peft datasets optax  # GPU&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Colab note:&lt;/strong&gt; The notebook installs packages and automatically restarts the runtime, since Colab pre-loads an older JAX that stays cached in memory until restart.&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  Key Concepts for Training
&lt;/h2&gt;

&lt;p&gt;Before writing code, let's understand the four concepts that make torchax training work.&lt;/p&gt;

&lt;h3&gt;
  
  
  1. Param/Buffer Separation
&lt;/h3&gt;

&lt;p&gt;JAX's &lt;code&gt;jax.value_and_grad&lt;/code&gt; needs to know &lt;em&gt;which&lt;/em&gt; inputs to differentiate. In standard PyTorch, the model owns its weights. In torchax training, we explicitly separate:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;params&lt;/strong&gt; — trainable parameters (get gradients)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;buffers&lt;/strong&gt; — everything else (frozen weights, running stats, constants)
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;frozen&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;buffers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_buffers&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;frozen&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;For LoRA, &lt;code&gt;params&lt;/code&gt; contains only the tiny adapter weights (~0.5% of the model). For full fine-tuning, it contains everything.&lt;/p&gt;

&lt;h3&gt;
  
  
  2. optax Optimizers
&lt;/h3&gt;

&lt;p&gt;Unlike PyTorch optimizers (which carry hidden mutable state), optax optimizers are &lt;strong&gt;pure functions&lt;/strong&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# PyTorch: hidden state inside optimizer
&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# optax: explicit state, no hidden pockets
&lt;/span&gt;&lt;span class="n"&gt;updates&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;new_opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;grads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;new_params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;apply_updates&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;updates&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This functional design means the optimizer state is just another pytree that flows through the training step — perfect for &lt;code&gt;jax.jit&lt;/code&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  3. make_train_step
&lt;/h3&gt;

&lt;p&gt;&lt;code&gt;torchax.train.make_train_step()&lt;/code&gt; is the central API. It composes three pieces into a single JIT-compilable function:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;model_fn&lt;/strong&gt; — a pure function: &lt;code&gt;(weights, buffers, batch) → output&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;loss_fn&lt;/strong&gt; — extracts the scalar loss: &lt;code&gt;(output, labels) → loss&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;optimizer&lt;/strong&gt; — an optax optimizer&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The result is &lt;code&gt;step_fn(params, buffers, opt_state, batch, labels) → (loss, new_params, new_opt_state)&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Under the hood, this uses &lt;code&gt;jax.value_and_grad&lt;/code&gt; for efficient gradient computation and &lt;code&gt;optax.apply_updates&lt;/code&gt; for weight updates — all compiled into a single XLA program.&lt;/p&gt;

&lt;h3&gt;
  
  
  4. Full Fine-Tuning vs LoRA
&lt;/h3&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;Full Fine-Tuning&lt;/th&gt;
&lt;th&gt;LoRA&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Trainable params&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;All (~2B)&lt;/td&gt;
&lt;td&gt;Tiny adapters (~0.5%)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Memory&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;~18-20 GB&lt;/td&gt;
&lt;td&gt;~5-7 GB&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Speed&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Slower&lt;/td&gt;
&lt;td&gt;Faster&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Quality&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Higher ceiling&lt;/td&gt;
&lt;td&gt;Nearly as good&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Free Colab TPU&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Tight / may OOM&lt;/td&gt;
&lt;td&gt;Fits comfortably&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;LoRA&lt;/strong&gt; (Low-Rank Adaptation) freezes the base model and adds small trainable matrices to attention layers. Instead of updating the full weight matrix W, it learns a low-rank decomposition: &lt;code&gt;W + (α/r) × B·A&lt;/code&gt; where A and B are tiny matrices.&lt;/p&gt;

&lt;p&gt;For free Colab, LoRA is the recommended path.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 1: Load and Prepare the Dataset
&lt;/h2&gt;

&lt;p&gt;We use &lt;a href="https://huggingface.co/datasets/databricks/databricks-dolly-15k" rel="noopener noreferrer"&gt;Databricks Dolly 15k&lt;/a&gt; — 15,000 human-written instruction-response pairs across 7 categories (QA, summarization, brainstorming, etc.).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;datasets&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;hf_datasets&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;

&lt;span class="n"&gt;MODEL_NAME&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-1b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;DATASET_NAME&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;databricks/databricks-dolly-15k&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;MODEL_NAME&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eos_token&lt;/span&gt;

&lt;span class="n"&gt;raw_dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hf_datasets&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;DATASET_NAME&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Each example has an &lt;code&gt;instruction&lt;/code&gt;, optional &lt;code&gt;context&lt;/code&gt;, &lt;code&gt;response&lt;/code&gt;, and &lt;code&gt;category&lt;/code&gt;. We format these into Gemma's chat template:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;format_example&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;user_content&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;instruction&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;context&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;""&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;user_content&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="se"&gt;\n\n&lt;/span&gt;&lt;span class="s"&gt;Context: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;context&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

    &lt;span class="n"&gt;messages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
        &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;user_content&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt;
        &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;assistant&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;response&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]},&lt;/span&gt;
    &lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;apply_chat_template&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;text&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then tokenize and create dataloaders:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DataCollatorForLanguageModeling&lt;/span&gt;

&lt;span class="c1"&gt;# Subset, split, tokenize
&lt;/span&gt;&lt;span class="n"&gt;subset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;raw_dataset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;shuffle&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;select&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2200&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;split&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;subset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train_test_split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;test_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seed&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;tokenize_example&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;formatted&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;format_example&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;example&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;formatted&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;text&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;padding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;max_length&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_length&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;512&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;truncation&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;train_tokenized&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenize_example&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;remove_columns&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;column_names&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;eval_tokenized&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;map&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenize_example&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;remove_columns&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;column_names&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;collator&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataCollatorForLanguageModeling&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mlm&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;train_dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_tokenized&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;collator&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;eval_dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;eval_tokenized&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;shuffle&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;collate_fn&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;collator&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 2: Load the Model and Apply LoRA
&lt;/h2&gt;

&lt;p&gt;Here is where the torchax pattern matters: load the model with torchax &lt;strong&gt;disabled&lt;/strong&gt;, then enable it before moving to JAX.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;peft&lt;/span&gt;

&lt;span class="c1"&gt;# Load model with torchax disabled to avoid intercepting init ops
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;disable_temporarily&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;MODEL_NAME&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bfloat16&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Sync pad_token_id so loss computation properly ignores padding
&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;pad_token_id&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Why disable?&lt;/strong&gt; HuggingFace model initialization uses operations (like in-place tensor filling) that torchax does not support. Disabling torchax during loading keeps everything on CPU, then we move to JAX after.&lt;/p&gt;

&lt;p&gt;Now apply LoRA:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;peft_config&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;peft&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;LoraConfig&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;task_type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;peft&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;TaskType&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CAUSAL_LM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;inference_mode&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;r&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                             &lt;span class="c1"&gt;# Rank of the LoRA matrices
&lt;/span&gt;    &lt;span class="n"&gt;lora_alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                   &lt;span class="c1"&gt;# Scaling factor
&lt;/span&gt;    &lt;span class="n"&gt;lora_dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                &lt;span class="c1"&gt;# 0.0 for bfloat16 numerical stability
&lt;/span&gt;    &lt;span class="n"&gt;target_modules&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;q_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;k_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;v_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;o_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt;  &lt;span class="c1"&gt;# All attention layers
&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;peft&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_peft_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;peft_config&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;print_trainable_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="c1"&gt;# Output: trainable params: 5,767,168 || all params: 2,619,206,656 || trainable%: 0.22%
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Only 0.22% of parameters are trainable — that is the power of LoRA.&lt;/p&gt;

&lt;p&gt;Finally, enable torchax and move to the JAX device:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_accuracy_mode&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Float32 accumulation for bfloat16 stability
&lt;/span&gt;&lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_globally&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 3: Baseline Evaluation
&lt;/h2&gt;

&lt;p&gt;Before training, we measure the model's performance to compare against later:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;evaluate_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batches&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;total_loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;total_batches&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;&amp;gt;=&lt;/span&gt; &lt;span class="n"&gt;max_batches&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="k"&gt;break&lt;/span&gt;
            &lt;span class="c1"&gt;# Drop attention_mask — Gemma's sliding window attention produces NaN
&lt;/span&gt;            &lt;span class="c1"&gt;# with padded masks on torchax/JAX. Labels already mask padding with -100.
&lt;/span&gt;            &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;attention_mask&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
            &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;total_batches&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;avg_loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;total_loss&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="nf"&gt;max&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;total_batches&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;math&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;min&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;avg_loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;baseline_loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;baseline_ppl&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eval_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Baseline loss: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;baseline_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;, perplexity: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;baseline_ppl&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We also generate sample responses for qualitative comparison. For fast generation, we register &lt;code&gt;StaticCache&lt;/code&gt; as a JAX pytree and use KV-cached decoding — only the new token is processed each step instead of the full sequence (~50x faster):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers.cache_utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;StaticCache&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;jax.tree_util&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;register_pytree_node&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;aux&lt;/span&gt;
    &lt;span class="n"&gt;kwargs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;dev&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dev&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;
    &lt;span class="n"&gt;sc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;sc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sc&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;sc&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The generation function uses prefill (process full prompt) then per-token decode with the cache and a &lt;code&gt;tqdm&lt;/code&gt; progress bar:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm.auto&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_response&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;instruction&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;messages&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;role&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;user&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;content&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;instruction&lt;/span&gt;&lt;span class="p"&gt;}]&lt;/span&gt;
    &lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;apply_chat_template&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;messages&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;add_generation_prompt&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;seq_len&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="n"&gt;kv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                     &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seq_len&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                     &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bfloat16&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="c1"&gt;# Prefill: process full prompt, populate cache
&lt;/span&gt;        &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                           &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;tok&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;generated&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;tok&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()]&lt;/span&gt;
        &lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;seq_len&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Decode: one token at a time using cached keys/values
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_new_tokens&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;desc&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Generating&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;leave&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kv&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tok&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;pos&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;kv&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                               &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;tok&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
            &lt;span class="n"&gt;tid&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tok&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;tid&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eos_token_id&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
                &lt;span class="k"&gt;break&lt;/span&gt;
            &lt;span class="n"&gt;generated&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tid&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;pos&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;generated&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 4: Set Up Functional Training
&lt;/h2&gt;

&lt;p&gt;This is where torchax diverges from standard PyTorch. We separate the model, create an optax optimizer, and compose everything into a JIT-compiled training step.&lt;/p&gt;

&lt;h3&gt;
  
  
  Separate params and buffers
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax.train&lt;/span&gt;

&lt;span class="n"&gt;params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;buffers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;dict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_buffers&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;frozen_params&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;named_parameters&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;update&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;frozen_params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Create the optimizer
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;schedule&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;warmup_cosine_decay_schedule&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;init_value&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;peak_value&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;warmup_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;decay_steps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;500&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;chain&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_by_global_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;adamw&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;learning_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;schedule&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight_decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;interop&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;call_jax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;init&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Note &lt;code&gt;tx.interop.call_jax&lt;/code&gt; — this bridges optax's JAX calls with torchax tensors.&lt;/p&gt;

&lt;h3&gt;
  
  
  Define model_fn and loss_fn
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;model_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Stateless forward pass using functional_call.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;func&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;functional_call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;},&lt;/span&gt; &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;batch&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;labels&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Extract loss from HuggingFace model output.&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;model_output&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;code&gt;torch.func.functional_call&lt;/code&gt; runs the model as a pure function — no hidden state, just inputs and outputs. This is what enables JAX to trace and compile it.&lt;/p&gt;

&lt;h3&gt;
  
  
  Compose into a training step
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;step_fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;train&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make_train_step&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss_fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That single line creates a function that does: forward pass → loss computation → gradient calculation → optimizer update — all compiled into one XLA program.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 5: The Training Loop
&lt;/h2&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;tqdm.auto&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tqdm&lt;/span&gt;

&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;train_losses&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;start_time&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;time&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;pbar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tqdm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;total&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_dataloader&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;step&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="c1"&gt;# Drop attention_mask — Gemma's sliding window attention produces NaN with
&lt;/span&gt;        &lt;span class="c1"&gt;# padded masks on torchax/JAX. Labels already mask padding with -100.
&lt;/span&gt;        &lt;span class="n"&gt;batch&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;!=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;attention_mask&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;}&lt;/span&gt;

        &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;step_fn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;buffers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;opt_state&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;labels&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
        &lt;span class="n"&gt;pbar&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;set_postfix&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;loss&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;

&lt;span class="n"&gt;elapsed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;time&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;start_time&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Training complete! &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;train_losses&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; steps in &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;elapsed&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;What to expect:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Step 1:&lt;/strong&gt; ~30-60 seconds (JAX compiles the entire training step)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Steps 2+:&lt;/strong&gt; ~1-3 seconds each (running the compiled program)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Total:&lt;/strong&gt; ~20-40 minutes for 2000 samples with LoRA on free Colab TPU&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The first step is slow because JAX traces through the entire model, loss computation, gradient calculation, and optimizer update — then compiles it all into a single optimized XLA program. Every subsequent step reuses this compiled program.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 6: Evaluate the Improvement
&lt;/h2&gt;

&lt;p&gt;After training, we compare against our baseline:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Load trained params back into model
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;param&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;parts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;obj&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;part&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parts&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
            &lt;span class="n"&gt;obj&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obj&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;part&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="nf"&gt;setattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;obj&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;parts&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;param&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="n"&gt;final_loss&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;final_ppl&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;evaluate_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eval_dataloader&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Metric&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Before&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;After&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Loss&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;baseline_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;10.4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;final_loss&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;10.4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Perplexity&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;lt;&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;baseline_ppl&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;10.2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;final_ppl&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;10.2&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You should see loss decrease and perplexity improve after training. The qualitative comparison (generated responses before vs. after) is even more telling — the fine-tuned model produces more focused, instruction-following responses.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 7: Save and Reload
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Save
&lt;/h3&gt;

&lt;p&gt;Convert JAX arrays back to CPU tensors and save using HuggingFace's standard format:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;

&lt;span class="n"&gt;save_dir&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;./fine_tuned_model&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;

&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;cpu_state_dict&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
        &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;array&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;contiguous&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;params&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="p"&gt;}&lt;/span&gt;
    &lt;span class="c1"&gt;# safe_serialization=False avoids a safetensors/torchax C-extension conflict on reload
&lt;/span&gt;    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;save_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cpu_state_dict&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;safe_serialization&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;save_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;save_dir&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;For LoRA, this saves only the tiny adapter weights (~20MB). For full fine-tuning, it saves the entire model (~4GB).&lt;/p&gt;

&lt;h3&gt;
  
  
  Reload
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;tx&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;disable_temporarily&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="c1"&gt;# For LoRA: load base model + adapters separately
&lt;/span&gt;    &lt;span class="n"&gt;reloaded_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;MODEL_NAME&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bfloat16&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# torch_device="cpu" forces PEFT to load adapter weights on CPU,
&lt;/span&gt;    &lt;span class="c1"&gt;# avoiding a safetensors/torchax C-extension conflict.
&lt;/span&gt;    &lt;span class="n"&gt;reloaded_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;peft&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;PeftModel&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;reloaded_model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;save_dir&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;reloaded_model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;reloaded_model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The pattern is the same as loading: disable torchax, load on CPU, then move to JAX. For LoRA models, you load the base model first, then attach the saved adapters with &lt;code&gt;PeftModel.from_pretrained()&lt;/code&gt;. The &lt;code&gt;torch_device="cpu"&lt;/code&gt; ensures PEFT loads weights through PyTorch's standard path rather than safetensors' C extension, which conflicts with torchax.&lt;/p&gt;




&lt;h2&gt;
  
  
  Full Fine-Tuning: When LoRA Is Not Enough
&lt;/h2&gt;

&lt;p&gt;The notebook supports full fine-tuning by changing one setting:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;TRAINING_MODE&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;full&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This trains all parameters instead of just the LoRA adapters. The trade-off is much higher memory usage. To make it fit on free Colab TPU:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;AdaFactor optimizer&lt;/strong&gt; — uses ~50% less memory than AdamW (stores only row/column statistics instead of per-parameter moments)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Reduced sequence length&lt;/strong&gt; — &lt;code&gt;MAX_SEQ_LEN = 256&lt;/code&gt; halves activation memory&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Smaller batch size&lt;/strong&gt; — &lt;code&gt;BATCH_SIZE = 1&lt;/code&gt; with higher gradient accumulation steps
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;USE_ADAFACTOR&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;span class="n"&gt;USE_GRADIENT_CHECKPOINTING&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;

&lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;TRAINING_MODE&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;full&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt; &lt;span class="ow"&gt;and&lt;/span&gt; &lt;span class="n"&gt;USE_ADAFACTOR&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;chain&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_by_global_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;adafactor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;learning_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;schedule&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;chain&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;clip_by_global_norm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;1.0&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="n"&gt;optax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;adamw&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;learning_rate&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;schedule&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weight_decay&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Full fine-tuning gives a higher quality ceiling but LoRA gets you 90%+ of the way with a fraction of the compute.&lt;/p&gt;




&lt;h2&gt;
  
  
  Troubleshooting
&lt;/h2&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Error&lt;/th&gt;
&lt;th&gt;Cause&lt;/th&gt;
&lt;th&gt;Fix&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;OutOfMemoryError&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Model + optimizer too large&lt;/td&gt;
&lt;td&gt;Switch to LoRA, reduce &lt;code&gt;BATCH_SIZE&lt;/code&gt; or &lt;code&gt;MAX_SEQ_LEN&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;TypeError: not a valid JAX type&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Custom HuggingFace type not registered&lt;/td&gt;
&lt;td&gt;Register with &lt;code&gt;jax.tree_util.register_pytree_node()&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;Loss is NaN&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Numerical instability in bfloat16&lt;/td&gt;
&lt;td&gt;1. Call &lt;code&gt;tx.enable_accuracy_mode()&lt;/code&gt; before &lt;code&gt;tx.enable_globally()&lt;/code&gt;. 2. Reduce LR (try 1e-4). 3. Set &lt;code&gt;lora_dropout=0.0&lt;/code&gt;. 4. Add &lt;code&gt;optax.clip_by_global_norm(1.0)&lt;/code&gt;.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;Slow first step&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Normal — JAX JIT compilation&lt;/td&gt;
&lt;td&gt;Wait ~30-60s; subsequent steps are fast&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;make_train_step error&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;API mismatch&lt;/td&gt;
&lt;td&gt;Update: &lt;code&gt;pip install -U torchax&lt;/code&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;




&lt;h2&gt;
  
  
  The Big Picture: Inference + Training
&lt;/h2&gt;

&lt;p&gt;With the &lt;a href="https://dev.to/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0"&gt;inference tutorial&lt;/a&gt; and this training tutorial, you now have the complete torchax story:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Run&lt;/strong&gt; any HuggingFace model on TPU (&lt;code&gt;model.to("jax")&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Benchmark&lt;/strong&gt; with JIT compilation (10-100x speedup)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Fine-tune&lt;/strong&gt; with LoRA or full training (&lt;code&gt;make_train_step&lt;/code&gt;)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Save&lt;/strong&gt; and reload for production inference&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;All using PyTorch code. No JAX rewrite needed.&lt;/p&gt;




&lt;h2&gt;
  
  
  Resources
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Notebooks:&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_tutorial.ipynb" rel="noopener noreferrer"&gt;Full training tutorial&lt;/a&gt; — all the code from this post, ready to run&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_training_quickstart.ipynb" rel="noopener noreferrer"&gt;Training quickstart&lt;/a&gt; — same pipeline in ~10 cells&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://colab.research.google.com/github/agemagician/torchax-huggingface/blob/main/notebooks/torchax_huggingface_tutorial.ipynb" rel="noopener noreferrer"&gt;Inference tutorial&lt;/a&gt; — Part 1 of this series&lt;/li&gt;
&lt;/ul&gt;


&lt;/li&gt;

&lt;li&gt;

&lt;strong&gt;Libraries:&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;torchax GitHub&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="https://huggingface.co/docs/peft" rel="noopener noreferrer"&gt;PEFT/LoRA documentation&lt;/a&gt;&lt;/li&gt;
&lt;li&gt;&lt;a href="https://optax.readthedocs.io/" rel="noopener noreferrer"&gt;optax documentation&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;


&lt;/li&gt;

&lt;li&gt;

&lt;strong&gt;References:&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://github.com/google/torchax/blob/main/examples/peft_lora_training.py" rel="noopener noreferrer"&gt;torchax PEFT LoRA example&lt;/a&gt; — the official example this tutorial builds on&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://huggingface.co/blog/qihqi/huggingface-jax-01" rel="noopener noreferrer"&gt;Han Qi's tutorial series&lt;/a&gt; — the original 3-part series on torchax + HuggingFace&lt;/li&gt;
&lt;/ul&gt;


&lt;/li&gt;

&lt;/ul&gt;

&lt;h2&gt;
  
  
  Credits
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://github.com/qihqi" rel="noopener noreferrer"&gt;Han Qi (@qihqi)&lt;/a&gt;&lt;/strong&gt; — author of torchax, PEFT training example, and the original tutorial series&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;torchax team at Google&lt;/a&gt;&lt;/strong&gt; — library development&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://huggingface.co/" rel="noopener noreferrer"&gt;HuggingFace&lt;/a&gt;&lt;/strong&gt; — transformers, PEFT, and datasets ecosystem&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://www.databricks.com/" rel="noopener noreferrer"&gt;Databricks&lt;/a&gt;&lt;/strong&gt; — Dolly 15k dataset&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;&lt;a href="https://github.com/jax-ml/jax" rel="noopener noreferrer"&gt;JAX team at Google&lt;/a&gt;&lt;/strong&gt; — JAX, XLA, and TPU support&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>machinelearning</category>
      <category>pytorch</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
    <item>
      <title>Run Any HuggingFace Model on TPUs: A Beginner's Guide to TorchAX</title>
      <dc:creator>Ahmed Elnaggar</dc:creator>
      <pubDate>Sun, 29 Mar 2026 17:04:14 +0000</pubDate>
      <link>https://forem.com/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0</link>
      <guid>https://forem.com/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0</guid>
      <description>&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcvjk9opldk0897wus527.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcvjk9opldk0897wus527.png" alt="TorchAX" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  What if you could run any HuggingFace model on TPUs — without rewriting a single line of model code?
&lt;/h2&gt;

&lt;p&gt;Here is what the end result looks like:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-1b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;bfloat16&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;
&lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_globally&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Enable AFTER loading the model
&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# That's it. Now running on JAX.
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Five lines. Your PyTorch model is now executing on JAX — with access to TPUs, JIT compilation, and automatic parallelism across devices.&lt;/p&gt;

&lt;p&gt;In this tutorial, we will go from zero to building a working chatbot powered by a HuggingFace model running on JAX. Along the way, you will learn key JAX concepts, see real benchmarks, and understand &lt;em&gt;why&lt;/em&gt; this approach exists.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/ahmed-elnaggar/torchax-huggingface/blob/main/notebooks/torchax_huggingface_tutorial.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Full Tutorial In Colab" width="117" height="20"&gt;&lt;/a&gt; &lt;a href="https://colab.research.google.com/github/ahmed-elnaggar/torchax-huggingface/blob/main/notebooks/torchax_quickstart.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Quick Start In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;




&lt;h2&gt;
  
  
  Why This Matters: The HuggingFace + JAX Problem
&lt;/h2&gt;

&lt;p&gt;In 2024, HuggingFace &lt;a href="https://github.com/huggingface/transformers/issues/25004" rel="noopener noreferrer"&gt;removed native JAX and TensorFlow support&lt;/a&gt; from its &lt;code&gt;transformers&lt;/code&gt; library to focus development on PyTorch. This left thousands of JAX users — especially those running on Google Cloud TPUs — without a straightforward way to use HuggingFace's massive model collection.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is JAX?
&lt;/h3&gt;

&lt;p&gt;If you are new to JAX, think of it as Google's high-performance numerical computing library. It looks like NumPy on the surface, but under the hood it offers three powerful capabilities:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;JIT Compilation&lt;/strong&gt; — JAX can compile your Python code into optimized machine code using the XLA compiler. The first run is slower (compilation), but every subsequent call is dramatically faster.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;TPU Support&lt;/strong&gt; — JAX is the native programming model for Google's Tensor Processing Units. If you want to use TPUs, JAX is the most natural path.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Automatic Parallelism&lt;/strong&gt; — JAX can automatically distribute computation across multiple devices (TPUs or GPUs) using a single-program model called gSPMD. You describe &lt;em&gt;what&lt;/em&gt; should be sharded; the compiler figures out &lt;em&gt;how&lt;/em&gt;.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;

&lt;h3&gt;
  
  
  Enter TorchAX
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;&lt;strong&gt;torchax&lt;/strong&gt;&lt;/a&gt; is a library from Google that bridges PyTorch and JAX. It works by creating a special &lt;code&gt;torch.Tensor&lt;/code&gt; subclass that secretly holds a &lt;code&gt;jax.Array&lt;/code&gt; inside. When PyTorch operations are called on this tensor, torchax intercepts them and executes the JAX equivalent instead.&lt;/p&gt;

&lt;p&gt;Think of it like a Trojan horse: PyTorch &lt;em&gt;thinks&lt;/em&gt; it is working with regular tensors, but the computation is actually happening on JAX.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;PyTorch Model
    |
    v
torchax.Tensor (looks like torch.Tensor)
    |
    v
jax.Array (actual computation on TPU/GPU)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This means you can take &lt;em&gt;any&lt;/em&gt; PyTorch model — including HuggingFace models — and run it on JAX without modifying the model code at all.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Credits:&lt;/strong&gt; This tutorial builds on the excellent &lt;a href="https://huggingface.co/blog/qihqi/huggingface-jax-01" rel="noopener noreferrer"&gt;3-part blog series&lt;/a&gt; by Han Qi (&lt;a href="https://github.com/qihqi" rel="noopener noreferrer"&gt;@qihqi&lt;/a&gt;), the author of torchax, and on the &lt;a href="https://google.github.io/torchax/" rel="noopener noreferrer"&gt;torchax documentation&lt;/a&gt;. We expand on those tutorials with beginner-friendly explanations, a different model (Gemma instead of Llama), benchmarks, and a complete Colab-ready notebook.&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  TorchAX vs. the Alternatives
&lt;/h2&gt;

&lt;p&gt;Before diving into code, it helps to understand where torchax fits in the broader ecosystem:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Approach&lt;/th&gt;
&lt;th&gt;Effort&lt;/th&gt;
&lt;th&gt;Performance&lt;/th&gt;
&lt;th&gt;Best For&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Rewrite in Flax/Equinox&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;High (full rewrite)&lt;/td&gt;
&lt;td&gt;Native JAX speed&lt;/td&gt;
&lt;td&gt;New projects starting in JAX&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;torch-xla (PyTorch/XLA)&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Low (add XLA device)&lt;/td&gt;
&lt;td&gt;Good (XLA compiled)&lt;/td&gt;
&lt;td&gt;PyTorch training on TPUs&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;torchax&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Low (change device to 'jax')&lt;/td&gt;
&lt;td&gt;Great (JAX JIT + interop)&lt;/td&gt;
&lt;td&gt;Running HF models on JAX, mixing PyTorch + JAX&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;ONNX export&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Medium (export + runtime)&lt;/td&gt;
&lt;td&gt;Variable&lt;/td&gt;
&lt;td&gt;Cross-framework deployment&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;When should you use torchax?&lt;/strong&gt; When you have a PyTorch model (especially from HuggingFace) and want to leverage JAX's JIT compilation, TPU support, or interop with JAX libraries — without rewriting the model.&lt;/p&gt;




&lt;h2&gt;
  
  
  Prerequisites &amp;amp; Setup
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;What you need:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Python 3.10+&lt;/li&gt;
&lt;li&gt;Basic familiarity with PyTorch (loading models, running inference)&lt;/li&gt;
&lt;li&gt;A Google Colab account (free tier works for the 1B model)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Zero-setup option:&lt;/strong&gt; Click the Colab badge above. The notebook handles all installation automatically.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Local setup:&lt;/strong&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="c"&gt;# 1. Install PyTorch (CPU version — torchax handles the accelerator)&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;torch &lt;span class="nt"&gt;--index-url&lt;/span&gt; https://download.pytorch.org/whl/cpu  &lt;span class="c"&gt;# Linux&lt;/span&gt;
&lt;span class="c"&gt;# pip install torch  # macOS&lt;/span&gt;

&lt;span class="c"&gt;# 2. Install JAX for your accelerator&lt;/span&gt;
pip &lt;span class="nb"&gt;install&lt;/span&gt; &lt;span class="nt"&gt;-U&lt;/span&gt; jax[tpu]     &lt;span class="c"&gt;# Google Cloud TPU&lt;/span&gt;
&lt;span class="c"&gt;# pip install -U jax[cuda12]  # NVIDIA GPU&lt;/span&gt;
&lt;span class="c"&gt;# pip install -U jax          # CPU only&lt;/span&gt;

&lt;span class="c"&gt;# 3. Install torchax, transformers, and flax (for JAX compatibility)&lt;/span&gt;
pip &lt;span class="nb"&gt;install&lt;/span&gt; &lt;span class="nt"&gt;-U&lt;/span&gt; torchax transformers flax
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Key Concepts for Beginners
&lt;/h2&gt;

&lt;p&gt;Before we write code, let's demystify three JAX concepts you will encounter throughout this tutorial.&lt;/p&gt;

&lt;h3&gt;
  
  
  Pytrees: JAX's Data Containers
&lt;/h3&gt;

&lt;p&gt;A &lt;strong&gt;pytree&lt;/strong&gt; is any nested structure of Python containers (dicts, lists, tuples) with arrays as leaves. JAX uses pytrees everywhere — model weights are pytrees, function inputs/outputs are pytrees.&lt;/p&gt;

&lt;p&gt;Think of a pytree like a shipping box with labeled compartments. JAX knows how to open standard boxes (dicts, lists, tuples), pull out all the arrays, do math on them, and put them back.&lt;/p&gt;

&lt;p&gt;The catch: JAX does not know how to open &lt;em&gt;custom&lt;/em&gt; boxes. HuggingFace defines custom output types like &lt;code&gt;CausalLMOutputWithPast&lt;/code&gt; — we need to teach JAX how to unpack and repack these. This is called &lt;strong&gt;pytree registration&lt;/strong&gt;, and we will see it in action shortly.&lt;/p&gt;

&lt;h3&gt;
  
  
  JIT Compilation: Translate Once, Run Fast Forever
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;JIT&lt;/strong&gt; (Just-In-Time) compilation is like translating a recipe from English to machine code. The first time you call a JIT-compiled function, JAX traces through it, records all the operations, and compiles an optimized version. Subsequent calls skip the tracing and run the compiled version directly.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;First call:  Python code → trace → compile → execute  (slow)
Second call: compiled code → execute                   (fast!)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The speedup can be 10-100x or more. The trade-off is that the compiled function is specialized for the input shapes it was traced with — if shapes change, JAX recompiles.&lt;/p&gt;

&lt;h3&gt;
  
  
  Static vs. Dynamic Values
&lt;/h3&gt;

&lt;p&gt;When JAX traces a function for JIT, it treats inputs as abstract shapes, not concrete values. If your code has a branch like &lt;code&gt;if use_cache:&lt;/code&gt;, JAX cannot evaluate it during tracing because &lt;code&gt;use_cache&lt;/code&gt; is abstract. This causes a &lt;code&gt;ConcretizationTypeError&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The fix: mark such values as &lt;strong&gt;static&lt;/strong&gt; (compile-time constants) so JAX knows their actual value during tracing. We will see two ways to do this: closures and &lt;code&gt;static_argnums&lt;/code&gt;.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 1: Your First Forward Pass
&lt;/h2&gt;

&lt;p&gt;Let's load a model and run it on JAX. We will use &lt;a href="https://huggingface.co/google/gemma-3-1b-it" rel="noopener noreferrer"&gt;Gemma 3 1B IT&lt;/a&gt; — a small, instruction-tuned model from Google that runs comfortably on free Colab hardware.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;

&lt;span class="c1"&gt;# Load model and tokenizer
&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-1b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bfloat16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device_map&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Enable torchax globally AFTER model loading
# This prevents intercepting unsupported initialization ops
&lt;/span&gt;&lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_globally&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# Move model weights to the JAX device
&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Tokenize an input prompt
&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The secret to baking a good cake is&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Run a forward pass (eager mode)
&lt;/span&gt;&lt;span class="n"&gt;start&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;elapsed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Output logits shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Eager forward pass: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;elapsed&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;What happened:&lt;/strong&gt;&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;We load the model on CPU first, &lt;em&gt;then&lt;/em&gt; call &lt;code&gt;torchax.enable_globally()&lt;/code&gt;. This ordering is important — enabling torchax before model loading can intercept unsupported initialization ops and cause errors.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;model.to("jax")&lt;/code&gt; moves every parameter from CPU to the JAX device — just like &lt;code&gt;model.to("cuda")&lt;/code&gt; for GPUs.&lt;/li&gt;
&lt;li&gt;The forward pass runs through PyTorch's code path, but every operation is executed by JAX under the hood.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The output &lt;code&gt;logits&lt;/code&gt; tensor has shape &lt;code&gt;(1, sequence_length, vocab_size)&lt;/code&gt;. Each position contains a score for every token in the vocabulary — the highest score is the model's prediction for the next token.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 2: Speed It Up with JIT Compilation
&lt;/h2&gt;

&lt;p&gt;The eager forward pass works, but it is slow — every operation goes through Python one at a time. Let's compile the model for dramatically faster inference.&lt;/p&gt;

&lt;h3&gt;
  
  
  The extract_jax Approach
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;torchax.extract_jax()&lt;/code&gt; function converts a PyTorch model into a pure JAX function:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Extract a JAX-callable function and the model weights as a pytree
&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extract_jax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This returns two things:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;weights&lt;/code&gt; — the model's state_dict as a pytree of &lt;code&gt;jax.Array&lt;/code&gt;s&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;jax_func&lt;/code&gt; — a function with signature &lt;code&gt;jax_func(weights, args_tuple, kwargs_dict)&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Register HuggingFace Output Types as Pytrees
&lt;/h3&gt;

&lt;p&gt;Before we can JIT this function, we need to teach JAX about HuggingFace's custom types:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;jax.tree_util&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;register_pytree_node&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache_utils&lt;/span&gt;

&lt;span class="c1"&gt;# Register CausalLMOutputWithPast
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;output_flatten&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_tuple&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;output_unflatten&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CausalLMOutputWithPast&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CausalLMOutputWithPast&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;output_flatten&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;output_unflatten&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Register DynamicCache
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_flatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_unflatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cache_utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DynamicCache&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;cache_utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DynamicCache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_flatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_unflatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Handle Static Arguments with a Closure
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;use_cache&lt;/code&gt; flag is a boolean that JAX cannot trace. We wrap it in a closure to make it a compile-time constant:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward_no_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;jax_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;use_cache&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;

&lt;span class="n"&gt;jitted_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;jit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;forward_no_cache&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Benchmark: Eager vs. JIT
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Convert input to a native JAX array for jax.jit
&lt;/span&gt;&lt;span class="n"&gt;jax_input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="c1"&gt;# Warm up (first call triggers compilation)
&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;jitted_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;block_until_ready&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Benchmark 3 runs
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;start&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;res&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;jitted_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;block_until_ready&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;elapsed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Run &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;elapsed&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Expected output (times will vary by hardware):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight console"&gt;&lt;code&gt;&lt;span class="gp"&gt;Run 0: 0.0142s  #&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;Already compiled from warm-up
&lt;span class="go"&gt;Run 1: 0.0038s
Run 2: 0.0035s
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The JIT-compiled version runs orders of magnitude faster than eager mode. This is the power of XLA compilation — operations are fused, memory is optimized, and the accelerator runs a single optimized program.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 3: The Simpler API — torchax.compile
&lt;/h2&gt;

&lt;p&gt;The &lt;code&gt;extract_jax&lt;/code&gt; + manual JIT approach gives you full control, but for most cases there is a simpler way. The catch is that &lt;code&gt;torchax.compile()&lt;/code&gt; uses &lt;code&gt;jax.jit&lt;/code&gt; under the hood, so we need to avoid passing dynamic boolean flags like &lt;code&gt;use_cache&lt;/code&gt;. We wrap the model in a thin module that bakes in these constants:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;NoCacheModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;base_model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;base_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;base_model&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Return only logits to avoid HuggingFace output class pytree issues
&lt;/span&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;base_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# One-liner: compile the wrapped model
&lt;/span&gt;&lt;span class="n"&gt;compiled_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nc"&gt;NoCacheModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Use it like a normal PyTorch model
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;compiled_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Under the hood, &lt;code&gt;torchax.compile()&lt;/code&gt; wraps your model in a &lt;code&gt;JittableModule&lt;/code&gt; and applies &lt;code&gt;jax.jit&lt;/code&gt;. The first call triggers compilation; subsequent calls are fast. The &lt;code&gt;NoCacheModel&lt;/code&gt; wrapper ensures that boolean flags are constants (not traced) and that the output is a plain tensor (not a custom HuggingFace type that needs pytree registration).&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 4: Text Classification
&lt;/h2&gt;

&lt;p&gt;Let's use our JIT-compiled model for a practical task — sentiment classification. Since Gemma is an instruction-tuned model, we can use prompt engineering:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;classify_sentiment&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Classify the following text as POSITIVE or NEGATIVE.
Text: &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;
Classification:&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Get the predicted next token
&lt;/span&gt;    &lt;span class="n"&gt;next_token_logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;
    &lt;span class="n"&gt;next_token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_token_logits&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;prediction&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;next_token_id&lt;/span&gt;&lt;span class="p"&gt;]).&lt;/span&gt;&lt;span class="nf"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;prediction&lt;/span&gt;

&lt;span class="c1"&gt;# Test it
&lt;/span&gt;&lt;span class="n"&gt;texts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;This movie was absolutely fantastic, I loved every minute!&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The service was terrible and the food was cold.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;A perfectly average experience, nothing special.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;classify_sentiment&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Text: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;...  =&amp;gt;  &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 5: Text Generation (Autoregressive Decoding)
&lt;/h2&gt;

&lt;p&gt;Classification is useful, but the real power of LLMs is generating text. Let's understand how this works.&lt;/p&gt;

&lt;h3&gt;
  
  
  How Autoregressive Decoding Works
&lt;/h3&gt;

&lt;p&gt;An LLM predicts one token at a time. Given an input of length &lt;code&gt;n&lt;/code&gt;, it produces scores for the next token. We pick one (e.g., the highest-scoring token via greedy decoding), append it to the input, and repeat:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Iteration 1: input (1, n)     → output (1, n)     → pick token
Iteration 2: input (1, n+1)   → output (1, n+1)   → pick token
Iteration 3: input (1, n+2)   → output (1, n+2)   → pick token
...
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The problem: &lt;strong&gt;input shapes change every iteration&lt;/strong&gt;. JIT compilation specializes for fixed shapes, so changing shapes means recompilation every step — worse than eager mode.&lt;/p&gt;

&lt;h3&gt;
  
  
  The KV Cache Solution
&lt;/h3&gt;

&lt;p&gt;The &lt;strong&gt;KV (Key-Value) cache&lt;/strong&gt; stores intermediate computations from previous tokens so the model only needs to process the &lt;em&gt;new&lt;/em&gt; token each iteration:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Iteration 1: input (1, n)              → output + kv_cache(n)
Iteration 2: input (1, 1) + cache(n)   → output + kv_cache(n+1)
Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With a &lt;strong&gt;DynamicCache&lt;/strong&gt;, the cache grows each step — shapes still change. With a &lt;strong&gt;StaticCache&lt;/strong&gt;, the cache has a fixed maximum length — shapes stay constant, making it JIT-friendly.&lt;/p&gt;

&lt;h3&gt;
  
  
  Implementation with StaticCache
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers.cache_utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;StaticCache&lt;/span&gt;

&lt;span class="c1"&gt;# Register StaticCache as a pytree
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt;
    &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;aux&lt;/span&gt;
    &lt;span class="n"&gt;kwargs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;
    &lt;span class="n"&gt;cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;

    &lt;span class="c1"&gt;# Create a static cache with fixed maximum length
&lt;/span&gt;    &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Prefill: process the full prompt
&lt;/span&gt;    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;next_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;generated_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()]&lt;/span&gt;
    &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Decode: generate one token at a time
&lt;/span&gt;    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_new_tokens&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;next_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eos_token_id&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;break&lt;/span&gt;
        &lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Generate!
&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The secret to baking a good cake is&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 6: Distributed Inference (Tensor Parallelism)
&lt;/h2&gt;

&lt;p&gt;If you have access to multiple devices (e.g., a TPU v2-8 with 8 chips, or multi-GPU), you can shard the model weights across devices for faster inference.&lt;/p&gt;

&lt;h3&gt;
  
  
  How Tensor Parallelism Works
&lt;/h3&gt;

&lt;p&gt;In tensor parallelism, we split weight matrices across devices:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Column-parallel&lt;/strong&gt;: Q, K, V, Gate, and Up projections are split along the output dimension&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Row-parallel&lt;/strong&gt;: O and Down projections are split along the input dimension&lt;/li&gt;
&lt;li&gt;Between these two, only a single &lt;strong&gt;all-reduce&lt;/strong&gt; operation is needed per layer&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;JAX's gSPMD handles the communication automatically — you just specify how each weight should be sharded.&lt;/p&gt;

&lt;h3&gt;
  
  
  Sharding the Weights
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;jax.sharding&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PartitionSpec&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NamedSharding&lt;/span&gt;

&lt;span class="c1"&gt;# Create a device mesh
&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make_mesh&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_count&lt;/span&gt;&lt;span class="p"&gt;(),),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;shard_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;sharded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tensor&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;any&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;q_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;k_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;v_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gate_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;up_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Column-parallel
&lt;/span&gt;        &lt;span class="k"&gt;elif&lt;/span&gt; &lt;span class="nf"&gt;any&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;o_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;down_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;lm_head&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;embed_tokens&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Row-parallel
&lt;/span&gt;        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Replicate (e.g., layer norms)
&lt;/span&gt;        &lt;span class="n"&gt;sharded&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;NamedSharding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;spec&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;sharded&lt;/span&gt;

&lt;span class="c1"&gt;# Apply sharding
&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extract_jax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;shard_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Replicate the input across all devices
&lt;/span&gt;&lt;span class="n"&gt;input_ids_sharded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nc"&gt;NamedSharding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With sharded weights, the same &lt;code&gt;jax.jit&lt;/code&gt;-compiled function now runs in parallel across all devices. The XLA compiler automatically inserts the necessary all-reduce operations.&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Note:&lt;/strong&gt; Tensor parallelism requires a multi-device environment. On free Colab TPU (single device), this section is for illustration. Use a TPU v2-8 or multi-GPU setup to run it.&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  Step 7: Build a Mini Chatbot
&lt;/h2&gt;

&lt;p&gt;Let's wrap everything into a simple chat function using Gemma's instruction template:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;user_message&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Gemma instruction format
&lt;/span&gt;    &lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;start_of_turn&amp;gt;user&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;user_message&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;end_of_turn&amp;gt;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;start_of_turn&amp;gt;model&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
    &lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;response&lt;/span&gt;

&lt;span class="c1"&gt;# Example conversation
&lt;/span&gt;&lt;span class="n"&gt;questions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;What is JAX and why would I use it?&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Explain tensor parallelism in simple terms.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Write a haiku about machine learning.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;questions&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;User: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Gemma: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Swapping to a Larger Model
&lt;/h2&gt;

&lt;p&gt;Everything above uses &lt;code&gt;google/gemma-3-1b-it&lt;/code&gt; (1B parameters). To use a larger model, change the model name:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# 7B model — needs more memory (Colab Pro or multi-device)
&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-7b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The rest of the code remains identical. Larger models produce higher quality outputs but require more memory and compute. The 7B model benefits significantly from tensor parallelism on multi-device setups.&lt;/p&gt;

&lt;p&gt;Other models that work well with torchax include any standard HuggingFace &lt;code&gt;AutoModelForCausalLM&lt;/code&gt; architecture — GPT-2, Llama, Mistral, Phi, and more.&lt;/p&gt;




&lt;h2&gt;
  
  
  Troubleshooting
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;TypeError: ... is not a valid JAX type&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
You need to register the type as a pytree. See the registration examples above for &lt;code&gt;CausalLMOutputWithPast&lt;/code&gt;, &lt;code&gt;DynamicCache&lt;/code&gt;, and &lt;code&gt;StaticCache&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;ConcretizationTypeError: Abstract tracer value encountered&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
A value that changes between calls (like a boolean flag) needs to be either: (1) made static via &lt;code&gt;static_argnums&lt;/code&gt; in &lt;code&gt;jax.jit&lt;/code&gt;, or (2) baked into a closure as a constant.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;UserWarning: A large amount of constants were captured&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
Model weights are being inlined as constants in the compiled graph. Pass them as explicit function arguments instead of closing over them.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;RuntimeError: No available devices&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
Ensure JAX can see your accelerator: &lt;code&gt;print(jax.devices())&lt;/code&gt;. In Colab, check that your runtime type is set to TPU or GPU.&lt;/p&gt;




&lt;h2&gt;
  
  
  Conclusion
&lt;/h2&gt;

&lt;p&gt;In this tutorial, we went from zero to a working chatbot running a HuggingFace model on JAX:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Forward pass&lt;/strong&gt; — moved a PyTorch model to JAX with &lt;code&gt;model.to("jax")&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;JIT compilation&lt;/strong&gt; — compiled for 10-100x speedup with &lt;code&gt;jax.jit&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Text classification&lt;/strong&gt; — used prompt engineering for sentiment analysis&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Text generation&lt;/strong&gt; — implemented autoregressive decoding with StaticCache&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Distributed inference&lt;/strong&gt; — sharded weights across devices with tensor parallelism&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Chatbot&lt;/strong&gt; — wrapped generation in an instruction-following chat function&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The key insight: torchax lets you use the entire HuggingFace ecosystem — models, tokenizers, configs — while running on JAX's high-performance backend. No model rewrites needed.&lt;/p&gt;

&lt;h3&gt;
  
  
  Resources
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;torchax GitHub&lt;/a&gt; — library source and documentation&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://google.github.io/torchax/" rel="noopener noreferrer"&gt;torchax Docs&lt;/a&gt; — official getting started guide&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://huggingface.co/blog/qihqi/huggingface-jax-01" rel="noopener noreferrer"&gt;Original tutorial series by Han Qi&lt;/a&gt; — the 3-part blog series this tutorial builds on&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://docs.jax.dev/" rel="noopener noreferrer"&gt;JAX Documentation&lt;/a&gt; — JIT compilation, pytrees, distributed arrays&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://huggingface.co/docs/transformers/llm_optims" rel="noopener noreferrer"&gt;HuggingFace LLM Inference Optimization&lt;/a&gt; — StaticCache and torch.compile docs&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://github.com/ahmed-elnaggar/torchax-huggingface" rel="noopener noreferrer"&gt;Companion GitHub repo&lt;/a&gt; — all code, notebooks, and diagrams&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Credits
&lt;/h3&gt;

&lt;p&gt;This tutorial would not be possible without the work of:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Han Qi (&lt;a href="https://github.com/qihqi" rel="noopener noreferrer"&gt;@qihqi&lt;/a&gt;)&lt;/strong&gt; — author of torchax and the &lt;a href="https://github.com/qihqi/learning_machine/tree/main/jax-huggingface" rel="noopener noreferrer"&gt;original HuggingFace + JAX tutorial series&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;torchax team at Google&lt;/strong&gt; — for building and maintaining the library&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;HuggingFace team&lt;/strong&gt; — for the transformers ecosystem&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;JAX team at Google&lt;/strong&gt; — for JAX, XLA, and TPU support&lt;/li&gt;
&lt;/ul&gt;




&lt;p&gt;What model will you try running on TPUs first? Let me know in the comments!&lt;/p&gt;

</description>
      <category>machinelearning</category>
      <category>pytorch</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
  </channel>
</rss>
