I am unable to run the `vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012` checkpoint with the config returned by `get_config` in `vmoe/configs/vmoe_paper/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012`. This code fragment:
```python
import jax
from vmoe.nn import models
from vmoe.data import input_pipeline
from vmoe.checkpoints import partitioned
from vmoe.configs.vmoe_paper.vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 import get_config

# Build the model from the paper config.
model = models.VisionTransformerMoe(**get_config()["model"])
# Restore the fine-tuned parameters from the public GCS bucket.
checkpoint = partitioned.restore_checkpoint("gs://vmoe_checkpoints/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012", tree=None)

IMAGE_SIZE = 384
BATCH_SIZE = 1
# Random input: 384 / 16 = 24 patches per side, i.e. 576 patches + 1 class token = 577 tokens.
image = jax.random.uniform(key=jax.random.key(1), shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
model.apply({'params': checkpoint}, image)
```

fails with the following error:

```
TypeError: cannot reshape array of shape (1, 577, 768) (size 443136) into shape (-1, 4616, 768) because the product of specified axis sizes (3545088) does not evenly divide 443136.
```
Am I using the config incorrectly? Issue #160 seems to describe the same problem I'm having.
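For reference, the sizes in the error decompose cleanly: 4616 = 8 × 577, so the target shape looks like fixed-size groups of 8 images' worth of tokens at 384×384. The pure-Python sketch below only checks that arithmetic. It assumes that 4616 is the MoE group size (tokens per routing group) coming from the config; that is inferred from the error message and not verified against the vmoe source. The names `GROUP_SIZE`, `tokens_per_image`, `reshape_ok` and `min_batch_size` are hypothetical helpers for illustration.

```python
# Sketch: decompose the sizes in the reported reshape error.
# ASSUMPTION: 4616 is the MoE group size (tokens per routing group) set in the
# config; this is inferred from the error message, not checked in vmoe itself.

PATCH_SIZE = 16      # ViT-B/16 patch size
IMAGE_SIZE = 384     # input resolution used in the snippet above
HIDDEN_SIZE = 768    # last dimension of the array in the error
GROUP_SIZE = 4616    # middle dimension of the target shape in the error (hypothetical name)


def tokens_per_image(image_size: int, patch_size: int) -> int:
    """Number of patch tokens plus one class token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


def reshape_ok(batch_size: int) -> bool:
    """True if (batch, tokens, hidden) can be reshaped to (-1, GROUP_SIZE, hidden)."""
    total = batch_size * tokens_per_image(IMAGE_SIZE, PATCH_SIZE) * HIDDEN_SIZE
    return total % (GROUP_SIZE * HIDDEN_SIZE) == 0


def min_batch_size() -> int:
    """Smallest batch size whose total token count is a multiple of GROUP_SIZE."""
    tokens = tokens_per_image(IMAGE_SIZE, PATCH_SIZE)
    batch = 1
    while (batch * tokens) % GROUP_SIZE != 0:
        batch += 1
    return batch


if __name__ == "__main__":
    tokens = tokens_per_image(IMAGE_SIZE, PATCH_SIZE)
    print(tokens)                        # 577
    print(tokens * HIDDEN_SIZE)          # 443136, the array size in the error
    print(GROUP_SIZE * HIDDEN_SIZE)      # 3545088, the product in the error
    print(GROUP_SIZE // tokens)          # 8 images per group
    print(reshape_ok(1), reshape_ok(8))  # False True
    print(min_batch_size())              # 8
```

Under that assumption, this particular reshape would only succeed when the batch size is a multiple of 8; that has not been tested against the actual model.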