Skip to content

Commit 4f8ec6f

Browse files
kochj23claude
andcommitted
security: Restrict model loading to SafeTensors format only
PyTorch pickle files (.bin, .pt) can execute arbitrary code during deserialization — a known supply chain attack vector. MLX Code now: - Rejects any model directory that lacks .safetensors files - Skips non-SafeTensors models during discovery (won't appear in list) - Throws MLXServiceError.unsafeModelFormat with an explanatory message if a user somehow tries to load a rejected model - Logs a warning when a model is skipped during discovery All mlx-community models on HuggingFace use SafeTensors by default, so this has no impact on normal usage. Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
1 parent 04f124c commit 4f8ec6f

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

MLX Code/Services/MLXService.swift

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ actor MLXService {
7373
}
7474

7575
let directory = URL(fileURLWithPath: expandedPath)
76+
77+
// Only allow SafeTensors format — reject PyTorch pickle (.bin/.pt) models
78+
guard isSafeTensorsModel(at: directory) else {
79+
await SecureLogger.shared.error(
80+
"Rejected unsafe model format at: \(expandedPath)",
81+
category: "MLXService"
82+
)
83+
throw MLXServiceError.unsafeModelFormat(expandedPath)
84+
}
85+
7686
let configuration = ModelConfiguration(directory: directory)
7787

7888
modelContainer = try await LLMModelFactory.shared.loadContainer(
@@ -303,6 +313,23 @@ actor MLXService {
303313

304314
// MARK: - Private Helpers
305315

316+
/// Validates that a model directory contains only SafeTensors weights.
317+
/// Returns true if safe, false if PyTorch pickle files (.bin, .pt) are present
318+
/// without any corresponding .safetensors files.
319+
private func isSafeTensorsModel(at url: URL) -> Bool {
320+
guard let contents = try? FileManager.default.contentsOfDirectory(
321+
at: url, includingPropertiesForKeys: nil
322+
) else { return false }
323+
324+
let hasSafeTensors = contents.contains { $0.pathExtension == "safetensors" }
325+
let hasPickle = contents.contains { $0.pathExtension == "bin" || $0.pathExtension == "pt" }
326+
327+
// Reject if pickle weights present without any safetensors counterpart
328+
if hasPickle && !hasSafeTensors { return false }
329+
// Require at least one safetensors file
330+
return hasSafeTensors
331+
}
332+
306333
/// Formats chat messages as a flat prompt string — used as fallback when
307334
/// the model's Jinja chat template is not supported by swift-jinja.
308335
private func formatMessagesAsPrompt(_ messages: [Message]) -> String {
@@ -335,6 +362,15 @@ actor MLXService {
335362
return nil
336363
}
337364

365+
// Only surface SafeTensors models in the UI
366+
guard isSafeTensorsModel(at: url) else {
367+
await SecureLogger.shared.warning(
368+
"Skipping non-SafeTensors model: \(url.lastPathComponent)",
369+
category: "MLXService"
370+
)
371+
return nil
372+
}
373+
338374
var totalSize: Int64 = 0
339375
for file in contents {
340376
if let attributes = try? fileManager.attributesOfItem(atPath: file.path),
@@ -364,6 +400,7 @@ enum MLXServiceError: LocalizedError {
364400
case inferenceInProgress
365401
case invalidParameters
366402
case generationFailed(String)
403+
case unsafeModelFormat(String)
367404

368405
var errorDescription: String? {
369406
switch self {
@@ -381,6 +418,8 @@ enum MLXServiceError: LocalizedError {
381418
return "The generation parameters are invalid"
382419
case .generationFailed(let message):
383420
return "Text generation failed: \(message)"
421+
case .unsafeModelFormat(let path):
422+
return "Unsafe model format rejected: \(path)\n\nMLX Code only loads SafeTensors (.safetensors) models. PyTorch pickle files (.bin, .pt) are not permitted."
384423
}
385424
}
386425
}

0 commit comments

Comments
 (0)