Use XGrammar for structured output generation #118
First changed file (the llama.cpp backend, `LlamaLanguageModel`):
```diff
@@ -2,6 +2,7 @@ import Foundation
 #if Llama
 import JSONSchema
 import LlamaSwift
+import XGrammar

 /// Global storage for the current log level threshold.
 /// This is needed because the C callback can't capture Swift context.
```
```diff
@@ -532,7 +533,7 @@ import Foundation
         )
     } else {
         let maxTokens = structuredOptions.maximumResponseTokens ?? 512
-        let jsonString = try generateStructuredJSON(
+        let jsonString = try await generateStructuredJSON(
             context: context,
             prompt: fullPrompt,
             schema: type.generationSchema,
```
```diff
@@ -913,6 +914,41 @@ import Foundation
         return "\(header):\n\(schemaJSON)"
     }

+    private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
+        let encoder = JSONEncoder()
+        encoder.outputFormatting = [.sortedKeys]
+        let data = try encoder.encode(schema)
+        guard let jsonSchema = String(data: data, encoding: .utf8) else {
+            throw LlamaLanguageModelError.schemaEncodingFailed
+        }
+        return jsonSchema
+    }
+
+    private func tokenizerInfo(
+        for vocab: OpaquePointer,
+        vocabSize: Int,
+        stopTokens: Set<Int>
+    ) throws -> TokenizerInfo {
+        guard vocabSize > 0 else {
+            throw LlamaLanguageModelError.contextInitializationFailed
+        }
+
+        var encodedVocab: [String] = []
+        encodedVocab.reserveCapacity(vocabSize)
+        for tokenId in 0 ..< vocabSize {
+            let token = llama_token(tokenId)
+            encodedVocab.append(tokenToText(vocab: vocab, token: token) ?? "")
+        }
+
+        let stopTokenIDs = stopTokens.map { Int32($0) }
+        return try TokenizerInfo(
+            encodedVocab: encodedVocab,
+            encoding: .byteFallback,
+            stopTokenIDs: stopTokenIDs,
+            addPrefixSpace: false
+        )
+    }
+
     // MARK: - Structured JSON Generation

     private func generateStructuredJSON(
```
```diff
@@ -921,7 +957,7 @@ import Foundation
         schema: GenerationSchema,
         maxTokens: Int,
         options: ResolvedGenerationOptions
-    ) throws -> String {
+    ) async throws -> String {
         guard let vocab = llama_model_get_vocab(model!) else {
             throw LlamaLanguageModelError.contextInitializationFailed
         }
```
```diff
@@ -964,21 +1000,56 @@ import Foundation

         let vocabSize = Int(llama_vocab_n_tokens(vocab))
         let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens
+        let jsonSchema = try jsonSchemaString(for: schema)
+        let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
+        let eosToken = Int(llama_vocab_eos(vocab))
+        let eotTokenValue = llama_vocab_eot(vocab)
+        let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken
+        let endTokens: Set<Int> = [eosToken, endOfTurnToken]
+
+        let tokenizerInfo = try tokenizerInfo(
+            for: vocab,
+            vocabSize: vocabSize,
+            stopTokens: endTokens
+        )
+        let matcher = try await grammar.matcher(
+            for: tokenizerInfo,
+            stopTokens: endTokens.map { Int32($0) },
+            terminatesWithoutStopToken: true
+        )
+        var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: vocabSize)

         return try withUnsafeMutablePointer(to: &batch) { batchPointer in
-            let backend = LlamaTokenBackend(
+            var backend = LlamaTokenBackend(
                 context: context,
                 vocab: vocab,
                 vocabSize: vocabSize,
                 sampler: samplerPointer,
                 batch: batchPointer,
                 position: initialPosition,
                 maximumTokens: maxTokens,
-                endTokens: [],
+                endTokens: endTokens,
                 tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) }
             )
-            var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
-            return try generator.generate()
+
+            var output = ""
+            while backend.remainingTokens > 0 {
+                bitmask.reset()
+                let needsMask = matcher.fillNextTokenBitmask(&bitmask)
+                let token = try backend.sample(using: bitmask, applyMask: needsMask)
+                if backend.endTokens.contains(token) {
+                    break
+                }
+                guard matcher.accept(Int32(token)) else {
+                    throw LlamaLanguageModelError.grammarMismatch
+                }
+                if let tokenText = backend.tokenText(token) {
+                    output += tokenText
+                }
+                try backend.decode(token)
+                if matcher.isTerminated { break }
+            }
+            return output
         }
     }
```
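Both backends now drive generation with the same mask, sample, accept loop. Below is a backend-agnostic distillation of that loop for reference. The `TokenSource` protocol is hypothetical (the PR hand-writes the loop in each backend instead), and the `Grammar.Matcher` type name is inferred from `Grammar.Matcher.TokenBitmask` in the diff. Note that `backend` had to become `var` because sampling and decoding are `mutating` operations.

```swift
import XGrammar

// Hypothetical protocol capturing what the loop needs from a backend.
protocol TokenSource {
    var remainingTokens: Int { get }
    var endTokens: Set<Int> { get }
    mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int
    func tokenText(_ token: Int) -> String?
    mutating func decode(_ token: Int) throws
}

enum ConstrainedDecodeError: Error { case grammarMismatch }

func constrainedDecode<Source: TokenSource>(
    _ source: inout Source,
    matcher: Grammar.Matcher,
    vocabSize: Int
) throws -> String {
    var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: vocabSize)
    var output = ""
    while source.remainingTokens > 0 {
        bitmask.reset()
        // Returns false when every token is currently legal,
        // letting the backend skip the masking pass entirely.
        let needsMask = matcher.fillNextTokenBitmask(&bitmask)
        let token = try source.sample(using: bitmask, applyMask: needsMask)
        if source.endTokens.contains(token) { break }   // model chose to stop
        guard matcher.accept(Int32(token)) else {
            throw ConstrainedDecodeError.grammarMismatch
        }
        output += source.tokenText(token) ?? ""
        try source.decode(token)                        // advance the KV cache
        if matcher.isTerminated { break }               // grammar fully satisfied
    }
    return output
}
```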
```diff
@@ -1105,6 +1176,21 @@ import Foundation
         }
     }

+    mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
+        guard let logits = llama_get_logits(context) else {
+            return eosToken
+        }
+
+        if applyMask {
+            for tokenIndex in 0 ..< vocabSize where !bitmask.isTokenAllowed(tokenIndex) {
+                logits[tokenIndex] = -Float.infinity
+            }
+        }
+
+        let tokenIndex = batch.pointee.n_tokens - 1
+        return Int(llama_sampler_sample(sampler, context, tokenIndex))
+    }
+
     mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
         guard let logits = llama_get_logits(context) else {
             return eosToken
```
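For clarity, here is the masking step in isolation, as a plain-Swift sketch with no llama.cpp types: every token the bitmask disallows has its logit forced to negative infinity, so softmax assigns it zero probability and the sampler can never pick it.

```swift
// Minimal sketch of bitmask-driven logit masking over a raw logits buffer.
func applyBitmask(
    to logits: UnsafeMutablePointer<Float>,
    vocabSize: Int,
    isTokenAllowed: (Int) -> Bool
) {
    for tokenIndex in 0 ..< vocabSize where !isTokenAllowed(tokenIndex) {
        logits[tokenIndex] = -.infinity
    }
}
```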
```diff
@@ -1536,6 +1622,8 @@ import Foundation
     case insufficientMemory
     case unsupportedFeature
     case encoderOnlyModel
+    case schemaEncodingFailed
+    case grammarMismatch

     public var errorDescription: String? {
         switch self {
```
```diff
@@ -1557,6 +1645,10 @@ import Foundation
             return "This LlamaLanguageModel does not support image segments"
         case .encoderOnlyModel:
             return "This model is encoder-only (e.g., BERT) and cannot generate text"
+        case .schemaEncodingFailed:
+            return "Failed to encode the JSON schema for structured generation"
+        case .grammarMismatch:
+            return "Grammar constraints could not be satisfied during generation"
         }
     }
 }
```
Second changed file (the MLX backend, `MLXLanguageModel`):
```diff
@@ -17,6 +17,7 @@ import Foundation
 import MLXVLM
 import Tokenizers
 import Hub
+import XGrammar

 /// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
 private final class CachedContext: NSObject, @unchecked Sendable {
```
```diff
@@ -782,19 +783,59 @@ import Foundation
         return "\(header):\n\(schemaJSON)"
     }

+    private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
+        let encoder = JSONEncoder()
+        encoder.outputFormatting = [.sortedKeys]
+        let data = try encoder.encode(schema)
+        guard let jsonSchema = String(data: data, encoding: .utf8) else {
+            throw MLXLanguageModelError.schemaEncodingFailed
+        }
+        return jsonSchema
+    }
+
+    private func tokenizerInfo(
+        for tokenizer: any Tokenizer,
+        vocabSize: Int,
+        stopTokens: Set<Int>
+    ) throws -> TokenizerInfo {
+        guard vocabSize > 0 else {
+            throw MLXLanguageModelError.invalidVocabSize
+        }
+
+        var encodedVocab: [String] = []
+        encodedVocab.reserveCapacity(vocabSize)
+        for tokenId in 0 ..< vocabSize {
+            encodedVocab.append(tokenizer.convertIdToToken(tokenId) ?? "")
+        }
+
+        let stopTokenIDs = stopTokens.map { Int32($0) }
+        return try TokenizerInfo(
+            encodedVocab: encodedVocab,
+            encoding: .byteLevel,
+            stopTokenIDs: stopTokenIDs,
+            addPrefixSpace: false
+        )
+    }
+
 // MARK: - Structured JSON Generation

 /// Errors that can occur when using MLXLanguageModel.
 public enum MLXLanguageModelError: Error, LocalizedError {
     case invalidVocabSize
     case unsupportedJSONValueType
+    case schemaEncodingFailed
+    case grammarMismatch

     public var errorDescription: String? {
         switch self {
         case .invalidVocabSize:
             return "Invalid vocabulary size for model output"
         case .unsupportedJSONValueType:
             return "Unsupported JSON value type for schema conversion"
+        case .schemaEncodingFailed:
+            return "Failed to encode the JSON schema for structured generation"
+        case .grammarMismatch:
+            return "Grammar constraints could not be satisfied during generation"
         }
     }
 }
```
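This helper mirrors its llama.cpp counterpart, with one notable difference: the MLX path declares the vocabulary as `.byteLevel` while the llama.cpp path uses `.byteFallback`, matching the tokenizer families each backend typically loads. The sketch below illustrates the shape of the data XGrammar consumes, using a hypothetical toy vocabulary (the real call sites pass tens of thousands of entries):

```swift
import XGrammar

// Toy illustration only: one decoded string per token ID, plus the IDs of
// the stop tokens, so XGrammar can compile per-step token bitmasks.
let toyVocab = ["{", "\"", "name", "\":", "}", "<eos>"]
let info = try TokenizerInfo(
    encodedVocab: toyVocab,
    encoding: .byteLevel,   // .byteFallback on the llama.cpp path
    stopTokenIDs: [5],      // the "<eos>" entry above
    addPrefixSpace: false
)
```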
```diff
@@ -827,13 +868,42 @@ import Foundation
             maximumTokens: maxTokens,
             endTokens: []
         )

-        var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
-        let json = try generator.generate()
+        let jsonSchema = try jsonSchemaString(for: schema)
+        let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
+        let tokenizerInfo = try tokenizerInfo(
+            for: context.tokenizer,
+            vocabSize: backend.vocabSize,
+            stopTokens: backend.endTokens
+        )
+        let matcher = try await grammar.matcher(
+            for: tokenizerInfo,
+            stopTokens: backend.endTokens.map { Int32($0) },
+            terminatesWithoutStopToken: true
+        )
+        var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: tokenizerInfo.vocabulary.size)
+
+        var backendState = backend
+        var output = ""
+        while backendState.remainingTokens > 0 {
+            bitmask.reset()
+            let needsMask = matcher.fillNextTokenBitmask(&bitmask)
+            let token = try backendState.sample(using: bitmask, applyMask: needsMask)
+            if backendState.endTokens.contains(token) {
+                break
+            }
+            guard matcher.accept(Int32(token)) else {
+                throw MLXLanguageModelError.grammarMismatch
+            }
+            if let tokenText = backendState.tokenText(token) {
+                output += tokenText
+            }
+            try backendState.decode(token)
+            if matcher.isTerminated { break }
+        }
         // Ensure pending MLX operations complete before returning JSON.
         // This synchronization can be a performance cost if called frequently.
         Stream().synchronize()
-        return json
+        return output
     }

     /// Merges system prompts and schema instructions into a user message.
```
```diff
@@ -1038,6 +1108,33 @@ import Foundation
         }
     }

+    mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
+        var logits = currentLogits[0..., -1, 0...]
+        logits = processor?.process(logits: logits) ?? logits
+        if logits.dtype == .bfloat16 {
+            logits = logits.asType(.float32)
+        }
+
+        if applyMask {
+            var allowedIndices: [UInt32] = []
+            allowedIndices.reserveCapacity(vocabSize)
+            for tokenId in 0 ..< vocabSize where bitmask.isTokenAllowed(tokenId) {
+                allowedIndices.append(UInt32(tokenId))
+            }
+            guard !allowedIndices.isEmpty else {
+                throw MLXLanguageModelError.grammarMismatch
+            }
+            let allowedArray = MLXArray(allowedIndices)
+            let maskedLogits = full(logits.shape, values: -Float.infinity)
+            maskedLogits[0..., allowedArray] = logits[0..., allowedArray]
+            let sampledToken = sampler.sample(logits: maskedLogits)
+            return sampledToken.item(Int.self)
+        }
+
+        let sampledToken = sampler.sample(logits: logits)
+        return sampledToken.item(Int.self)
+    }
+
     mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
         guard !allowedTokens.isEmpty else {
             throw ConstrainedGenerationError.tokenizationFailed
```
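Unlike the llama.cpp path, which writes negative infinity into a host logits buffer element by element, the MLX version gathers the allowed token IDs and rebuilds the row in one vectorized indexing operation, presumably because per-element host writes into a device-resident `MLXArray` would be costly. A minimal standalone sketch of that idea, using only calls that appear in the diff:

```swift
import MLX

// Gather-based masking: start from an all negative-infinity row and copy
// the allowed columns back in with a single fancy-indexing assignment.
// Assumes `logits` has shape [1, vocabSize].
func maskedLogits(_ logits: MLXArray, allowedTokenIDs: [UInt32]) -> MLXArray {
    let allowed = MLXArray(allowedTokenIDs)
    let masked = full(logits.shape, values: -Float.infinity)
    masked[0..., allowed] = logits[0..., allowed]
    return masked
}
```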
A review comment on the diff:

> `XGrammar` is only imported/used inside the `MLX` and `Llama` conditional compilation blocks, but the target depends on it unconditionally. This forces consumers to fetch/build `swift-xgrammar` even when neither trait is enabled, and can break traitless builds if that dependency has platform/toolchain constraints. Consider making the `XGrammar` product dependency conditional on the `MLX` and/or `Llama` traits (similar to `MLXLLM`/`LlamaSwift`).
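A sketch of what the reviewer is suggesting. The target, package, and trait names here are assumptions rather than values taken from this repo, and conditional trait dependencies require swift-tools-version 6.1 or later:

```swift
// Package.swift (hypothetical excerpt): gate the XGrammar product on the
// same traits that gate the backends using it, so a traitless build never
// fetches or builds swift-xgrammar.
.target(
    name: "AnyLanguageModel",
    dependencies: [
        .product(name: "MLXLLM", package: "mlx-swift-examples", condition: .when(traits: ["MLX"])),
        .product(name: "LlamaSwift", package: "llama.cpp", condition: .when(traits: ["Llama"])),
        .product(name: "XGrammar", package: "swift-xgrammar", condition: .when(traits: ["MLX", "Llama"])),
    ]
)
```

With `.when(traits:)`, the dependency is resolved only when at least one of the listed traits is enabled, which matches how the `MLXLLM` and `LlamaSwift` products are presumably gated today.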