Shape mismatch running vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 checkpoint #182

@seliayeu

I am unable to run the vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 checkpoint with the config returned by get_config in vmoe/configs/vmoe_paper/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012. Here is a minimal reproduction:

```python
import jax
from vmoe.nn import models
from vmoe.data import input_pipeline
from vmoe.checkpoints import partitioned
from vmoe.configs.vmoe_paper.vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012 import get_config

# Build the model from the "model" section of the paper config.
model = models.VisionTransformerMoe(**get_config()["model"])
# Restore the checkpoint without a target tree (tree=None).
checkpoint = partitioned.restore_checkpoint("gs://vmoe_checkpoints/vmoe_b16_imagenet21k_randaug_strong_ft_ilsvrc2012", tree=None)

IMAGE_SIZE = 384
BATCH_SIZE = 1

# Random dummy input in NHWC layout: (batch, height, width, channels).
image = jax.random.uniform(key=jax.random.key(1), shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))
model.apply({'params': checkpoint}, image)
```

Running it gives the following error:
TypeError: cannot reshape array of shape (1, 577, 768) (size 443136) into shape (-1, 4616, 768) because the product of specified axis sizes (3545088) does not evenly divide 443136.
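For reference, the numbers in the error can be reproduced with a little arithmetic. A B/16 model at 384×384 sees (384/16)² = 576 patches plus one extra token (presumably the class token), i.e. 577 tokens per image, and 4616 = 8 × 577. The sketch below assumes, without having checked the vmoe source, that 4616 is the MoE routing group size coming from the config and that the encoder reshapes all tokens in the batch into groups of exactly that size. The helper names are hypothetical and not part of vmoe.

```python
from math import gcd


def tokens_per_image(image_size: int, patch_size: int, extra_tokens: int = 1) -> int:
    """Tokens a ViT sees for a square image: patches plus extra (e.g. class) tokens."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + extra_tokens


def can_form_groups(batch_size: int, num_tokens: int, group_size: int) -> bool:
    """True if batch_size * num_tokens tokens split evenly into groups of group_size."""
    return (batch_size * num_tokens) % group_size == 0


def smallest_valid_batch(num_tokens: int, group_size: int) -> int:
    """Smallest batch size whose total token count is a multiple of group_size."""
    return group_size // gcd(num_tokens, group_size)


HIDDEN = 768        # last dimension in the error message
GROUP_SIZE = 4616   # from the error message; assumed to be the MoE group size

tokens = tokens_per_image(image_size=384, patch_size=16)
print(tokens)                                    # 577
print(tokens * HIDDEN, GROUP_SIZE * HIDDEN)      # 443136 3545088, the sizes in the error
print(can_form_groups(1, tokens, GROUP_SIZE))    # False
print(smallest_valid_batch(tokens, GROUP_SIZE))  # 8
```

Under these assumptions, a batch of 8 images (or a group size that divides BATCH_SIZE × 577) would reshape cleanly. Whether the released config is meant to be used only with such batch sizes is something the maintainers would need to confirm.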

Am I using the config incorrectly? Issue #160 appears to describe the same problem I'm having.
