Hijax#

from flax import nnx
import jax
import jax.numpy as jnp
import optax

# Save the current default Variable mode so it can be restored at the very
# end of the guide (cleanup for CI).
current_mode = nnx.var_defaults().hijax # ignore: only needed for testing
# Make all Variables created below hijax-backed by default.
nnx.var_defaults(hijax=True)

rngs = nnx.Rngs(0)
model = nnx.Linear(2, 3, rngs=rngs)
# The optimizer tracks and updates only the model's nnx.Param Variables.
optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)

@jax.jit
def train_step(x, y):
  """One MSE training step; closes over `model` and `optimizer` above."""
  loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
  # Differentiate w.r.t. an immutable view of the model's Variables.
  loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(model, mutable=False))  # tmp fix for jax.grad
  # In-place update; the decreasing losses printed below show that the
  # mutation persists across jit calls.
  optimizer.update(model, grads)
  return loss

# Random regression data; run a few steps to show the loss decreasing.
x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))
for _ in range(3):
  print(train_step(x, y))
0.85250294
0.8165137
0.7814907

Hijax Variable#

State propagation:

# A scalar hijax Variable: mutations made inside jit propagate to the caller.
v = nnx.Variable(jnp.array(0), hijax=True)

@jax.jit
def inc(v):
  # In-place increment of the Variable's value through `v[...]`.
  v[...] += 1

print(v[...]); inc(v); print(v[...])  # prints 0 then 1: the jit'd write is visible
0
1
# Inspect the jaxpr: Variable reads/writes show up as get_variable/set_variable.
v = nnx.Variable(jnp.array(0), hijax=True)
print(jax.make_jaxpr(inc)(v))
{ lambda ; a:Variable(). let
    jit[
      name=inc
      jaxpr={ lambda ; a:Variable(). let
          b:i32[] = get_variable[
            avals=(ShapedArray(int32[], weak_type=True),)
            has_qdd=True
            treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
            var_type=<class 'flax.nnx.variablelib.Variable'>
          ] a
          c:i32[] = add b 1:i32[]
          _:i32[] = get_variable[
            avals=(ShapedArray(int32[], weak_type=True),)
            has_qdd=True
            treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
            var_type=<class 'flax.nnx.variablelib.Variable'>
          ] a
          set_variable[
            treedef=PyTreeDef(CustomNode(Variable[(('eager_sharding', True), ('hijax', True), ('mutable', True), ('ref', False))], [*]))
            var_type=<class 'flax.nnx.variablelib.Variable'>
          ] a c
        in () }
    ] a
  in () }

Pytree values:

# A Variable can hold a pytree value; individual leaves can be mutated in jit.
v = nnx.Variable({'a': jnp.array(0), 'b': jnp.array(2)}, hijax=True)

@jax.jit
def inc_and_double(v):
  v['a'] += 1
  v['b'] *= 2

print(v); inc_and_double(v); print(v)
Variable( # 2 (8 B)
  value={'a': Array(0, dtype=int32, weak_type=True), 'b': Array(2, dtype=int32, weak_type=True)},
  hijax=True
)
Variable( # 2 (8 B)
  value={'a': Array(1, dtype=int32, weak_type=True), 'b': Array(4, dtype=int32, weak_type=True)},
  hijax=True
)

Dynamic state structure:

rngs = nnx.Rngs(0)
x = rngs.uniform((4, 5))
w = rngs.normal((5, 3))
# Start from an empty dict; jit-traced code may add new entries to it.
metrics = nnx.Variable({}, hijax=True)

@jax.jit
def linear(x, w, metrics: nnx.Variable):
  y = x @ w
  # Adds a new key to the Variable's dict; the structure change persists
  # outside jit (compare the Before/After prints below).
  metrics['y_mean'] = jnp.mean(y)
  return y

print("Before:", metrics)
y = linear(x, w, metrics)
print("After:", metrics)
Before: Variable(
  value={},
  hijax=True
)
After: Variable( # 1 (4 B)
  value={'y_mean': Array(-1.1782329, dtype=float32)},
  hijax=True
)
# set default Variable mode for the rest of the guide
nnx.var_defaults(hijax=True)

# With the default set, plain nnx.Variable(...) is hijax without the kwarg.
variable = nnx.Variable(jnp.array([1, 2, 3]))

print(variable)
Variable( # 3 (12 B)
  value=Array([1, 2, 3], dtype=int32),
  hijax=True
)

Mutability#

class Linear(nnx.Module):
  """Minimal linear layer (no bias): y = x @ kernel."""

  def __init__(self, in_features, out_features, rngs: nnx.Rngs):
    self.kernel = nnx.Param(rngs.normal((in_features, out_features)))

  def __call__(self, x):
    return x @ self.kernel

model = Linear(1, 3, rngs=nnx.Rngs(0))

# vars_as toggles the `mutable` flag on every Variable (compare the reprs below).
print(f"{nnx.vars_as(model, mutable=False) = !s}")
print(f"{nnx.vars_as(model, mutable=True) = !s}")
nnx.vars_as(model, mutable=False) = Linear( # Param: 3 (12 B)
  kernel=Param( # 3 (12 B)
    value=Array(shape=(1, 3), dtype=dtype('float32')),
    hijax=True,
    mutable=False
  )
)
nnx.vars_as(model, mutable=True) = Linear( # Param: 3 (12 B)
  kernel=Param( # 3 (12 B)
    value=Array(shape=(1, 3), dtype=dtype('float32')),
    hijax=True
  )
)
# Immutable Variables reject in-place updates with ImmutableVariableError.
v = nnx.Variable(jnp.array(0))
v_immut = nnx.vars_as(v, mutable=False)
assert not v_immut.mutable

try:
  v_immut[...] += 1  # raises an error
except Exception as e:
  print(f"{type(e).__name__}: {e}")
ImmutableVariableError: Cannot mutate Variable as it is marked as immutable. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ImmutableVariableError)

Ref support#

# `ref=True` stores the value behind a Ref (see the raw value printed below).
v = nnx.Variable(jnp.array(0))
v_ref = nnx.vars_as(v, ref=True)
assert v_ref.ref
print(v_ref)
print(v_ref.get_raw_value())  # raw value is a Ref, not a plain Array
Variable( # 1 (4 B)
  value=Array(0, dtype=int32, weak_type=True),
  hijax=True,
  ref=True
)
Ref(0, dtype=int32, weak_type=True)
# Making a ref-Variable immutable drops the ref (remembered as `had_ref` in
# the repr below); making it mutable again restores the ref.
v_immut = nnx.vars_as(v_ref, mutable=False)
assert not v_immut.ref
print("immutable =", v_immut)

v_ref = nnx.vars_as(v_immut, mutable=True)
assert v_ref.ref
print("mutable =", v_ref)
immutable = Variable( # 1 (4 B)
  value=Array(0, dtype=int32, weak_type=True),
  had_ref=True,
  hijax=True,
  mutable=False
)
mutable = Variable( # 1 (4 B)
  value=Array(0, dtype=int32, weak_type=True),
  hijax=True,
  ref=True
)

Examples#

class Block(nnx.Module):
  """Linear -> BatchNorm -> Dropout -> gelu -> Linear block."""

  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)
    self.linear_out = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.gelu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

Training Loop#

# hijax Variables by default
model = Block(2, 64, 3, rngs=nnx.Rngs(0))
# Only nnx.Param Variables are optimized.
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, x, y):
  """One MSE training step using the split/merge pattern.

  Returns the scalar loss; `model` and `optimizer` state are updated in
  place via `optimizer.update`.
  """
  # Separate the differentiable Params from the remaining (non-diff) state.
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)

  def loss_fn(params):
    model = nnx.merge(graphdef, params, nondiff)
    return ((model(x) - y) ** 2).mean()

  loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False))  # immutable for jax.grad
  optimizer.update(model, grads)

  return loss

# A few steps on dummy data; the decreasing loss shows the state is updating.
for _ in range(3):
  loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
  print(f"{loss = !s}")
loss = 1.000178
loss = 0.9700456
loss = 0.93967044

Scan Over Layers#

# TODO: does not work with hijax yet
# @jax.vmap
# def create_stack(rngs):
#   return nnx.as_immutable_vars(Block(2, 64, 2, rngs=rngs))

# block_stack = nnx.as_mutable_vars(create_stack(nnx.Rngs(0).fork(split=8)))

# def scan_fn(x, block):
#   x = block(x)
#   return x, None

# x = jax.random.uniform(jax.random.key(0), (3, 2))
# y, _ = jax.lax.scan(scan_fn, x, block_stack)

# print("y = ", y)

Limitations#

Mutable Outputs#

@jax.jit
def create_model(rngs):
  """Attempt to return a freshly created (mutable, hijax) model from jit."""
  return Block(2, 64, 3, rngs=rngs)

try:
  model = create_model(nnx.Rngs(0))
except Exception as e:
  # Returning mutable hijax state from jit currently raises (see error below).
  print("Error:", e)  # was print(f"Error:", e): f-string had no placeholders (F541)
Error: mutable hitypes should use lo_ty_qdd instead
# Workaround: return non-hijax Variables from jit, convert back outside.
@jax.jit
def create_model(rngs):
  return nnx.vars_as((Block(2, 64, 3, rngs=rngs)), hijax=False)

model = nnx.vars_as(create_model(nnx.Rngs(0)), hijax=True)

print("model.linear =", model.linear)
model.linear = Linear( # Param: 128 (512 B)
  kernel=Param( # 128 (512 B)
    value=Array(shape=(2, 64), dtype=dtype('float32')),
    hijax=True
  )
)

Reference Sharing (aliasing)#

# NOTE: doesn't currently fail on the jax side
def get_error(f, *args):
  """Call ``f(*args)``; return its result, or a ``"Type: message"`` string on failure."""
  try:
    result = f(*args)
  except Exception as exc:  # intentionally broad: demo helper reports any error
    return f"{type(exc).__name__}: {exc}"
  return result

# Passing the same Variable twice (argument aliasing) is a known limitation;
# per the NOTE above it does not currently raise — get_error prints None.
x = nnx.Variable(jnp.array(0))

@jax.jit
def f(a, b):
  ...

print(get_error(f, x, x))
None
# NOTE: doesn't currently fail on the jax side
class HasShared(nnx.Pytree):
  """Pytree whose two attributes alias the same Variable instance."""

  def __init__(self):
    self.a = nnx.Variable(jnp.array(0))
    self.b = self.a  # alias: `b` is the very same Variable object as `a`

@jax.jit
def g(has_shared):
  has_shared.a[...] = 5

has_shared = HasShared()

# Known limitation: no error is raised, but the write inside jit is lost
# (the print below still shows 0).
print(get_error(g, has_shared))
print(has_shared)  # updates don't propagate
None
HasShared( # Variable: 1 (4 B)
  a=Variable( # 1 (4 B)
    value=Array(0, dtype=int32, weak_type=True),
    hijax=True
  ),
  b=Variable( # 1 (4 B)
    value=Array(0, dtype=int32, weak_type=True),
    hijax=True
  )
)
# find_duplicates reports groups of paths that alias the same Variable.
print("Duplicates found:")
if (all_duplicates := nnx.find_duplicates(has_shared)):
  for duplicates in all_duplicates:
    print("-", duplicates)
Duplicates found:
- [('a',), ('b',)]
# Workaround: go through split/merge so updates propagate even with aliasing
# (both `a` and `b` show 5 in the print below).
@jax.jit
def h(graphdef, state):
  has_shared = nnx.merge(graphdef, state)
  has_shared.a[...] = 5

graphdef, state = nnx.split(has_shared)
h(graphdef, state)
print(has_shared)
HasShared( # Variable: 1 (4 B)
  a=Variable( # 1 (4 B)
    value=Array(5, dtype=int32, weak_type=True),
    hijax=True
  ),
  b=Variable( # 1 (4 B)
    value=Array(5, dtype=int32, weak_type=True),
    hijax=True
  )
)
# clean up for CI tests
# Restore the default Variable mode saved at the top of the guide.
_ = nnx.var_defaults(hijax=current_mode)