20 changes: 19 additions & 1 deletion Package.resolved

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions Package.swift
@@ -31,6 +31,7 @@ let package = Package(
.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
.package(url: "https://github.com/mattt/EventSource", from: "1.3.0"),
.package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"),
.package(url: "https://github.com/mattt/swift-xgrammar", from: "0.1.0"),
.package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")),
.package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"),
// mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:).
@@ -45,6 +46,7 @@ let package = Package(
.product(name: "EventSource", package: "EventSource"),
.product(name: "JSONSchema", package: "JSONSchema"),
.product(name: "PartialJSONDecoder", package: "PartialJSONDecoder"),
.product(name: "XGrammar", package: "swift-xgrammar"),
.product(
Comment on lines +49 to 50
Copilot AI Feb 9, 2026

XGrammar is only imported/used inside the MLX and Llama conditional compilation blocks, but the target depends on it unconditionally. This forces consumers to fetch/build swift-xgrammar even when neither trait is enabled, and can break traitless builds if that dependency has platform/toolchain constraints. Consider making the XGrammar product dependency conditional on the MLX and/or Llama traits (similar to MLXLLM / LlamaSwift).

Suggested change
.product(name: "XGrammar", package: "swift-xgrammar"),
.product(
.product(
name: "XGrammar",
package: "swift-xgrammar",
condition: .when(traits: ["MLX", "Llama"])
),
.product(

name: "MLXLLM",
package: "mlx-swift-lm",
104 changes: 98 additions & 6 deletions Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift
@@ -2,6 +2,7 @@ import Foundation
#if Llama
import JSONSchema
import LlamaSwift
import XGrammar

/// Global storage for the current log level threshold.
/// This is needed because the C callback can't capture Swift context.
@@ -532,7 +533,7 @@ import Foundation
)
} else {
let maxTokens = structuredOptions.maximumResponseTokens ?? 512
let jsonString = try generateStructuredJSON(
let jsonString = try await generateStructuredJSON(
context: context,
prompt: fullPrompt,
schema: type.generationSchema,
@@ -913,6 +914,41 @@ import Foundation
return "\(header):\n\(schemaJSON)"
}

private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
let encoder = JSONEncoder()
encoder.outputFormatting = [.sortedKeys]
let data = try encoder.encode(schema)
guard let jsonSchema = String(data: data, encoding: .utf8) else {
throw LlamaLanguageModelError.schemaEncodingFailed
}
return jsonSchema
}

private func tokenizerInfo(
for vocab: OpaquePointer,
vocabSize: Int,
stopTokens: Set<Int>
) throws -> TokenizerInfo {
guard vocabSize > 0 else {
throw LlamaLanguageModelError.contextInitializationFailed
}

var encodedVocab: [String] = []
encodedVocab.reserveCapacity(vocabSize)
for tokenId in 0 ..< vocabSize {
let token = llama_token(tokenId)
encodedVocab.append(tokenToText(vocab: vocab, token: token) ?? "")
}

let stopTokenIDs = stopTokens.map { Int32($0) }
return try TokenizerInfo(
encodedVocab: encodedVocab,
encoding: .byteFallback,
stopTokenIDs: stopTokenIDs,
addPrefixSpace: false
)
}
Comment on lines +927 to +950
Copilot AI Feb 9, 2026

Building TokenizerInfo reconstructs the entire encodedVocab (size = vocabSize) for every structured-generation request. This is expensive and deterministic for a given model/vocab; consider caching the computed TokenizerInfo (keyed by vocab pointer + vocabSize, or by model ID) to avoid repeatedly traversing the entire vocabulary.
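
One way to do that, as a rough sketch rather than a drop-in change: memoize the result on the model instance, keyed by the vocab pointer and size. The `tokenizerInfoCache`/`tokenizerInfoCacheKey` properties are hypothetical names (not part of this PR), the sketch assumes the enclosing model type is a reference type that can hold mutable state, and it reuses the `tokenizerInfo(for:vocabSize:stopTokens:)` helper added in this diff.

// Hypothetical caching layer (not part of this PR): reuse the TokenizerInfo
// built for a given vocab instead of re-walking the vocabulary every request.
private var tokenizerInfoCache: TokenizerInfo?
private var tokenizerInfoCacheKey: String?

private func cachedTokenizerInfo(
    for vocab: OpaquePointer,
    vocabSize: Int,
    stopTokens: Set<Int>
) throws -> TokenizerInfo {
    // Key on the vocab pointer and size; both are stable for a loaded model.
    let key = "\(Int(bitPattern: vocab))-\(vocabSize)"
    if let cached = tokenizerInfoCache, tokenizerInfoCacheKey == key {
        return cached
    }
    let info = try tokenizerInfo(for: vocab, vocabSize: vocabSize, stopTokens: stopTokens)
    tokenizerInfoCache = info
    tokenizerInfoCacheKey = key
    return info
}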


// MARK: - Structured JSON Generation

private func generateStructuredJSON(
@@ -921,7 +957,7 @@
schema: GenerationSchema,
maxTokens: Int,
options: ResolvedGenerationOptions
) throws -> String {
) async throws -> String {
guard let vocab = llama_model_get_vocab(model!) else {
throw LlamaLanguageModelError.contextInitializationFailed
}
@@ -964,21 +1000,56 @@

let vocabSize = Int(llama_vocab_n_tokens(vocab))
let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens
let jsonSchema = try jsonSchemaString(for: schema)
let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
let eosToken = Int(llama_vocab_eos(vocab))
let eotTokenValue = llama_vocab_eot(vocab)
let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken
let endTokens: Set<Int> = [eosToken, endOfTurnToken]

let tokenizerInfo = try tokenizerInfo(
for: vocab,
vocabSize: vocabSize,
stopTokens: endTokens
)
let matcher = try await grammar.matcher(
for: tokenizerInfo,
stopTokens: endTokens.map { Int32($0) },
terminatesWithoutStopToken: true
)
var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: vocabSize)

return try withUnsafeMutablePointer(to: &batch) { batchPointer in
let backend = LlamaTokenBackend(
var backend = LlamaTokenBackend(
context: context,
vocab: vocab,
vocabSize: vocabSize,
sampler: samplerPointer,
batch: batchPointer,
position: initialPosition,
maximumTokens: maxTokens,
endTokens: [],
endTokens: endTokens,
tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) }
)
var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
return try generator.generate()

var output = ""
while backend.remainingTokens > 0 {
bitmask.reset()
let needsMask = matcher.fillNextTokenBitmask(&bitmask)
let token = try backend.sample(using: bitmask, applyMask: needsMask)
if backend.endTokens.contains(token) {
break
}
guard matcher.accept(Int32(token)) else {
throw LlamaLanguageModelError.grammarMismatch
}
Comment on lines +1039 to +1045
Copilot AI Feb 9, 2026

The generation loop breaks immediately when a sampled token is in endTokens, before checking/advancing the grammar matcher. If the model samples EOS/EOT before the grammar has fully terminated (or in a step where masking is not applied), this will return incomplete/invalid JSON. Consider only stopping on end tokens when matcher.isTerminated is already true (or accept the token into the matcher and verify termination) so early EOS doesn’t truncate the structured output.
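
A sketch of that adjustment, assuming `matcher.isTerminated` is already valid at the point an end token is sampled (names mirror this diff; illustrative only, not a tested patch):

// Illustrative only: honor EOS/EOT solely when the grammar has terminated;
// otherwise surface the early stop as grammarMismatch instead of returning
// truncated JSON.
if backend.endTokens.contains(token) {
    guard matcher.isTerminated else {
        throw LlamaLanguageModelError.grammarMismatch
    }
    break
}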

if let tokenText = backend.tokenText(token) {
output += tokenText
}
try backend.decode(token)
if matcher.isTerminated { break }
}
return output
}
}

@@ -1105,6 +1176,21 @@ import Foundation
}
}

mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
guard let logits = llama_get_logits(context) else {
return eosToken
}

if applyMask {
for tokenIndex in 0 ..< vocabSize where !bitmask.isTokenAllowed(tokenIndex) {
logits[tokenIndex] = -Float.infinity
}
}

let tokenIndex = batch.pointee.n_tokens - 1
return Int(llama_sampler_sample(sampler, context, tokenIndex))
}

mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
guard let logits = llama_get_logits(context) else {
return eosToken
@@ -1536,6 +1622,8 @@ import Foundation
case insufficientMemory
case unsupportedFeature
case encoderOnlyModel
case schemaEncodingFailed
case grammarMismatch

public var errorDescription: String? {
switch self {
@@ -1557,6 +1645,10 @@ import Foundation
return "This LlamaLanguageModel does not support image segments"
case .encoderOnlyModel:
return "This model is encoder-only (e.g., BERT) and cannot generate text"
case .schemaEncodingFailed:
return "Failed to encode the JSON schema for structured generation"
case .grammarMismatch:
return "Grammar constraints could not be satisfied during generation"
}
}
}
105 changes: 101 additions & 4 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
@@ -17,6 +17,7 @@ import Foundation
import MLXVLM
import Tokenizers
import Hub
import XGrammar

/// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
private final class CachedContext: NSObject, @unchecked Sendable {
@@ -782,19 +783,59 @@ import Foundation
return "\(header):\n\(schemaJSON)"
}

private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
let encoder = JSONEncoder()
encoder.outputFormatting = [.sortedKeys]
let data = try encoder.encode(schema)
guard let jsonSchema = String(data: data, encoding: .utf8) else {
throw MLXLanguageModelError.schemaEncodingFailed
}
return jsonSchema
}

private func tokenizerInfo(
for tokenizer: any Tokenizer,
vocabSize: Int,
stopTokens: Set<Int>
) throws -> TokenizerInfo {
guard vocabSize > 0 else {
throw MLXLanguageModelError.invalidVocabSize
}

var encodedVocab: [String] = []
encodedVocab.reserveCapacity(vocabSize)
for tokenId in 0 ..< vocabSize {
encodedVocab.append(tokenizer.convertIdToToken(tokenId) ?? "")
}

let stopTokenIDs = stopTokens.map { Int32($0) }
return try TokenizerInfo(
encodedVocab: encodedVocab,
encoding: .byteLevel,
stopTokenIDs: stopTokenIDs,
addPrefixSpace: false
)
}

// MARK: - Structured JSON Generation

/// Errors that can occur when using MLXLanguageModel.
public enum MLXLanguageModelError: Error, LocalizedError {
case invalidVocabSize
case unsupportedJSONValueType
case schemaEncodingFailed
case grammarMismatch

public var errorDescription: String? {
switch self {
case .invalidVocabSize:
return "Invalid vocabulary size for model output"
case .unsupportedJSONValueType:
return "Unsupported JSON value type for schema conversion"
case .schemaEncodingFailed:
return "Failed to encode the JSON schema for structured generation"
case .grammarMismatch:
return "Grammar constraints could not be satisfied during generation"
}
}
}
@@ -827,13 +868,42 @@ import Foundation
maximumTokens: maxTokens,
endTokens: []
)

var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
let json = try generator.generate()
let jsonSchema = try jsonSchemaString(for: schema)
let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
let tokenizerInfo = try tokenizerInfo(
for: context.tokenizer,
vocabSize: backend.vocabSize,
stopTokens: backend.endTokens
)
let matcher = try await grammar.matcher(
for: tokenizerInfo,
stopTokens: backend.endTokens.map { Int32($0) },
terminatesWithoutStopToken: true
)
var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: tokenizerInfo.vocabulary.size)

var backendState = backend
var output = ""
while backendState.remainingTokens > 0 {
bitmask.reset()
let needsMask = matcher.fillNextTokenBitmask(&bitmask)
let token = try backendState.sample(using: bitmask, applyMask: needsMask)
if backendState.endTokens.contains(token) {
break
}
guard matcher.accept(Int32(token)) else {
throw MLXLanguageModelError.grammarMismatch
}
if let tokenText = backendState.tokenText(token) {
output += tokenText
}
try backendState.decode(token)
if matcher.isTerminated { break }
}
// Ensure pending MLX operations complete before returning JSON.
// This synchronization can be a performance cost if called frequently.
Stream().synchronize()
return json
return output
}

/// Merges system prompts and schema instructions into a user message.
@@ -1038,6 +1108,33 @@ import Foundation
}
}

mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
var logits = currentLogits[0..., -1, 0...]
logits = processor?.process(logits: logits) ?? logits
if logits.dtype == .bfloat16 {
logits = logits.asType(.float32)
}

if applyMask {
var allowedIndices: [UInt32] = []
allowedIndices.reserveCapacity(vocabSize)
for tokenId in 0 ..< vocabSize where bitmask.isTokenAllowed(tokenId) {
allowedIndices.append(UInt32(tokenId))
}
guard !allowedIndices.isEmpty else {
throw MLXLanguageModelError.grammarMismatch
}
let allowedArray = MLXArray(allowedIndices)
let maskedLogits = full(logits.shape, values: -Float.infinity)
maskedLogits[0..., allowedArray] = logits[0..., allowedArray]
let sampledToken = sampler.sample(logits: maskedLogits)
return sampledToken.item(Int.self)
}
Comment on lines +1118 to +1132
Copilot AI Feb 9, 2026

The masked sampling path is O(vocabSize) per generated token (scans every token ID, builds an indices array, and allocates a full maskedLogits tensor each step). For typical vocab sizes (50k–200k) this can be a major bottleneck for structured generation. Consider reusing buffers across steps and/or applying the mask more directly (e.g., materializing only the mask once per step without reserving vocabSize, or using a bitmask→indices iterator if XGrammar exposes one) to avoid repeated full-vocab scans and allocations.
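
One possible shape for that, sketched under the assumption that an additive mask is numerically equivalent here; `maskBuffer` would be a new, reusable `[Float]` stored property sized to `vocabSize`, which is not part of this diff:

// Illustrative sketch: write an additive mask (-inf disallowed, 0 allowed)
// into a preallocated buffer and broadcast-add it onto the logits, instead of
// building an index array plus a fresh full-size tensor every step.
if applyMask {
    for tokenId in 0 ..< vocabSize {
        maskBuffer[tokenId] = bitmask.isTokenAllowed(tokenId) ? 0 : -Float.infinity
    }
    // MLXArray(maskBuffer) has shape [vocabSize] and broadcasts across the batch axis.
    logits = logits + MLXArray(maskBuffer)
}
let sampledToken = sampler.sample(logits: logits)
return sampledToken.item(Int.self)

This still touches every token ID when converting the bitmask, but it avoids the per-step index array and the full-size tensor allocation; whether it is actually faster depends on how MLX handles the broadcasted add.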


let sampledToken = sampler.sample(logits: logits)
return sampledToken.item(Int.self)
}

mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
guard !allowedTokens.isEmpty else {
throw ConstrainedGenerationError.tokenizationFailed