fattorib/transformer_shmap

Tensor Parallelism with JAX + Shard Map

Quality score: 28 / 100 (Experimental)

This project implements tensor parallelism for training large transformer models in JAX, using shard_map to split the model's computation across multiple accelerators such as TPUs or GPUs. Given a transformer model definition and training data, it trains the model with its weights and computation distributed over the available devices. It is aimed at machine learning engineers who need to scale large language model training beyond what a single device can handle.
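The repository's own code is not reproduced on this page. As a rough illustration of the technique named in the title, the sketch below shows a Megatron-style tensor-parallel MLP block written with shard_map: the first weight matrix is split by columns and the second by rows across a "model" mesh axis, each device computes its slice of the hidden layer, and a psum all-reduce combines the partial outputs. The function names (mlp_block, make_tp_mlp, run_demo) and the tensor sizes are hypothetical and not taken from the project.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX (0.4.x) keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map


def mlp_block(x, w1, w2):
    # Hypothetical per-device body. Inside shard_map each device sees:
    #   x:  [batch, d_model]        (replicated on every device)
    #   w1: [d_model, d_ff / n]     (column slice of the first weight)
    #   w2: [d_ff / n, d_model]     (row slice of the second weight)
    h = jax.nn.gelu(x @ w1)        # local slice of the hidden activations
    partial_out = h @ w2           # this device's partial sum of the output
    # All-reduce the partial sums so every device holds the full output.
    return jax.lax.psum(partial_out, axis_name="model")


def make_tp_mlp(mesh):
    # Column-parallel w1, row-parallel w2, replicated input and output.
    return jax.jit(
        shard_map(
            mlp_block,
            mesh=mesh,
            in_specs=(P(), P(None, "model"), P("model", None)),
            out_specs=P(),
        )
    )


def run_demo(batch=2, d_model=16):
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("model",))
    d_ff = 4 * devices.size  # d_ff must divide evenly across the devices
    kx, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
    x = jax.random.normal(kx, (batch, d_model))
    w1 = jax.random.normal(k1, (d_model, d_ff))
    w2 = jax.random.normal(k2, (d_ff, d_model))
    out = make_tp_mlp(mesh)(x, w1, w2)
    ref = jax.nn.gelu(x @ w1) @ w2  # unsharded reference computation
    return out, ref


if __name__ == "__main__":
    out, ref = run_demo()
    print(out.shape, bool(jnp.allclose(out, ref, rtol=1e-4, atol=1e-4)))
```

The sketch runs on a single CPU device as well as on several accelerators. Splitting w1 by columns and w2 by rows means the nonlinearity can be applied locally without communication, so only one all-reduce per MLP block is needed.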

No commits in the last 6 months.

Use this if you are a machine learning engineer working with JAX and need to train extremely large transformer models efficiently across multiple accelerators.

Not ideal if you are not familiar with JAX or are training smaller models that don't require tensor parallelism.

large-language-models distributed-training model-scaling deep-learning transformer-architectures
Stale 6m · No Package · No Dependents
Maintenance 0 / 25
Adoption 5 / 25
Maturity 16 / 25
Community 7 / 25


Stars: 11
Forks: 1
Language: Python
License: MIT
Last pushed: Sep 29, 2023
Commits (30d): 0

Get this data via API

curl "https://pt-edge.onrender.com/api/v1/quality/transformers/fattorib/transformer_shmap"

Open to everyone — 100 requests/day, no key needed. Get a free key for 1,000/day.