Hijax (experimental)

Contents

Hijax (experimental)#

Basic usage#

from flax import nnx
import jax
import optax

# Globally opt in to the experimental hijax variable backend for all
# nnx Variables created after this point.
# NOTE(review): this mutates global state — presumably the previous mode
# should be saved first so it can be restored later; verify against the
# cleanup call at the end of this example.
nnx.var_defaults(hijax=True)

class Model(nnx.Module):
  """A small MLP: Linear -> BatchNorm -> Dropout -> relu -> Linear."""

  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    """Build the layers eagerly with the given feature sizes and RNGs."""
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x, rngs):
    """Forward pass; `rngs` supplies the dropout key."""
    h = self.linear(x)
    h = self.bn(h)
    h = self.dropout(h, rngs=rngs)
    h = nnx.relu(h)
    return self.linear_out(h)

# Construct the model eagerly (parameters are materialized here, not lazily)
# with input dim 2, hidden dim 64, output dim 3, seeded from Rngs(0).
model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
# Adam optimizer tracking only nnx.Param variables (wrt=nnx.Param),
# i.e. BatchNorm statistics and other non-Param state are excluded.
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, rngs, x, y):
  """One jitted training step: MSE loss, grads w.r.t. Params, in-place update.

  Args:
    model: the nnx Model (updated in place via the optimizer).
    optimizer: nnx.Optimizer wrapping an optax transform (wrt=nnx.Param).
    rngs: RNGs forwarded to the model's dropout.
    x, y: batch inputs and regression targets.

  Returns:
    The scalar MSE loss for this batch.
  """
  # Split the model into its static graph definition, the differentiable
  # Param state, and everything else (e.g. BatchNorm statistics).
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
  def loss_fn(params):
    # Re-assemble a model from the (possibly traced) params to run the
    # forward pass under jax.value_and_grad.
    model = nnx.merge(graphdef, params, nondiff)
    return ((model(x, rngs) - y) ** 2).mean()
  # NOTE(review): vars_as(..., mutable=False) presumably freezes the hijax
  # variables so value_and_grad sees them as immutable leaves — confirm
  # against the hijax docs.
  loss, grads = jax.value_and_grad(loss_fn)(nnx.vars_as(params, mutable=False))
  optimizer.update(model, grads)  # in-place updates
  return loss

# Restore the previous variable-backend mode so later code/tests are not
# affected by the hijax opt-in above.
# NOTE(review): `current_mode` is never defined in this snippet — it was
# presumably captured (e.g. from nnx.var_defaults()) before the
# `nnx.var_defaults(hijax=True)` call; as written this line raises NameError.
nnx.var_defaults(hijax=current_mode)  # clean up for CI tests