diff --git a/Package.resolved b/Package.resolved index 5aa2b99..e3750b0 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "f7b86b800200fa069a2b288e06bafe53bc937a1851b6effeebba326a62be227e", + "originHash" : "29f2699893382ca66cfbbd685a48c6ce8aacb263744cf3874cc5aaff4c26a63e", "pins" : [ { "identity" : "eventsource", @@ -19,6 +19,15 @@ "version" : "1.3.1" } }, + { + "identity" : "llama.swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/llama.swift", + "state" : { + "revision" : "0391e7847330f0bb3ff2ca3498beb5a4a511d6e0", + "version" : "2.7974.0" + } + }, { "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", @@ -45,6 +54,15 @@ "revision" : "0687f71944021d616d34d922343dcef086855920", "version" : "600.0.1" } + }, + { + "identity" : "swift-xgrammar", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/swift-xgrammar", + "state" : { + "revision" : "e0f2de361aa4891dbca5113629bc90ab8e5c3888", + "version" : "0.1.0" + } } ], "version" : 3 diff --git a/Package.swift b/Package.swift index 3916bf0..b8dfb0b 100644 --- a/Package.swift +++ b/Package.swift @@ -31,6 +31,7 @@ let package = Package( .package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"), .package(url: "https://github.com/mattt/EventSource", from: "1.3.0"), .package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"), + .package(url: "https://github.com/mattt/swift-xgrammar", from: "0.1.0"), .package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")), .package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"), // mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:). 
@@ -45,6 +46,7 @@ let package = Package( .product(name: "EventSource", package: "EventSource"), .product(name: "JSONSchema", package: "JSONSchema"), .product(name: "PartialJSONDecoder", package: "PartialJSONDecoder"), + .product(name: "XGrammar", package: "swift-xgrammar"), .product( name: "MLXLLM", package: "mlx-swift-lm", diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 9142187..ff05263 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -2,6 +2,7 @@ import Foundation #if Llama import JSONSchema import LlamaSwift + import XGrammar /// Global storage for the current log level threshold. /// This is needed because the C callback can't capture Swift context. @@ -532,7 +533,7 @@ import Foundation ) } else { let maxTokens = structuredOptions.maximumResponseTokens ?? 512 - let jsonString = try generateStructuredJSON( + let jsonString = try await generateStructuredJSON( context: context, prompt: fullPrompt, schema: type.generationSchema, @@ -913,6 +914,41 @@ import Foundation return "\(header):\n\(schemaJSON)" } + private func jsonSchemaString(for schema: GenerationSchema) throws -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + let data = try encoder.encode(schema) + guard let jsonSchema = String(data: data, encoding: .utf8) else { + throw LlamaLanguageModelError.schemaEncodingFailed + } + return jsonSchema + } + + private func tokenizerInfo( + for vocab: OpaquePointer, + vocabSize: Int, + stopTokens: Set<Int> + ) throws -> TokenizerInfo { + guard vocabSize > 0 else { + throw LlamaLanguageModelError.contextInitializationFailed + } + + var encodedVocab: [String] = [] + encodedVocab.reserveCapacity(vocabSize) + for tokenId in 0 ..< vocabSize { + let token = llama_token(tokenId) + encodedVocab.append(tokenToText(vocab: vocab, token: token) ?? 
"") + } + + let stopTokenIDs = stopTokens.map { Int32($0) } + return try TokenizerInfo( + encodedVocab: encodedVocab, + encoding: .byteFallback, + stopTokenIDs: stopTokenIDs, + addPrefixSpace: false + ) + } + // MARK: - Structured JSON Generation private func generateStructuredJSON( @@ -921,7 +957,7 @@ import Foundation schema: GenerationSchema, maxTokens: Int, options: ResolvedGenerationOptions - ) throws -> String { + ) async throws -> String { guard let vocab = llama_model_get_vocab(model!) else { throw LlamaLanguageModelError.contextInitializationFailed } @@ -964,9 +1000,27 @@ import Foundation let vocabSize = Int(llama_vocab_n_tokens(vocab)) let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens + let jsonSchema = try jsonSchemaString(for: schema) + let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true) + let eosToken = Int(llama_vocab_eos(vocab)) + let eotTokenValue = llama_vocab_eot(vocab) + let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? 
Int(eotTokenValue) : eosToken + let endTokens: Set = [eosToken, endOfTurnToken] + + let tokenizerInfo = try tokenizerInfo( + for: vocab, + vocabSize: vocabSize, + stopTokens: endTokens + ) + let matcher = try await grammar.matcher( + for: tokenizerInfo, + stopTokens: endTokens.map { Int32($0) }, + terminatesWithoutStopToken: true + ) + var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: vocabSize) return try withUnsafeMutablePointer(to: &batch) { batchPointer in - let backend = LlamaTokenBackend( + var backend = LlamaTokenBackend( context: context, vocab: vocab, vocabSize: vocabSize, @@ -974,11 +1028,28 @@ import Foundation batch: batchPointer, position: initialPosition, maximumTokens: maxTokens, - endTokens: [], + endTokens: endTokens, tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) } ) - var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - return try generator.generate() + + var output = "" + while backend.remainingTokens > 0 { + bitmask.reset() + let needsMask = matcher.fillNextTokenBitmask(&bitmask) + let token = try backend.sample(using: bitmask, applyMask: needsMask) + if backend.endTokens.contains(token) { + break + } + guard matcher.accept(Int32(token)) else { + throw LlamaLanguageModelError.grammarMismatch + } + if let tokenText = backend.tokenText(token) { + output += tokenText + } + try backend.decode(token) + if matcher.isTerminated { break } + } + return output } } @@ -1105,6 +1176,21 @@ import Foundation } } + mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int { + guard let logits = llama_get_logits(context) else { + return eosToken + } + + if applyMask { + for tokenIndex in 0 ..< vocabSize where !bitmask.isTokenAllowed(tokenIndex) { + logits[tokenIndex] = -Float.infinity + } + } + + let tokenIndex = batch.pointee.n_tokens - 1 + return Int(llama_sampler_sample(sampler, context, tokenIndex)) + } + mutating func sample(from 
allowedTokens: Set) throws -> Int { guard let logits = llama_get_logits(context) else { return eosToken @@ -1536,6 +1622,8 @@ import Foundation case insufficientMemory case unsupportedFeature case encoderOnlyModel + case schemaEncodingFailed + case grammarMismatch public var errorDescription: String? { switch self { @@ -1557,6 +1645,10 @@ import Foundation return "This LlamaLanguageModel does not support image segments" case .encoderOnlyModel: return "This model is encoder-only (e.g., BERT) and cannot generate text" + case .schemaEncodingFailed: + return "Failed to encode the JSON schema for structured generation" + case .grammarMismatch: + return "Grammar constraints could not be satisfied during generation" } } } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 7b781b2..fd963bb 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -17,6 +17,7 @@ import Foundation import MLXVLM import Tokenizers import Hub + import XGrammar /// Wrapper to store ModelContext in NSCache (requires NSObject subclass). 
private final class CachedContext: NSObject, @unchecked Sendable { @@ -782,12 +783,48 @@ import Foundation return "\(header):\n\(schemaJSON)" } + private func jsonSchemaString(for schema: GenerationSchema) throws -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + let data = try encoder.encode(schema) + guard let jsonSchema = String(data: data, encoding: .utf8) else { + throw MLXLanguageModelError.schemaEncodingFailed + } + return jsonSchema + } + + private func tokenizerInfo( + for tokenizer: any Tokenizer, + vocabSize: Int, + stopTokens: Set<Int> + ) throws -> TokenizerInfo { + guard vocabSize > 0 else { + throw MLXLanguageModelError.invalidVocabSize + } + + var encodedVocab: [String] = [] + encodedVocab.reserveCapacity(vocabSize) + for tokenId in 0 ..< vocabSize { + encodedVocab.append(tokenizer.convertIdToToken(tokenId) ?? "") + } + + let stopTokenIDs = stopTokens.map { Int32($0) } + return try TokenizerInfo( + encodedVocab: encodedVocab, + encoding: .byteLevel, + stopTokenIDs: stopTokenIDs, + addPrefixSpace: false + ) + } + // MARK: - Structured JSON Generation /// Errors that can occur when using MLXLanguageModel. public enum MLXLanguageModelError: Error, LocalizedError { case invalidVocabSize case unsupportedJSONValueType + case schemaEncodingFailed + case grammarMismatch public var errorDescription: String? 
{ switch self { @@ -795,6 +832,10 @@ import Foundation return "Invalid vocabulary size for model output" case .unsupportedJSONValueType: return "Unsupported JSON value type for schema conversion" + case .schemaEncodingFailed: + return "Failed to encode the JSON schema for structured generation" + case .grammarMismatch: + return "Grammar constraints could not be satisfied during generation" } } } @@ -827,13 +868,42 @@ import Foundation maximumTokens: maxTokens, endTokens: [] ) - - var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let json = try generator.generate() + let jsonSchema = try jsonSchemaString(for: schema) + let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true) + let tokenizerInfo = try tokenizerInfo( + for: context.tokenizer, + vocabSize: backend.vocabSize, + stopTokens: backend.endTokens + ) + let matcher = try await grammar.matcher( + for: tokenizerInfo, + stopTokens: backend.endTokens.map { Int32($0) }, + terminatesWithoutStopToken: true + ) + var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: tokenizerInfo.vocabulary.size) + + var backendState = backend + var output = "" + while backendState.remainingTokens > 0 { + bitmask.reset() + let needsMask = matcher.fillNextTokenBitmask(&bitmask) + let token = try backendState.sample(using: bitmask, applyMask: needsMask) + if backendState.endTokens.contains(token) { + break + } + guard matcher.accept(Int32(token)) else { + throw MLXLanguageModelError.grammarMismatch + } + if let tokenText = backendState.tokenText(token) { + output += tokenText + } + try backendState.decode(token) + if matcher.isTerminated { break } + } // Ensure pending MLX operations complete before returning JSON. // This synchronization can be a performance cost if called frequently. Stream().synchronize() - return json + return output } /// Merges system prompts and schema instructions into a user message. 
@@ -1038,6 +1108,33 @@ import Foundation } } + mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int { + var logits = currentLogits[0..., -1, 0...] + logits = processor?.process(logits: logits) ?? logits + if logits.dtype == .bfloat16 { + logits = logits.asType(.float32) + } + + if applyMask { + var allowedIndices: [UInt32] = [] + allowedIndices.reserveCapacity(vocabSize) + for tokenId in 0 ..< vocabSize where bitmask.isTokenAllowed(tokenId) { + allowedIndices.append(UInt32(tokenId)) + } + guard !allowedIndices.isEmpty else { + throw MLXLanguageModelError.grammarMismatch + } + let allowedArray = MLXArray(allowedIndices) + let maskedLogits = full(logits.shape, values: -Float.infinity) + maskedLogits[0..., allowedArray] = logits[0..., allowedArray] + let sampledToken = sampler.sample(logits: maskedLogits) + return sampledToken.item(Int.self) + } + + let sampledToken = sampler.sample(logits: logits) + return sampledToken.item(Int.self) + } + mutating func sample(from allowedTokens: Set) throws -> Int { guard !allowedTokens.isEmpty else { throw ConstrainedGenerationError.tokenizationFailed