This piece serves as a short pre-read on JAX for my Ray Tracer in JAX.
JAX is a Python library developed by Google that treats functions as mathematical objects that can be transformed, analyzed, and optimized.
JAX’s transformations are designed for pure functions: functions where the same input always produces the same output, with no side effects such as print statements or mutations of global state. Purity is what lets JAX reason about code mathematically.
```python
import jax

# ❌ Impure: mutates global state, so the JIT-compiled version silently goes stale
counter = 0

def impure_function(x):
    global counter
    counter += 1  # Side effect: only runs while JAX traces the function
    return x * counter

jitted_fn = jax.jit(impure_function)
jitted_fn(5.0)  # First call traces once (counter becomes 1) and returns 5.0
jitted_fn(5.0)  # Not re-traced: counter=1 is baked in, so this returns 5.0, not 10.0

# ✅ Pure: the same input always gives the same output
def pure_function(x):
    return x ** 2 + 3 * x

# JAX can safely compile this
jitted_pure = jax.jit(pure_function)
```
When JAX applies transformations like `jit`, `vmap`, or `grad`, it needs to understand your function’s mathematical behavior completely. The JIT compiler traces your function once for each new combination of input shapes and dtypes, then generates optimized machine code. If the function could behave differently on subsequent calls due to hidden state, the compiled version would be incorrect.
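The trace-once behavior can be observed directly with a Python side effect. In this sketch, `trace_log` and `scaled_norm` are hypothetical names used only for illustration: the `append` runs only while JAX traces the function, so it fires once per new input shape rather than once per call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records each time the Python body actually runs

def scaled_norm(x):
    trace_log.append(x.shape)  # Python side effect: executes only during tracing
    return jnp.sqrt(jnp.sum(x ** 2)) * 2.0

fast = jax.jit(scaled_norm)
fast(jnp.ones(3))   # shape (3,): traces, compiles, runs
fast(jnp.zeros(3))  # same shape and dtype: reuses the compiled code, no trace
fast(jnp.ones(5))   # new shape (5,): traces again
print(trace_log)
```

The second call never touches the Python body, which is exactly why hidden state inside a jitted function is not updated between calls.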
Function purity enables three key transformations:

- `jit` can aggressively optimize because the function’s behavior is guaranteed to be the same on every call.
- `vmap` can safely vectorize operations because there are no hidden dependencies between iterations.
- `grad` produces mathematically correct gradients because the computation graph is well-defined.

Traditional GPU acceleration typically requires rewriting code as CUDA kernels. JAX instead compiles to XLA (Accelerated Linear Algebra), Google’s domain-specific compiler, which generates optimized code for CPUs, GPUs, and TPUs from the same program.
XLA’s fusion optimization combines multiple operations into a single kernel, so intermediate results stay in registers or on-chip memory instead of making round trips to device memory. This lets ordinary JAX code compile into efficient low-level implementations without hand-written kernels.
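One way to look for fusion in practice is to compile a function ahead of time and dump the optimized HLO that XLA produced. This is a minimal sketch: `shade` is a hypothetical function, and the exact HLO text (including how fused computations are named) depends on the backend and JAX version.

```python
import jax
import jax.numpy as jnp

def shade(x):
    # Three elementwise ops; unfused, each would write a full
    # intermediate array to memory before the next op reads it back
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.linspace(0.0, 1.0, 1_000_000)
fast_shade = jax.jit(shade)

# Lower and compile ahead of time, then inspect the optimized HLO
compiled = fast_shade.lower(x).compile()
hlo_text = compiled.as_text()
print(hlo_text.count("fusion"))  # fused computations usually appear as "fusion" entries

result = fast_shade(x)
```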
Frameworks such as PyTorch implement automatic differentiation through operator overloading: they intercept each operation as it runs to build a computational graph, then traverse that graph backward.

```python
import torch

# Traditional framework approach
x = torch.tensor([1.0], requires_grad=True)
y = x ** 2 + 3 * x  # Framework records: pow, mul, add
y.backward()        # Traverse graph backwards
print(x.grad)       # tensor([5.])
```
JAX’s `grad` returns a new function that computes gradients. Since gradients are functions, you can differentiate them again for higher-order derivatives. `jax.grad` itself uses reverse-mode differentiation; for Jacobians, JAX provides both forward mode (`jax.jacfwd`) and reverse mode (`jax.jacrev`), so you choose the mode that fits your computation’s input and output dimensions.
```python
import jax

# JAX approach
def f(x):
    return x ** 2 + 3 * x

grad_f = jax.grad(f)  # Returns a new function: f'(x) = 2x + 3
grad_f(1.0)           # 5.0
```
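Because `grad` returns an ordinary function, it composes with itself for higher-order derivatives, and the Jacobian helpers make the forward/reverse choice explicit. A short sketch; `g` is a hypothetical vector-valued function chosen only to have a non-trivial Jacobian.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2 + 3 * x

# grad returns a function, so it can be differentiated again
d1 = jax.grad(f)            # f'(x)  = 2x + 3
d2 = jax.grad(jax.grad(f))  # f''(x) = 2

# Hypothetical vector-valued function R^2 -> R^2
def g(v):
    return jnp.stack([v[0] * v[1], jnp.sin(v[0])])

v = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(g)(v)  # forward mode: cost scales with the number of inputs
J_rev = jax.jacrev(g)(v)  # reverse mode: cost scales with the number of outputs
# Both compute the same 2x2 Jacobian [[v1, v0], [cos(v0), 0]]
```

As a rule of thumb, forward mode is cheaper for functions with few inputs and many outputs, and reverse mode for the opposite case, which is why scalar losses use `grad`.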
JAX’s SPMD (Single Program, Multiple Data) model runs identical computations on different data shards across devices. You specify how to split your data; JAX distributes the computation accordingly. This lets the same code scale from a single-device prototype to multi-device production with few changes.
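A minimal sketch of the "you specify the split" step, using JAX’s `Mesh` and `NamedSharding` APIs. The names `rays` and `normalize` are hypothetical stand-ins for ray-tracer data; the code runs on however many devices are available, including a single CPU, because the batch size is chosen as a multiple of the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices along one mesh axis called "batch"
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Hypothetical batch of ray directions, one row per ray
n = 4 * len(devices)  # divisible by the device count so shards are even
rays = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3) + 1.0

# Split rows across the "batch" axis; keep the xyz dimension whole
sharding = NamedSharding(mesh, P("batch", None))
rays_sharded = jax.device_put(rays, sharding)

@jax.jit
def normalize(v):
    # The same program runs on every shard
    return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

out = normalize(rays_sharded)
print(out.sharding)  # shows how the result is laid out across devices
```

The function body never mentions devices; only the `device_put` call describes the layout, and `jit` propagates it through the computation.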
Traditional scientific computing forces a choice: readable Python (slow) or optimized C++/CUDA (fast but hard to iterate on). JAX largely removes this tradeoff by keeping transformations orthogonal to your core logic: you write the computation once, then wrap it in `jit`, `vmap`, or `grad` as needed.
This composability enables research directions that would be prohibitively complex in other frameworks: differentiable physics simulations, inverse rendering, neural radiance fields. JAX makes combining optimization, vectorization, and differentiation straightforward, lowering the barrier between domain expertise and high-performance computing.
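Stacking the three transformations on one plain function illustrates this orthogonality. In this sketch, `loss` is a hypothetical per-example scalar function: `grad` differentiates it, `vmap` maps that gradient over a batch while sharing the weights, and `jit` compiles the whole pipeline.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical per-example scalar function of weights w and input x
    return jnp.sum(jnp.tanh(w * x) ** 2)

grad_loss = jax.grad(loss)  # d(loss)/dw for a single example
# Batch over x (axis 0) while broadcasting the same w to every example
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))
fast_batched_grad = jax.jit(batched_grad)  # compile the composed function

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))
per_example_grads = fast_batched_grad(w, xs)  # one gradient row per example
```

None of the transformations required changing `loss` itself, which is the point of keeping them orthogonal to the core logic.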
Read the full JAX documentation here.
GPUs evolved from graphics rendering and have thousands of flexible cores that execute in a Single Instruction, Multiple Thread (SIMT) model. TPUs, by contrast, were purpose-built for one thing: matrix multiplication at massive scale.
TPUs use a systolic array architecture: a grid of simple processing elements where data flows in waves, creating highly efficient pipelines for ML workloads.
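The "data flows in waves" idea can be simulated cycle by cycle. This is a simplified output-stationary model written in plain NumPy, not the actual dataflow of Google’s TPU matrix units; `systolic_matmul` is a hypothetical name. Each processing element (PE) keeps one output accumulator, rows of `A` stream in from the left and columns of `B` from the top, and each stream is delayed by one cycle per row or column so matching operands meet at the right PE.

```python
import numpy as np

def systolic_matmul(A, B):
    """Simplified output-stationary systolic array simulation (illustrative only).

    PE (i, j) accumulates C[i, j]. Row i of A enters i cycles late and column j
    of B enters j cycles late, so at cycle t PE (i, j) receives A[i, k] and
    B[k, j] with k = t - i - j.
    """
    n, K = A.shape
    K2, m = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((n, m), dtype=np.result_type(A, B))
    # The last PE (n-1, m-1) finishes at cycle (n-1) + (m-1) + (K-1)
    total_cycles = n + m + K - 2
    for t in range(total_cycles):
        # In hardware, every PE below updates in parallel during cycle t
        for i in range(n):
            for j in range(m):
                k = t - i - j
                if 0 <= k < K:
                    C[i, j] += A[i, k] * B[k, j]  # one multiply-accumulate per PE
    return C
```

The full product takes `n + m + K - 2` cycles, and each PE only ever talks to its neighbors, which is why the hardware can be so dense and power-efficient.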
TPUs started in 2015 as inference-only chips powering Google Search, which processes billions of queries daily. Google had both the hardware expertise and the massive-scale ML workloads to justify custom silicon, which let it iterate on designs battle-tested in production systems before releasing them externally.
This creates a competitive moat. While NVIDIA’s GPUs serve broad markets, TPUs focus solely on ML. Google optimizes the entire stack: TPU hardware, XLA compiler, JAX framework, and cloud infrastructure. This vertical integration enables optimizations that are difficult to achieve with commodity hardware. Making TPUs accessible through Colab, Kaggle, and Google Cloud also builds ecosystem lock-in: researchers prototyping on free TPUs naturally scale on Google Cloud’s paid infrastructure.
This hardware-software co-design is a large part of why JAX performs so well on Google’s infrastructure.