From 6499e504235e4bf0211930773108341c6ca1c88c Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:36:52 -0500 Subject: [PATCH 1/7] llama tp --- llms/llama/llama_tp.py | 402 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 402 insertions(+) create mode 100644 llms/llama/llama_tp.py diff --git a/llms/llama/llama_tp.py b/llms/llama/llama_tp.py new file mode 100644 index 000000000..495e4b592 --- /dev/null +++ b/llms/llama/llama_tp.py @@ -0,0 +1,402 @@ +# Copyright © 2023 Apple Inc. + +import argparse +import glob +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from sentencepiece import SentencePieceProcessor + +world = mx.distributed.init() +rank = world.rank() +world_size = world.size() + +@dataclass +class ModelArgs: + dim: int + n_layers: int + head_dim: int + hidden_dim: int + n_heads: int + n_kv_heads: int + norm_eps: float + vocab_size: int + rope_theta: float + rope_traditional: bool = True + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + self.n_heads: int = args.n_heads + self.n_kv_heads: int = args.n_kv_heads + + # Number of heads per rank (sharded) + self.n_heads_per_rank = self.n_heads // world_size + self.n_kv_heads_per_rank = self.n_kv_heads // world_size + + self.repeats = self.n_heads_per_rank // self.n_kv_heads_per_rank + + self.scale = self.args.head_dim**-0.5 + + self.wq = nn.QuantizedAllToShardedLinear(args.dim, args.n_heads * args.head_dim, bias=False, group=world) + self.wk = nn.QuantizedAllToShardedLinear(args.dim, args.n_kv_heads * args.head_dim, bias=False, group=world) + self.wv = nn.QuantizedAllToShardedLinear(args.dim, args.n_kv_heads * args.head_dim, bias=False, group=world) + self.wo = nn.QuantizedShardedToAllLinear(args.n_heads * args.head_dim, args.dim, bias=False, group=world) + self.rope = nn.RoPE( + args.head_dim, traditional=args.rope_traditional, base=args.rope_theta + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + B, L, D = x.shape + + queries, keys, values = self.wq(x), self.wk(x), self.wv(x) + + # Prepare the queries, keys and values for the attention computation + # queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + # keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + # values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + queries = queries.reshape(B, L, self.n_heads_per_rank, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads_per_rank, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads_per_rank, -1).transpose(0, 2, 1, 3) + + def repeat(a): + a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) + return a.reshape([B, self.n_heads_per_rank, L, -1]) + + keys, values = map(repeat, (keys, values)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) + if mask is not None: + scores += mask + scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) + output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.wo(output), (keys, values) + + +class FeedForward(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.w1 = nn.QuantizedAllToShardedLinear(args.dim, args.hidden_dim, bias=False, group=world) + self.w2 = nn.QuantizedAllToShardedLinear(args.hidden_dim, args.dim, bias=False, group=world) + self.w3 = nn.QuantizedShardedToAllLinear(args.dim, args.hidden_dim, bias=False, group=world) + + def __call__(self, x) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.attention = Attention(args) + self.feed_forward = FeedForward(args=args) + self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.attention(self.attention_norm(x), mask, cache) + h = x + r + r = self.feed_forward(self.ffn_norm(h)) + out = h + r + return out, cache + + +class Llama(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) + self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] + self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + + def __call__(self, x): + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.tok_embeddings.weight.dtype) + + x = self.tok_embeddings(x) + for l in self.layers: + x, _ = l(x, mask) + x = self.norm(x) + return self.output(x) + + def generate(self, x, temp=1.0): + def sample(logits): + if temp == 0: + return mx.argmax(logits, axis=-1) + else: + return mx.random.categorical(logits * (1 / temp)) + + cache = [] + + # Make an additive causal mask. We will need that to process the prompt. + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.tok_embeddings.weight.dtype) + + # First we process the prompt x the same was as in __call__ but + # save the caches in cache + x = self.tok_embeddings(x) + for l in self.layers: + x, c = l(x, mask=mask) + # We store the per layer cache in a simple python list + cache.append(c) + x = self.norm(x) + # We only care about the last logits that generate the next token + y = self.output(x[:, -1]) + y = sample(y) + + # y now has size [1] + # Since MLX is lazily evaluated nothing is computed yet. + # Calling y.item() would force the computation to happen at + # this point but we can also choose not to do that and let the + # user choose when to start the computation. + yield y + + # Now we parsed the prompt and generated the first token we + # need to feed it back into the model and loop to generate the + # rest. + while True: + # Unsqueezing the last dimension to add a sequence length + # dimension of 1 + x = y[:, None] + + x = self.tok_embeddings(x) + for i in range(len(cache)): + # We are overwriting the arrays in the cache list. When + # the computation will happen, MLX will be discarding the + # old cache the moment it is not needed anymore. + x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) + x = self.norm(x) + y = sample(self.output(x[:, -1])) + + yield y + + +# overwrite print to make only rank 0 output to terminal +_builtin_print = print +def print(*args, sep=' ', end='\n', file=None, flush=False): + if rank == 0: _builtin_print(*args, sep=sep, end=end, file=file, flush=flush) + +def tic(): + return time.time() + + +def toc(msg, start): + end = time.time() + return f"[INFO] {msg}: {end - start:.3f} s" + + +def generate(args): + if rank == 0: input("Press enter to start generation") + print("------") + print(args.prompt) + x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) + skip = 0 + prompt_processing = None + tokens = [] + start = tic() + for token in model.generate(x, args.temp): + tokens.append(token) + + if len(tokens) == 1: + # Actually perform the computation to measure the prompt processing time + mx.eval(token) + prompt_processing = toc("Prompt processing", start) + + if len(tokens) >= args.max_tokens: + break + + elif (len(tokens) % args.write_every) == 0: + # It is perfectly ok to eval things we have already eval-ed. + mx.eval(tokens) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], end="", flush=True) + skip = len(s) + + mx.eval(tokens) + full_gen = toc("Full generation", start) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], flush=True) + print("------") + print(prompt_processing) + print(full_gen) + + +def few_shot_generate(args): + def possible_end(s): + word = "[Instruction]" + for i in range(len(word) - 1, 0, -1): + if s[-i:] == word[:i]: + return 0 + if s[-len(word) :] == word: + return 1 + return -1 + + def generate(question): + x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)]) + skip = 0 + prompt_processing = None + tokens = [] + start = tic() + for token in model.generate(x, args.temp): + tokens.append(token) + + if len(tokens) == 1: + # Actually perform the computation to measure the prompt processing time + mx.eval(token) + prompt_processing = toc("Prompt processing", start) + + if len(tokens) >= args.max_tokens: + break + + mx.eval(tokens) + token_list = [t.item() for t in tokens] + s = tokenizer.decode(token_list) + + end = possible_end(s) + if end == 0: + continue + if end == 1: + skip = len(s) + break + + print(s[skip:], end="", flush=True) + skip = len(s) + if token_list[-1] == tokenizer.eos_id(): + break + + mx.eval(tokens) + full_gen = toc("Full generation", start) + s = tokenizer.decode([t.item() for t in tokens]) + print(s[skip:], end="", flush=True) + + print("[INFO] Loading few-shot examples from: {}".format(args.few_shot)) + prompt = open(args.few_shot).read().strip() + while True: + question = input("Ask a question: ") + generate(prompt.replace("{}", question)) + print() + + +def sanitize_config(config, weights): + config.pop("model_type", None) + n_heads = config["n_heads"] + if "n_kv_heads" not in config: + config["n_kv_heads"] = n_heads + if "head_dim" not in config: + config["head_dim"] = config["dim"] // n_heads + if "hidden_dim" not in config: + config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] + if config.get("vocab_size", -1) < 0: + config["vocab_size"] = weights["output.weight"].shape[-1] + if "rope_theta" not in config: + config["rope_theta"] = 10000 + unused = ["multiple_of", "ffn_dim_multiplier"] + for k in unused: + config.pop(k, None) + return config + + +def load_model(model_path): + model_path = Path(model_path) + + unsharded_weights_path = Path(model_path / "weights.npz") + if unsharded_weights_path.is_file(): + print("[INFO] Loading model from {}.".format(unsharded_weights_path)) + weights = mx.load(str(unsharded_weights_path)) + else: + sharded_weights_glob = str(model_path / "weights.*.npz") + weight_files = glob.glob(sharded_weights_glob) + print("[INFO] Loading model from {}.".format(sharded_weights_glob)) + + if len(weight_files) == 0: + raise FileNotFoundError("No weights found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + with open(model_path / "config.json", "r") as f: + config = sanitize_config(json.loads(f.read()), weights) + quantization = config.pop("quantization", None) + model = Llama(ModelArgs(**config)) + if quantization is not None: + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize(model, **quantization, class_predicate=class_predicate) + model.update(tree_unflatten(list(weights.items()))) + tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llama inference script") + parser.add_argument( + "--model-path", + help="Path to the model weights and tokenizer", + default="mlx_model", + ) + parser.add_argument( + "--prompt", + help="The message to be processed by the model. Ignored when --few-shot is provided.", + default="In the beginning the Universe was created.", + ) + parser.add_argument( + "--few-shot", + help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", + ) + parser.add_argument( + "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate" + ) + parser.add_argument( + "--write-every", type=int, default=1, help="After how many tokens to detokenize" + ) + parser.add_argument( + "--temp", type=float, default=0.0, help="The sampling temperature" + ) + parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") + + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load_model(args.model_path) + if args.few_shot: + few_shot_generate(args) + else: + generate(args) From 7e4a1b8bd8db2097b5149c1ec8b60c0a7f7fd353 Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:59:52 -0500 Subject: [PATCH 2/7] ffn fix --- llms/llama/llama.py | 6 +++++- llms/llama/llama_tp.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index b791a5c20..c180d2792 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -339,7 +339,11 @@ def load_model(model_path): quantization = config.pop("quantization", None) model = Llama(ModelArgs(**config)) if quantization is not None: - nn.quantize(model, **quantization) + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize(model, **quantization, class_predicate=class_predicate) model.update(tree_unflatten(list(weights.items()))) tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) return model, tokenizer diff --git a/llms/llama/llama_tp.py b/llms/llama/llama_tp.py index 495e4b592..db2fc0be2 100644 --- a/llms/llama/llama_tp.py +++ b/llms/llama/llama_tp.py @@ -83,8 +83,8 @@ def repeat(a): key_cache, value_cache = cache queries = self.rope(queries, offset=key_cache.shape[2]) keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) + keys = mx.concatenate([key_cache, keys], axis=2).astype(mx.float16) + values = mx.concatenate([value_cache, values], axis=2).astype(mx.float16) else: queries = self.rope(queries) keys = self.rope(keys) @@ -102,8 +102,8 @@ def __init__(self, args: ModelArgs): super().__init__() self.w1 = nn.QuantizedAllToShardedLinear(args.dim, args.hidden_dim, bias=False, group=world) - self.w2 = nn.QuantizedAllToShardedLinear(args.hidden_dim, args.dim, bias=False, group=world) - self.w3 = nn.QuantizedShardedToAllLinear(args.dim, args.hidden_dim, bias=False, group=world) + self.w2 = nn.QuantizedShardedToAllLinear(args.hidden_dim, args.dim, bias=False, group=world) + self.w3 = nn.QuantizedAllToShardedLinear(args.dim, args.hidden_dim, bias=False, group=world) def __call__(self, x) -> mx.array: return self.w2(nn.silu(self.w1(x)) * self.w3(x)) From 9051e9d351585cfbf7eb60cec56e052d88a58590 Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Sat, 10 Jan 2026 14:07:08 -0500 Subject: [PATCH 3/7] add support for TP in llama inference --- llms/llama/llama.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index c180d2792..13358f2a1 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -13,6 +13,8 @@ from mlx.utils import tree_unflatten from sentencepiece import SentencePieceProcessor +world = mx.distributed.init() +rank = world.rank() @dataclass class ModelArgs: @@ -48,6 +50,16 @@ def __init__(self, args: ModelArgs): args.head_dim, traditional=args.rope_traditional, base=args.rope_theta ) + def shard(self, group: mx.distributed.Group): + self.n_heads = self.n_heads // group.size() + self.n_kv_heads = self.n_kv_heads // group.size() + self.repeats = self.n_heads // self.n_kv_heads + + self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group) + self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group) + self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group) + self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group) + def __call__( self, x: mx.array, @@ -95,6 +107,11 @@ def __init__(self, args: ModelArgs): self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False) self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) + def shard(self, group: mx.distributed.Group): + self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group) + self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group) + self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group) + def __call__(self, x) -> mx.array: return self.w2(nn.silu(self.w1(x)) * self.w3(x)) @@ -195,6 +212,13 @@ def sample(logits): yield y +_builtin_print = print +def print(*args, sep=' ', end='\n', file=None, flush=False): + """Overwrite the print statement to only print to terminal on the rank 0 + device so that model output is not doubled.""" + if rank == 0: _builtin_print(*args, sep=sep, end=end, file=file, flush=flush) + + def tic(): return time.time() @@ -205,7 +229,7 @@ def toc(msg, start): def generate(args): - input("Press enter to start generation") + if rank == 0: input("Press enter to start generation") print("------") print(args.prompt) x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) @@ -346,9 +370,15 @@ def load_model(model_path): nn.quantize(model, **quantization, class_predicate=class_predicate) model.update(tree_unflatten(list(weights.items()))) tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) + + if world.size() > 1: + # convert Linear layers in Transformer/FFN to appropriate Sharded Layers + for layer in model.layers: + layer.attention.shard(group=world) + layer.feed_forward.shard(group=world) + return model, tokenizer - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Llama inference script") parser.add_argument( From 55f0e5c947f927ff502d1a7841f93734a04d782f Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Sat, 10 Jan 2026 18:19:38 -0500 Subject: [PATCH 4/7] cleanup --- llms/llama/llama_tp.py | 402 ----------------------------------------- 1 file changed, 402 deletions(-) delete mode 100644 llms/llama/llama_tp.py diff --git a/llms/llama/llama_tp.py b/llms/llama/llama_tp.py deleted file mode 100644 index db2fc0be2..000000000 --- a/llms/llama/llama_tp.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright © 2023 Apple Inc. - -import argparse -import glob -import json -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from sentencepiece import SentencePieceProcessor - -world = mx.distributed.init() -rank = world.rank() -world_size = world.size() - -@dataclass -class ModelArgs: - dim: int - n_layers: int - head_dim: int - hidden_dim: int - n_heads: int - n_kv_heads: int - norm_eps: float - vocab_size: int - rope_theta: float - rope_traditional: bool = True - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - - self.n_heads: int = args.n_heads - self.n_kv_heads: int = args.n_kv_heads - - # Number of heads per rank (sharded) - self.n_heads_per_rank = self.n_heads // world_size - self.n_kv_heads_per_rank = self.n_kv_heads // world_size - - self.repeats = self.n_heads_per_rank // self.n_kv_heads_per_rank - - self.scale = self.args.head_dim**-0.5 - - self.wq = nn.QuantizedAllToShardedLinear(args.dim, args.n_heads * args.head_dim, bias=False, group=world) - self.wk = nn.QuantizedAllToShardedLinear(args.dim, args.n_kv_heads * args.head_dim, bias=False, group=world) - self.wv = nn.QuantizedAllToShardedLinear(args.dim, args.n_kv_heads * args.head_dim, bias=False, group=world) - self.wo = nn.QuantizedShardedToAllLinear(args.n_heads * args.head_dim, args.dim, bias=False, group=world) - self.rope = nn.RoPE( - args.head_dim, traditional=args.rope_traditional, base=args.rope_theta - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: - B, L, D = x.shape - - queries, keys, values = self.wq(x), self.wk(x), self.wv(x) - - # Prepare the queries, keys and values for the attention computation - # queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - # keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - # values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - queries = queries.reshape(B, L, self.n_heads_per_rank, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads_per_rank, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads_per_rank, -1).transpose(0, 2, 1, 3) - - def repeat(a): - a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2) - return a.reshape([B, self.n_heads_per_rank, L, -1]) - - keys, values = map(repeat, (keys, values)) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2).astype(mx.float16) - values = mx.concatenate([value_cache, values], axis=2).astype(mx.float16) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2) - if mask is not None: - scores += mask - scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) - output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.wo(output), (keys, values) - - -class FeedForward(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.w1 = nn.QuantizedAllToShardedLinear(args.dim, args.hidden_dim, bias=False, group=world) - self.w2 = nn.QuantizedShardedToAllLinear(args.hidden_dim, args.dim, bias=False, group=world) - self.w3 = nn.QuantizedAllToShardedLinear(args.dim, args.hidden_dim, bias=False, group=world) - - def __call__(self, x) -> mx.array: - return self.w2(nn.silu(self.w1(x)) * self.w3(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.attention = Attention(args) - self.feed_forward = FeedForward(args=args) - self.attention_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = nn.RMSNorm(args.dim, eps=args.norm_eps) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.attention(self.attention_norm(x), mask, cache) - h = x + r - r = self.feed_forward(self.ffn_norm(h)) - out = h + r - return out, cache - - -class Llama(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) - self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)] - self.norm = nn.RMSNorm(args.dim, eps=args.norm_eps) - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) - - def __call__(self, x): - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.tok_embeddings.weight.dtype) - - x = self.tok_embeddings(x) - for l in self.layers: - x, _ = l(x, mask) - x = self.norm(x) - return self.output(x) - - def generate(self, x, temp=1.0): - def sample(logits): - if temp == 0: - return mx.argmax(logits, axis=-1) - else: - return mx.random.categorical(logits * (1 / temp)) - - cache = [] - - # Make an additive causal mask. We will need that to process the prompt. - mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) - mask = mask.astype(self.tok_embeddings.weight.dtype) - - # First we process the prompt x the same was as in __call__ but - # save the caches in cache - x = self.tok_embeddings(x) - for l in self.layers: - x, c = l(x, mask=mask) - # We store the per layer cache in a simple python list - cache.append(c) - x = self.norm(x) - # We only care about the last logits that generate the next token - y = self.output(x[:, -1]) - y = sample(y) - - # y now has size [1] - # Since MLX is lazily evaluated nothing is computed yet. - # Calling y.item() would force the computation to happen at - # this point but we can also choose not to do that and let the - # user choose when to start the computation. - yield y - - # Now we parsed the prompt and generated the first token we - # need to feed it back into the model and loop to generate the - # rest. - while True: - # Unsqueezing the last dimension to add a sequence length - # dimension of 1 - x = y[:, None] - - x = self.tok_embeddings(x) - for i in range(len(cache)): - # We are overwriting the arrays in the cache list. When - # the computation will happen, MLX will be discarding the - # old cache the moment it is not needed anymore. - x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) - x = self.norm(x) - y = sample(self.output(x[:, -1])) - - yield y - - -# overwrite print to make only rank 0 output to terminal -_builtin_print = print -def print(*args, sep=' ', end='\n', file=None, flush=False): - if rank == 0: _builtin_print(*args, sep=sep, end=end, file=file, flush=flush) - -def tic(): - return time.time() - - -def toc(msg, start): - end = time.time() - return f"[INFO] {msg}: {end - start:.3f} s" - - -def generate(args): - if rank == 0: input("Press enter to start generation") - print("------") - print(args.prompt) - x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) - skip = 0 - prompt_processing = None - tokens = [] - start = tic() - for token in model.generate(x, args.temp): - tokens.append(token) - - if len(tokens) == 1: - # Actually perform the computation to measure the prompt processing time - mx.eval(token) - prompt_processing = toc("Prompt processing", start) - - if len(tokens) >= args.max_tokens: - break - - elif (len(tokens) % args.write_every) == 0: - # It is perfectly ok to eval things we have already eval-ed. - mx.eval(tokens) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - skip = len(s) - - mx.eval(tokens) - full_gen = toc("Full generation", start) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], flush=True) - print("------") - print(prompt_processing) - print(full_gen) - - -def few_shot_generate(args): - def possible_end(s): - word = "[Instruction]" - for i in range(len(word) - 1, 0, -1): - if s[-i:] == word[:i]: - return 0 - if s[-len(word) :] == word: - return 1 - return -1 - - def generate(question): - x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)]) - skip = 0 - prompt_processing = None - tokens = [] - start = tic() - for token in model.generate(x, args.temp): - tokens.append(token) - - if len(tokens) == 1: - # Actually perform the computation to measure the prompt processing time - mx.eval(token) - prompt_processing = toc("Prompt processing", start) - - if len(tokens) >= args.max_tokens: - break - - mx.eval(tokens) - token_list = [t.item() for t in tokens] - s = tokenizer.decode(token_list) - - end = possible_end(s) - if end == 0: - continue - if end == 1: - skip = len(s) - break - - print(s[skip:], end="", flush=True) - skip = len(s) - if token_list[-1] == tokenizer.eos_id(): - break - - mx.eval(tokens) - full_gen = toc("Full generation", start) - s = tokenizer.decode([t.item() for t in tokens]) - print(s[skip:], end="", flush=True) - - print("[INFO] Loading few-shot examples from: {}".format(args.few_shot)) - prompt = open(args.few_shot).read().strip() - while True: - question = input("Ask a question: ") - generate(prompt.replace("{}", question)) - print() - - -def sanitize_config(config, weights): - config.pop("model_type", None) - n_heads = config["n_heads"] - if "n_kv_heads" not in config: - config["n_kv_heads"] = n_heads - if "head_dim" not in config: - config["head_dim"] = config["dim"] // n_heads - if "hidden_dim" not in config: - config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0] - if config.get("vocab_size", -1) < 0: - config["vocab_size"] = weights["output.weight"].shape[-1] - if "rope_theta" not in config: - config["rope_theta"] = 10000 - unused = ["multiple_of", "ffn_dim_multiplier"] - for k in unused: - config.pop(k, None) - return config - - -def load_model(model_path): - model_path = Path(model_path) - - unsharded_weights_path = Path(model_path / "weights.npz") - if unsharded_weights_path.is_file(): - print("[INFO] Loading model from {}.".format(unsharded_weights_path)) - weights = mx.load(str(unsharded_weights_path)) - else: - sharded_weights_glob = str(model_path / "weights.*.npz") - weight_files = glob.glob(sharded_weights_glob) - print("[INFO] Loading model from {}.".format(sharded_weights_glob)) - - if len(weight_files) == 0: - raise FileNotFoundError("No weights found in {}".format(model_path)) - - weights = {} - for wf in weight_files: - weights.update(mx.load(wf).items()) - - with open(model_path / "config.json", "r") as f: - config = sanitize_config(json.loads(f.read()), weights) - quantization = config.pop("quantization", None) - model = Llama(ModelArgs(**config)) - if quantization is not None: - class_predicate = ( - lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) - and f"{p}.scales" in weights - ) - nn.quantize(model, **quantization, class_predicate=class_predicate) - model.update(tree_unflatten(list(weights.items()))) - tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) - return model, tokenizer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Llama inference script") - parser.add_argument( - "--model-path", - help="Path to the model weights and tokenizer", - default="mlx_model", - ) - parser.add_argument( - "--prompt", - help="The message to be processed by the model. Ignored when --few-shot is provided.", - default="In the beginning the Universe was created.", - ) - parser.add_argument( - "--few-shot", - help="Read a few shot prompt from a file (as in `sample_prompt.txt`).", - ) - parser.add_argument( - "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate" - ) - parser.add_argument( - "--write-every", type=int, default=1, help="After how many tokens to detokenize" - ) - parser.add_argument( - "--temp", type=float, default=0.0, help="The sampling temperature" - ) - parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") - - args = parser.parse_args() - - mx.random.seed(args.seed) - - model, tokenizer = load_model(args.model_path) - if args.few_shot: - few_shot_generate(args) - else: - generate(args) From fb1603964367d3a0e8ad9649d24064c4c7173abb Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Sun, 11 Jan 2026 13:22:02 -0500 Subject: [PATCH 5/7] pre-commit formatting --- llms/llama/llama.py | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 13358f2a1..95ce432fb 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -16,6 +16,7 @@ world = mx.distributed.init() rank = world.rank() + @dataclass class ModelArgs: dim: int @@ -55,10 +56,18 @@ def shard(self, group: mx.distributed.Group): self.n_kv_heads = self.n_kv_heads // group.size() self.repeats = self.n_heads // self.n_kv_heads - self.wq = nn.layers.distributed.shard_linear(self.wq, "all-to-sharded", group=group) - self.wk = nn.layers.distributed.shard_linear(self.wk, "all-to-sharded", group=group) - self.wv = nn.layers.distributed.shard_linear(self.wv, "all-to-sharded", group=group) - self.wo = nn.layers.distributed.shard_linear(self.wo, "sharded-to-all", group=group) + self.wq = nn.layers.distributed.shard_linear( + self.wq, "all-to-sharded", group=group + ) + self.wk = nn.layers.distributed.shard_linear( + self.wk, "all-to-sharded", group=group + ) + self.wv = nn.layers.distributed.shard_linear( + self.wv, "all-to-sharded", group=group + ) + self.wo = nn.layers.distributed.shard_linear( + self.wo, "sharded-to-all", group=group + ) def __call__( self, @@ -108,9 +117,15 @@ def __init__(self, args: ModelArgs): self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) def shard(self, group: mx.distributed.Group): - self.w1 = nn.layers.distributed.shard_linear(self.w1, "all-to-sharded", group=group) - self.w2 = nn.layers.distributed.shard_linear(self.w2, "sharded-to-all", group=group) - self.w3 = nn.layers.distributed.shard_linear(self.w3, "all-to-sharded", group=group) + self.w1 = nn.layers.distributed.shard_linear( + self.w1, "all-to-sharded", group=group + ) + self.w2 = nn.layers.distributed.shard_linear( + self.w2, "sharded-to-all", group=group + ) + self.w3 = nn.layers.distributed.shard_linear( + self.w3, "all-to-sharded", group=group + ) def __call__(self, x) -> mx.array: return self.w2(nn.silu(self.w1(x)) * self.w3(x)) @@ -213,10 +228,13 @@ def sample(logits): _builtin_print = print -def print(*args, sep=' ', end='\n', file=None, flush=False): + + +def print(*args, sep=" ", end="\n", file=None, flush=False): """Overwrite the print statement to only print to terminal on the rank 0 device so that model output is not doubled.""" - if rank == 0: _builtin_print(*args, sep=sep, end=end, file=file, flush=flush) + if rank == 0: + _builtin_print(*args, sep=sep, end=end, file=file, flush=flush) def tic(): @@ -229,7 +247,8 @@ def toc(msg, start): def generate(args): - if rank == 0: input("Press enter to start generation") + if rank == 0: + input("Press enter to start generation") print("------") print(args.prompt) x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)]) @@ -370,15 +389,16 @@ def load_model(model_path): nn.quantize(model, **quantization, class_predicate=class_predicate) model.update(tree_unflatten(list(weights.items()))) tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model")) - + if world.size() > 1: # convert Linear layers in Transformer/FFN to appropriate Sharded Layers for layer in model.layers: layer.attention.shard(group=world) layer.feed_forward.shard(group=world) - + return model, tokenizer + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Llama inference script") parser.add_argument( From 704bab703dda88b4679f20c3df5bc17986099f66 Mon Sep 17 00:00:00 2001 From: stefpi <19478336+stefpi@users.noreply.github.com> Date: Thu, 15 Jan 2026 13:09:23 -0500 Subject: [PATCH 6/7] import shard_linear --- llms/llama/llama.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 95ce432fb..8269c7482 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -10,6 +10,7 @@ import mlx.core as mx import mlx.nn as nn +from mlx.nn.layers.distributed import shard_linear from mlx.utils import tree_unflatten from sentencepiece import SentencePieceProcessor @@ -56,18 +57,10 @@ def shard(self, group: mx.distributed.Group): self.n_kv_heads = self.n_kv_heads // group.size() self.repeats = self.n_heads // self.n_kv_heads - self.wq = nn.layers.distributed.shard_linear( - self.wq, "all-to-sharded", group=group - ) - self.wk = nn.layers.distributed.shard_linear( - self.wk, "all-to-sharded", group=group - ) - self.wv = nn.layers.distributed.shard_linear( - self.wv, "all-to-sharded", group=group - ) - self.wo = nn.layers.distributed.shard_linear( - self.wo, "sharded-to-all", group=group - ) + self.wq = shard_linear(self.wq, "all-to-sharded", group=group) + self.wk = shard_linear(self.wk, "all-to-sharded", group=group) + self.wv = shard_linear(self.wv, "all-to-sharded", group=group) + self.wo = shard_linear(self.wo, "sharded-to-all", group=group) def __call__( self, @@ -117,15 +110,9 @@ def __init__(self, args: ModelArgs): self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False) def shard(self, group: mx.distributed.Group): - self.w1 = nn.layers.distributed.shard_linear( - self.w1, "all-to-sharded", group=group - ) - self.w2 = nn.layers.distributed.shard_linear( - self.w2, "sharded-to-all", group=group - ) - self.w3 = nn.layers.distributed.shard_linear( - self.w3, "all-to-sharded", group=group - ) + self.w1 = shard_linear(self.w1, "all-to-sharded", group=group) + self.w2 = shard_linear(self.w2, "sharded-to-all", group=group) + self.w3 = shard_linear(self.w3, "all-to-sharded", group=group) def __call__(self, x) -> mx.array: return self.w2(nn.silu(self.w1(x)) * self.w3(x)) From 3c7583edb1b3c69a6670220dc67e61df9933c305 Mon Sep 17 00:00:00 2001 From: stef <19478336+stefpi@users.noreply.github.com> Date: Wed, 28 Jan 2026 19:01:25 -0500 Subject: [PATCH 7/7] remove redundant repeats definition Co-authored-by: Awni Hannun --- llms/llama/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llms/llama/llama.py b/llms/llama/llama.py index 8269c7482..b7ba28812 100644 --- a/llms/llama/llama.py +++ b/llms/llama/llama.py @@ -55,7 +55,6 @@ def __init__(self, args: ModelArgs): def shard(self, group: mx.distributed.Group): self.n_heads = self.n_heads // group.size() self.n_kv_heads = self.n_kv_heads // group.size() - self.repeats = self.n_heads // self.n_kv_heads self.wq = shard_linear(self.wq, "all-to-sharded", group=group) self.wk = shard_linear(self.wk, "all-to-sharded", group=group)