I tried to instantiate a BERT model with the following code:
```rust
use candle_core::DType;
use candle_lora::LoraConfig;
use candle_lora_transformers::bert::{BertModel, Config};
use candle_nn::{VarBuilder, VarMap};

fn main() {
    let config = "config.json";
    let device: candle_core::Device = candle_core::Device::Cpu;
    let config_str = std::fs::read_to_string(config).expect("Failed to load config");
    let config: Config = serde_json::from_str(&config_str).expect("failed to parse config");

    let map = VarMap::new();
    let builder = VarBuilder::from_varmap(&map, DType::F32, &device);
    let lora_config = LoraConfig::new(32, 1.0, Some(0.1));

    BertModel::load(builder, &config, false, lora_config)
        .expect("Failed to instantiate bert model");

    let data = map.data().lock().expect("Failed to lock var map data");
    for (key, tensor) in &*data {
        println!("{key}: {:?}", tensor.shape())
    }
}
```

Cargo manifest:
```toml
[package]
name = "candle-lora-test"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.6.0" }
candle-lora = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-macro = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
candle-lora-transformers = { git = "https://github.com/EricLBuehler/candle-lora.git", version = "0.2.0" }
serde_json = "1.0.127"
```

And the model config:
```json
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
```

Which outputs:

```
embeddings.position_embeddings.lora_embed.a0.weight: [32, 512]
embeddings.LayerNorm.weight: [384]
encoder.layer.2.attention.self.key.bias: [384]
encoder.layer.1.attention.output.dense.bias: [384]
encoder.layer.1.output.LayerNorm.weight: [384]
encoder.layer.4.attention.self.value.weight: [384, 384]
encoder.layer.2.intermediate.dense.bias: [1536]
encoder.layer.4.attention.self.query.bias: [384]
encoder.layer.0.output.LayerNorm.bias: [384]
embeddings.word_embeddings.lora_embed.a0.weight: [32, 30522]
encoder.layer.2.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.weight: [384]
encoder.layer.3.output.dense.weight: [384, 1536]
encoder.layer.3.attention.output.dense.weight: [384, 384]
encoder.layer.3.output.LayerNorm.bias: [384]
encoder.layer.4.output.dense.bias: [384]
encoder.layer.5.attention.output.dense.weight: [384, 384]
encoder.layer.3.attention.output.LayerNorm.bias: [384]
encoder.layer.0.attention.output.dense.bias: [384]
encoder.layer.5.output.dense.weight: [384, 1536]
encoder.layer.4.attention.self.key.bias: [384]
encoder.layer.2.attention.output.dense.bias: [384]
encoder.layer.1.attention.self.key.bias: [384]
encoder.layer.0.attention.output.LayerNorm.weight: [384]
encoder.layer.4.intermediate.dense.weight: [1536, 384]
encoder.layer.0.attention.self.value.weight: [384, 384]
encoder.layer.3.attention.self.key.bias: [384]
encoder.layer.3.attention.output.dense.bias: [384]
embeddings.token_type_embeddings.weight: [2, 384]
encoder.layer.4.attention.output.LayerNorm.bias: [384]
encoder.layer.1.output.LayerNorm.bias: [384]
encoder.layer.5.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.key.bias: [384]
encoder.layer.5.output.LayerNorm.weight: [384]
encoder.layer.2.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.key.bias: [384]
encoder.layer.0.output.dense.weight: [384, 1536]
encoder.layer.1.output.dense.bias: [384]
encoder.layer.2.output.dense.weight: [384, 1536]
embeddings.word_embeddings.weight: [30522, 384]
encoder.layer.0.attention.self.query.weight: [384, 384]
encoder.layer.2.attention.output.LayerNorm.weight: [384]
encoder.layer.0.intermediate.dense.bias: [1536]
encoder.layer.2.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.query.bias: [384]
encoder.layer.4.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.output.dense.bias: [384]
embeddings.position_embeddings.weight: [512, 384]
embeddings.token_type_embeddings.lora_embed.a0.weight: [32, 2]
encoder.layer.1.intermediate.dense.weight: [1536, 384]
encoder.layer.1.attention.self.query.weight: [384, 384]
encoder.layer.1.attention.self.value.weight: [384, 384]
embeddings.position_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.4.output.dense.weight: [384, 1536]
encoder.layer.5.attention.output.LayerNorm.weight: [384]
encoder.layer.5.output.dense.bias: [384]
encoder.layer.0.attention.output.LayerNorm.bias: [384]
encoder.layer.2.output.dense.bias: [384]
embeddings.word_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.1.intermediate.dense.bias: [1536]
encoder.layer.2.attention.self.value.weight: [384, 384]
encoder.layer.2.attention.self.query.bias: [384]
encoder.layer.1.attention.output.LayerNorm.bias: [384]
encoder.layer.1.attention.self.value.bias: [384]
encoder.layer.2.output.LayerNorm.bias: [384]
encoder.layer.3.attention.self.query.weight: [384, 384]
encoder.layer.3.attention.self.value.bias: [384]
encoder.layer.3.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.output.LayerNorm.weight: [384]
encoder.layer.1.attention.output.dense.weight: [384, 384]
embeddings.LayerNorm.bias: [384]
encoder.layer.3.attention.self.value.weight: [384, 384]
encoder.layer.3.intermediate.dense.weight: [1536, 384]
encoder.layer.3.output.dense.bias: [384]
encoder.layer.0.attention.self.key.weight: [384, 384]
encoder.layer.4.attention.self.value.bias: [384]
encoder.layer.3.intermediate.dense.bias: [1536]
encoder.layer.4.attention.output.dense.weight: [384, 384]
encoder.layer.5.attention.self.query.bias: [384]
encoder.layer.5.attention.output.dense.bias: [384]
encoder.layer.5.output.LayerNorm.bias: [384]
encoder.layer.4.attention.output.LayerNorm.weight: [384]
encoder.layer.2.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.self.query.weight: [384, 384]
encoder.layer.5.attention.output.LayerNorm.bias: [384]
encoder.layer.4.intermediate.dense.bias: [1536]
encoder.layer.0.attention.self.value.bias: [384]
encoder.layer.0.output.LayerNorm.weight: [384]
encoder.layer.3.attention.self.query.bias: [384]
encoder.layer.0.intermediate.dense.weight: [1536, 384]
encoder.layer.4.attention.self.query.weight: [384, 384]
encoder.layer.4.output.LayerNorm.weight: [384]
encoder.layer.0.attention.self.query.bias: [384]
encoder.layer.5.intermediate.dense.weight: [1536, 384]
encoder.layer.1.output.dense.weight: [384, 1536]
encoder.layer.4.output.LayerNorm.bias: [384]
encoder.layer.0.output.dense.bias: [384]
encoder.layer.5.attention.self.value.bias: [384]
encoder.layer.2.attention.self.value.bias: [384]
embeddings.token_type_embeddings.lora_embed.b0.weight: [384, 32]
encoder.layer.2.attention.self.key.weight: [384, 384]
encoder.layer.1.attention.self.key.weight: [384, 384]
encoder.layer.3.output.LayerNorm.weight: [384]
encoder.layer.5.attention.self.key.weight: [384, 384]
encoder.layer.5.intermediate.dense.bias: [1536]
```
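To make the pattern easier to spot, the dump can be filtered down to just the LoRA parameters; a minimal sketch, reusing `map` from the snippet above in place of the final print loop:

```rust
// Print only the LoRA parameters. With the code above, only the
// embeddings.*.lora_embed.{a0,b0} tensors appear; there are no
// encoder.layer.* LoRA keys at all.
let data = map.data().lock().expect("Failed to lock var map data");
let mut lora_keys: Vec<&String> = data.keys().filter(|k| k.contains("lora")).collect();
lora_keys.sort();
for key in lora_keys {
    println!("{key}");
}
```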
Importantly, it doesn't seem to create LoRA weights for any of the encoder layers, only the embedding layers. I looked at the expanded code and noticed that the generated constructor for the LoRA linear layer looks like this:
```rust
impl BertLinear {
    pub fn new(
        vb: VarBuilder,
        weight: Tensor,
        bias: Option<Tensor>,
        merge: bool,
        lora_config: LoraConfig,
    ) -> Self {
        let span = {
            use ::tracing::__macro_support::Callsite as _;
            static __CALLSITE: ::tracing::callsite::DefaultCallsite = {
                static META: ::tracing::Metadata<'static> = {
                    ::tracing_core::metadata::Metadata::new(
                        "linear",
                        "candle_lora_transformers::bert",
                        tracing::Level::TRACE,
                        ::core::option::Option::Some(
                            "candle-lora-transformers/src/bert.rs",
                        ),
                        ::core::option::Option::Some(58u32),
                        ::core::option::Option::Some(
                            "candle_lora_transformers::bert",
                        ),
                        ::tracing_core::field::FieldSet::new(
                            &[],
                            ::tracing_core::callsite::Identifier(&__CALLSITE),
                        ),
                        ::tracing::metadata::Kind::SPAN,
                    )
                };
                ::tracing::callsite::DefaultCallsite::new(&META)
            };
            let mut interest = ::tracing::subscriber::Interest::never();
            if tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                && tracing::Level::TRACE
                    <= ::tracing::level_filters::LevelFilter::current()
                && {
                    interest = __CALLSITE.interest();
                    !interest.is_never()
                }
                && ::tracing::__macro_support::__is_enabled(
                    __CALLSITE.metadata(),
                    interest,
                )
            {
                let meta = __CALLSITE.metadata();
                ::tracing::Span::new(meta, &{ meta.fields().value_set(&[]) })
            } else {
                let span = ::tracing::__macro_support::__disabled_span(
                    __CALLSITE.metadata(),
                );
                {};
                span
            }
        };
        let dims = weight.dims2().unwrap();
        let linear_config = LoraLinearConfig::new(dims.1, dims.0);
        let mut this = Self {
            inner: Arc::new(Linear::new(weight, bias)),
            span,
        };
        if merge {
            this.get_merged_lora_model(
                lora_config,
                &vb.pp("lora_linear"),
                Some(linear_config),
                None,
                None,
                None,
            )
        } else {
            this.get_lora_model(
                lora_config,
                &vb.pp("lora_linear"),
                Some(linear_config),
                None,
                None,
                None,
            )
        }
        this
    }
}
```

But when I dug into `this.get_lora_model`, I noticed that it doesn't actually use the `self` parameter:
```rust
impl BertLinear {
    /// Be sure to provide a configuration for each type!
    pub fn get_lora_model<'a>(
        &'a mut self,
        lora_config: candle_lora::LoraConfig,
        vb: &candle_nn::VarBuilder,
        linear_config: Option<candle_lora::LoraLinearConfig>,
        conv1d_config: Option<candle_lora::LoraConv1dConfig>,
        conv2d_config: Option<candle_lora::LoraConv2dConfig>,
        embed_config: Option<candle_lora::LoraEmbeddingConfig>,
    ) {
        let mut linear: ::std::collections::HashMap<
            String,
            &dyn candle_lora::LinearLayerLike,
        > = ::std::collections::HashMap::new();
        let mut conv1d: ::std::collections::HashMap<
            String,
            &dyn candle_lora::Conv1dLayerLike,
        > = ::std::collections::HashMap::new();
        let mut conv2d: ::std::collections::HashMap<
            String,
            &dyn candle_lora::Conv2dLayerLike,
        > = ::std::collections::HashMap::new();
        let mut embed: ::std::collections::HashMap<
            String,
            &dyn candle_lora::EmbeddingLayerLike,
        > = ::std::collections::HashMap::new();
        let mut embed: ::std::collections::HashMap<
            String,
            &dyn candle_lora::EmbeddingLayerLike,
        > = ::std::collections::HashMap::new();
        if !linear.is_empty() && linear_config.is_none() {
            {
                ::core::panicking::panic_fmt(
                    format_args!("Config not specified for linear layers."),
                );
            };
        }
        if !conv1d.is_empty() && conv1d_config.is_none() {
            {
                ::core::panicking::panic_fmt(
                    format_args!("Config not specified for conv1d layers."),
                );
            };
        }
        if !conv2d.is_empty() && conv2d_config.is_none() {
            {
                ::core::panicking::panic_fmt(
                    format_args!("Config not specified for conv2d layers."),
                );
            };
        }
        if !embed.is_empty() && embed_config.is_none() {
            {
                ::core::panicking::panic_fmt(
                    format_args!("Config not specified for embedding layers."),
                );
            };
        }
        let mut builder = candle_lora::SelectedLayersBuilder::new();
        if linear_config.is_some() {
            builder = builder.add_linear_layers(linear, linear_config.unwrap());
        }
        if conv1d_config.is_some() {
            builder = builder.add_conv1d_layers(conv1d, conv1d_config.unwrap());
        }
        if conv2d_config.is_some() {
            builder = builder.add_conv2d_layers(conv2d, conv2d_config.unwrap());
        }
        if embed_config.is_some() {
            builder = builder.add_embed_layers(embed, embed_config.unwrap());
        }
        let selection = builder.build();
        let new_layers = candle_lora::Lora::convert_model(
            selection,
            lora_config,
            &vb,
        );
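        // NOTE: new_layers is dropped here; nothing is ever assigned back to self
        // (compare the BertEmbeddings expansion below).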
    }
}
```

For comparison, the `get_lora_model` of `BertEmbeddings` ends with:
```rust
// ...
[
    (self
        .inner = ::std::sync::Arc::new(
            new_layers.embed.get("inner").unwrap().clone(),
        )),
];
```

It seems like the macro isn't quite expanding correctly. Could this be the case?
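For reference, I would have expected the linear expansion to end with an analogous write-back, roughly like the sketch below (the `linear` field of `new_layers` and the `"inner"` key are my guesses based on the embedding case, not the macro's actual output):

```rust
// Hypothetical ending for the BertLinear expansion, mirroring the
// BertEmbeddings write-back above; the field and key names are assumptions.
self.inner = ::std::sync::Arc::new(
    new_layers.linear.get("inner").unwrap().clone(),
);
```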