
Bert model doesn't seem to instantiate with lora weights #21

@jcrist1

Description

I tried to instantiate a bert model with the following code:

```rust
use candle_core::DType;
use candle_lora::LoraConfig;
use candle_lora_transformers::bert::{BertModel, Config};
use candle_nn::{VarBuilder, VarMap};

fn main() {
    let config = "config.json";
    let device: candle_core::Device = candle_core::Device::Cpu;

    let config_str = std::fs::read_to_string(config).expect("Failed to load config");
    let config: Config = serde_json::from_str(&config_str).expect("failed to parse config");

    let map = VarMap::new();

    let builder = VarBuilder::from_varmap(&map, DType::F32, &device);

    let lora_config = LoraConfig::new(32, 1.0, Some(0.1));
    BertModel::load(builder, &config, false, lora_config)
        .expect("Failed to instantiate bert model");

    let data = map.data().lock().expect("Failed to lock var map data");
    for (key, tensor) in &*data {
        println!("{key}: {:?}", tensor.shape())
    }
}
```

Cargo manifest:

```toml
[package]
name = "candle-lora-test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.6.0"}
candle-lora = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-macro = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-transformers = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
serde_json = "1.0.127"
```

and the model config:

```json
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
```

Running this produces the following output:

```
embeddings.position_embeddings.lora_embed.a0.weight: [32, 512]
embeddings.LayerNorm.weight: [384]
encoder.layer.2.attention.self.key.bias: [384]
encoder.layer.1.attention.output.dense.bias: [384]
encoder.layer.1.output.LayerNorm.weight: [384]
encoder.layer.4.attention.self.value.weight: [384, 384]
encoder.layer.2.intermediate.dense.bias: [1536]
encoder.layer.4.attention.self.query.bias: [384]
encoder.layer.0.output.LayerNorm.bias: [384]
embeddings.word_embeddings.lora_embed.a0.weight: [32, 30522]
encoder.layer.2.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.weight: [384]
encoder.layer.3.output.dense.weight: [384, 1536]
encoder.layer.3.attention.output.dense.weight: [384, 384]
encoder.layer.3.output.LayerNorm.bias: [384]
encoder.layer.4.output.dense.bias: [384]
encoder.layer.5.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.bias: [384]
encoder.layer.0.attention.output.dense.bias: [384]
encoder.layer.5.output.dense.weight: [384, 1536]
encoder.layer.4.attention.self.key.bias: [384]
encoder.layer.2.attention.output.dense.bias: [384]
encoder.layer.1.attention.self.key.bias: [384]
encoder.layer.0.attention.output.LayerNorm.weight: [384]
encoder.layer.4.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.self.value.weight: [384, 384]
encoder.layer.3.attention.self.key.bias: [384]
encoder.layer.3.attention.output.dense.bias: [384]
embeddings.token_type_embeddings.weight: [2, 384]
encoder.layer.4.attention.output.LayerNorm.bias: [384]
encoder.layer.1.output.LayerNorm.bias: [384]
encoder.layer.5.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.key.bias: [384]
encoder.layer.5.output.LayerNorm.weight: [384]
encoder.layer.2.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.key.bias: [384]
encoder.layer.0.output.dense.weight: [384, 1536]
encoder.layer.1.output.dense.bias: [384]
encoder.layer.2.output.dense.weight: [384, 1536]
embeddings.word_embeddings.weight: [30522, 384]
encoder.layer.0.attention.self.query.weight: [384, 384]
encoder.layer.2.attention.output.LayerNorm.weight: [384]
encoder.layer.0.intermediate.dense.bias: [1536]
encoder.layer.2.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.query.bias: [384]
encoder.layer.4.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.output.dense.bias: [384]
embeddings.position_embeddings.weight: [512, 384]
embeddings.token_type_embeddings.lora_embed.a0.weight: [32, 2]
encoder.layer.1.intermediate.dense.weight: [1536, 384]
encoder.layer.1.attention.self.query.weight: [384, 384]
encoder.layer.1.attention.self.value.weight: [384, 384]
embeddings.position_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.4.output.dense.weight: [384, 1536]
encoder.layer.5.attention.output.LayerNorm.weight: [384]
encoder.layer.5.output.dense.bias: [384]
encoder.layer.0.attention.output.LayerNorm.bias: [384]
encoder.layer.2.output.dense.bias: [384]
embeddings.word_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.1.intermediate.dense.bias: [1536]
encoder.layer.2.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.self.query.bias: [384]
encoder.layer.1.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.value.bias: [384]
encoder.layer.2.output.LayerNorm.bias: [384]
encoder.layer.3.attention.self.query.weight: [384, 384]
encoder.layer.3.attention.self.value.bias: [384]
encoder.layer.3.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.output.LayerNorm.weight: [384]
encoder.layer.1.attention.output.dense.weight: [384, 384]
embeddings.LayerNorm.bias: [384]
encoder.layer.3.attention.self.value.weight: [384, 384]
encoder.layer.3.intermediate.dense.weight: [1536, 384]
encoder.layer.3.output.dense.bias: [384]
encoder.layer.0.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.self.value.bias: [384]
encoder.layer.3.intermediate.dense.bias: [1536]
encoder.layer.4.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.query.bias: [384]
encoder.layer.5.attention.output.dense.bias: [384]
encoder.layer.5.output.LayerNorm.bias: [384]
encoder.layer.4.attention.output.LayerNorm.weight: [384]
encoder.layer.2.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.output.LayerNorm.bias: [384]
encoder.layer.4.intermediate.dense.bias: [1536]
encoder.layer.0.attention.self.value.bias: [384]
encoder.layer.0.output.LayerNorm.weight: [384]
encoder.layer.3.attention.self.query.bias: [384]
encoder.layer.0.intermediate.dense.weight: [1536, 384]
encoder.layer.4.attention.self.query.weight: [384, 384]
encoder.layer.4.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.query.bias: [384]
encoder.layer.5.intermediate.dense.weight: [1536, 384]
encoder.layer.1.output.dense.weight: [384, 1536]
encoder.layer.4.output.LayerNorm.bias: [384]
encoder.layer.0.output.dense.bias: [384]
encoder.layer.5.attention.self.value.bias: [384]
encoder.layer.2.attention.self.value.bias: [384]
embeddings.token_type_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.2.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.self.key.weight: [384, 384]
encoder.layer.3.output.LayerNorm.weight: [384]
encoder.layer.5.attention.self.key.weight: [384, 384]
encoder.layer.5.intermediate.dense.bias: [1536]
```
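To make the gap easier to spot than in the raw dump, here is a small variation of the printing loop from the program above (a sketch reusing the same `map`; the `lora_embed`/`lora_linear` substrings follow the key naming visible in the output, where `lora_linear` is what I would expect the linear layers to produce):

```rust
// Sketch: count registered LoRA tensors by key substring instead of dumping
// every shape. This would replace the printing loop in the program above.
let data = map.data().lock().expect("Failed to lock var map data");
let n_lora_embed = data.keys().filter(|k| k.contains("lora_embed")).count();
let n_lora_linear = data.keys().filter(|k| k.contains("lora_linear")).count();
// For the output above this counts 6 lora_embed tensors and 0 lora_linear tensors.
println!("lora_embed tensors: {n_lora_embed}, lora_linear tensors: {n_lora_linear}");
```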

Importantly, it doesn't seem to create LoRA weights for any of the encoder layers, only for the embedding layers. I looked at the expanded code and noticed that the generated constructor for a linear layer looks like this:

The expanded LoRA linear layer:
```rust
    impl BertLinear {
        pub fn new(
            vb: VarBuilder,
            weight: Tensor,
            bias: Option<Tensor>,
            merge: bool,
            lora_config: LoraConfig,
        ) -> Self {
            let span = {
                use ::tracing::__macro_support::Callsite as _;
                static __CALLSITE: ::tracing::callsite::DefaultCallsite = {
                    static META: ::tracing::Metadata<'static> = {
                        ::tracing_core::metadata::Metadata::new(
                            "linear",
                            "candle_lora_transformers::bert",
                            tracing::Level::TRACE,
                            ::core::option::Option::Some(
                                "candle-lora-transformers/src/bert.rs",
                            ),
                            ::core::option::Option::Some(58u32),
                            ::core::option::Option::Some(
                                "candle_lora_transformers::bert",
                            ),
                            ::tracing_core::field::FieldSet::new(
                                &[],
                                ::tracing_core::callsite::Identifier(&__CALLSITE),
                            ),
                            ::tracing::metadata::Kind::SPAN,
                        )
                    };
                    ::tracing::callsite::DefaultCallsite::new(&META)
                };
                let mut interest = ::tracing::subscriber::Interest::never();
                if tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                    && tracing::Level::TRACE
                        <= ::tracing::level_filters::LevelFilter::current()
                    && {
                        interest = __CALLSITE.interest();
                        !interest.is_never()
                    }
                    && ::tracing::__macro_support::__is_enabled(
                        __CALLSITE.metadata(),
                        interest,
                    )
                {
                    let meta = __CALLSITE.metadata();
                    ::tracing::Span::new(meta, &{ meta.fields().value_set(&[]) })
                } else {
                    let span = ::tracing::__macro_support::__disabled_span(
                        __CALLSITE.metadata(),
                    );
                    {};
                    span
                }
            };
            let dims = weight.dims2().unwrap();
            let linear_config = LoraLinearConfig::new(dims.1, dims.0);
            let mut this = Self {
                inner: Arc::new(Linear::new(weight, bias)),
                span,
            };
            if merge {
                this.get_merged_lora_model(
                    lora_config,
                    &vb.pp("lora_linear"),
                    Some(linear_config),
                    None,
                    None,
                    None,
                )
            } else {
                this.get_lora_model(
                    lora_config,
                    &vb.pp("lora_linear"),
                    Some(linear_config),
                    None,
                    None,
                    None,
                )
            }
            this
        }
```

But when I dig into `this.get_lora_model`, I notice that it doesn't actually use the `self` parameter:

```rust
    impl BertLinear {
        /// Be sure to provide a configuration for each type!
        pub fn get_lora_model<'a>(
            &'a mut self,
            lora_config: candle_lora::LoraConfig,
            vb: &candle_nn::VarBuilder,
            linear_config: Option<candle_lora::LoraLinearConfig>,
            conv1d_config: Option<candle_lora::LoraConv1dConfig>,
            conv2d_config: Option<candle_lora::LoraConv2dConfig>,
            embed_config: Option<candle_lora::LoraEmbeddingConfig>,
        ) {
            let mut linear: ::std::collections::HashMap<
                String,
                &dyn candle_lora::LinearLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv1d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv1dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut conv2d: ::std::collections::HashMap<
                String,
                &dyn candle_lora::Conv2dLayerLike,
            > = ::std::collections::HashMap::new();
            let mut embed: ::std::collections::HashMap<
                String,
                &dyn candle_lora::EmbeddingLayerLike,
            > = ::std::collections::HashMap::new();
            let mut embed: ::std::collections::HashMap<
                String,
                &dyn candle_lora::EmbeddingLayerLike,
            > = ::std::collections::HashMap::new();
            if !linear.is_empty() && linear_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for linear layers."),
                    );
                };
            }
            if !conv1d.is_empty() && conv1d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for conv1d layers."),
                    );
                };
            }
            if !conv2d.is_empty() && conv2d_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for conv2d layers."),
                    );
                };
            }
            if !embed.is_empty() && embed_config.is_none() {
                {
                    ::core::panicking::panic_fmt(
                        format_args!("Config not specified for embedding layers."),
                    );
                };
            }
            let mut builder = candle_lora::SelectedLayersBuilder::new();
            if linear_config.is_some() {
                builder = builder.add_linear_layers(linear, linear_config.unwrap());
            }
            if conv1d_config.is_some() {
                builder = builder.add_conv1d_layers(conv1d, conv1d_config.unwrap());
            }
            if conv2d_config.is_some() {
                builder = builder.add_conv2d_layers(conv2d, conv2d_config.unwrap());
            }
            if embed_config.is_some() {
                builder = builder.add_embed_layers(embed, embed_config.unwrap());
            }
            let selection = builder.build();
            let new_layers = candle_lora::Lora::convert_model(
                selection,
                lora_config,
                &vb,
            );
        }
```

For comparison, the expansion of `get_lora_model` for `BertEmbeddings` ends with:

```rust
          // ...
          [
                (self
                    .inner = ::std::sync::Arc::new(
                    new_layers.embed.get("inner").unwrap().clone(),
                )),
            ];
```

It seems like the macro isn't expanding correctly for the linear layers. Could this be the case?
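For reference, here is roughly what I would expect the macro-generated code to do for a linear layer, by analogy with the `BertEmbeddings` expansion above, written as a standalone helper against the same candle_lora selection API the expansion already calls. This is only a sketch: the `"inner"` key, the `new_layers.linear` field, and the helper name are assumptions based on the quoted expansions.

```rust
use std::collections::HashMap;

use candle_lora::{
    LinearLayerLike, Lora, LoraConfig, LoraLinearConfig, SelectedLayersBuilder,
};
use candle_nn::VarBuilder;

/// Hypothetical helper: convert a single linear layer the way I would expect
/// the generated `get_lora_model` to, i.e. actually registering the layer in
/// the `linear` map before building the selection.
fn convert_linear(
    layer: &dyn LinearLayerLike,
    lora_config: LoraConfig,
    vb: &VarBuilder,
    in_features: usize,
    out_features: usize,
) {
    let mut linear: HashMap<String, &dyn LinearLayerLike> = HashMap::new();
    // This is the step the expanded BertLinear::get_lora_model never performs:
    // its `linear` map stays empty, so no lora_linear weights are created.
    linear.insert("inner".to_string(), layer);

    let selection = SelectedLayersBuilder::new()
        .add_linear_layers(linear, LoraLinearConfig::new(in_features, out_features))
        .build();

    let new_layers = Lora::convert_model(selection, lora_config, vb);
    // By analogy with the embeddings expansion, the converted layer would then
    // be read back from `new_layers.linear` and assigned to `self.inner`.
    let _converted = new_layers.linear.get("inner");
}
```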
