phlippe/jax_trainer
Lightning-like training API for JAX with Flax
This project offers a standardized way to train machine learning models with JAX and Flax. You supply the data and a model definition, and it handles the training loop, evaluation, and logging. It is aimed at machine learning researchers and engineers who build and experiment with models.
No commits in the last 6 months.
Use this if you are a machine learning practitioner who wants to efficiently train and evaluate models in JAX/Flax without repeatedly writing boilerplate code for training loops, logging, and checkpointing.
Not ideal if you work outside JAX and Flax, or if you need fine-grained, manual control over every step of your model's training process.
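For context, the boilerplate referred to above looks roughly like the following when written by hand. This is a minimal sketch in plain JAX, not jax_trainer's API: the linear model, the names `init_params`, `train_step`, and `fit`, and the in-memory "logger" and "checkpoints" are all hypothetical stand-ins chosen for illustration.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim):
    # Hypothetical linear model: a weight vector and a scalar bias.
    return {"w": jax.random.normal(key, (in_dim,)) * 0.1, "b": jnp.zeros(())}


def loss_fn(params, x, y):
    # Mean squared error of the linear prediction.
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)


@jax.jit
def train_step(params, x, y, lr):
    # One plain SGD update; returns the loss measured before the update.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


def fit(params, x, y, num_epochs=200, lr=0.1, log_every=50):
    history = []      # stands in for a logger: (epoch, pre-update loss)
    checkpoints = {}  # stands in for checkpointing: epoch -> updated params
    for epoch in range(num_epochs):
        params, loss = train_step(params, x, y, lr)
        if epoch % log_every == 0:
            history.append((epoch, float(loss)))
            checkpoints[epoch] = params
    return params, history, checkpoints


# Synthetic regression data, purely for illustration.
k_init, k_data = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_data, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params, history, checkpoints = fit(init_params(k_init, 3), x, y)
```

Every new experiment tends to repeat the same three pieces (update step, logging, checkpointing), which is the repetition a trainer abstraction is meant to remove.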
Stars: 45
Forks: 4
Language: Python
License: MIT
Category
Last pushed: Dec 08, 2024
Commits (30d): 0
Get this data via API
curl "https://pt-edge.onrender.com/api/v1/quality/ml-frameworks/phlippe/jax_trainer"
Open to everyone: 100 requests/day with no key needed. A free key raises the limit to 1,000/day.
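The same request can be made from Python with only the standard library. This is a sketch under two assumptions that the listing does not confirm: that the path follows a category/owner/repo pattern (inferred from the single URL shown) and that the response body is JSON. The helper names `quality_url` and `fetch_quality` are hypothetical.

```python
import json
import urllib.request
from urllib.parse import quote

BASE_URL = "https://pt-edge.onrender.com/api/v1/quality"


def quality_url(category, owner, repo):
    # Assumed path layout: <base>/<category>/<owner>/<repo>.
    parts = [quote(part, safe="") for part in (category, owner, repo)]
    return "/".join([BASE_URL, *parts])


def fetch_quality(category, owner, repo, timeout=10):
    # Assumes a JSON body; the schema is not documented here,
    # so the parsed object is returned unchanged.
    url = quality_url(category, owner, repo)
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Calling `fetch_quality("ml-frameworks", "phlippe", "jax_trainer")` issues the request shown in the curl command; unauthenticated calls count against the daily limit.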
Higher-rated alternatives
explosion/thinc
🔮 A refreshing functional take on deep learning, compatible with your favorite libraries
google-deepmind/optax
Optax is a gradient processing and optimization library for JAX.
patrick-kidger/diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable....
google/grain
Library for reading and processing ML training data.
patrick-kidger/equinox
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/