diff --git a/bert_nli.py b/bert_nli.py index 9171ae2..0325438 100644 --- a/bert_nli.py +++ b/bert_nli.py @@ -17,15 +17,15 @@ def __init__(self,model_path=None,gpu=True,bert_type='bert-base',label_num=3,bat super(BertNLIModel, self).__init__() self.bert_type = bert_type - if 'bert-base' in bert_type: + if 'albert' in bert_type: + self.bert = AlbertModel.from_pretrained(bert_type) + self.tokenizer = AlbertTokenizer.from_pretrained(bert_type) + elif 'bert-base' in bert_type: self.bert = BertModel.from_pretrained('bert-base-uncased') self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') elif 'bert-large' in bert_type: self.bert = BertModel.from_pretrained('bert-large-uncased') self.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') - elif 'albert' in bert_type: - self.bert = AlbertModel.from_pretrained(bert_type) - self.tokenizer = AlbertTokenizer.from_pretrained(bert_type) else: print('illegal bert type {}!'.format(bert_type))