diff --git a/ane_transformers/huggingface/distilbert.py b/ane_transformers/huggingface/distilbert.py
index 4845c22..c6bd117 100644
--- a/ane_transformers/huggingface/distilbert.py
+++ b/ane_transformers/huggingface/distilbert.py
@@ -2,6 +2,7 @@
 # For licensing see accompanying LICENSE.md file.
 # Copyright (C) 2022 Apple Inc. All Rights Reserved.
 #
+import re
 
 from ane_transformers.reference.layer_norm import LayerNormANE
 
@@ -520,14 +521,28 @@ def forward(
         return ((loss, ) + output) if loss is not None else output
 
 
+# State-dict key suffixes of the nn.Linear weights that
+# `linear_to_conv2d_map` reshapes from 2-D to 4-D (nn.Conv2d layout).
+_LINEAR_TO_CONV2D_LAYERS = (
+    "q_lin.weight",
+    "k_lin.weight",
+    "v_lin.weight",
+    "out_lin.weight",
+    "lin1.weight",
+    "lin2.weight",
+    "classifier.weight",
+    "pre_classifier.weight",
+    "vocab_transform.weight",
+    "vocab_projector.weight",
+    "qa_outputs.weight",
+)
+
+
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
     """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
     """
     for k in state_dict:
-        is_internal_proj = all(substr in k for substr in ['lin', '.weight'])
-        is_output_proj = all(substr in k
-                             for substr in ['classifier', '.weight'])
-        if is_internal_proj or is_output_proj:
+        if k.endswith(_LINEAR_TO_CONV2D_LAYERS):
             if len(state_dict[k].shape) == 2:
                 state_dict[k] = state_dict[k][:, :, None, None]
diff --git a/ane_transformers/huggingface/test_distilbert.py b/ane_transformers/huggingface/test_distilbert.py
index b14f14d..7e43a07 100644
--- a/ane_transformers/huggingface/test_distilbert.py
+++ b/ane_transformers/huggingface/test_distilbert.py
@@ -9,7 +9,6 @@
 import logging
 import numpy as np
 import unittest
-import time
 
 import torch
 
@@ -32,6 +31,10 @@
     ("This is not what I expected!", "NEGATIVE"),
 ])
 
+MASKED_LM_MODEL = 'distilbert-base-uncased'
+QUESTION_ANSWERING_MODEL = 'distilbert-base-uncased-distilled-squad'
+TOKEN_CLASSIFICATION_MODEL = 'elastic/distilbert-base-uncased-finetuned-conll03-english'
+MULTIPLE_CHOICE_MODEL = 'Gladiator/distilbert-base-uncased_swag_mqa'
 
 class TestDistilBertForSequenceClassification(unittest.TestCase):
     """
@@ -191,5 +194,56 @@ def test_coreml_conversion_and_speedup(self):
         )
 
 
+class TestDistilBertLoadState(unittest.TestCase):
+    """
+    Test load_state_dict compatibility.
+    """
+
+    test_params = (
+        (
+            MASKED_LM_MODEL,
+            transformers.AutoModelForMaskedLM,
+            ane_transformers.DistilBertForMaskedLM,
+        ),
+        (
+            QUESTION_ANSWERING_MODEL,
+            transformers.AutoModelForQuestionAnswering,
+            ane_transformers.DistilBertForQuestionAnswering,
+        ),
+        (
+            TOKEN_CLASSIFICATION_MODEL,
+            transformers.AutoModelForTokenClassification,
+            ane_transformers.DistilBertForTokenClassification,
+        ),
+        (
+            MULTIPLE_CHOICE_MODEL,
+            transformers.AutoModelForMultipleChoice,
+            ane_transformers.DistilBertForMultipleChoice,
+        ),
+    )
+
+    def test_load_state(self):
+        for model_name, auto_model_cls, ane_model_cls in self.test_params:
+            with self.subTest(ane_model_cls=ane_model_cls):
+                try:
+                    # Instantiate the reference model from an exemplar pre-trained
+                    # model hosted on huggingface.co/models
+                    reference_model = auto_model_cls.from_pretrained(
+                        model_name,
+                        return_dict=False,
+                        torchscript=True,
+                    ).eval()
+                except Exception as e:
+                    raise RuntimeError(
+                        "Failed to download reference model from huggingface.co/models!"
+                    ) from e
+                logger.info("Downloaded reference model from huggingface.co/models")
+
+                # Initialize an ANE equivalent model and restore the checkpoint
+                test_model = ane_model_cls(reference_model.config).eval()
+                test_model.load_state_dict(reference_model.state_dict())
+                logger.info("Initialized and restored test model")
+
+
 if __name__ == "__main__":
     unittest.main()