From 6f5b4bd64f6bf065c36e8293cfbb4b5514f76aee Mon Sep 17 00:00:00 2001 From: youbo_sun Date: Mon, 24 Apr 2023 20:50:13 +0800 Subject: [PATCH 1/3] # issue 38 update transformers class name:LlamaTokenizer,LlamaForCausalLM --- chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chat.py b/chat.py index ba1f7ee..04b4aa3 100644 --- a/chat.py +++ b/chat.py @@ -24,8 +24,8 @@ def load_model(model_name, eight_bit=0, device_map="auto"): gpu_count = torch.cuda.device_count() print('gpu_count', gpu_count) - tokenizer = transformers.LLaMATokenizer.from_pretrained(model_name) - model = transformers.LLaMAForCausalLM.from_pretrained( + tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name) + model = transformers.LlamaForCausalLM.from_pretrained( model_name, #device_map=device_map, #device_map="auto", From add764bdf8443a146feea840a3b922981f5ab52b Mon Sep 17 00:00:00 2001 From: youbo_sun Date: Tue, 25 Apr 2023 22:38:14 +0800 Subject: [PATCH 2/3] # update README.md How to fine-tune LlamaDecoderLayer --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e451438..ba4097f 100644 --- a/README.md +++ b/README.md @@ -84,7 +84,7 @@ torchrun --nproc_per_node=4 --master_port= train.py \ --lr_scheduler_type "cosine" \ --logging_steps 1 \ --fsdp "full_shard auto_wrap" \ - --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \ + --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ --tf32 True ``` From 4d247fd2282fdd9b6b122d11e456c976168713cd Mon Sep 17 00:00:00 2001 From: youbo_sun Date: Tue, 25 Apr 2023 22:59:14 +0800 Subject: [PATCH 3/3] # update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d9f02de..8ac4d59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy rouge_score fire openai -git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176 
+transformers==4.28.1 torch sentencepiece tokenizers==0.12.1