Mastering JAX Basics: Arrays, Data Types, and Functional Style

Date

April 25, 2025

Author

Akshay Atam

Introduction

JAX is one of the most exciting libraries to emerge in the machine learning and scientific computing communities. Built by researchers at Google, JAX combines the ease of NumPy with the power of automatic differentiation (grad), just-in-time compilation (jit), and hardware acceleration (GPU, TPU).

If you're a Machine Learning Engineer, Data Scientist, or Researcher who wants to tap into the next generation of high-performance computation, this article is for you. We'll walk through the fundamentals: setting up JAX, understanding its data types, and seeing how functional programming principles are deeply embedded into its design.

Whether you're training deep neural networks or exploring scientific simulations, mastering the basics of JAX will open doors to more efficient and scalable workflows.

Try It Yourself on Google Colab 🚀

You can run this notebook interactively on Google Colab without installing anything!

👉 Open the Notebook in Google Colab

No setup needed — just click, connect to a TPU, and start exploring JAX!

(Note: To enable TPU acceleration, go to Runtime → Change runtime type → Select TPU.)

Setting Up JAX

Before starting, you have two options:

  • Local installation:

  • Google Colab (recommended for beginners): enable the free TPU runtime by navigating to Runtime → Change runtime type → Select v2-8 TPU → Save.

This setup ensures that you can take full advantage of JAX's capabilities.
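For a local install, the standard CPU-only command is `pip install -U jax` (an assumption here; GPU and TPU builds need different extras). Whichever option you pick, a minimal sketch like the following can confirm that JAX imports and show which devices it found:

```python
import jax

# Version of the installed JAX package.
print(jax.__version__)

# Devices JAX can see. A CPU-only install lists only CPU devices,
# while a Colab TPU runtime should list TPU cores.
devices = jax.devices()
print(devices)

# Platform of the default backend, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print(backend)
```

If the backend printed on Colab is `cpu`, the TPU runtime was probably not selected before the session started.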


Importing Libraries

Let's import the libraries we'll be using:


You'll notice we import jax.numpy as jnp. This module provides a NumPy-like interface, and its functions work with JAX transformations such as grad (differentiation) and jit (compilation).
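A minimal import cell for this article might look like the following. The exact set of imports is an assumption: `lax` and `np` are included only because later sections mention `jax.lax` and NumPy.

```python
import jax                 # core transformations such as jax.grad and jax.jit
import jax.numpy as jnp    # NumPy-like array API
from jax import lax        # lower-level operations
import numpy as np         # plain NumPy, handy for comparison

print(jax.__version__)
```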


Working with Data Types in JAX

You can create simple scalars using JAX, just like in NumPy.


Notice that the type returned is ArrayImpl, the concrete class behind jax.Array. Think of it like NumPy's ndarray, but one that can flow through JAX's differentiation and compilation machinery.
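As a rough sketch of scalar creation (the variable names and values are illustrative, not taken from the original notebook):

```python
import jax
import jax.numpy as jnp

x = jnp.array(5)      # integer scalar
y = jnp.array(2.5)    # floating-point scalar

print(type(x).__name__)          # name of the concrete array class
print(x.dtype, y.dtype)          # 32-bit dtypes unless 64-bit mode is enabled
print(x.shape)                   # () for a zero-dimensional array
print(isinstance(x, jax.Array))  # jax.Array is the public array type
```

Checking against `jax.Array` with `isinstance` is usually more robust than comparing class names, since the concrete class is an implementation detail.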

You can even add these variables:
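Adding two scalars could look like this; the names `a` and `b` are hypothetical:

```python
import jax.numpy as jnp

a = jnp.array(3)
b = jnp.array(4)

total = a + b          # the + operator works on JAX arrays
same = jnp.add(a, b)   # explicit function form of the same addition

print(total, same)
```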



Going Lower Level with lax

JAX also offers jax.lax, a lower-level API for finer control. Here, operations are stricter:


While jnp allows implicit type promotion (like NumPy), lax requires operands to have matching dtypes, so any conversion has to be written out explicitly. That strictness makes its behavior more predictable.


Creating Arrays in JAX

Array creation is very similar to NumPy:

import jax.numpy as jnp

a = jnp.arange(10)  # integers 0 to 9
b = jnp.linspace(0, 10, 30)  # 30 evenly spaced points from 0 to 10, inclusive
arr = jnp.array([50, 14, 235, 63])  # array built from a Python list

These operations will feel extremely natural if you already know NumPy.
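To make the similarity concrete, here is a small sketch that builds the same arrays with both libraries and compares them. The comparison helpers and the tolerance are choices made for this example, not part of the original notebook:

```python
import numpy as np
import jax.numpy as jnp

a_jax = jnp.arange(10)
a_np = np.arange(10)

b_jax = jnp.linspace(0, 10, 30)
b_np = np.linspace(0, 10, 30)

# Values match; dtypes may differ because JAX defaults to 32-bit types.
same_range = np.array_equal(np.asarray(a_jax), a_np)
close_linspace = np.allclose(np.asarray(b_jax), b_np, atol=1e-5)

print(same_range, close_linspace)
print(a_jax.dtype, a_np.dtype)
print(b_jax.dtype, b_np.dtype)
```

The main visible difference is precision: JAX uses 32-bit integers and floats by default, while NumPy typically uses 64-bit.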


Key Takeaways

  • JAX builds on familiar concepts from NumPy but adds the power of differentiation, compilation, and parallelism.

  • It favors a functional programming paradigm: transformations such as jit and grad assume that functions are pure (no side effects, no global state changes).

  • JAX arrays are like NumPy arrays but are designed to run seamlessly on CPUs, GPUs, and TPUs.

In short: if you know NumPy, you already have a head start!
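The functional-style takeaway can be sketched with a small pure function passed through `jit` and `grad`. The function `sum_of_squares` is a hypothetical example chosen for illustration:

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Pure: the result depends only on x, and nothing outside is read or modified.
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# Compile the function with XLA, then call the compiled version.
value = jax.jit(sum_of_squares)(x)

# Build a new function that returns the gradient with respect to x.
gradient = jax.grad(sum_of_squares)(x)

print(value)      # 1 + 4 + 9 = 14.0
print(gradient)   # derivative 2 * x, i.e. [2. 4. 6.]
```

Because the function has no hidden state, JAX can trace it once and reuse the compiled result safely.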


What's Next?

This was just the beginning! In Part 2, we'll dive into:

  • Understanding pure vs impure functions in JAX

  • Array operations in JAX

  • jit for faster execution


Stay tuned: JAX is about to get a lot more exciting!
