A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- It's compatible with both dataclasses and regular classes
- It has no intention of supporting Neural Network use cases (e.g. partitioning)
Installation
pip install simple-pytree
Usage
    import jax
    from simple_pytree import Pytree

    class Foo(Pytree):
        def __init__(self, x, y):
            self.x = x
            self.y = y

    foo = Foo(1, 2)
    # Both attributes are pytree leaves, so both are negated.
    foo = jax.tree_util.tree_map(lambda x: -x, foo)

    assert foo.x == -1 and foo.y == -2
Static fields
You can mark fields as static by assigning static_field() to a class attribute with the same name
as the instance attribute:
    import jax
    from simple_pytree import Pytree, static_field

    class Foo(Pytree):
        y = static_field()  # same name as the instance attribute set in __init__

        def __init__(self, x, y):
            self.x = x
            self.y = y

    foo = Foo(1, 2)
    foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

    assert foo.x == -1 and foo.y == 2

Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
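To illustrate what "passed as pytree metadata" means, here is a minimal sketch that registers a class by hand with plain JAX, without simple_pytree. The class name `ManualFoo` is hypothetical, and this is not the package's actual implementation, only the general leaf/metadata split that JAX's pytree registration offers:

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ManualFoo:
    """Hypothetical hand-registered pytree: x is a leaf, y is static metadata."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        children = (self.x,)  # leaves: visible to tree_map, jit, grad, ...
        aux_data = self.y     # metadata: carried along, never transformed
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (x,) = children
        return cls(x, aux_data)


foo = ManualFoo(1, 2)
leaves = jax.tree_util.tree_leaves(foo)              # contains only x
negated = jax.tree_util.tree_map(lambda v: -v, foo)  # y passes through untouched
```

Because metadata becomes part of the tree structure, JAX expects it to be hashable and comparable, which is worth keeping in mind when choosing what to mark as static.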
Dataclasses
simple_pytree provides a dataclass decorator you can use with classes
that contain static_fields:
    import jax
    from simple_pytree import Pytree, dataclass, static_field

    @dataclass
    class Foo(Pytree):
        x: int
        y: int = static_field(default=2)

    foo = Foo(1)
    foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

    assert foo.x == -1 and foo.y == 2

simple_pytree.dataclass is just a wrapper around dataclasses.dataclass, but
when it is used, static analysis tools and IDEs will understand that static_field is a
field specifier, just like dataclasses.field.
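One common way to get this kind of tooling support is `typing.dataclass_transform` (PEP 681). The sketch below assumes that mechanism; the names `my_static_field`, `my_dataclass`, and the `"static"` metadata key are hypothetical and are not taken from simple_pytree's source:

```python
import dataclasses

try:
    from typing import dataclass_transform  # available in Python 3.11+
except ImportError:
    def dataclass_transform(**kwargs):
        # Fallback for older Pythons: a no-op decorator with no runtime effect.
        return lambda obj: obj


def my_static_field(**kwargs):
    """Hypothetical field specifier: a dataclasses.field tagged as static."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(metadata=metadata, **kwargs)


# Declares to type checkers that my_dataclass behaves like dataclasses.dataclass
# and that my_static_field may appear wherever dataclasses.field may.
@dataclass_transform(field_specifiers=(my_static_field, dataclasses.field))
def my_dataclass(cls=None, **kwargs):
    decorator = dataclasses.dataclass(**kwargs)
    return decorator if cls is None else decorator(cls)


@my_dataclass
class Config:
    x: int
    y: int = my_static_field(default=2)


cfg = Config(1)
static_names = [
    f.name for f in dataclasses.fields(Config) if f.metadata.get("static")
]
```

At runtime the decorator changes nothing beyond what dataclasses.dataclass already does; the annotation only informs type checkers and IDEs.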
Mutability
Pytree objects are immutable by default after __init__:
    from simple_pytree import Pytree, static_field

    class Foo(Pytree):
        y = static_field()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    foo = Foo(1, 2)
    foo.x = 3  # AttributeError

If you want to make them mutable, pass mutable=True in the class definition:

    from simple_pytree import Pytree, static_field

    class Foo(Pytree, mutable=True):
        y = static_field()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    foo = Foo(1, 2)
    foo.x = 3  # OK
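For readers curious how "immutable after __init__" with an opt-out class keyword can be built in plain Python, here is a rough sketch. The names `FreezeAfterInit`, `Base`, and `_initializing` are hypothetical, and nothing here is claimed to match simple_pytree's internals:

```python
class FreezeAfterInit(type):
    """Hypothetical metaclass: allows attribute writes only while __init__ runs."""

    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            object.__setattr__(obj, "_initializing", False)
        return obj


class Base(metaclass=FreezeAfterInit):
    _mutable = False

    def __init_subclass__(cls, mutable=False, **kwargs):
        # Receives the keyword from `class Foo(Base, mutable=True)`.
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable

    def __setattr__(self, name, value):
        if not (self._initializing or self._mutable):
            raise AttributeError(
                f"{type(self).__name__} is immutable; cannot set {name!r}"
            )
        object.__setattr__(self, name, value)


class Frozen(Base):
    def __init__(self, x):
        self.x = x  # allowed: still inside __init__


class Thawed(Base, mutable=True):
    def __init__(self, x):
        self.x = x


frozen = Frozen(1)
thawed = Thawed(1)
thawed.x = 3  # allowed: the class opted into mutability
```

Assigning `frozen.x = 3` at this point raises AttributeError, mirroring the behavior described above.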
Replacing fields
If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:
    from simple_pytree import Pytree, static_field

    class Foo(Pytree):
        y = static_field()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    foo = Foo(1, 2)
    foo = foo.replace(x=10)  # returns a modified copy

    assert foo.x == 10 and foo.y == 2

replace works for both mutable and immutable Pytree objects. If the class
is a dataclass, replace internally uses dataclasses.replace.
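The two code paths described above can be sketched with the standard library alone. `replace_fields`, `Plain`, and `FrozenPoint` are hypothetical names, and the non-dataclass branch (shallow copy plus direct attribute writes) is an assumption about one reasonable approach, not a description of the package's code:

```python
import copy
import dataclasses


def replace_fields(obj, **changes):
    """Hypothetical helper: return a copy of obj with some attributes changed."""
    if dataclasses.is_dataclass(obj):
        # Dataclasses: delegate to the standard library, which re-runs __init__.
        return dataclasses.replace(obj, **changes)
    new = copy.copy(obj)  # shallow copy; the original stays untouched
    for name, value in changes.items():
        if not hasattr(new, name):
            raise ValueError(f"{type(obj).__name__} has no field {name!r}")
        # Bypass any custom __setattr__ so this also works on frozen objects.
        object.__setattr__(new, name, value)
    return new


class Plain:
    def __init__(self, x, y):
        self.x = x
        self.y = y


@dataclasses.dataclass(frozen=True)
class FrozenPoint:
    x: int
    y: int


plain = Plain(1, 2)
plain2 = replace_fields(plain, x=10)

point = FrozenPoint(1, 2)
point2 = replace_fields(point, x=10)
```

In both branches the original object is left unchanged, which is what makes the same call usable on mutable and immutable instances alike.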