91 changes: 91 additions & 0 deletions examples/Llama/bmtrain_mgpu.sh
@@ -0,0 +1,91 @@
# Defined by User
export TRIGGER_FILE=bmtrain_mgpu.sh
export SCRIPT_FILE=llama_pretrain.py

# ENVS
export PROJ_HOME=$PWD
export PRE_LOAD_DIR=$PROJ_HOME/checkpoints_in
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_IB_GID_INDEX=0
export NCCL_IB_HCA=mlx5_2,mlx5_5
export NCCL_DEBUG=INFO
export OMP_NUM_THREADS=4

echo "[INFO] $0: hostfile configfile model_name exp_name exp_version"
set -u
hostfile=$1
configfile=$2
model_name=$3
exp_name=$4
exp_version=$5
set +u

# DIST
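# Assumed hostfile format (DeepSpeed-style, one line per node): "<node_ip> slots=<num_gpus>".
# The local GPU count, node rank, and master address below are all derived from this file.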
export HOSTFILE=$hostfile
export CONFIGFILE=$configfile
export NODE_ADDR=$(ifconfig -a|grep inet|grep -v 127.0.0.1|grep -v inet6|awk '{print $2;}'|tr -d "addr:")
export GPU_NUM_PER_NODE=$(awk -F" |=" '{ranks[$1]=$NF;}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
export NODES_NUM=$(cat $HOSTFILE | wc -l)
export MASTER_ADDR=$(head -n1 $HOSTFILE | awk '{print $1;}')
export RANK=$(awk '{ranks[$1]=(FNR-1);}END{print ranks["'$NODE_ADDR'"];}' $HOSTFILE)
export MASTER_PORT=23456


## wandb
export WANDB_MODE=offline

## EXP
export MODEL_NAME=$model_name
export EXP_NAME=$exp_name
export WANDB_DIR=$PROJ_HOME/wandb/${EXP_NAME}/$exp_version
mkdir -p $PROJ_HOME/checkpoints_out
export SAVE_DIR=$PROJ_HOME/checkpoints_out/${EXP_NAME}/$exp_version
mkdir -p $SAVE_DIR
mkdir -p $WANDB_DIR
## Backup ckpts & scripts into exp versions
cp -r $PRE_LOAD_DIR/$MODEL_NAME $SAVE_DIR
cp -r $PROJ_HOME/$TRIGGER_FILE $SAVE_DIR
cp -r $hostfile $SAVE_DIR
cp -r $configfile $SAVE_DIR

export EPOCH_NUM=1
export BATCH_SIZE=6
export GRADIENT_ACCUM_STEPS=1
export LR=6.0e-5            # alternatives tried: 3.0e-4, 1.0e-5
export WARMUP_RATE=0.2      # alternatives tried: 0.008, 0.02, 0.1

## EXTRA OPTS
OPTS=" --batch_size $BATCH_SIZE \
--epochs $EPOCH_NUM \
--gradient_accumulation_steps $GRADIENT_ACCUM_STEPS \
--lr $LR \
--warm_up $WARMUP_RATE \
--weight_decay 0.1 \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--save_dir $SAVE_DIR \
--pre_load_dir $PRE_LOAD_DIR \
--experiment_name $EXP_NAME \
--model_name $MODEL_NAME \
--wandb_dir $WANDB_DIR \
--yaml_config $CONFIGFILE"

## Launch the training job on each node (bmtrain or ddp backend).

mkdir -p $PRE_LOAD_DIR
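# --not_call_launch tells llama_pretrain.py that the workers are already spawned by torchrun,
# so it proceeds to train instead of acting as a launcher (see the trigger block in that script).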
torchrun \
--nproc_per_node $GPU_NUM_PER_NODE \
--nnodes $NODES_NUM \
--node_rank $RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
$SCRIPT_FILE \
--not_call_launch \
$OPTS
40 changes: 40 additions & 0 deletions examples/Llama/generate.py
@@ -0,0 +1,40 @@
# Copyright © 2023 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.data.tokenizer import Tokenizer
import transformers

state_dict = "./checkpoints_in/"
model_name = 'Llama-3.1-8B'
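# The weights are assumed to live under <state_dict>/<model_name>, i.e. ./checkpoints_in/Llama-3.1-8B/ here.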

loader = AutoLoader("llama3",
model_dir=state_dict,
model_name=model_name,
device='cuda',
use_cache=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()

model.eval()

model.cuda()

print("model loaded")

text = "Gravity is "

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
**model_inputs,
max_new_tokens=1024
)
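# Drop the prompt tokens so only the newly generated continuation is decoded.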
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

print("content:", content)
10 changes: 10 additions & 0 deletions examples/Llama/llama-pretrain.yaml
@@ -0,0 +1,10 @@
batch_size: 1
gradient_accumulation_steps: 1
lr: 1.0e-5
warm_up: 0.01
save_interval: 100
log_interval: 1
bmt_loss_scale: 1.0
save_optim: True
save_rng: True
eps: 1.0e-8
82 changes: 82 additions & 0 deletions examples/Llama/llama_bmt_monkey_patch.py
@@ -0,0 +1,82 @@
from typing import Optional

import torch
import transformers

from transformers.cache_utils import Cache, DynamicCache
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import TransformersKwargs
from transformers.masking_utils import create_causal_mask
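# A trimmed-down replacement for transformers' LlamaModel.forward: it keeps the token embedding,
# causal-mask construction, rotary embeddings, decoder stack, and final norm, and returns only
# last_hidden_state and past_key_values. It is installed by replace_llama_attn_with_bmt() below
# so the model can be wrapped and trained under BMTrain.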


def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)


def replace_llama_attn_with_bmt():
print("replace_llama_attn_with_bmt")
transformers.models.llama.modeling_llama.LlamaModel.forward = forward

141 changes: 141 additions & 0 deletions examples/Llama/llama_pretrain.py
@@ -0,0 +1,141 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import os
import torch
from torch.utils.data import Dataset
import gc
import json

gc.collect()
torch.cuda.empty_cache()

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from llama_bmt_monkey_patch import (
replace_llama_attn_with_bmt,
)

from flagai.env_args import EnvArgs
from flagai.env_trainer_v1 import EnvTrainer
from flagai.data.dataset.indexed_dataset.build_index_mappings import _build_train_valid_test_datasets, _build_train_valid_test_weighted_datasets
import bmtrain as bmt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All parameters can also be overridden on the command line, e.g.:
#   python llama_pretrain.py --epochs=300 --batch_size=4 --env_type=bmtrain
env_args = EnvArgs(
env_type="bmtrain",
experiment_name="llama3",
batch_size=1,
gradient_accumulation_steps=1,
lr=2e-4,
weight_decay=1e-3,
epochs=10000,
log_interval=1,
eval_interval=5000,
num_gpus=1,
load_dir=None,
pytorch_device=device,
save_dir="checkpoints_out",
checkpoint_activations=False,
save_interval=100,
fp16=True,
training_script=__file__,
)
env_args = env_args.parse_args()
#env_args.wandb = False

# Override env_args with values from the YAML config (lists are extended, scalars overwritten)
if env_args.yaml_config:
import yaml
    with open(env_args.yaml_config, 'r', encoding="utf-8") as f:
        file_data = f.read()
    data = yaml.load_all(file_data, Loader=yaml.SafeLoader)
delattr(env_args, 'yaml_config')
arg_dict = env_args.__dict__
for subdata in data:
for key, value in subdata.items():
if isinstance(value, list):
for v in value:
arg_dict[key].append(v)
else:
arg_dict[key] = value
trainer = EnvTrainer(env_args)

# Trigger mode: if --not_call_launch was not passed, this process only serves as the launch trigger and exits here.
if not env_args.not_call_launch:
import sys
sys.exit(0)

print(f"Trainer effective env_args={env_args} local_rank={os.environ['LOCAL_RANK']}",
flush=True)
checkpoints = env_args.pre_load_dir
model_name = env_args.model_name

print('*' * 20, "model_name", model_name, flush=True)

cache_dir = os.path.join(checkpoints, model_name)
print('*' * 20, "cache_dir", cache_dir)
tokenizer = AutoTokenizer.from_pretrained(cache_dir)
print('*' * 20, "tokenizer", tokenizer)

# Stagger model loading across local ranks to avoid host-memory OOM when all ranks load at once
if env_args.bmt_async_load:
    import time
    time.sleep(10 * 60 * (int(os.environ['LOCAL_RANK']) % 4))

config_file = os.path.join(cache_dir, 'config.json')
with open(config_file, 'r') as f:
model_args = json.load(f)

# Install the BMTrain-compatible LlamaModel.forward before loading the pretrained weights
replace_llama_attn_with_bmt()

model = LlamaForCausalLM.from_pretrained(cache_dir)

## bmt_pre_load

trainer.pre_train(model)

print('*' * 20, "model", model, flush=True)

## Use prebuilt indexed datasets (mmap format)
data_prefix = '../indexed_dataset/data/demo_text_document'
data_impl = 'mmap'
splits_string = '90,10'
train_valid_test_num_samples = [90, 10]
seq_length = 1024
seed = 2023
skip_warmup = True

train_dataset, valid_dataset, _ = _build_train_valid_test_datasets(
data_prefix, data_impl, splits_string, train_valid_test_num_samples,
seq_length, seed, skip_warmup)
print("Total train_dataset: ", len(train_dataset), flush=True)
print("Total valid_dataset: ", len(valid_dataset), flush=True)


def collate_fn(batch):
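    # Pad every sample to the longest sequence in the batch, truncate to seq_length,
    # and reuse input_ids as labels (the trainer is assumed to handle the causal label shift).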

def padding(indice, max_length, pad_idx=0):
pad_indice = [
item.tolist() + [pad_idx] * max(0, max_length - len(item.tolist()))
for item in indice
]
return torch.tensor(pad_indice)

input_ids = [data["input_ids"] for data in batch]
max_length = max([len(t) for t in input_ids])
input_ids = padding(input_ids, max_length)[:, :seq_length]

data = {"input_ids": input_ids, "labels": input_ids}
return data


trainer.do_train(train_dataset=train_dataset,
valid_dataset=None,
collate_fn=collate_fn,
optimizer=None,
rank_split=False)

29 changes: 29 additions & 0 deletions examples/Llama/local_trigger_docker.sh
@@ -0,0 +1,29 @@
#!/bin/bash
#
# Defined by user
export PROJ_HOME=$PWD

echo "[INFO] $0: hostfile configfile model_name exp_name"
set -u
hostfile=$1
configfile=$2
model_name=$3
exp_name=$4
set +u
NODES_NUM=$(cat $hostfile | wc -l)
echo "NODES_NUM": $NODES_NUM
if [ $NODES_NUM -ne 1 ]; then
    echo "[ERROR] hostfile must contain exactly one node for a local run"
    exit 1
fi

exp_YYYYMMDDHH=$(date +"%Y%m%d%H")
echo "exp_YYYYMMDDHH": $exp_YYYYMMDDHH

SAVE_DIR=$PROJ_HOME/checkpoints_out/${exp_name}/$exp_YYYYMMDDHH
LOGFILE=$SAVE_DIR/$(basename $configfile).log.txt
echo "LOGFILE": $LOGFILE

cd $PROJ_HOME;
mkdir -p $SAVE_DIR;
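# Launch training in the background and capture all output in $LOGFILE (use `tail -f $LOGFILE` to follow).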
bash bmtrain_mgpu.sh $hostfile $configfile $model_name $exp_name $exp_YYYYMMDDHH 1>$LOGFILE 2>&1 &