diff --git a/Package.resolved b/Package.resolved index 5aa2b99..837d776 100644 --- a/Package.resolved +++ b/Package.resolved @@ -19,6 +19,33 @@ "version" : "1.3.1" } }, + { + "identity" : "llama.swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/llama.swift", + "state" : { + "revision" : "4d57cff84ba85914baa39850157e7c27684db9c8", + "version" : "2.7966.0" + } + }, + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "revision" : "072b684acaae80b6a463abab3a103732f33774bf", + "version" : "0.29.1" + } + }, + { + "identity" : "mlx-swift-lm", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift-lm", + "state" : { + "revision" : "5064b8c5d8ed3b0bbb71385c4124f0fc102e74a2", + "version" : "2.29.3" + } + }, { "identity" : "partialjsondecoder", "kind" : "remoteSourceControl", @@ -37,6 +64,24 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-jinja", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-jinja.git", + "state" : { + "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0", + "version" : "2.3.1" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2", + "version" : "1.1.1" + } + }, { "identity" : "swift-syntax", "kind" : "remoteSourceControl", @@ -45,6 +90,15 @@ "revision" : "0687f71944021d616d34d922343dcef086855920", "version" : "600.0.1" } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers", + "state" : { + "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0", + "version" : "1.1.6" + } } ], "version" : 3 diff --git a/Sources/AnyLanguageModel/Models/CoreMLLanguageModel.swift b/Sources/AnyLanguageModel/Models/CoreMLLanguageModel.swift index cde04e4..a813dbb 100644 --- a/Sources/AnyLanguageModel/Models/CoreMLLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/CoreMLLanguageModel.swift @@ -2,6 +2,7 @@ import Foundation import CoreML import Tokenizers + import JSONSchema @preconcurrency import Generation @preconcurrency import Models @@ -75,13 +76,25 @@ includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("CoreMLLanguageModel only supports generating String content") - } - try validateNoImageSegments(in: session) + if type != String.self { + let jsonString = try await generateStructuredJSON( + session: session, + prompt: prompt, + schema: type.generationSchema, + options: options, + includeSchemaInPrompt: includeSchemaInPrompt + ) + let generatedContent = try GeneratedContent(json: jsonString) + let content = try type.init(generatedContent) + return LanguageModelSession.Response( + content: content, + rawContent: generatedContent, + transcriptEntries: ArraySlice([]) + ) + } + // Convert AnyLanguageModel GenerationOptions to swift-transformers GenerationConfig let generationConfig = toGenerationConfig(options) @@ -99,15 +112,25 @@ // Reset model state for new generation await model.resetState() - let response = await model.generate( + let outputTokens = await model.generate( config: generationConfig, tokens: tokens, model: model.callAsFunction ) + let promptTextPrefix = tokenizer.decode(tokens: tokens) + let fullText = tokenizer.decode(tokens: outputTokens) + let assistantText: String + if fullText.hasPrefix(promptTextPrefix) { + let startIdx = fullText.index(fullText.startIndex, offsetBy: promptTextPrefix.count) + assistantText = String(fullText[startIdx...]) + } else { + assistantText = fullText + } + return LanguageModelSession.Response( - content: response as! Content, - rawContent: GeneratedContent(response), + content: assistantText as! Content, + rawContent: GeneratedContent(assistantText), transcriptEntries: ArraySlice([]) ) } @@ -261,28 +284,343 @@ // MARK: - - private func toGenerationConfig(_ options: GenerationOptions) -> GenerationConfig { - var config = GenerationConfig(maxNewTokens: options.maximumResponseTokens ?? 2048) + @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) + extension CoreMLLanguageModel { + private func toGenerationConfig(_ options: GenerationOptions) -> GenerationConfig { + var config = GenerationConfig(maxNewTokens: options.maximumResponseTokens ?? 2048) + + // Map temperature + if let temperature = options.temperature { + config.temperature = Float(temperature) + } + + // Map sampling mode + if let sampling = options.sampling { + switch sampling.mode { + case .greedy: + config.doSample = false + case .topK(let k, _): + config.doSample = true + config.topK = k + case .nucleus(let p, _): + config.doSample = true + config.topP = Float(p) + } + } + + return config + } + + private func toStructuredGenerationConfig(_ options: GenerationOptions) -> GenerationConfig { + var config = GenerationConfig(maxNewTokens: options.maximumResponseTokens ?? 512) + + config.doSample = true + if let temperature = options.temperature { + config.temperature = Float(temperature) + } else { + config.temperature = 0.2 + } + config.topP = 0.95 + config.repetitionPenalty = 1.1 + + if let sampling = options.sampling { + switch sampling.mode { + case .greedy: + config.doSample = false + case .topK(let k, _): + config.doSample = true + config.topK = k + case .nucleus(let p, _): + config.doSample = true + config.topP = Float(p) + } + } + + return config + } + + private func generateStructuredJSON( + session: LanguageModelSession, + prompt: Prompt, + schema: GenerationSchema, + options: GenerationOptions, + includeSchemaInPrompt: Bool + ) async throws -> String { + let maxTokens = options.maximumResponseTokens ?? 512 + var generationConfig = toStructuredGenerationConfig(options) + + let promptTokens = try structuredPromptTokens( + in: session, + prompt: prompt, + schema: schema, + includeSchemaInPrompt: includeSchemaInPrompt + ) + + generationConfig.maxLength = generationConfig.maxNewTokens + promptTokens.count + generationConfig.eosTokenId = tokenizer.eosTokenId + generationConfig.bosTokenId = tokenizer.bosTokenId + + await model.resetState() + + let tokenTensor = MLTensor(promptTokens.map(Int32.init)).expandingShape(at: 0) + let initialLogits = await model.predictNextTokenScores(tokenTensor, config: generationConfig) + let endTokens = buildEndTokens(tokenizer: tokenizer) + + let backend = try CoreMLTokenBackend( + model: model, + tokenizer: tokenizer, + config: generationConfig, + tokens: promptTokens, + initialLogits: initialLogits, + maximumTokens: maxTokens, + endTokens: endTokens + ) + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) + let json = try await generator.generate() + return json + } + + private func structuredPromptTokens( + in session: LanguageModelSession, + prompt: Prompt, + schema: GenerationSchema, + includeSchemaInPrompt: Bool + ) throws -> [Int] { + if let chatTemplateHandler = chatTemplateHandler { + var messages = chatTemplateHandler(session.instructions, prompt) + if includeSchemaInPrompt { + let schemaPrompt = schemaPrompt(for: schema) + if !schemaPrompt.isEmpty { + messages.insert(["role": "system", "content": schemaPrompt], at: 0) + } + } + let toolSpecs: [ToolSpec]? = toolsHandler?(session.tools) + return try tokenizer.applyChatTemplate(messages: messages, tools: toolSpecs) + } + + var text = prompt.description + if includeSchemaInPrompt { + let schemaPrompt = schemaPrompt(for: schema) + if !schemaPrompt.isEmpty { + text = "\(schemaPrompt)\n\n\(text)" + } + } + return tokenizer.encode(text: text) + } + + private func schemaPrompt(for schema: GenerationSchema) -> String { + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + guard + let data = try? encoder.encode(schema), + let jsonSchema = try? JSONDecoder().decode(JSONSchema.self, from: data), + let schemaJSON = String(data: data, encoding: .utf8) + else { + return schema.schemaPrompt() + } + + var header = "Respond with valid JSON matching this \(jsonSchema.typeName) schema" + if let description = jsonSchema.description, !description.isEmpty { + header += " (\(description))" + } - // Map temperature - if let temperature = options.temperature { - config.temperature = Float(temperature) + if let constValue = jsonSchema.const, + let data = try? encoder.encode(constValue), + let constString = String(data: data, encoding: .utf8) + { + header += ". Expected value: \(constString)" + } else if let enumValues = jsonSchema.enum, !enumValues.isEmpty, + let data = try? encoder.encode(JSONValue.array(enumValues)), + let enumString = String(data: data, encoding: .utf8) + { + header += ". Allowed values: \(enumString)" + } + + return "\(header):\n\(schemaJSON)" } - // Map sampling mode - if let sampling = options.sampling { - switch sampling.mode { - case .greedy: - config.doSample = false - case .topK(let k, _): - config.doSample = true - config.topK = k - case .nucleus(let p, _): - config.doSample = true - config.topP = Float(p) + private func buildEndTokens(tokenizer: any Tokenizer) -> Set { + var tokens: Set = [] + if let eosTokenId = tokenizer.eosTokenId { + tokens.insert(eosTokenId) + } + if let eosToken = tokenizer.eosToken, let eosTokenId = tokenizer.convertTokenToId(eosToken) { + tokens.insert(eosTokenId) + } + return tokens + } + + private struct CoreMLTokenBackend: TokenBackend { + struct MaskCacheKey: Hashable, Sendable { + let vocabSize: Int + let tokens: Set + } + + let model: Models.LanguageModel + let tokenizer: any Tokenizer + let config: GenerationConfig + let logitsProcessorList: LogitsProcessorList + let endTokens: Set + let eosToken: Int + let vocabSize: Int + + var tokens: [Int] + var currentLogits: MLTensor + var remainingTokens: Int + let totalTokenBudget: Int + var maskCache: [MaskCacheKey: MLTensor] = [:] + + init( + model: Models.LanguageModel, + tokenizer: any Tokenizer, + config: GenerationConfig, + tokens: [Int], + initialLogits: MLTensor, + maximumTokens: Int, + endTokens: Set + ) throws { + self.model = model + self.tokenizer = tokenizer + self.config = config + self.tokens = tokens + self.currentLogits = initialLogits + self.remainingTokens = maximumTokens + self.totalTokenBudget = maximumTokens + self.endTokens = endTokens + self.eosToken = config.eosTokenId ?? tokenizer.eosTokenId ?? 0 + self.vocabSize = initialLogits.shape.last ?? 0 + self.logitsProcessorList = CoreMLLanguageModel.makeLogitsProcessorList(config: config) + } + + func tokenize(_ text: String) throws -> [Int] { + tokenizer.encode(text: text, addSpecialTokens: false) } + + func tokenText(_ token: Int) -> String? { + let decoded = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + return decoded.isEmpty ? nil : decoded + } + + func isSpecialToken(_ token: Int) -> Bool { + let raw = tokenizer.decode(tokens: [token], skipSpecialTokens: false) + guard !raw.isEmpty else { return false } + let filtered = tokenizer.decode(tokens: [token], skipSpecialTokens: true) + return filtered.isEmpty + } + + mutating func decode(_ token: Int) async throws { + tokens.append(token) + remainingTokens -= 1 + let tokenTensor = MLTensor(tokens.map(Int32.init)).expandingShape(at: 0) + currentLogits = await model.predictNextTokenScores(tokenTensor, config: config) + } + + mutating func sample(from allowedTokens: Set) async throws -> Int { + guard !allowedTokens.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed + } + + // Run logits processors on Float32 scores for stable behavior + let inputIds = MLTensor(tokens.map(Int32.init)).expandingShape(at: 0) + let floatScores = + currentLogits.scalarType == Float.self + ? currentLogits + : currentLogits.cast(to: Float.self) + let processedScores = await logitsProcessorList(inputIds, floatScores) + + // Limit candidates to allowed token ids within the current vocab + let vocabSize = processedScores.shape.last ?? self.vocabSize + let candidateTokens = allowedTokens.filter { $0 >= 0 && $0 < vocabSize }.sorted() + guard !candidateTokens.isEmpty else { + throw ConstrainedGenerationError.tokenizationFailed + } + + // Build or reuse a mask tensor that keeps only the allowed tokens. + let cacheKey = MaskCacheKey(vocabSize: vocabSize, tokens: Set(candidateTokens)) + let maskTensor: MLTensor + if let cachedMask = maskCache[cacheKey] { + maskTensor = cachedMask + } else { + var maskValues = Array(repeating: -Float.infinity, count: vocabSize) + for token in candidateTokens { + maskValues[token] = 0 + } + let builtMask = MLTensor(maskValues).reshaped(to: processedScores.shape) + maskCache[cacheKey] = builtMask + maskTensor = builtMask + } + let maskedScores = processedScores + maskTensor + + let tokenTensor: MLTensor + if config.doSample { + // Multinomial sample from candidate probabilities + let probs = maskedScores.softmax(alongAxis: -1) + let prefixShape = Array(maskedScores.shape.dropLast()) + let randomShape = prefixShape + [1] + let rndTensor = MLTensor(randomUniform: randomShape, in: 0 ..< 1, scalarType: Float.self) + let cumulativeProbs = probs.cumulativeSum(alongAxis: -1) + let rnd = + cumulativeProbs.scalarType == Float.self + ? rndTensor : rndTensor.cast(to: cumulativeProbs.scalarType) + + let mask = cumulativeProbs .< rnd + let penalized = mask * 1000.0 + let indexed = penalized + cumulativeProbs + let sampledIndex = indexed.argmin(alongAxis: -1) + tokenTensor = + sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self) + } else { + // Greedy select the best-scoring candidate + let selectedIndex = maskedScores.argmax(alongAxis: -1) + tokenTensor = + selectedIndex.scalarType == Int32.self ? selectedIndex : selectedIndex.cast(to: Int32.self) + } + + // Materialize the chosen token id + let tokenArray = await tokenTensor.shapedArray(of: Int32.self) + guard let token = tokenArray.scalars.last else { + throw ConstrainedGenerationError.tokenizationFailed + } + + return Int(token) + } + } + + fileprivate static func makeLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList { + var processors: [any LogitsProcessor] = [] + + if config.repetitionPenalty != 1.0 { + if let processor = try? RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty)) { + processors.append(processor) + } + } + + if config.temperature > 0 && config.temperature != 1.0 { + if let processor = try? TemperatureLogitsWarper(temperature: config.temperature) { + processors.append(processor) + } + } + + if config.topK > 0 && config.topK < Int.max { + if let processor = try? TopKLogitsWarper(topK: config.topK) { + processors.append(processor) + } + } + + if config.topP < 1.0 { + if let processor = try? TopPLogitsWarper(topP: Float(config.topP)) { + processors.append(processor) + } + } + + if let minP = config.minP { + if let processor = try? MinPLogitsWarper(minP: Float(minP)) { + processors.append(processor) + } + } + + return LogitsProcessorList(processors: processors) } - return config } #endif // CoreML diff --git a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift index 9142187..6bcd804 100644 --- a/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift @@ -532,7 +532,7 @@ import Foundation ) } else { let maxTokens = structuredOptions.maximumResponseTokens ?? 512 - let jsonString = try generateStructuredJSON( + let jsonString = try await generateStructuredJSON( context: context, prompt: fullPrompt, schema: type.generationSchema, @@ -921,7 +921,7 @@ import Foundation schema: GenerationSchema, maxTokens: Int, options: ResolvedGenerationOptions - ) throws -> String { + ) async throws -> String { guard let vocab = llama_model_get_vocab(model!) else { throw LlamaLanguageModelError.contextInitializationFailed } @@ -931,11 +931,16 @@ import Foundation throw LlamaLanguageModelError.tokenizationFailed } - var batch = llama_batch_init(Int32(options.batchSize), 0, 1) - defer { llama_batch_free(batch) } + let batchPointer = UnsafeMutablePointer.allocate(capacity: 1) + batchPointer.initialize(to: llama_batch_init(Int32(options.batchSize), 0, 1)) + defer { + llama_batch_free(batchPointer.pointee) + batchPointer.deinitialize(count: 1) + batchPointer.deallocate() + } let hasEncoder = try prepareInitialBatch( - batch: &batch, + batch: &batchPointer.pointee, promptTokens: promptTokens, model: model!, vocab: vocab, @@ -963,23 +968,21 @@ import Foundation applySampling(sampler: samplerPointer, effectiveTemperature: options.temperature, options: options) let vocabSize = Int(llama_vocab_n_tokens(vocab)) - let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens + let initialPosition: Int32 = hasEncoder ? 1 : batchPointer.pointee.n_tokens - return try withUnsafeMutablePointer(to: &batch) { batchPointer in - let backend = LlamaTokenBackend( - context: context, - vocab: vocab, - vocabSize: vocabSize, - sampler: samplerPointer, - batch: batchPointer, - position: initialPosition, - maximumTokens: maxTokens, - endTokens: [], - tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) } - ) - var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - return try generator.generate() - } + let backend = LlamaTokenBackend( + context: context, + vocab: vocab, + vocabSize: vocabSize, + sampler: samplerPointer, + batch: batchPointer, + position: initialPosition, + maximumTokens: maxTokens, + endTokens: [], + tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) } + ) + var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) + return try await generator.generate() } private struct LlamaTokenBackend: TokenBackend { @@ -1080,7 +1083,7 @@ import Foundation tokenToTextFn(token) } - mutating func decode(_ token: Int) throws { + mutating func decode(_ token: Int) async throws { let llamaToken = llama_token(token) batch.pointee.n_tokens = 1 @@ -1105,7 +1108,7 @@ import Foundation } } - mutating func sample(from allowedTokens: Set) throws -> Int { + mutating func sample(from allowedTokens: Set) async throws -> Int { guard let logits = llama_get_logits(context) else { return eosToken } diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 7b781b2..130b103 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -829,7 +829,7 @@ import Foundation ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let json = try generator.generate() + let json = try await generator.generate() // Ensure pending MLX operations complete before returning JSON. // This synchronization can be a performance cost if called frequently. Stream().synchronize() @@ -1021,7 +1021,7 @@ import Foundation return decoded.isEmpty ? nil : decoded } - mutating func decode(_ token: Int) throws { + mutating func decode(_ token: Int) async throws { let inputText = MLXLMCommon.LMInput.Text(tokens: MLXArray([Int32(token)])) let output = model( inputText[text: .newAxis], @@ -1038,7 +1038,7 @@ import Foundation } } - mutating func sample(from allowedTokens: Set) throws -> Int { + mutating func sample(from allowedTokens: Set) async throws -> Int { guard !allowedTokens.isEmpty else { throw ConstrainedGenerationError.tokenizationFailed } diff --git a/Sources/AnyLanguageModel/StructuredGeneration.swift b/Sources/AnyLanguageModel/StructuredGeneration.swift index f5db790..9ca0a3c 100644 --- a/Sources/AnyLanguageModel/StructuredGeneration.swift +++ b/Sources/AnyLanguageModel/StructuredGeneration.swift @@ -10,8 +10,8 @@ protocol TokenBackend { func tokenize(_ text: String) throws -> [Int] func tokenText(_ token: Int) -> String? func isSpecialToken(_ token: Int) -> Bool - mutating func decode(_ token: Int) throws - mutating func sample(from allowedTokens: Set) throws -> Int + mutating func decode(_ token: Int) async throws + mutating func sample(from allowedTokens: Set) async throws -> Int var eosToken: Int { get } var endTokens: Set { get } @@ -125,9 +125,9 @@ struct ConstrainedJSONGenerator { /// - Returns: A JSON string that satisfies the schema. If the backend emits /// an end token early, the partial output is returned. /// - Throws: ``ConstrainedGenerationError`` if generation fails. - mutating func generate() throws -> String { + mutating func generate() async throws -> String { do { - return try generateNode(schema.root) + return try await generateNode(schema.root) } catch let error as ConstrainedGenerationError { if case .earlyTermination(let partial) = error { return partial @@ -201,6 +201,7 @@ struct ConstrainedJSONGenerator { private static func buildValidIntegerTokens(backend: Backend) -> Set { var allowed = Set() for token in 0 ..< backend.vocabSize { + if backend.isSpecialToken(token) { continue } guard let text = backend.tokenText(token), !text.isEmpty else { continue } if text.allSatisfy({ $0.isNumber || $0 == "-" }), text.contains(where: { $0.isNumber }) @@ -214,6 +215,7 @@ struct ConstrainedJSONGenerator { private static func buildValidDecimalTokens(backend: Backend) -> Set { var allowed = Set() for token in 0 ..< backend.vocabSize { + if backend.isSpecialToken(token) { continue } guard let text = backend.tokenText(token), !text.isEmpty else { continue } if text.allSatisfy({ $0.isNumber || $0 == "-" || $0 == "." }), text.contains(where: { $0.isNumber }) @@ -224,12 +226,12 @@ struct ConstrainedJSONGenerator { return allowed } - private mutating func emit(_ text: String) throws -> String { + private mutating func emit(_ text: String) async throws -> String { for token in try backend.tokenize(text) { guard backend.remainingTokens > 0 else { throw ConstrainedGenerationError.tokenBudgetExceeded } - try backend.decode(token) + try await backend.decode(token) } emittedText += text return text @@ -244,13 +246,13 @@ struct ConstrainedJSONGenerator { return min(remainingAfterClosingQuote, perStringLimit) } - private mutating func generateFreeString(maxTokens: Int) throws -> String { + private mutating func generateFreeString(maxTokens: Int) async throws -> String { var result = "" var generated = 0 while backend.remainingTokens > 0, generated < maxTokens { let allowed = result.isEmpty ? stringInitialAllowedTokens : stringContinuationAllowedTokens - let token = try backend.sample(from: allowed) + let token = try await backend.sample(from: allowed) if backend.endTokens.contains(token) { throw ConstrainedGenerationError.earlyTermination(emittedText) } @@ -260,13 +262,13 @@ struct ConstrainedJSONGenerator { result += text emittedText += text generated += 1 - try backend.decode(token) + try await backend.decode(token) } return result } - private mutating func generateChoice(_ candidates: [String]) throws -> String { + private mutating func generateChoice(_ candidates: [String]) async throws -> String { guard !candidates.isEmpty else { throw ConstrainedGenerationError.tokenizationFailed } @@ -284,7 +286,7 @@ struct ConstrainedJSONGenerator { if hasEmptyCandidate || hasPrefixCollision { let chosen = deterministicChoice(from: candidates) if !chosen.isEmpty { - _ = try emit(chosen) + _ = try await emit(chosen) } return chosen } @@ -303,14 +305,14 @@ struct ConstrainedJSONGenerator { } ) - let token = try backend.sample(from: allowed) + let token = try await backend.sample(from: allowed) if backend.endTokens.contains(token) { throw ConstrainedGenerationError.earlyTermination(emittedText) } let text = backend.tokenText(token) ?? "" emitted += text emittedText += text - try backend.decode(token) + try await backend.decode(token) prefixes = prefixes.filter { $0.count > position && $0[position] == token } position += 1 @@ -337,14 +339,14 @@ struct ConstrainedJSONGenerator { return min(backend.remainingTokens, limit) } - private mutating func generateNumber(_ node: GenerationSchema.NumberNode) throws -> String { + private mutating func generateNumber(_ node: GenerationSchema.NumberNode) async throws -> String { let allowedTokens = node.integerOnly ? integerTerminators : doubleTerminators var result = "" let maxTokens = maxNumberTokens(for: node) var generatedTokens = 0 while backend.remainingTokens > 0, generatedTokens < maxTokens { - let token = try backend.sample(from: allowedTokens) + let token = try await backend.sample(from: allowedTokens) if backend.endTokens.contains(token) { throw ConstrainedGenerationError.earlyTermination(emittedText) } @@ -354,7 +356,7 @@ struct ConstrainedJSONGenerator { result += text emittedText += text generatedTokens += 1 - try backend.decode(token) + try await backend.decode(token) } guard !result.isEmpty else { @@ -389,58 +391,58 @@ struct ConstrainedJSONGenerator { } } - private mutating func generateNode(_ node: GenerationSchema.Node) throws -> String { + private mutating func generateNode(_ node: GenerationSchema.Node) async throws -> String { guard backend.remainingTokens > 0 else { throw ConstrainedGenerationError.tokenBudgetExceeded } switch node { case .object(let objectNode): - return try generateObject(objectNode) + return try await generateObject(objectNode) case .array(let arrayNode): - return try generateArray(arrayNode) + return try await generateArray(arrayNode) case .string(let stringNode): - return try generateString(stringNode) + return try await generateString(stringNode) case .number(let numberNode): - return try generateNumber(numberNode) + return try await generateNumber(numberNode) case .boolean: - return try generateChoice(["true", "false"]) + return try await generateChoice(["true", "false"]) case .ref(let typeName): guard let referenced = schema.defs[typeName] else { throw ConstrainedGenerationError.missingReference(typeName) } - return try generateNode(referenced) + return try await generateNode(referenced) case .anyOf(let variants): guard !variants.isEmpty else { throw ConstrainedGenerationError.emptyAnyOf } if variants.count == 1 { - return try generateNode(variants[0]) + return try await generateNode(variants[0]) } // Choose the first variant to keep selection deterministic. - return try generateNode(variants[0]) + return try await generateNode(variants[0]) } } - private mutating func generateObject(_ node: GenerationSchema.ObjectNode) throws -> String { + private mutating func generateObject(_ node: GenerationSchema.ObjectNode) async throws -> String { let keys = node.properties.keys.sorted() let includedKeys = keys.filter { shouldIncludeOptionalProperty($0, required: node.required) } - var output = try emit("{") + var output = try await emit("{") for (index, key) in includedKeys.enumerated() { - output += try emit("\"\(key)\":") - output += try generateNode(node.properties[key] ?? .string(.init())) + output += try await emit("\"\(key)\":") + output += try await generateNode(node.properties[key] ?? .string(.init())) if index < includedKeys.count - 1 { - output += try emit(",") + output += try await emit(",") } } - output += try emit("}") + output += try await emit("}") return output } - private mutating func generateArray(_ node: GenerationSchema.ArrayNode) throws -> String { + private mutating func generateArray(_ node: GenerationSchema.ArrayNode) async throws -> String { // Derive a default item count from the total token budget when the schema // does not specify explicit minItems/maxItems. We use a small fraction of the // budget and clamp it to a reasonable range to avoid overlong arrays. @@ -464,21 +466,21 @@ struct ConstrainedJSONGenerator { } else { count = defaultCount } - var output = try emit("[") + var output = try await emit("[") for index in 0 ..< count { - output += try generateNode(node.items) + output += try await generateNode(node.items) if index < count - 1 { - output += try emit(",") + output += try await emit(",") } } - output += try emit("]") + output += try await emit("]") return output } - private mutating func generateString(_ node: GenerationSchema.StringNode) throws -> String { - var output = try emit("\"") + private mutating func generateString(_ node: GenerationSchema.StringNode) async throws -> String { + var output = try await emit("\"") let content: String let pattern = node.pattern let regex = try pattern.map { try compilePattern($0) } @@ -496,9 +498,9 @@ struct ConstrainedJSONGenerator { } else { applicableChoices = choices } - content = try generateChoice(applicableChoices) + content = try await generateChoice(applicableChoices) } else { - content = try generateFreeString(maxTokens: maxFreeStringTokens()) + content = try await generateFreeString(maxTokens: maxFreeStringTokens()) } if let pattern, let regex { @@ -510,7 +512,7 @@ struct ConstrainedJSONGenerator { } output += content - output += try emit("\"") + output += try await emit("\"") return output } diff --git a/Tests/AnyLanguageModelTests/CoreMLLanguageModelTests.swift b/Tests/AnyLanguageModelTests/CoreMLLanguageModelTests.swift index 62e2768..432129a 100644 --- a/Tests/AnyLanguageModelTests/CoreMLLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/CoreMLLanguageModelTests.swift @@ -21,24 +21,35 @@ import Testing return true }() - @Suite("CoreMLLanguageModel", .enabled(if: shouldRunCoreMLTests)) + @Suite("CoreMLLanguageModel", .enabled(if: shouldRunCoreMLTests), .serialized) struct CoreMLLanguageModelTests { let modelId = "apple/mistral-coreml" let modelPackageName = "StatefulMistral7BInstructInt4.mlpackage" @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) - func getModel() async throws -> CoreMLLanguageModel { + private static let modelTask = Task { let hasToken = ProcessInfo.processInfo.environment["HF_TOKEN"] != nil let hubApi = HubApi(useOfflineMode: !hasToken) let repoURL = try await hubApi.snapshot( - from: Hub.Repo(id: modelId, type: .models), + from: Hub.Repo(id: "apple/mistral-coreml", type: .models), matching: "*Int4.mlpackage/**" ) { progress in print("Download progress: \(Int(progress.fractionCompleted * 100))%") } - let modelURL = repoURL.appending(component: modelPackageName) - return try await CoreMLLanguageModel(url: modelURL) + let modelURL = repoURL.appending(component: "StatefulMistral7BInstructInt4.mlpackage") + let compiledURL: URL + if modelURL.pathExtension == "mlmodelc" { + compiledURL = modelURL + } else { + compiledURL = try await MLModel.compileModel(at: modelURL) + } + return try await CoreMLLanguageModel(url: compiledURL) + } + + @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) + func getModel() async throws -> CoreMLLanguageModel { + try await Self.modelTask.value } @Test @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) @@ -165,5 +176,69 @@ import Testing #expect(Bool(true)) } } + + @Test @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) + func structuredGenerationSimpleString() async throws { + let model = try await getModel() + let session = LanguageModelSession( + model: model, + instructions: "You are a helpful assistant that generates structured data." + ) + let response = try await session.respond( + to: "Generate a greeting message that says hello", + generating: SimpleString.self + ) + #expect(!response.content.message.isEmpty) + } + + @Test @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) + func structuredGenerationSimpleInt() async throws { + let model = try await getModel() + let session = LanguageModelSession( + model: model, + instructions: "You are a helpful assistant that generates structured data." + ) + let response = try await session.respond( + to: "Generate a count value of 42", + generating: SimpleInt.self + ) + #expect(response.content.count == 42) + let jsonData = response.rawContent.jsonString.data(using: .utf8) + #expect(jsonData != nil) + if let jsonData { + let json = try JSONSerialization.jsonObject(with: jsonData) + let dictionary = json as? [String: Any] + #expect(dictionary != nil) + if let dictionary { + let countValue = dictionary["count"] as? NSNumber + #expect(countValue?.intValue == 42) + } + } + } + + @Test @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) + func structuredGenerationSimpleBool() async throws { + let model = try await getModel() + let session = LanguageModelSession( + model: model, + instructions: "You are a helpful assistant that generates structured data." + ) + let response = try await session.respond( + to: "Generate a boolean value: true", + generating: SimpleBool.self + ) + #expect(response.content.value == true) + let jsonData = response.rawContent.jsonString.data(using: .utf8) + #expect(jsonData != nil) + if let jsonData { + let json = try JSONSerialization.jsonObject(with: jsonData) + let dictionary = json as? [String: Any] + #expect(dictionary != nil) + if let dictionary { + let boolValue = dictionary["value"] as? Bool + #expect(boolValue == true) + } + } + } } #endif // CoreML diff --git a/Tests/AnyLanguageModelTests/Shared/MockTokenBackend.swift b/Tests/AnyLanguageModelTests/Shared/MockTokenBackend.swift index de6595c..9014396 100644 --- a/Tests/AnyLanguageModelTests/Shared/MockTokenBackend.swift +++ b/Tests/AnyLanguageModelTests/Shared/MockTokenBackend.swift @@ -70,12 +70,12 @@ struct MockTokenBackend: TokenBackend { specialTokens.contains(token) } - mutating func decode(_ token: Int) throws { + mutating func decode(_ token: Int) async throws { capture.record(token: token) remainingTokens -= 1 } - mutating func sample(from allowedTokens: Set) throws -> Int { + mutating func sample(from allowedTokens: Set) async throws -> Int { guard !allowedTokens.isEmpty else { throw ConstrainedGenerationError.tokenizationFailed } diff --git a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift index 180ae18..3b5f0b1 100644 --- a/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift +++ b/Tests/AnyLanguageModelTests/StructuredGenerationTests.swift @@ -62,7 +62,7 @@ struct StructuredGenerationTests { return (tokenToText, textToTokens) } - @Test func numberOutOfRangeThrows() throws { + @Test func numberOutOfRangeThrows() async throws { let maps = baseTokenMaps() let numberNode = GenerationSchema.NumberNode( description: nil, @@ -85,7 +85,7 @@ struct StructuredGenerationTests { var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) do { - _ = try generator.generate() + _ = try await generator.generate() Issue.record("Expected number out-of-range error.") } catch let error as ConstrainedGenerationError { guard case .numberOutOfRange = error else { @@ -95,7 +95,7 @@ struct StructuredGenerationTests { } } - @Test func patternMismatchThrows() throws { + @Test func patternMismatchThrows() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -119,7 +119,7 @@ struct StructuredGenerationTests { var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) do { - _ = try generator.generate() + _ = try await generator.generate() Issue.record("Expected pattern mismatch error.") } catch let error as ConstrainedGenerationError { guard case .patternMismatch = error else { @@ -129,7 +129,7 @@ struct StructuredGenerationTests { } } - @Test func emptyStringEnumProducesEmptyValue() throws { + @Test func emptyStringEnumProducesEmptyValue() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -147,11 +147,11 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "\"\"") } - @Test func prefixEnumSelectsLongerCandidateDeterministically() throws { + @Test func prefixEnumSelectsLongerCandidateDeterministically() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -169,11 +169,11 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "\"ab\"") } - @Test func eosStopsGenerationAndReturnsPartialOutput() throws { + @Test func eosStopsGenerationAndReturnsPartialOutput() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -193,7 +193,7 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "\"a") } @@ -228,7 +228,7 @@ struct StructuredGenerationTests { } } - @Test func outputMatchesDecodedTokens() throws { + @Test func outputMatchesDecodedTokens() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -250,11 +250,11 @@ struct StructuredGenerationTests { let capture = backend.capture var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == capture.decodedText) } - @Test func negativeIntegerWithinRange() throws { + @Test func negativeIntegerWithinRange() async throws { let maps = baseTokenMaps() let numberNode = GenerationSchema.NumberNode( description: nil, @@ -276,11 +276,11 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "-1") } - @Test func decimalOutOfRangeThrows() throws { + @Test func decimalOutOfRangeThrows() async throws { let maps = baseTokenMaps() let numberNode = GenerationSchema.NumberNode( description: nil, @@ -303,7 +303,7 @@ struct StructuredGenerationTests { var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) do { - _ = try generator.generate() + _ = try await generator.generate() Issue.record("Expected number out-of-range error.") } catch let error as ConstrainedGenerationError { guard case .numberOutOfRange = error else { @@ -313,7 +313,7 @@ struct StructuredGenerationTests { } } - @Test func tokenBudgetExceededThrows() throws { + @Test func tokenBudgetExceededThrows() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -332,7 +332,7 @@ struct StructuredGenerationTests { var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) do { - _ = try generator.generate() + _ = try await generator.generate() Issue.record("Expected token budget exceeded error.") } catch let error as ConstrainedGenerationError { guard case .tokenBudgetExceeded = error else { @@ -342,7 +342,7 @@ struct StructuredGenerationTests { } } - @Test func anyOfSingleVariantUsesOnlyChoice() throws { + @Test func anyOfSingleVariantUsesOnlyChoice() async throws { let maps = baseTokenMaps() let stringNode = GenerationSchema.StringNode( description: nil, @@ -363,7 +363,7 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "\"a\"") } @@ -398,7 +398,7 @@ struct StructuredGenerationTests { } } - @Test func invalidArrayBoundsThrows() throws { + @Test func invalidArrayBoundsThrows() async throws { let maps = baseTokenMaps() let arrayNode = GenerationSchema.ArrayNode( description: nil, @@ -418,7 +418,7 @@ struct StructuredGenerationTests { var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) do { - _ = try generator.generate() + _ = try await generator.generate() Issue.record("Expected invalid array bounds error.") } catch let error as ConstrainedGenerationError { guard case .invalidArrayBounds = error else { @@ -428,7 +428,7 @@ struct StructuredGenerationTests { } } - @Test func arrayCountIsDeterministic() throws { + @Test func arrayCountIsDeterministic() async throws { let maps = baseTokenMaps() let arrayNode = GenerationSchema.ArrayNode( description: nil, @@ -447,7 +447,7 @@ struct StructuredGenerationTests { ) var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema) - let result = try generator.generate() + let result = try await generator.generate() #expect(result == "[\"a\",\"a\",\"a\"]") } }