From a041e5a756d44435d86bfd2ff076d08b903c097c Mon Sep 17 00:00:00 2001 From: RainRat Date: Sat, 20 Sep 2025 14:24:01 -0700 Subject: [PATCH] Protect NNUE header parsing against oversized descriptions --- src/nnue/evaluate_nnue.cpp | 48 ++++++++++++++++++++++- test.py | 78 +++++++++++++++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/src/nnue/evaluate_nnue.cpp b/src/nnue/evaluate_nnue.cpp index c13fe535f..0c85a78a1 100644 --- a/src/nnue/evaluate_nnue.cpp +++ b/src/nnue/evaluate_nnue.cpp @@ -90,6 +90,8 @@ namespace Stockfish::Eval::NNUE { Detail::initialize(network[i]); } + constexpr std::uint32_t MaxDescriptionLength = 4096; + // Read network header bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc) { @@ -99,8 +101,52 @@ namespace Stockfish::Eval::NNUE { *hashValue = read_little_endian(stream); size = read_little_endian(stream); if (!stream || version != Version) return false; + if (size > MaxDescriptionLength) return false; + + std::uint64_t remaining = 0; + bool hasRemaining = false; + std::streambuf* buffer = stream.rdbuf(); + + if (!buffer) + return false; + + const std::istream::pos_type currentPos = stream.tellg(); + + if (currentPos != std::istream::pos_type(-1)) + { + stream.seekg(0, std::ios::end); + const std::istream::pos_type endPos = stream.tellg(); + + if (endPos != std::istream::pos_type(-1) && endPos >= currentPos) + { + remaining = static_cast(endPos - currentPos); + hasRemaining = true; + } + + stream.seekg(currentPos, std::ios::beg); + if (!stream) + return false; + } + + if (!hasRemaining) + { + const std::streamsize available = buffer->in_avail(); + + if (available >= 0) + { + remaining = static_cast(available); + hasRemaining = true; + } + } + + if (!hasRemaining) + return false; + + if (remaining < size) + return false; + desc->resize(size); - stream.read(&(*desc)[0], size); + stream.read(desc->data(), static_cast(size)); return !stream.fail(); } diff --git a/test.py b/test.py index ca5d39106..0c56947c0 100644 --- a/test.py +++ b/test.py @@ -1,6 +1,12 @@ # -*- coding: utf-8 -*- import faulthandler +import os +import pathlib +import shutil +import struct +import subprocess +import tempfile import unittest import pyffish as sf @@ -1282,7 +1288,77 @@ def test_get_fog_fen(self): fen = "rnbqkbnr/p1p2ppp/8/Pp1pp3/4P3/8/1PPP1PPP/RNBQKBNR w KQkq b6 0 1" result = sf.get_fog_fen(fen, "fogofwar") self.assertEqual(result, "********/********/2******/Pp*p***1/4P3/4*3/1PPP1PPP/RNBQKBNR w KQkq b6 0 1") - + + +class NnueHeaderValidationTests(unittest.TestCase): + VERSION = 0x7AF32F20 + HASH_VALUE = 1007697522 + MAX_DESCRIPTION_LENGTH = 4096 + + @classmethod + def setUpClass(cls): + cls.root_dir = pathlib.Path(__file__).resolve().parent + binary_name = "stockfish.exe" if os.name == "nt" else "stockfish" + cls.stockfish_path = cls.root_dir / "src" / binary_name + + if not cls.stockfish_path.exists(): + if shutil.which("make") is None: + raise unittest.SkipTest("make is required to build the Stockfish binary") + + subprocess.run(["make", "build"], cwd=cls.root_dir / "src", check=True) + + if not cls.stockfish_path.exists(): + raise unittest.SkipTest("Stockfish binary is not available for NNUE header tests") + + def run_stockfish_with_net(self, payload: bytes) -> subprocess.CompletedProcess: + with tempfile.NamedTemporaryFile(prefix="chess-", suffix=".nnue", delete=False) as tmp: + tmp.write(payload) + tmp_path = pathlib.Path(tmp.name) + + try: + commands = "\n".join( + [ + "uci", + "setoption name UCI_Variant value chess", + "setoption name Use NNUE value true", + f"setoption name EvalFile value {tmp_path.as_posix()}", + "isready", + "go depth 1", + "quit", + ] + ) + "\n" + + result = subprocess.run( + [str(self.stockfish_path)], + input=commands, + text=True, + capture_output=True, + timeout=10, + ) + finally: + tmp_path.unlink(missing_ok=True) + + return result + + def assert_header_failure(self, payload: bytes, message: str) -> None: + result = self.run_stockfish_with_net(payload) + self.assertNotEqual( + result.returncode, + 0, + f"{message}: expected a non-zero exit code. stdout={result.stdout!r} stderr={result.stderr!r}", + ) + self.assertIn("was not loaded successfully", result.stdout, message) + + def test_oversized_description_rejected(self): + oversized_size = self.MAX_DESCRIPTION_LENGTH + 1 + payload = struct.pack("