Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/nnue/evaluate_nnue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ namespace Stockfish::Eval::NNUE {
Detail::initialize(network[i]);
}

constexpr std::uint32_t MaxDescriptionLength = 4096;

// Read network header
bool read_header(std::istream& stream, std::uint32_t* hashValue, std::string* desc)
{
Expand All @@ -99,8 +101,52 @@ namespace Stockfish::Eval::NNUE {
*hashValue = read_little_endian<std::uint32_t>(stream);
size = read_little_endian<std::uint32_t>(stream);
if (!stream || version != Version) return false;
if (size > MaxDescriptionLength) return false;

std::uint64_t remaining = 0;
bool hasRemaining = false;
std::streambuf* buffer = stream.rdbuf();

if (!buffer)
return false;

const std::istream::pos_type currentPos = stream.tellg();

if (currentPos != std::istream::pos_type(-1))
{
stream.seekg(0, std::ios::end);
const std::istream::pos_type endPos = stream.tellg();

if (endPos != std::istream::pos_type(-1) && endPos >= currentPos)
{
remaining = static_cast<std::uint64_t>(endPos - currentPos);
hasRemaining = true;
}

stream.seekg(currentPos, std::ios::beg);
if (!stream)
return false;
}

if (!hasRemaining)
{
const std::streamsize available = buffer->in_avail();

if (available >= 0)
{
remaining = static_cast<std::uint64_t>(available);
hasRemaining = true;
}
}

if (!hasRemaining)
return false;

if (remaining < size)
return false;

desc->resize(size);
stream.read(&(*desc)[0], size);
stream.read(desc->data(), static_cast<std::streamsize>(size));
return !stream.fail();
}

Expand Down
78 changes: 77 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# -*- coding: utf-8 -*-

import faulthandler
import os
import pathlib
import shutil
import struct
import subprocess
import tempfile
import unittest
import pyffish as sf

Expand Down Expand Up @@ -1282,7 +1288,77 @@ def test_get_fog_fen(self):
fen = "rnbqkbnr/p1p2ppp/8/Pp1pp3/4P3/8/1PPP1PPP/RNBQKBNR w KQkq b6 0 1"
result = sf.get_fog_fen(fen, "fogofwar")
self.assertEqual(result, "********/********/2******/Pp*p***1/4P3/4*3/1PPP1PPP/RNBQKBNR w KQkq b6 0 1")



class NnueHeaderValidationTests(unittest.TestCase):
VERSION = 0x7AF32F20
HASH_VALUE = 1007697522
MAX_DESCRIPTION_LENGTH = 4096

@classmethod
def setUpClass(cls):
cls.root_dir = pathlib.Path(__file__).resolve().parent
binary_name = "stockfish.exe" if os.name == "nt" else "stockfish"
cls.stockfish_path = cls.root_dir / "src" / binary_name

if not cls.stockfish_path.exists():
if shutil.which("make") is None:
raise unittest.SkipTest("make is required to build the Stockfish binary")

subprocess.run(["make", "build"], cwd=cls.root_dir / "src", check=True)

if not cls.stockfish_path.exists():
raise unittest.SkipTest("Stockfish binary is not available for NNUE header tests")

def run_stockfish_with_net(self, payload: bytes) -> subprocess.CompletedProcess:
with tempfile.NamedTemporaryFile(prefix="chess-", suffix=".nnue", delete=False) as tmp:
tmp.write(payload)
tmp_path = pathlib.Path(tmp.name)

try:
commands = "\n".join(
[
"uci",
"setoption name UCI_Variant value chess",
"setoption name Use NNUE value true",
f"setoption name EvalFile value {tmp_path.as_posix()}",
"isready",
"go depth 1",
"quit",
]
) + "\n"

result = subprocess.run(
[str(self.stockfish_path)],
input=commands,
text=True,
capture_output=True,
timeout=10,
)
finally:
tmp_path.unlink(missing_ok=True)

return result

def assert_header_failure(self, payload: bytes, message: str) -> None:
result = self.run_stockfish_with_net(payload)
self.assertNotEqual(
result.returncode,
0,
f"{message}: expected a non-zero exit code. stdout={result.stdout!r} stderr={result.stderr!r}",
)
self.assertIn("was not loaded successfully", result.stdout, message)

def test_oversized_description_rejected(self):
oversized_size = self.MAX_DESCRIPTION_LENGTH + 1
payload = struct.pack("<III", self.VERSION, self.HASH_VALUE, oversized_size)
self.assert_header_failure(payload, "Oversized network description should be rejected")

def test_truncated_description_rejected(self):
declared_size = 16
payload = struct.pack("<III", self.VERSION, self.HASH_VALUE, declared_size) + b"abcd"
self.assert_header_failure(payload, "Truncated network description should be rejected")


if __name__ == '__main__':
unittest.main(verbosity=2)
Loading