This repository was archived by the owner on Mar 31, 2025. It is now read-only.

TypeError during gradient computation: type <class 'objax.variable.TrainVar'> is not a valid JAX type #260

@cyugao

Description


I was trying to run a simple example, but I get a type error when evaluating the gradients:

TypeError: Argument 'objax.TrainVar(Traced<ConcreteArray([-1.1010288  -0.6818452  -0.95236534], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([-1.1010288 , -0.6818452 , -0.95236534], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = LambdaBinding(), reduce=reduce_mean)' of type <class 'objax.variable.TrainVar'> is not a valid JAX type.
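The message says that the argument of type `objax.variable.TrainVar` is not a valid JAX type, i.e. JAX is receiving the variable object rather than the array it holds. The sketch below does not use Objax. It only illustrates the general pattern with a hypothetical `Box` class standing in for a variable wrapper: `jax.numpy` functions raise a `TypeError` for an arbitrary Python object that holds an array, but accept the array itself. In Objax the array inside a `TrainVar` is read through its `.value` attribute, so writing `jn.dot(x, w.value) + b.value` inside `loss` may avoid the error; that is an assumption, not a confirmed fix for this issue.

```python
import jax.numpy as jn


class Box:
    """Hypothetical stand-in for a variable wrapper such as objax.TrainVar."""

    def __init__(self, value):
        self.value = value


w = Box(jn.ones(3))
x = jn.arange(3.0)

try:
    jn.dot(x, w)  # the wrapper object itself is passed to JAX
    raised = False
except TypeError:
    raised = True  # JAX does not know how to treat Box as an array

result = jn.dot(x, w.value)  # passing the underlying array works
```

Here `raised` ends up `True`, and `result` is `0*1 + 1*1 + 2*1 = 3.0`.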

Minimal example from the docs:

import objax
import jax.numpy as jn

n = 1000
ndim = 10
X = objax.random.normal((n, ndim))
y = objax.random.normal((n, 1))
w = objax.TrainVar(jn.zeros(ndim))
b = objax.TrainVar(jn.zeros(1))

def loss(x, y):
    # w and b are the TrainVar objects captured from the enclosing scope
    pred = jn.dot(x, w) + b
    return 0.5 * ((y - pred) ** 2).mean()

g_fn = objax.Grad(loss,           # g_fn is an Objax module
                  objax.VarCollection({'w': w, 'b': b}))
g_value = g_fn(X, y)
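For comparison, here is a minimal sketch of the same gradient computation in plain JAX, without Objax variables: the parameters are ordinary arrays in a dict, and `jax.grad` differentiates the loss with respect to that dict. This illustrates what the example computes; it is not Objax's API, and the PRNG seed and the dict keys are arbitrary choices. One deliberate difference: `y` is 1-D here, because with `y` of shape `(n, 1)` and `pred` of shape `(n,)`, as in the example above, `y - pred` broadcasts to shape `(n, n)`.

```python
import jax
import jax.numpy as jn

n, ndim = 1000, 10
kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (n, ndim))
# 1-D targets so that y - pred has shape (n,) instead of broadcasting to (n, n).
y = jax.random.normal(ky, (n,))

# Parameters are plain arrays collected in a dict (a pytree), not wrapper objects.
params = {'w': jn.zeros(ndim), 'b': jn.zeros(1)}


def loss(params, x, y):
    pred = jn.dot(x, params['w']) + params['b']
    return 0.5 * ((y - pred) ** 2).mean()


# jax.grad differentiates with respect to the first argument and returns a dict
# with the same keys and shapes as params.
g_fn = jax.grad(loss)
g_value = g_fn(params, X, y)
```

At `w = 0` and `b = 0` the prediction is zero, so the gradients reduce to `-X.T @ y / n` for `w` and `-mean(y)` for `b`.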
