44 changes: 42 additions & 2 deletions llms/llama/llama.py
@@ -10,9 +10,13 @@

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor

world = mx.distributed.init()
rank = world.rank()


@dataclass
class ModelArgs:
@@ -48,6 +52,15 @@ def __init__(self, args: ModelArgs):
            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
        )

    def shard(self, group: mx.distributed.Group):
        self.n_heads = self.n_heads // group.size()
        self.n_kv_heads = self.n_kv_heads // group.size()

        self.wq = shard_linear(self.wq, "all-to-sharded", group=group)
        self.wk = shard_linear(self.wk, "all-to-sharded", group=group)
        self.wv = shard_linear(self.wv, "all-to-sharded", group=group)
        self.wo = shard_linear(self.wo, "sharded-to-all", group=group)

    def __call__(
        self,
        x: mx.array,
@@ -95,6 +108,11 @@ def __init__(self, args: ModelArgs):
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def shard(self, group: mx.distributed.Group):
        self.w1 = shard_linear(self.w1, "all-to-sharded", group=group)
        self.w2 = shard_linear(self.w2, "sharded-to-all", group=group)
        self.w3 = shard_linear(self.w3, "all-to-sharded", group=group)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))

@@ -195,6 +213,16 @@ def sample(logits):
            yield y


_builtin_print = print


def print(*args, sep=" ", end="\n", file=None, flush=False):
"""Overwrite the print statement to only print to terminal on the rank 0
device so that model output is not doubled."""
if rank == 0:
_builtin_print(*args, sep=sep, end=end, file=file, flush=flush)


def tic():
    return time.time()

@@ -205,7 +233,8 @@ def toc(msg, start):


def generate(args):
input("Press enter to start generation")
if rank == 0:
input("Press enter to start generation")
print("------")
print(args.prompt)
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
@@ -339,9 +368,20 @@ def load_model(model_path):
    quantization = config.pop("quantization", None)
    model = Llama(ModelArgs(**config))
    if quantization is not None:
        nn.quantize(model, **quantization)
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
Comment on lines +371 to +374

Member:
I'm curious about that change. Was it needed for a specific model?

Author:
I was having trouble running quantized models, specifically Llama-2-4bit, and I saw this being used in the gguf llm example for quantized models.
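For context, the pattern under discussion can be sketched in isolation. The snippet below is illustrative only (the toy model, weight names, and 4-bit settings are assumptions, not taken from this PR): nn.quantize is asked to quantize only the Linear/Embedding modules whose quantized scales are actually present in the loaded checkpoint, so layers that were saved in full precision are left untouched.

import mlx.core as mx
import mlx.nn as nn

# Toy model: two Linear layers; pretend only the first was saved quantized.
model = nn.Sequential(nn.Linear(64, 64, bias=False), nn.Linear(64, 64, bias=False))

# Stand-in for the checkpoint's weight dict. Only the key names matter to the
# predicate, so the arrays here are placeholders.
weights = {
    "layers.0.weight": mx.zeros((64, 8), dtype=mx.uint32),
    "layers.0.scales": mx.ones((64, 1)),
    "layers.0.biases": mx.zeros((64, 1)),
    "layers.1.weight": mx.zeros((64, 64)),
}

# Quantize only the modules whose ".scales" entry exists in the checkpoint,
# mirroring the predicate added in this PR.
nn.quantize(
    model,
    group_size=64,
    bits=4,
    class_predicate=lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
    and f"{p}.scales" in weights,
)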

        nn.quantize(model, **quantization, class_predicate=class_predicate)
    model.update(tree_unflatten(list(weights.items())))
    tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))

    if world.size() > 1:
        # Convert the attention and feed-forward Linear layers to their sharded
        # equivalents so the model is split across the process group.
        for layer in model.layers:
            layer.attention.shard(group=world)
            layer.feed_forward.shard(group=world)

    return model, tokenizer
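The sharding wired up here can also be exercised on its own. Below is a minimal sketch of the tensor-parallel pattern used in this PR; the layer sizes and the launch command (for example mpirun -np 2 python sketch.py, assuming an MPI-backed distributed backend) are illustrative assumptions, not part of this change.

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_linear

world = mx.distributed.init()

up = nn.Linear(16, 32, bias=False)
down = nn.Linear(32, 16, bias=False)

if world.size() > 1:
    # "all-to-sharded": every rank sees the full input and keeps a slice of
    # the output features.
    up = shard_linear(up, "all-to-sharded", group=world)
    # "sharded-to-all": each rank consumes its slice and the partial results
    # are combined across ranks, so every rank ends up with the full output.
    down = shard_linear(down, "sharded-to-all", group=world)

x = mx.random.normal((1, 16))
y = down(nn.silu(up(x)))
mx.eval(y)
print(f"rank {world.rank()}: output shape {y.shape}")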

