From decbac3b7ef03ac0e2b69940fa96b4ccd2c82c0e Mon Sep 17 00:00:00 2001
From: v-zetang
Date: Tue, 21 Mar 2023 18:01:15 +0800
Subject: [PATCH 1/2] 1

---
 alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py b/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
index ef7bcae..ba3bc37 100644
--- a/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
+++ b/alpaca_lora/scripts/utils/process_llama_megatron_ckpt.py
@@ -1,5 +1,5 @@
 import torch
-import json
+import os
 import argparse
 
 
@@ -140,7 +140,7 @@ def build_llama_state_dict(llama_dir, llama_file, parallel_size):
     for parallel_idx, parallel_state in enumerate(split_parameter(llama_state, parallel_size)):
         state['model'] = parallel_state
         dump_file = "model-model_part-{}.pt".format(parallel_idx)
-        torch.save(state, llama_dir + 'megatron_{}/'.format(parallel_size) + dump_file)
+        torch.save(state, os.path.join(llama_dir + 'megatron_{}/'.format(parallel_size)) + dump_file)
         print("dump new model to {}{}".format(llama_dir, dump_file))
 
 def main():

From fb73e5d8eb854ac5c059140e9fdafb797fc2c4e9 Mon Sep 17 00:00:00 2001
From: v-zetang
Date: Wed, 22 Mar 2023 13:28:21 +0800
Subject: [PATCH 2/2] Remove ambiguity
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 alpaca_lora/src/llama_model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/alpaca_lora/src/llama_model.py b/alpaca_lora/src/llama_model.py
index e2bfda5..34762fd 100644
--- a/alpaca_lora/src/llama_model.py
+++ b/alpaca_lora/src/llama_model.py
@@ -201,7 +201,7 @@ def get_normalized_probs(
         else:
             return utils.softmax(logits, dim=-1)
 
-    def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, tgt_tokens):
 
         src_x, src_padding, src_attn, src_hiddens = self.decoder(
             prev_output_tokens=src_tokens,
@@ -215,7 +215,7 @@ def forward(self, src_tokens, src_lengths, src_pos, tgt_pos, prev_output_tokens)
             incremental_state[layer_idx]['key'] = layer_hidden_states
 
         tgt_x, tgt_padding, tgt_attn, tgt_hiddens = self.decoder(
-            prev_output_tokens=prev_output_tokens,
+            prev_output_tokens=tgt_tokens,
             incremental_state=incremental_state,
             src_pos=src_pos,
             tgt_pos=tgt_pos,
@@ -410,7 +410,7 @@ def forward(
     ):
 
         if incremental_state is not None and trunc_flg:
-            prev_output_tokens = prev_output_tokens[:, -1:]
+            prev_output_tokens = prev_output_tokens[:, -1:]
             bsz, target_len = prev_output_tokens.size()
 
         x = self.embed_tokens(prev_output_tokens)
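
Reviewer note on PATCH 1/2 (not part of either commit): as committed, os.path.join receives a single, already-concatenated argument, so the dump path is still built by string concatenation and llama_dir still needs a trailing '/'. Below is a minimal sketch of what was presumably intended; the save_part helper name and the makedirs call are illustrative assumptions, not existing project code.

    import os

    import torch


    def save_part(llama_dir, parallel_size, parallel_idx, state):
        # Sketch: join path components instead of concatenating strings,
        # so llama_dir works with or without a trailing '/'.
        out_dir = os.path.join(llama_dir, "megatron_{}".format(parallel_size))
        os.makedirs(out_dir, exist_ok=True)  # assumption: the split dir may not exist yet
        dump_file = "model-model_part-{}.pt".format(parallel_idx)
        dump_path = os.path.join(out_dir, dump_file)
        torch.save(state, dump_path)
        print("dump new model to {}".format(dump_path))
        return dump_path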