



JAXFlow: A JAX-based Deep Learning and Machine Learning Framework

JAXFlow is a modern, lightweight neural network library built on top of JAX. It is a pure-functional, multi-device-ready, and modular deep learning framework designed for research, experimentation, and production-ready machine learning pipelines.

If you're searching for a fast, flexible, and fully-JAX-compatible framework for building neural networks, JAXFlow is designed for you.


🚀 Why JAXFlow?

JAXFlow is not just another wrapper around JAX: it's a ground-up, PyTree-aware system for creating, training, and deploying high-performance deep learning models with minimal overhead and full control.

🔑 Key Features

  • ✅ Modular Model API: define networks using Sequential, subclassed Models, or flexible functional blocks.

  • 🦮 JAX-compatible execution: built from the ground up to support jit, vmap, pmap, pjit, and full PyTree semantics.

  • 🗺 Rich layer library: includes Dense, Conv, BatchNorm, Embedding, Dropout, and custom Layer classes.

  • 🏋️ Training API: use .compile() + .fit(), or write custom training loops for full control.

  • ⚙️ Optimizers & schedulers: built-in integration with Optax.

  • 📊 Losses & streaming metrics: includes CrossEntropy, MSE, Accuracy, F1, Precision, and more.

  • 📂 Callbacks & checkpoints: support for EarlyStopping, LearningRateScheduler, and Orbax save/load.

  • 🧠 Built-in models: ready-to-use ResNet, MLP, Transformer, and composable blocks.

  • ⚡ Lazy imports: fast to import, loading heavier submodules only when needed.
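The "full PyTree semantics" point above can be illustrated with plain JAX (nothing JAXFlow-specific): any pure function over a PyTree of parameters composes directly with jit and vmap. A minimal sketch:

```python
import jax
import jax.numpy as jnp

# Parameters stored as a PyTree (here, a nested dict) -- the same kind of
# structure a PyTree-aware model exposes to jit / vmap / pmap.
params = {
    "dense": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))},
}

def forward(params, x):
    # A pure function of (params, inputs), as required by jax.jit.
    return x @ params["dense"]["w"] + params["dense"]["b"]

# Compile once, then map over a batch of inputs with vmap.
batched = jax.jit(jax.vmap(forward, in_axes=(None, 0)))
out = batched(params, jnp.ones((8, 3)))
print(out.shape)  # (8, 2)
```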


📦 Installation

Requires Python ≥3.9 and a valid JAX installation.

To install JAX with GPU or TPU support:

pip install "jax[cuda]>=0.6.0" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Or install JAXFlow with the matching accelerator extras:

pip install --upgrade jaxflow[GPU]
pip install --upgrade jaxflow[TPU]
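After installing, a quick sanity check (plain JAX, independent of JAXFlow) confirms which devices the backend can see:

```python
import jax

# Lists the accelerators visible to JAX; falls back to CPU if no GPU/TPU.
print(jax.devices())
print(jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
```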

๐Ÿง‘โ€๐Ÿ’ป Quickstart: Build Your First JAXFlow Model

JAXFlow supports two modeling styles: subclassing Model, and a sequential style built with .add().

1. Subclassing a Model

import jaxflow as jf
from jaxflow.models import Model
from jaxflow.layers import Conv2D, MaxPooling2D, Dense

class CNN(Model):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = Conv2D(32, (3, 3), activation=jf.activations.relu)
        self.pool1 = MaxPooling2D((2, 2))
        self.conv2 = Conv2D(64, (3, 3), activation=jf.activations.relu)
        self.pool2 = MaxPooling2D((2, 2))
        self.flatten = jf.layers.GlobalAveragePooling2D()
        self.dense1 = Dense(64, activation=jf.activations.relu)
        self.outputs = Dense(num_classes, activation=jf.activations.softmax)

    def call(self, inputs, training=False):
        x = self.conv1(inputs, training=training)
        x = self.pool1(x)
        x = self.conv2(x, training=training)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.dense1(x, training=training)
        return self.outputs(x, training=training)

2. Sequential API with .add()

model = jf.models.Model()
model.add(jf.layers.Conv2D(32, (3, 3), activation=jf.activations.relu))
model.add(jf.layers.MaxPooling2D((2, 2)))
model.add(jf.layers.Conv2D(64, (3, 3), activation=jf.activations.relu))
model.add(jf.layers.MaxPooling2D((2, 2)))
model.add(jf.layers.GlobalAveragePooling2D())
model.add(jf.layers.Dense(64, activation=jf.activations.relu))
model.add(jf.layers.Dense(10, activation=jf.activations.softmax))

model.build(input_shape=(None, 28, 28, 1))
model.compile(optimizer=jf.optimizers.Adam(0.001), loss_fn=jf.losses.SparseCategoricalCrossentropy())
model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.1)

📖 Documentation

Full documentation and an API reference are available in the project repository.


๐Ÿ—‚๏ธ Project Structure

jaxflow/
├── activations/       # ReLU, GELU, Swish, etc.
├── callbacks/         # EarlyStopping, Logger, Checkpointing
├── core/              # Base module, scopes, tree utilities
├── gradient/          # JAX custom grad support
├── initializers/      # Glorot, He, Zeros, ...
├── layers/            # Dense, Conv2D, Embedding, ...
├── losses/            # CrossEntropy, MSE, ...
├── metrics/           # Accuracy, F1, Precision, ...
├── models/            # Sequential, Transformer, ResNet
├── optimizers/        # Optax integrations
└── regularizers/      # L1, L2, Dropout

🔮 What's Next

Planned additions to JAXFlow:

  • โ˜‘๏ธ Transformer layer with attention support
  • โ˜‘๏ธ Full callback system with exportable training logs
  • โ˜‘๏ธ Model persistence and loading with Orbax
  • โ˜‘๏ธ Classical ML algorithms: SVM, KNN, Logistic Regression

📄 License

JAXFlow is licensed under the Apache-2.0 License.

💖 Built with care for the JAX community. Keep your models clean, fast, and functional, with JAXFlow.
