diff --git a/llms/llama/llama.py b/llms/llama/llama.py
index b791a5c20..b7ba28812 100644
--- a/llms/llama/llama.py
+++ b/llms/llama/llama.py
@@ -10,9 +10,13 @@
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.nn.layers.distributed import shard_linear
 from mlx.utils import tree_unflatten
 from sentencepiece import SentencePieceProcessor
 
+world = mx.distributed.init()
+rank = world.rank()
+
 
 @dataclass
 class ModelArgs:
@@ -48,6 +52,15 @@ def __init__(self, args: ModelArgs):
             args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
         )
 
+    def shard(self, group: mx.distributed.Group):
+        self.n_heads = self.n_heads // group.size()
+        self.n_kv_heads = self.n_kv_heads // group.size()
+
+        self.wq = shard_linear(self.wq, "all-to-sharded", group=group)
+        self.wk = shard_linear(self.wk, "all-to-sharded", group=group)
+        self.wv = shard_linear(self.wv, "all-to-sharded", group=group)
+        self.wo = shard_linear(self.wo, "sharded-to-all", group=group)
+
     def __call__(
         self,
         x: mx.array,
@@ -95,6 +108,11 @@ def __init__(self, args: ModelArgs):
         self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
         self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
 
+    def shard(self, group: mx.distributed.Group):
+        self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
+        self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)
+        self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)
+
     def __call__(self, x) -> mx.array:
         return self.w2(nn.silu(self.w1(x)) * self.w3(x))
 
@@ -195,6 +213,16 @@ def sample(logits):
             yield y
 
 
+_builtin_print = print
+
+
+def print(*args, sep=" ", end="\n", file=None, flush=False):
+    """Override the built-in print so that only the rank 0 device prints to
+    the terminal and the model output is not duplicated."""
+    if rank == 0:
+        _builtin_print(*args, sep=sep, end=end, file=file, flush=flush)
+
+
 def tic():
     return time.time()
 
@@ -205,7 +233,8 @@ def toc(msg, start):
 
 
 def generate(args):
-    input("Press enter to start generation")
+    if rank == 0:
+        input("Press enter to start generation")
     print("------")
     print(args.prompt)
     x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
@@ -339,9 +368,20 @@ def load_model(model_path):
     quantization = config.pop("quantization", None)
     model = Llama(ModelArgs(**config))
     if quantization is not None:
-        nn.quantize(model, **quantization)
+        class_predicate = (
+            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
+            and f"{p}.scales" in weights
+        )
+        nn.quantize(model, **quantization, class_predicate=class_predicate)
     model.update(tree_unflatten(list(weights.items())))
     tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
+
+    if world.size() > 1:
+        # Convert the attention and feed-forward Linear layers to sharded layers.
+        for layer in model.layers:
+            layer.attention.shard(group=world)
+            layer.feed_forward.shard(group=world)
+
     return model, tokenizer
 
 
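
Note: the "all-to-sharded" / "sharded-to-all" pairing used in the patch is the standard tensor-parallel split. The opening projection of each block keeps only this rank's slice of the output features, the closing projection consumes that slice, and an all-reduce sums the per-rank partial results so every rank ends up with the full activation. The sketch below spells out that arithmetic with plain arrays and an explicit mx.distributed.all_sum. The sizes, names, and slicing are made up for illustration (and the SiLU gate is omitted); it shows the math shard_linear has to preserve, not how shard_linear is implemented.

    import mlx.core as mx

    world = mx.distributed.init()
    r, n = world.rank(), world.size()

    dim, hidden = 8, 16                    # toy sizes, both divisible by n
    mx.random.seed(0)                      # same "full" weights on every rank
    w1 = mx.random.normal((hidden, dim))   # up-projection, split along output rows
    w2 = mx.random.normal((dim, hidden))   # down-projection, split along input columns

    rows = hidden // n
    w1_shard = w1[r * rows : (r + 1) * rows]     # "all-to-sharded": (hidden/n, dim)
    w2_shard = w2[:, r * rows : (r + 1) * rows]  # "sharded-to-all": (dim, hidden/n)

    x = mx.ones((1, dim))                  # identical input on every rank
    h = x @ w1_shard.T                     # each rank computes a slice of the hidden layer
    partial = h @ w2_shard.T               # partial contribution to the output
    y = mx.distributed.all_sum(partial)    # one all-reduce recovers x @ w1.T @ w2.T

Run with a single process the sketch degenerates to the unsharded computation, since world.size() is 1 and all_sum returns its input unchanged.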