JAXFlow: A JAX-based Deep Learning and Machine Learning Framework
JAXFlow is a modern, lightweight neural network library built on top of JAX. It is a pure-functional, multi-device-ready, and modular deep learning framework designed for research, experimentation, and production-ready machine learning pipelines.
If you're searching for a fast, flexible, and fully-JAX-compatible framework for building neural networks, JAXFlow is designed for you.
Why JAXFlow?
JAXFlow is not just another wrapper around JAX: it is a ground-up, PyTree-aware system for creating, training, and deploying high-performance deep learning models with minimal overhead and full control.
Key Features
- **Modular Model API**: Define networks using `Sequential`, subclassed `Model`s, or flexible functional blocks.
- **JAX-Compatible Execution**: Built from the ground up to support `jit`, `vmap`, `pmap`, `pjit`, and full PyTree semantics.
- **Rich Layer Library**: Includes `Dense`, `Conv`, `BatchNorm`, `Embedding`, `Dropout`, and custom `Layer` classes.
- **Training API**: Use `.compile()` + `.fit()`, or write custom training loops for full control (see the sketch after this list).
- **Optimizers & Schedulers**: Built-in integration with Optax.
- **Losses & Streaming Metrics**: Includes `CrossEntropy`, `MSE`, `Accuracy`, `F1`, `Precision`, and more.
- **Callbacks & Checkpoints**: Support for `EarlyStopping`, `LearningRateScheduler`, and Orbax save/load.
- **Built-in Models**: Comes with ready-to-use `ResNet`, `MLP`, `Transformer`, and composable blocks.
- **Lazy Imports**: Fast to import, loading deep components only when needed.
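To make the custom-training-loop bullet concrete: JAXFlow's internal parameter layout is not shown in this README, so the sketch below uses plain JAX + Optax with an assumed `params` dict and `loss_fn(params, batch)`. It illustrates the functional training-step pattern the framework targets, not JAXFlow's own API.

```python
# Minimal custom training-step sketch in plain JAX + Optax.
# The params layout and loss_fn are hypothetical stand-ins,
# not JAXFlow internals.
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    x, y = batch
    logits = x @ params["w"] + params["b"]  # stand-in linear model
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))

optimizer = optax.adam(1e-3)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros((10,))}
opt_state = optimizer.init(params)
```

Because `train_step` is a pure function of its inputs, it composes directly with `jax.jit`, `jax.vmap`, and `jax.pmap`.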
Installation
Requires Python ≥ 3.9 and a valid JAX installation.
To install JAX with GPU or TPU support:
pip install "jax[cuda]>=0.6.0" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.htmlOr use:
pip install --upgrade jaxflow[GPU] pip install --upgrade jaxflow[TPU]
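After installing, a quick check that JAX sees your accelerator (these are standard JAX calls; output varies by machine):

```python
import jax

print(jax.devices())          # e.g. [CudaDevice(id=0)] on a CUDA install
print(jax.default_backend())  # "gpu", "tpu", or "cpu"
```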
Quickstart: Build Your First JAXFlow Model
JAXFlow supports two modeling styles: subclassing `Model` and building sequentially with `.add()`.
1. Subclassing a Model
```python
import jaxflow as jf
from jaxflow.models import Model
from jaxflow.layers import Conv2D, MaxPooling2D, Dense


class CNN(Model):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = Conv2D(32, (3, 3), activation=jf.activations.relu)
        self.pool1 = MaxPooling2D((2, 2))
        self.conv2 = Conv2D(64, (3, 3), activation=jf.activations.relu)
        self.pool2 = MaxPooling2D((2, 2))
        self.flatten = jf.layers.GlobalAveragePooling2D()
        self.dense1 = Dense(64, activation=jf.activations.relu)
        self.outputs = Dense(num_classes, activation=jf.activations.softmax)

    def call(self, inputs, training=False):
        x = self.conv1(inputs, training=training)
        x = self.pool1(x)
        x = self.conv2(x, training=training)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x, training=training)
        return self.outputs(x, training=training)
```
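A quick smoke test of the subclassed model. This assumes, as the Keras-style API above suggests, that calling the model dispatches to `call` and that layers build lazily on first use; both are assumptions, not documented behavior.

```python
import jax.numpy as jnp

model = CNN(num_classes=10)
x = jnp.zeros((8, 28, 28, 1))     # batch of 8 grayscale 28x28 images
probs = model(x, training=False)  # assumed to dispatch to CNN.call
print(probs.shape)                # expected: (8, 10)
```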
2. Sequential API with `.add()`
```python
model = jf.models.Model()
model.add(jf.layers.Conv2D(32, (3, 3), activation=jf.activations.relu))
model.add(jf.layers.MaxPooling2D((2, 2)))
model.add(jf.layers.Conv2D(64, (3, 3), activation=jf.activations.relu))
model.add(jf.layers.MaxPooling2D((2, 2)))
model.add(jf.layers.GlobalAveragePooling2D())
model.add(jf.layers.Dense(64, activation=jf.activations.relu))
model.add(jf.layers.Dense(10, activation=jf.activations.softmax))

model.build(input_shape=(None, 28, 28, 1))
model.compile(
    optimizer=jf.optimizers.Adam(0.001),
    loss_fn=jf.losses.SparseCategoricalCrossentropy(),
)
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)
```
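Callbacks from the feature list plug into `.fit()`. The wiring below is a sketch: `EarlyStopping` lives in JAXFlow's callback module, but the constructor arguments (`monitor`, `patience`) and the `callbacks=` keyword are assumptions modeled on the Keras-style API shown above, not confirmed JAXFlow signatures.

```python
# Hypothetical wiring; argument names are assumptions, not confirmed
# JAXFlow signatures.
early_stop = jf.callbacks.EarlyStopping(monitor="val_loss", patience=3)

model.fit(
    x_train, y_train,
    epochs=50,
    batch_size=64,
    validation_split=0.1,
    callbacks=[early_stop],
)
```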
Documentation
Full documentation and API reference are available:
- GitHub Repo
- API Docs
- PyPI Package
Project Structure
jaxflow/
├── activations/   # ReLU, GELU, Swish, etc.
├── callbacks/     # EarlyStopping, Logger, Checkpointing
├── core/          # Base module, scopes, tree utilities
├── gradient/      # JAX custom grad support
├── initializers/  # Glorot, He, Zeros, ...
├── layers/        # Dense, Conv2D, Embedding, ...
├── losses/        # CrossEntropy, MSE, ...
├── metrics/       # Accuracy, F1, Precision, ...
├── models/        # Sequential, Transformer, ResNet
├── optimizers/    # Optax integrations
└── regularizers/  # L1, L2, Dropout
What's Next
Planned additions to JAXFlow:
- Transformer layer with attention support
- Full callback system with exportable training logs
- Model persistence and loading with Orbax
- Classical ML algorithms: SVM, KNN, Logistic Regression
License
JAXFlow is licensed under the Apache-2.0 License.
Built with care for the JAX community. Keep your models clean, fast, and functional, with JAXFlow.
