Hijax#
In hijax mode, nnx.Variables become first-class JAX objects: they can be passed into (or captured by) transforms such as jax.jit and mutated in place, with the updates propagating back to the caller.
from flax import nnx
import jax
import jax.numpy as jnp
import optax
current_mode = nnx.var_defaults().hijax  # save the current default so it can be restored at the end (CI housekeeping)
nnx.var_defaults(hijax=True)
rngs = nnx.Rngs(0)
model = nnx.Linear(2, 3, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
@jax.jit
def train_step(x, y):
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
  loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False))  # temporary workaround: jax.grad expects immutable Variables
optimizer.update(model, grads)
return loss
x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))
for _ in range(3):
print(train_step(x, y))
0.85250294
0.8165137
0.7814907
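Since the updates propagate in place, the trained model can be used directly after the loop; a quick sanity check:
# the in-place updates above mean `model` already holds the trained parameters
preds = model(x)
print(preds.shape)  # (4, 3)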
Hijax Variable#
State propagation:
v = nnx.Variable(jnp.array(0), hijax=True)
@jax.jit
def inc(v):
v[...] += 1
print(v[...]); inc(v); print(v[...])
0
1
v = nnx.Variable(jnp.array(0), hijax=True)
print(jax.make_jaxpr(inc)(v))
{ lambda ; a:Variable(). let
jit[
name=inc
jaxpr={ lambda ; a:Variable(). let
b:i32[] = get_variable[
avals=(ShapedArray(int32[], weak_type=True),)
has_qdd=True
treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
var_type=<class 'flax.nnx.variablelib.Variable'>
] a
c:i32[] = add b 1:i32[]
_:i32[] = get_variable[
avals=(ShapedArray(int32[], weak_type=True),)
has_qdd=True
treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
var_type=<class 'flax.nnx.variablelib.Variable'>
] a
set_variable[
treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
var_type=<class 'flax.nnx.variablelib.Variable'>
] a c
in () }
] a
in () }
Pytree values:
v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True)
@jax.jit
def inc_and_double(v):
v['a'] += 1
v['b'] *= 2
print(v); inc_and_double(v); print(v)
Variable( # 2 (8 B)
value={'a': Array(0, dtype=int32, weak_type=True), 'b': Array(2, dtype=int32, weak_type=True)},
hijax=True
)
Variable( # 2 (8 B)
value={'a': Array(1, dtype=int32, weak_type=True), 'b': Array(4, dtype=int32, weak_type=True)},
hijax=True
)
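Individual entries can be read back with the same indexing used for the writes above (a small sketch, assuming dictionary-style reads mirror the writes):
# read entries back out of the dict-valued Variable
print(v['a'])  # Array(1, ...)
print(v['b'])  # Array(4, ...)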
Dynamic state structure:
rngs = nnx.Rngs(0)
x = rngs.uniform((4, 5))
w = rngs.normal((5, 3))
metrics = nnx.Variable({}, hijax=True)
@jax.jit
def linear(x, w, metrics: nnx.Variable):
y = x @ w
metrics['y_mean'] = jnp.mean(y)
return y
print("Before:", metrics)
y = linear(x, w, metrics)
print("After:", metrics)
Before: Variable(
value={},
hijax=True
)
After: Variable( # 1 (4 B)
value={'y_mean': Array(-1.1782329, dtype=float32)},
hijax=True
)
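Calling the jitted function again simply overwrites the existing entry; a brief follow-up reusing the same objects:
# second call: `metrics` already carries the 'y_mean' key, so it is overwritten
y = linear(x, w, metrics)
print(metrics['y_mean'])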
# set default Variable mode for the rest of the guide
nnx.var_defaults(hijax=True)
variable = nnx.Variable(jnp.array([1, 2, 3]))
print(variable)
Variable( # 3 (12 B)
value=Array([1, 2, 3], dtype=int32),
hijax=True
)
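With the default flipped, Variable subclasses such as nnx.Param pick it up as well; a quick check (the repr should report hijax=True):
p = nnx.Param(jnp.zeros(()))
print(p)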
Mutability#
A Variable can be switched between mutable and immutable views with nnx.vars_as; immutable views are what transforms like jax.grad expect:
class Linear(nnx.Module):
def __init__(self, in_features, out_features, rngs: nnx.Rngs):
self.kernel = nnx.Param(rngs.normal((in_features, out_features)))
def __call__(self, x):
return x @ self.kernel
model = Linear(1, 3, rngs=nnx.Rngs(0))
print(f"{nnx.vars_as(model, mutable=False) = !s}")
print(f"{nnx.vars_as(model, mutable=True) = !s}")
nnx.vars_as(model, mutable=False) = Linear( # Param: 3 (12 B)
kernel=Param( # 3 (12 B)
value=Array(shape=(1, 3), dtype=dtype('float32')),
hijax=True,
mutable=False
)
)
nnx.vars_as(model, mutable=True) = Linear( # Param: 3 (12 B)
kernel=Param( # 3 (12 B)
value=Array(shape=(1, 3), dtype=dtype('float32')),
hijax=True
)
)
v = nnx.Variable(jnp.array(0))
v_immut = nnx.vars_as(v, mutable=False)
assert not v_immut.mutable
try:
v_immut[...] += 1 # raises an error
except Exception as e:
print(f"{type(e).__name__}: {e}")
ImmutableVariableError: Cannot mutate Variable as it is marked as immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ImmutableVariableError)
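The conversion also works in the other direction; a small sketch restoring mutability with the same nnx.vars_as API:
v_mut = nnx.vars_as(v_immut, mutable=True)
v_mut[...] += 1  # mutation is allowed again
print(v_mut[...])  # 1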
Ref support#
A Variable can also store its value in a mutable Ref; converting to an immutable view drops the ref, and converting back to a mutable view restores it:
v = nnx.Variable(jnp.array(0))
v_ref = nnx.vars_as(v, ref=True)
assert v_ref.ref
print(v_ref)
print(v_ref.get_raw_value())
Variable( # 1 (4 B)
value=Array(0, dtype=int32, weak_type=True),
hijax=True,
ref=True
)
Ref(0, dtype=int32, weak_type=True)
v_immut = nnx.vars_as(v_ref, mutable=False)
assert not v_immut.ref
print("immutable =", v_immut)
v_ref = nnx.vars_as(v_immut, mutable=True)
assert v_ref.ref
print("mutable =", v_ref)
immutable = Variable( # 1 (4 B)
value=Array(0, dtype=int32, weak_type=True),
had_ref=True,
hijax=True,
mutable=False
)
mutable = Variable( # 1 (4 B)
value=Array(0, dtype=int32, weak_type=True),
hijax=True,
ref=True
)
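A ref-backed Variable still supports the same in-place indexing updates as any other mutable hijax Variable (a minimal sketch, assuming the [...] update path shown earlier applies unchanged):
v_ref[...] += 1
print(v_ref[...])  # 1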
Examples#
class Block(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.linear_out = Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.gelu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
Training Loop#
# hijax Variables by default
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@jax.jit
def train_step(model, optimizer, x, y):
graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
def loss_fn(params):
model = nnx.merge(graphdef, params, nondiff)
return ((model(x) - y) ** 2).mean()
loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False)) # immutable for jax.grad
optimizer.update(model, grads)
return loss
for _ in range(3):
loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
print(f"{loss = !s}")
loss = 1.000178
loss = 0.9700456
loss = 0.93967044
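After training, the usual nnx train/eval toggles still apply; a hedged example assuming nnx.Module.eval() disables Dropout and switches BatchNorm to its running statistics, as in regular nnx:
model.eval()
preds = model(jnp.ones((10, 2)))
print(preds.shape)  # (10, 3)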
Scan Over Layers#
# TODO: does not work with hijax yet
# @jax.vmap
# def create_stack(rngs):
# return nnx.as_immutable_vars(Block(2, 64, 2, rngs=rngs))
# block_stack = nnx.as_mutable_vars(create_stack(nnx.Rngs(0).fork(split=8)))
# def scan_fn(x, block):
# x = block(x)
# return x, None
# x = jax.random.uniform(jax.random.key(0), (3, 2))
# y, _ = jax.lax.scan(scan_fn, x, block_stack)
# print("y = ", y)
Limitations#
Mutable Outputs#
Mutable hijax Variables cannot currently be returned from a jit-compiled function:
@jax.jit
def create_model(rngs):
return Block(2, 64, 3, rngs=rngs)
try:
model = create_model(nnx.Rngs(0))
except Exception as e:
print(f"Error:", e)
Error: mutable hitypes should use lo_ty_qdd instead
As a workaround, convert the model to non-hijax Variables before returning it, then convert it back outside:
@jax.jit
def create_model(rngs):
  return nnx.vars_as(Block(2, 64, 3, rngs=rngs), hijax=False)
model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)
print("model.linear =", model.linear)
model.linear = Linear( # Param: 128 (512 B)
kernel=Param( # 128 (512 B)
value=Array(shape=(2, 64), dtype=dtype('float32')),
hijax=True
)
)
Reference Sharing (aliasing)#
Passing the same Variable into a transformed function more than once, or sharing one Variable between attributes, is not yet detected, and updates may silently fail to propagate:
# NOTE: doesn't currently fail on the jax side
def get_error(f, *args):
try:
return f(*args)
except Exception as e:
return f"{type(e).__name__}: {e}"
x = nnx.Variable(jnp.array(0))
@jax.jit
def f(a, b):
...
print(get_error(f, x, x))
None
# NOTE: doesn't currently fail on the jax side
class HasShared(nnx.Pytree):
def __init__(self):
self.a = nnx.Variable(jnp.array(0))
self.b = self.a
@jax.jit
def g(has_shared):
has_shared.a[...] = 5
has_shared = HasShared()
print(get_error(g, has_shared))
print(has_shared) # updates don't propagate
None
HasShared( # Variable: 1 (4 B)
a=Variable( # 1 (4 B)
value=Array(0, dtype=int32, weak_type=True),
hijax=True
),
b=Variable( # 1 (4 B)
value=Array(0, dtype=int32, weak_type=True),
hijax=True
)
)
print("Duplicates found:")
if (all_duplicates := nnx.find_duplicates(has_shared)):
for duplicates in all_duplicates:
print("-", duplicates)
Duplicates found:
- [('a',), ('b',)]
As a workaround, going through nnx.split / nnx.merge preserves the shared reference, so the update reaches both aliases:
@jax.jit
def h(graphdef, state):
has_shared = nnx.merge(graphdef, state)
has_shared.a[...] = 5
graphdef, state = nnx.split(has_shared)
h(graphdef, state)
print(has_shared)
HasShared( # Variable: 1 (4 B)
a=Variable( # 1 (4 B)
value=Array(5, dtype=int32, weak_type=True),
hijax=True
),
b=Variable( # 1 (4 B)
value=Array(5, dtype=int32, weak_type=True),
hijax=True
)
)
# clean up for CI tests
_ = nnx.var_defaults(hijax=current_mode)