JAX ML Frameworks

Libraries and frameworks built on JAX for neural networks, optimization, and machine learning research. Includes JAX-based neural network layers, training utilities, optimization algorithms, and domain-specific extensions (graphs, replay buffers, Earth observation). Does NOT include general Python ML frameworks, probabilistic programming languages, or domain applications built with JAX.

There are 126 JAX ML frameworks tracked. Five score above 70 (the Verified tier). The highest-rated is explosion/thinc at 80/100 with 2,893 stars. All of the top 10 are actively maintained.
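The Tier column follows the score. This page only states that a score above 70 places a project in the Verified tier. The lower cut-offs in the sketch below (50 for Established, 30 for Emerging, anything lower Experimental) are inferred from the scores in this listing and are an assumption, not a documented rule of the site. The function name `tier_for` is hypothetical, and scores are assumed to be integers.

```python
# Tier cut-offs. Only "above 70 = Verified" is stated on the page; the 50 and 30
# boundaries are ASSUMED, inferred from the scores and tiers in this listing.
TIER_CUTOFFS = [(71, "Verified"), (50, "Established"), (30, "Emerging")]


def tier_for(score: int) -> str:
    """Map an integer 0-100 score to a tier label using the cut-offs above."""
    for cutoff, label in TIER_CUTOFFS:
        if score >= cutoff:
            return label
    return "Experimental"


# Scores of the top ten entries, copied from the listing below.
top_ten = [80, 79, 73, 73, 71, 68, 67, 66, 65, 62]
verified = [s for s in top_ten if tier_for(s) == "Verified"]
print(len(verified))  # → 5
```

With these cut-offs, every score/tier pair in the listing below is reproduced, but the site may use different boundaries for scores that do not appear here.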

Get the projects as JSON (the example request below sets limit=20):

curl "https://pt-edge.onrender.com/api/v1/datasets/quality?domain=ml-frameworks&subcategory=jax-ml-frameworks&limit=20"

Open to everyone: 100 requests/day with no key needed, or 1,000/day with a free key.
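For scripted access, here is a minimal Python sketch that rebuilds the same request with the standard library and decodes the response. It assumes the endpoint returns a JSON body, as the page states. The helper names `build_url` and `fetch` are hypothetical, and the sketch sends no API key because this page does not say how a key is passed.

```python
import json
from urllib.parse import urlencode
from urllib.request import urlopen

# Endpoint and query parameters copied from the curl command above.
BASE_URL = "https://pt-edge.onrender.com/api/v1/datasets/quality"


def build_url(limit: int = 20) -> str:
    """Build the query URL for the jax-ml-frameworks subcategory."""
    params = {
        "domain": "ml-frameworks",
        "subcategory": "jax-ml-frameworks",
        "limit": limit,
    }
    return f"{BASE_URL}?{urlencode(params)}"


def fetch(limit: int = 20, timeout: float = 30.0):
    """Perform the request anonymously and decode the JSON body (assumed UTF-8)."""
    with urlopen(build_url(limit), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Calling `fetch()` performs a live request and counts against the anonymous 100 requests/day quota. The structure of the returned object is not documented on this page, so inspect it before relying on particular fields.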

# Framework Score Tier
1 explosion/thinc

🔮 A refreshing functional take on deep learning, compatible with your...

80
Verified
2 google-deepmind/optax

Optax is a gradient processing and optimization library for JAX.

79
Verified
3 patrick-kidger/diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and...

73
Verified
4 google/grain

Library for reading and processing ML training data.

73
Verified
5 patrick-kidger/equinox

Elegant easy-to-use neural networks + scientific computing in JAX....

71
Verified
6 extropic-ai/thrml

Thermodynamic Hypergraphical Model Library in JAX

68
Established
7 thomaspinder/GPJax

Gaussian processes in JAX and Flax.

67
Established
8 tumaer/JAXFLUIDS

Differentiable Fluid Dynamics Package

66
Established
9 patrick-kidger/optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox....

65
Established
10 google-deepmind/dm-haiku

JAX-based neural network library

62
Established
11 MichaelTMatthews/Craftax

(Crafter + NetHack) in JAX. ICML 2024 Spotlight.

62
Established
12 google/torchax

torchax is a PyTorch frontend for JAX. It gives JAX the ability to author...

61
Established
13 google-research/kauldron

Modular, scalable library to train ML models

60
Established
14 apple/axlearn

An Extensible Deep Learning Library

60
Established
15 google-deepmind/kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.

60
Established
16 google/jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

60
Established
17 lockwo/distreqx

Distrax, but in equinox. Lightweight JAX library of probability...

59
Established
18 apax-hub/apax

A flexible and performant framework for training machine learning potentials.

58
Established
19 Ceyron/exponax

Efficient Differentiable n-d PDE Solvers in JAX.

58
Established
20 flaport/sax

S + Autograd + XLA :: S-parameter based frequency domain circuit simulations...

58
Established
21 ekzhang/jax-js

JAX in JavaScript – ML library for the web, running on WebGPU & Wasm

57
Established
22 jax-ml/bonsai

Minimal, lightweight JAX implementations of popular models.

56
Established
23 n2cholas/awesome-jax

JAX: a curated list of resources (https://github.com/google/jax)

55
Established
24 arpastrana/jax_fdm

Auto-differentiable and hardware-accelerated force density method

53
Established
25 e3nn/e3nn-jax

JAX library for E(3)-equivariant neural networks

52
Established
26 camail-official/discretax

Discretax is a lightweight collection of state-space models implemented in JAX ⚡️

52
Established
27 dpiras/cosmopower-jax

Differentiable cosmological emulators: the JAX version of CosmoPower

50
Established
28 instadeepai/flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX

50
Established
29 sotetsuk/pgx

♟️ Vectorized RL game environments in JAX

50
Established
30 GalacticDynamics/diffraxtra

Extras for Diffrax: OOP and vectorization

49
Emerging
31 Dicklesworthstone/model_guided_research

Systematic investigation of 11 exotic math frameworks (Lie groups, tropical...

49
Emerging
32 bsc-quantic/tn4ml

Tensor Networks for Machine Learning

49
Emerging
33 thorben-frank/mlff

Build neural networks for machine learning force fields with JAX

49
Emerging
34 ergodicio/adept

Automatic-Differentiation-Enabled Plasma Transport in JAX

49
Emerging
35 gordicaleksa/get-started-with-JAX

The purpose of this repo is to make it easy to get started with JAX, Flax,...

48
Emerging
36 perrin-isir/xpag

A modular reinforcement learning library with JAX agents

48
Emerging
37 poets-ai/elegy

A High Level API for Deep Learning in JAX

48
Emerging
38 danijar/ninjax

General Modules for JAX

48
Emerging
39 srush/annotated-s4

Implementation of https://srush.github.io/annotated-s4

48
Emerging
40 google-deepmind/jeo

Jeo: Jax model training lib for Earth Observation

47
Emerging
41 genjax-community/genjax

Probabilistic programming with programmable inference for parallel accelerators.

47
Emerging
42 google/trax

Trax — Deep Learning with Clear Code and Speed

46
Emerging
43 MahmudulAlam/Holographic-Reduced-Representations

Holographic Reduced Representations

46
Emerging
44 BirkhoffG/jax-dataloader

PyTorch-like dataloaders for JAX.

45
Emerging
45 FLAIROx/Kinetix

Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral.

45
Emerging
46 thebuckleylab/jpc

Flexible Inference for Predictive Coding Networks in JAX.

45
Emerging
47 eserie/wax-ml

A Python library for machine-learning and feedback loops on streaming data

44
Emerging
48 tfm000/copulax

JAX based probability modelling

44
Emerging
49 tinker495/JAxtar

JAxtar is a project with a JAX-native implementation of parallelizable A* &...

44
Emerging
50 RobinKa/jaxga

Geometric Algebra package for JAX

44
Emerging
51 google-deepmind/jraph

A Graph Neural Network Library in Jax

44
Emerging
52 ikostrikov/jaxrl

JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with...

44
Emerging
53 liblaf/apple

🍎 A JAX and Warp library for differentiable physics simulation, featuring...

43
Emerging
54 shyamsn97/hyper-nn

Easy hypernetworks in PyTorch and JAX

43
Emerging
55 matomatical/hijax

An introduction to vanilla JAX for deep learning research

43
Emerging
56 matthias-wright/flaxmodels

Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.

42
Emerging
57 francois-rozet/inox

Stainless neural networks in JAX

42
Emerging
58 lockwo/awesome-jax

Curated list of JAX Resources and Packages

41
Emerging
59 Ceyron/pdequinox

Neural Emulator Architectures in JAX.

41
Emerging
60 duyongan/sunstreaker

A Keras-like framework with JAX as the backend

40
Emerging
61 AakashKumarNain/TF_JAX_tutorials

All about the fundamental blocks of TF and JAX!

40
Emerging
62 instadeepai/catx

🐈‍⬛ Contextual bandits library for continuous action trees with smoothing in JAX

40
Emerging
63 AaltoML/kalman-jax

Approximate inference for Markov Gaussian processes using iterated Kalman...

39
Emerging
64 mancusolab/susiepca

Scalable Ultra-Sparse Bayesian PCA

39
Emerging
65 m-wojnar/reinforced-lib

Reinforcement learning library

38
Emerging
66 affjljoo3581/deit3-jax

Jax/Flax implementation of DeiT and DeiT-III (ViT)

38
Emerging
67 cgarciae/treex

A Pytree Module system for Deep Learning in JAX

38
Emerging
68 ivy-llc/mech

Mechanics functions with end-to-end support for deep learning developers,...

38
Emerging
69 danielkelshaw/riemax

Riemannian geometry in JAX

37
Emerging
70 BobMcDear/flaim

Flax Image Models - State-of-the-art pre-trained vision backbones for Flax.

37
Emerging
71 evgenii-nikishin/omd

JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning...

37
Emerging
72 auto-differentiation/xad-py

High-Performance Automatic Differentiation for Python

36
Emerging
73 jeertmans/sampling-paths

Generative Path Candidate Sampling for Faster Point-to-Point Ray Tracing

36
Emerging
74 IvanIsCoding/GNN-for-Combinatorial-Optimization

JAX + Flax implementation of "Combinatorial Optimization with...

36
Emerging
75 XanaduAI/GradDFT

GradDFT is a JAX-based library enabling the differentiable design and...

36
Emerging
76 ericjang/pt-jax

Path Tracing in JAX

35
Emerging
77 Twistient/HoloVec

Holographic vectors you can compute with. Bind structure, bundle sets,...

35
Emerging
78 satojkovic/vit-jax-flax

Vision Transformer from scratch (JAX/Flax).

35
Emerging
79 wladekpal/golden-standard

Is Temporal Difference Learning the Gold Standard for Stitching in RL? Code...

35
Emerging
80 google-deepmind/dks

Multi-framework implementation of Deep Kernel Shaping and Tailored...

35
Emerging
81 mila-iqia/torch_jax_interop

Simple tools to mix and match PyTorch and Jax - Get the best of both worlds!

35
Emerging
82 evgenii-nikishin/rl_with_resets

JAX implementation of deep RL agents with resets from the paper "The Primacy...

34
Emerging
83 omron-sinicx/jaxmapp

JAX-based implementation for multi-agent path planning (MAPP) in continuous spaces.

34
Emerging
84 HomebrewML/revlib

Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed...

34
Emerging
85 ml-for-gp/jaxgptoolbox

Geometry processing utilities compatible with jax for autodifferentiation.

34
Emerging
86 google-research/jestimator

Amos optimizer with JEstimator lib.

33
Emerging
87 davisyoshida/haiku-mup

A port of muP to JAX/Haiku

33
Emerging
88 phlippe/jax_trainer

Lightning-like training API for JAX with Flax

33
Emerging
89 juliuskunze/cwvae-jax

Clockwork VAEs in JAX/Flax

32
Emerging
90 amoudgl/celo

Code for Celo: Training Versatile Learned Optimizers on a Compute Diet

32
Emerging
91 yonesuke/jaxfss

JAX/Flax implementation of finite-size scaling

31
Emerging
92 cor3bit/awesome-soms

A curated list of resources for second-order stochastic optimization

31
Emerging
93 graphcore-research/jax-experimental

JAX for Graphcore IPU (experimental)

30
Emerging
94 malbertosm/frp_rl

Explore the "frp_rl" repository to discover the Free Random Projection...

30
Emerging
95 ethanluoyc/magi

Reinforcement learning library in JAX.

30
Emerging
96 stefanosele/GPfY

Gaussian processes with spherical harmonic features in JAX

29
Experimental
97 pythoncrazy/jimm

JAX Image Modeling of Models contains Computer Vision/Vision Language Model...

29
Experimental
98 NITHISHM2410/flax-pilot

A simplistic trainer for Flax

29
Experimental
99 mzguntalan/neptune

[WIP] Neptune: JAX-interoperable library in Haskell.

29
Experimental
100 cgarciae/nnx

Neural Networks for JAX

29
Experimental
101 norabelrose/classroom

Preference-based reinforcement learning in PyTorch and JAX with a browser-based GUI.

28
Experimental
102 mzguntalan/zephyr

Zephyr is a declarative neural network library on top of JAX allowing for...

28
Experimental
103 lweitkamp/GANs-JAX

Implementation of several Generative Adversarial Networks in JAX / Flax

27
Experimental
104 alexOarga/haiku-geometric

A collection of graph neural networks implementations in JAX

27
Experimental
105 cifkao/jax-spectral

Short-time Fourier transform (STFT) for JAX

27
Experimental
106 yardenas/jax-dreamer

Dreamer on JAX

27
Experimental
107 Auxeno/ion

A minimal neural network library for JAX

27
Experimental
108 phydra-labs/phydrax

Modular Physics-ML Components in JAX

26
Experimental
109 camml-lab/reax

REAX — Scalable, flexible training for JAX, inspired by the simplicity of...

25
Experimental
110 Anuoluwapo65/pytorch-jax-implementation

PyTorch JAX

25
Experimental
111 OleksiiBevza/jaxpsmc

JAX based Preconditioned Sequential Monte Carlo framework

25
Experimental
112 ethanluoyc/corax

Corax: Core RL in JAX

23
Experimental
113 ASEM000/serket

The ✨Magical✨ JAX ML Library.

22
Experimental
114 nathanwispinski/meta-rl

A short conceptual replication of "Prefrontal cortex as a meta-reinforcement...

22
Experimental
115 forynski/jax-pid-nn

High-performance JAX/Flax neural network for particle identification in...

22
Experimental
116 ysngshn/ivon-optax

An Optax-based JAX implementation of the IVON optimizer for large-scale VI...

21
Experimental
117 MizuhoAOKI/jax_playground

A collection of hands-on examples for exploring numerical algorithms with JAX.

21
Experimental
118 ScottAlexanderCameron/Jynx

A neural network library written in jax

21
Experimental
119 Ceyron/trainax

Training methodologies for autoregressive neural operators/emulators in JAX.

21
Experimental
120 dtunai/LongConv-Jax

Jax/Flax/Linen implementation of "Simple Hardware-Efficient Long...

19
Experimental
121 wolfwdavid/jax-pinn

JAX/Flax physics-informed neural network with jax2tf export — benchmark JAX...

14
Experimental
122 elliotvilhelm/jax-policy-gradient

🤖 JAX implementation of REINFORCE policy gradient with baseline subtraction,...

12
Experimental
123 PundarikakshNTripathi/JAX-Deep-Dive

Deep dive into JAX: JIT compilation, autodiff, vmap, GPU neural networks,...

12
Experimental
124 bhadreshpsavani/LearningJax

Learning Notes and resources on Jax

12
Experimental
125 SalamanderXing/graph-transformer-jax

Graph transformer in JAX

11
Experimental
126 imdebamrita/JAX-FLAX

A learning repository for JAX and FLAX, exploring automatic differentiation,...

10
Experimental