
Problem with Loading Weights #12

@Alexwe12

Hi, thank you for your work.
I want to further instruction-tune the mBLIP mt0 model on my own data. I set the `blip_pretrained_checkpoint` argument to the `pytorch_model-00001-of-00002.bin` file from the mBLIP repository, and the `lm_pretrained` argument to the mBLIP .bin files that correspond to the encoder and decoder, together with the `config.json` and `pytorch_model.bin.index.json` files from the bigscience/mt0-xl repository.

However, when I instantiate the mBLIP class and load the weights, I get an error at this line: https://github.com/gregorge/mBLIP/blob/f804c1f8bb84b13795b71aaaa6fe3f44851c908b/src/modules/modeling/mblip.py#L185C97-L185C97. Loading succeeds when I set `load_in_8bit=False`, but `model.generate(**inputs)` then produces nonsensical output. Could you please advise how to train the model starting from your provided checkpoint?
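Nonsensical generations can be a sign that some weights were never loaded and were left at their initial values. A minimal way to check this is to look at which parameters each shard file actually contains, and which keys the model expects but does not receive. The sketch below assumes the Hugging Face sharded-checkpoint index format (a `weight_map` from parameter name to shard file). The helper names `load_index`, `summarize_shards`, and `compare_keys` are hypothetical, and the parameter names in the toy index are made up for illustration, not taken from the mBLIP checkpoint.

```python
import json
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple


def load_index(path: str) -> dict:
    """Read a Hugging Face style sharded-checkpoint index (e.g. pytorch_model.bin.index.json)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def summarize_shards(index: dict, depth: int = 1) -> Dict[str, Counter]:
    """Count, per shard file, how many parameters share each name prefix.

    The index's "weight_map" maps every parameter name to the shard file
    that stores it. The prefix is the first `depth` dot-separated
    components of the parameter name.
    """
    summary: Dict[str, Counter] = defaultdict(Counter)
    for param_name, shard_file in index["weight_map"].items():
        prefix = ".".join(param_name.split(".")[:depth])
        summary[shard_file][prefix] += 1
    return dict(summary)


def compare_keys(expected: Iterable[str], found: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) parameter names as sorted lists.

    `expected` could be model.state_dict().keys() and `found` the keys of a
    loaded checkpoint; this is the same kind of information that
    torch.nn.Module.load_state_dict(..., strict=False) reports.
    """
    expected_set, found_set = set(expected), set(found)
    return sorted(expected_set - found_set), sorted(found_set - expected_set)


if __name__ == "__main__":
    # Hypothetical toy index: these parameter names are illustrative only
    # and are NOT taken from the real mBLIP checkpoint.
    toy_index = {
        "metadata": {"total_size": 0},
        "weight_map": {
            "vision_model.embeddings.weight": "pytorch_model-00001-of-00002.bin",
            "qformer.layer.0.weight": "pytorch_model-00001-of-00002.bin",
            "language_model.encoder.block.0.weight": "pytorch_model-00001-of-00002.bin",
            "language_model.decoder.block.0.weight": "pytorch_model-00002-of-00002.bin",
        },
    }
    for shard, counts in sorted(summarize_shards(toy_index).items()):
        print(shard, dict(counts))
    # For the real files, something like:
    # summarize_shards(load_index("pytorch_model.bin.index.json"))
```

In this made-up layout, the first shard also holds some language-model weights, so loading only that one file would not cover the whole model. Running `summarize_shards` on the actual index, and `compare_keys` on the model's expected keys versus the loaded keys, would show whether any weights are being left out.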
Thank you very much.
