Summary
Implement neural network inference for MetalFish's MCTS search that uses Lc0-format network weights (.pb files) and produces identical results to Lc0 for the same positions. The implementation should be heavily optimized for Apple Silicon's unified memory architecture.
Background
MetalFish currently has a working MCTS implementation (src/mcts/thread_safe_mcts.cpp) that uses NNUE evaluation. To achieve stronger play, we need to integrate transformer-based neural network evaluation similar to Lc0 (Leela Chess Zero).
We have:
- Network weights: `networks/BT4-1024x15x32h-swa-6147500.pb` (365 MB transformer network)
- Reference implementation: the official Lc0 source code
- Target platform: Apple Silicon Macs (M1/M2/M3/M4) with Metal GPU acceleration
Implementation Strategy: Copy from Lc0
Lc0 is open source (GPL-3.0). You are encouraged to directly copy entire code files from [lc0](https://github.com/LeelaChessZero/lc0) using cp or mv commands. Do not rewrite from scratch what already exists.
Copyright Header Requirement
All copied files MUST have their copyright headers replaced with the MetalFish header:
```cpp
/*
MetalFish - A GPU-accelerated UCI chess engine
Copyright (C) 2025 Nripesh Niketan
Licensed under GPL-3.0
*/
```
Namespace and Naming Requirements (CRITICAL)
There must be NO mention of "lc0", "lczero", "leela", or "Leela" anywhere in the final code.
This includes:
- ❌ Namespaces: `lczero::` → ✅ `MetalFish::` or `MF::`
- ❌ Class names: `Lc0Network` → ✅ `NeuralNetwork`
- ❌ Function names: `lc0_encode()` → ✅ `encode_position()`
- ❌ Variable names: `lc0_weights` → ✅ `nn_weights`
- ❌ Comments: `// Lc0-style encoding` → ✅ `// Position encoding`
- ❌ File names: `lc0_backend.cpp` → ✅ `nn_backend.cpp`
- ❌ Macros: `LC0_API` → ✅ `METALFISH_API` or remove
- ❌ Include guards: `LC0_NEURAL_H` → ✅ `METALFISH_NN_H`
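Because the rule applies to every identifier, comment, and macro, it is easiest to enforce mechanically. A minimal sketch of a case-insensitive check that could run over each source file in CI; `contains_forbidden_name` is a hypothetical helper, not an existing MetalFish function:

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cctype>
#include <string>
#include <string_view>

// Hypothetical helper: lowercase a copy so "Lc0", "LC0_API", "Leela" all match.
inline std::string to_lower_copy(std::string_view text) {
    std::string out(text);
    std::transform(out.begin(), out.end(), out.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return out;
}

// Hypothetical helper: true if the text mentions any forbidden project name.
// "lczero" is listed separately because it does not contain "lc0".
inline bool contains_forbidden_name(std::string_view text) {
    static constexpr std::array<std::string_view, 3> kForbidden = {"lc0", "lczero", "leela"};
    const std::string lowered = to_lower_copy(text);
    return std::any_of(kForbidden.begin(), kForbidden.end(), [&](std::string_view word) {
        return lowered.find(word) != std::string::npos;
    });
}
```

The checker's own source necessarily contains the forbidden strings, so it would need to live outside the scanned `src/` tree or read its word list from a file.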
Example transformation:
```cpp
// BEFORE (Lc0 original)
namespace lczero {
class Lc0Network {
  void Lc0Encode(const lczero::Position& pos);
};
}  // namespace lczero

// AFTER (MetalFish)
namespace MetalFish {
namespace NN {
class Network {
  void encode(const Position& pos);
};
}  // namespace NN
}  // namespace MetalFish
```
Directory Guidelines
DO:
- Copy files into our existing directory structure (`src/nn/`, `src/mcts/`, `src/gpu/`)
- Create sensible new directories if needed (e.g., `src/nn/`, `src/nn/metal/`)
- Maintain a clean, professional codebase structure

DO NOT:
- Create directories like `lc0_implementation/`, `lc0_copy/`, `external_lc0/`
- Keep Lc0-specific directory structures that don't fit our layout
- Leave Lc0 copyright headers in any file
- Leave any `lczero::` namespace references
Example Workflow
```bash
# Clone the reference source into a local reference directory
git clone https://github.com/LeelaChessZero/lc0 reference/lc0
mkdir -p src/nn/metal
# Copy the network interface and position encoder
cp reference/lc0/src/neural/network.h src/nn/network.h
cp reference/lc0/src/neural/encoder.cc src/nn/encoder.cpp
# Copy Metal backend
cp reference/lc0/src/neural/metal/*.mm src/nn/metal/
# Then for EACH copied file:
# 1. Replace copyright header with MetalFish header
# 2. Change namespace lczero:: to MetalFish::NN::
# 3. Rename any Lc0-prefixed classes/functions
# 4. Update all comments to remove Lc0 references
# 5. Update include guards
```
Files to Consider Copying
From `reference/lc0/src/`:
- `neural/network.h` - Network interface
- `neural/encoder.cc` - Position encoding (112 planes)
- `neural/writer.cc` - Protobuf parsing
- `neural/metal/` - Metal backend (MPSGraph)
- `mcts/node.cc` - MCTS node structure
- `mcts/search.cc` - Search algorithms
- `chess/board.cc` - Board representation (for encoding compatibility)
- `utils/weights_adapter.cc` - Weight loading utilities
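Step 1 of the workflow (header replacement) is mechanical and easy to script. A minimal sketch, assuming each upstream file begins with a single `/* ... */` license block; `replace_license_header` and `kMetalFishHeader` are hypothetical names:

```cpp
#include <cassert>
#include <string>

// Hypothetical constant: the MetalFish header text required above.
const std::string kMetalFishHeader =
    "/*\n"
    "MetalFish - A GPU-accelerated UCI chess engine\n"
    "Copyright (C) 2025 Nripesh Niketan\n"
    "Licensed under GPL-3.0\n"
    "*/";

// Hypothetical helper: replaces a leading /* ... */ block comment (assumed to
// be the upstream license) with the MetalFish header. If the file does not
// start with a block comment, the header is prepended instead.
std::string replace_license_header(const std::string& source) {
    const std::size_t start = source.find_first_not_of(" \t\r\n");
    if (start != std::string::npos && source.compare(start, 2, "/*") == 0) {
        const std::size_t end = source.find("*/", start + 2);
        if (end != std::string::npos) {
            return kMetalFishHeader + source.substr(end + 2);
        }
    }
    return kMetalFishHeader + "\n" + source;
}
```

Steps 2 to 5 (renaming) touch identifiers in context and are safer to review by hand, using a name check such as the one sketched earlier as a final gate.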
Requirements
1. Neural Network Components
Create src/nn/ directory with:
Weight Loading (`loader.h`, `loader.cpp`):
- Parse Lc0 protobuf format (`.pb` and `.pb.gz`)
- Extract transformer weights, policy head, value head, moves-left head
- Support for BT4 (Big Transformer 4) architecture
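A loader would most likely use classes generated from the network's `.proto` schema, but the underlying wire format is simple enough to show directly. A minimal sketch of gzip detection plus protobuf varint and field-key decoding; decompression (e.g. via zlib) is not shown, and all function names are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: gzip streams start with the magic bytes 0x1f 0x8b,
// which distinguishes a .pb.gz file from a raw .pb file.
bool is_gzip(const std::vector<std::uint8_t>& data) {
    return data.size() >= 2 && data[0] == 0x1f && data[1] == 0x8b;
}

// Decodes one base-128 varint starting at `pos` and advances `pos` past it.
// Returns std::nullopt on truncated or over-long (more than 10 byte) input.
std::optional<std::uint64_t> read_varint(const std::vector<std::uint8_t>& data,
                                         std::size_t& pos) {
    std::uint64_t result = 0;
    for (int shift = 0; shift < 64; shift += 7) {
        if (pos >= data.size()) return std::nullopt;
        const std::uint8_t byte = data[pos++];
        result |= static_cast<std::uint64_t>(byte & 0x7f) << shift;
        if ((byte & 0x80) == 0) return result;  // high bit clear: last byte
    }
    return std::nullopt;
}

// Every field starts with a varint key: (field_number << 3) | wire_type.
struct FieldKey {
    std::uint32_t field_number;
    std::uint32_t wire_type;  // 0 varint, 1 64-bit, 2 length-delimited, 5 32-bit
};

std::optional<FieldKey> read_key(const std::vector<std::uint8_t>& data, std::size_t& pos) {
    const auto key = read_varint(data, pos);
    if (!key) return std::nullopt;
    return FieldKey{static_cast<std::uint32_t>(*key >> 3),
                    static_cast<std::uint32_t>(*key & 0x7)};
}
```

Weight tensors arrive as length-delimited (wire type 2) fields, so a loader reads the key, then a varint length, then that many bytes.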
Position Encoding (`encoder.h`, `encoder.cpp`):
- Encode chess positions into 112-plane input format (identical to Lc0)
- 8 history positions × 13 planes + 8 auxiliary planes
- Handle board flipping for black-to-move positions
- Support canonical format transformations
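The plane arithmetic above (8 × 13 + 8 = 112) and the vertical flip can be pinned down in a few constants. A minimal sketch, assuming squares are numbered a1 = 0 to h8 = 63 and the 13 per-position planes are 6 own piece types, 6 opponent piece types, and 1 repetition plane; the exact plane order and auxiliary plane contents must be checked against the reference encoder:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical constants mirroring the 112-plane layout described above.
constexpr int kHistoryPositions = 8;
constexpr int kPlanesPerPosition = 13;  // assumed: 6 own, 6 opponent, 1 repetition
constexpr int kAuxPlanes = 8;
constexpr int kInputPlanes = kHistoryPositions * kPlanesPerPosition + kAuxPlanes;
static_assert(kInputPlanes == 112, "input must have 112 planes");

// Plane index for history step h (0 = current position) and per-position plane p.
constexpr int history_plane(int h, int p) { return h * kPlanesPerPosition + p; }

// Auxiliary planes follow all history planes (indices 104..111).
constexpr int aux_plane(int a) { return kHistoryPositions * kPlanesPerPosition + a; }

// Vertical mirror of a square index: rank r -> 7 - r, file unchanged.
constexpr int mirror_square(int sq) { return sq ^ 56; }

// Vertical mirror of a bitboard with a1 = bit 0: each byte holds one rank,
// so reversing the byte order swaps rank 1 <-> 8, 2 <-> 7, and so on.
constexpr std::uint64_t mirror_bitboard(std::uint64_t bb) {
    std::uint64_t out = 0;
    for (int rank = 0; rank < 8; ++rank) {
        const std::uint64_t row = (bb >> (8 * rank)) & 0xffULL;
        out |= row << (8 * (7 - rank));
    }
    return out;
}
```

Applying `mirror_bitboard` to every piece bitboard when black is to move gives the side-to-move view required by the board-representation convention.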
Policy Tables (`policy_tables.h`, `policy_tables.cpp`):
- Map between UCI moves and neural network policy indices
- 1858 policy outputs for standard chess
- Attention policy map for transformer networks
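The move/index mapping is a bidirectional lookup over a fixed move list. A minimal sketch, assuming the 1858-entry list is copied from the reference tables and that black-to-move positions are looked up after mirroring ranks; `PolicyMap` and `mirror_uci` are hypothetical names, and the attention policy map for transformer heads is not shown:

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical class: index -> UCI string and UCI string -> index.
class PolicyMap {
public:
    explicit PolicyMap(std::vector<std::string> index_to_move)
        : index_to_move_(std::move(index_to_move)) {
        for (int i = 0; i < static_cast<int>(index_to_move_.size()); ++i)
            move_to_index_.emplace(index_to_move_[i], i);
    }

    std::optional<int> index_of(const std::string& uci) const {
        const auto it = move_to_index_.find(uci);
        if (it == move_to_index_.end()) return std::nullopt;
        return it->second;
    }

    const std::string& move_at(int index) const { return index_to_move_.at(index); }

private:
    std::vector<std::string> index_to_move_;
    std::unordered_map<std::string, int> move_to_index_;
};

// Hypothetical helper: mirrors the ranks of a UCI move (rank r -> 9 - r)
// for a board flipped to the side to move; any promotion suffix is kept.
std::string mirror_uci(const std::string& uci) {
    assert(uci.size() >= 4);
    std::string out = uci;
    out[1] = static_cast<char>('1' + ('8' - uci[1]));
    out[3] = static_cast<char>('1' + ('8' - uci[3]));
    return out;
}
```

How promotions (especially knight promotions) are spelled in the table is an assumption that must follow the reference list exactly.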
Metal Backend (`metal/` or `nn/metal/`):
- Use MPSGraph for transformer inference
- Optimize for unified memory (zero-copy between CPU/GPU)
- Support batch inference for MCTS
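Batch inference and zero-copy both favor one flat input buffer per batch, laid out `[batch][112][64]`, that the encoder writes into directly. A minimal sketch of that layout; `InputBatch` is a hypothetical name, and wrapping the buffer in a Metal/MPSGraph object is not shown:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical class: one contiguous float buffer for a whole batch.
class InputBatch {
public:
    static constexpr int kPlanes = 112;
    static constexpr int kSquares = 64;

    explicit InputBatch(int capacity)
        : capacity_(capacity),
          data_(static_cast<std::size_t>(capacity) * kPlanes * kSquares, 0.0f) {}

    // Reserves the next slot; returns its index, or -1 if the batch is full.
    int add_position() { return size_ < capacity_ ? size_++ : -1; }

    float& at(int b, int plane, int square) {
        assert(b >= 0 && b < size_ && plane < kPlanes && square < kSquares);
        return data_[offset(b, plane, square)];
    }

    // Flat offset of (batch, plane, square) in row-major order.
    static std::size_t offset(int b, int plane, int square) {
        return (static_cast<std::size_t>(b) * kPlanes + plane) * kSquares + square;
    }

    int size() const { return size_; }
    bool full() const { return size_ == capacity_; }
    const float* data() const { return data_.data(); }

    // Zeroes the buffer and empties the batch after evaluation.
    void reset() {
        std::fill(data_.begin(), data_.end(), 0.0f);
        size_ = 0;
    }

private:
    int capacity_;
    int size_ = 0;
    std::vector<float> data_;
};
```

MCTS threads would reserve slots until the batch is full (or a timeout fires), then submit `data()` to the backend in one call.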
2. MCTS Integration
Update src/mcts/:
NNMCTSEvaluator (`nn_mcts_evaluator.h`, `nn_mcts_evaluator.cpp`):
- Bridge between MCTS and neural network
- Cache evaluations (transposition table)
- Apply policy to MCTS edges
- Return the Q value (derived from the win/draw/loss output) for backpropagation
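The evaluator's cache and its Q value can be sketched together. This assumes the network's W/D/L is from the side to move's point of view and that Q = W − L; under the "player who just moved" convention, the value stored on the node would then be the negation of `q()`. `NNEval` and `EvalCache` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical struct: network output for one position.
struct NNEval {
    std::vector<float> policy;          // probabilities, one per legal move
    float win = 0, draw = 0, loss = 0;  // assumed side-to-move perspective

    // Expected score in [-1, 1]; draws count as 0.
    float q() const { return win - loss; }
};

// Hypothetical thread-safe cache keyed by a 64-bit position hash. One mutex
// keeps the sketch simple; a real cache would likely be sharded and bounded.
class EvalCache {
public:
    std::optional<NNEval> find(std::uint64_t key) const {
        std::lock_guard<std::mutex> lock(mutex_);
        const auto it = table_.find(key);
        if (it == table_.end()) return std::nullopt;
        return it->second;
    }

    void insert(std::uint64_t key, NNEval eval) {
        std::lock_guard<std::mutex> lock(mutex_);
        table_.insert_or_assign(key, std::move(eval));
    }

    std::size_t size() const {
        std::lock_guard<std::mutex> lock(mutex_);
        return table_.size();
    }

private:
    mutable std::mutex mutex_;
    std::unordered_map<std::uint64_t, NNEval> table_;
};
```

Because the input includes 8 history positions, castling rights, and the rule-50 counter, the cache key should hash everything the network sees, not only the current board.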
-
ThreadSafeMCTS updates
- Use NN policy for move ordering
- Use NN value for leaf evaluation
- Match Lc0's PUCT formula exactly
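For reference while porting, a minimal sketch of a common PUCT form, score = Q + c · P · √max(N_parent, 1) / (1 + N_child), with an optional logarithmic cpuct growth term. Only `cpuct = 1.745` comes from this issue; `cpuct_base` and `cpuct_factor` are placeholders (factor 0 disables growth), and first-play-urgency handling for unvisited children is omitted. Whether this matches the reference exactly must be confirmed against its source:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical parameters. cpuct is the value quoted in the test criteria;
// base and factor are placeholders to be replaced with reference values.
struct PuctParams {
    float cpuct = 1.745f;
    float cpuct_base = 1.0f;    // placeholder
    float cpuct_factor = 0.0f;  // placeholder: 0 disables cpuct growth
};

// Effective cpuct, growing logarithmically with the parent's visit count.
float effective_cpuct(const PuctParams& p, float parent_visits) {
    return p.cpuct +
           p.cpuct_factor * std::log((parent_visits + p.cpuct_base) / p.cpuct_base);
}

// Score of one child edge. `q` is the child's value from the perspective of
// the player choosing the move; `prior` is the NN policy for that move.
float puct_score(const PuctParams& p, float q, float prior,
                 float parent_visits, float child_visits) {
    const float u = effective_cpuct(p, parent_visits) * prior *
                    std::sqrt(std::fmax(parent_visits, 1.0f)) / (1.0f + child_visits);
    return q + u;
}
```

Selection picks the child with the highest score; the `fmax` keeps the exploration term non-zero at a freshly expanded node.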
3. Verification Test Suite
Create tests/test_nn_comparison.cpp that verifies:
```cpp
#include <string>
#include <vector>

// Standard benchmark positions - MUST return identical moves
const std::vector<std::string> kBenchmarkPositions = {
    // Starting position
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    // Kiwipete - famous test position
    "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10",
    // Endgame positions
    "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 11",
    // Complex middlegame
    "4rrk1/pp1n3p/3q2pQ/2p1pb2/2PP4/2P3N1/P2B2PP/4RRK1 b - - 7 19",
    // Tactical positions
    "r3r1k1/2p2ppp/p1p1bn2/8/1q2P3/2NPQN2/PPP3PP/R4RK1 b - - 2 15",
    "r1bbk1nr/pp3p1p/2n5/1N4p1/2Np1B2/8/PPP2PPP/2KR1B1R w kq - 0 13",
    "r1bq1rk1/ppp1nppp/4n3/3p3Q/3P4/1BP1B3/PP1N2PP/R4RK1 w - - 1 16",
    "4r1k1/r1q2ppp/ppp2n2/4P3/5Rb1/1N1BQ3/PPP3PP/R5K1 w - - 1 17",
    // More complex positions
    "2rqkb1r/ppp2p2/2npb1p1/1N1Nn2p/2P1PP2/8/PP2B1PP/R1BQK2R b KQ - 0 11",
    "r1bq1r1k/b1p1npp1/p2p3p/1p6/3PP3/1B2NN2/PP3PPP/R2Q1RK1 w - - 1 16",
    // Pawn endgames
    "8/1p3pp1/7p/5P1P/2k3P1/8/2K2P2/8 w - - 0 1",
    "8/pp2r1k1/2p1p3/3pP2p/1P1P1P1P/P5KR/8/8 w - - 0 1",
    // Rook endgames
    "5k2/7R/4P2p/5K2/p1r2P1p/8/8/8 b - - 0 1",
    "6k1/6p1/P6p/r1N5/5p2/7P/1b3PP1/4R1K1 w - - 0 1",
    // Queen vs pieces
    "3q2k1/pb3p1p/4pbp1/2r5/PpN2N2/1P2P2P/5PP1/Q2R2K1 b - - 4 26",
};
```
Test criteria (100% match required):
- Raw NN Output Match
  - Policy logits must match reference within floating-point tolerance (1e-5)
  - WDL (Win/Draw/Loss) outputs must match exactly
  - Q value must match exactly
- MCTS Best Move Match
  - With identical parameters (nodes=800, cpuct=1.745, etc.)
  - MetalFish must select the same best move as reference for all 15 positions
  - This validates the entire pipeline: encoding → inference → MCTS
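The logit criterion reduces to a maximum absolute difference against the reference dump. A minimal sketch for `tests/test_nn_comparison.cpp`, where `max_abs_diff` and `policy_matches` are hypothetical helpers; a size mismatch or any NaN is treated as a failure so a broken output can never pass silently:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper: largest element-wise |a[i] - b[i]|. Returns infinity on
// a size mismatch or NaN so either case always fails the tolerance check.
float max_abs_diff(const std::vector<float>& a, const std::vector<float>& b) {
    const float kFail = std::numeric_limits<float>::infinity();
    if (a.size() != b.size()) return kFail;
    float worst = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const float d = std::fabs(a[i] - b[i]);
        if (std::isnan(d)) return kFail;
        if (d > worst) worst = d;
    }
    return worst;
}

// Policy logits pass if every entry is within `tol` (1e-5 per the criteria).
bool policy_matches(const std::vector<float>& ours, const std::vector<float>& reference,
                    float tol = 1e-5f) {
    return max_abs_diff(ours, reference) <= tol;
}
```

Logging the worst index alongside the value would make a failing position much easier to debug.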
Implementation Notes
Conventions to Match (from reference implementation)
- Board Representation: Store the board from the side-to-move's perspective. When black is to move, the board is mirrored vertically.
- Q Value Convention: Q values are stored from the perspective of "the player who just moved" to reach a node. During backpropagation, negate at each level.
- Policy Encoding: Use `MoveToNNIndex()` and `MoveFromNNIndex()` with the correct transform for the position.
- Edge Sorting: After applying NN policy, sort edges by policy value (descending).
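The Q-negation and edge-sorting conventions can be sketched in a few lines. This assumes a path stored root-to-leaf and a leaf value already expressed from the perspective of the player who moved into the leaf; `Edge`, `NodeStats`, and both functions are hypothetical names. Using `stable_sort` for tie-breaking is also an assumption; to reproduce best moves exactly, ties must be broken the same way as in the reference:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical types for illustration.
struct Edge {
    int move_index;  // policy index of the move
    float policy;    // prior from the network
};

struct NodeStats {
    int visits = 0;
    float value_sum = 0.0f;  // perspective of the player who moved into this node
    float q() const { return visits ? value_sum / visits : 0.0f; }
};

// Sorts edges by prior, highest first; equal priors keep their original order.
void sort_edges_by_policy(std::vector<Edge>& edges) {
    std::stable_sort(edges.begin(), edges.end(),
                     [](const Edge& a, const Edge& b) { return a.policy > b.policy; });
}

// Backpropagates along path (front = root, back = leaf). The value is negated
// at each level because the player who moved alternates.
void backpropagate(const std::vector<NodeStats*>& path, float v) {
    for (auto it = path.rbegin(); it != path.rend(); ++it) {
        (*it)->visits += 1;
        (*it)->value_sum += v;
        v = -v;
    }
}
```

If the network returns Q for the side to move at the leaf, the value passed to `backpropagate` is its negation.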
Apple Silicon Optimizations
- Use unified memory for zero-copy GPU access
- Leverage Metal Performance Shaders Graph (MPSGraph) for transformer ops
- 128-byte cache line alignment for M-series chips
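Cache-line alignment matters most for data written by several search threads at once and for large buffers handed to the GPU. A minimal C++17 sketch assuming the 128-byte line size stated above; `PaddedCounter` and the allocation helpers are hypothetical names:

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>

// Cache-line size assumed from the note above (128 bytes on M-series chips).
constexpr std::size_t kCacheLine = 128;

// Hypothetical per-thread counter padded to a full line, so counters owned by
// different threads never share a cache line (avoids false sharing).
struct alignas(kCacheLine) PaddedCounter {
    std::atomic<std::uint64_t> value{0};
};
static_assert(sizeof(PaddedCounter) == kCacheLine, "one counter per cache line");

// Hypothetical helpers: allocate/free floats starting on a cache-line boundary
// using C++17 aligned operator new/delete.
float* alloc_aligned_floats(std::size_t count) {
    return static_cast<float*>(
        ::operator new(count * sizeof(float), std::align_val_t{kCacheLine}));
}

void free_aligned_floats(float* p) {
    ::operator delete(p, std::align_val_t{kCacheLine});
}
```

In C++17, `std::vector<PaddedCounter>` also honors the 128-byte alignment automatically, since standard allocators use aligned `operator new` for over-aligned types.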
Final Integration
The NN implementation must integrate cleanly with our existing MCTS in src/mcts/thread_safe_mcts.cpp. The evaluator should be a drop-in replacement that:
- Takes a `Position` from our codebase
- Returns policy probabilities for legal moves
- Returns a Q value (or WDL) for backpropagation
- Supports batched evaluation for efficiency
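Turning raw logits into "policy probabilities for legal moves" means a softmax restricted to the legal moves' policy indices. A minimal sketch where `legal_move_softmax` is a hypothetical name and the temperature parameter is an assumption (1.0 leaves the logits unchanged; any value used must match the reference settings):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Hypothetical helper: softmax over the logits at `legal_indices` only
// (e.g. out of 1858 outputs). Output order follows `legal_indices`.
std::vector<float> legal_move_softmax(const std::vector<float>& logits,
                                      const std::vector<int>& legal_indices,
                                      float temperature = 1.0f) {
    assert(temperature > 0.0f);
    std::vector<float> probs;
    probs.reserve(legal_indices.size());
    if (legal_indices.empty()) return probs;

    // Subtract the largest legal logit before exponentiating (numerical stability).
    float max_logit = -std::numeric_limits<float>::infinity();
    for (int idx : legal_indices) max_logit = std::max(max_logit, logits.at(idx));

    float sum = 0.0f;
    for (int idx : legal_indices) {
        const float e = std::exp((logits.at(idx) - max_logit) / temperature);
        probs.push_back(e);
        sum += e;
    }
    for (float& p : probs) p /= sum;
    return probs;
}
```

Illegal moves never enter the normalization, so the resulting priors sum to 1 over exactly the edges MCTS will create.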
Acceptance Criteria
- All 15 benchmark positions return identical best moves to reference
- Raw NN outputs match reference within tolerance
- No memory leaks or crashes during extended testing
- Performance: Extremely optimized for Apple Silicon chips with unified memory
- Clean integration with existing MCTS code
- All files have MetalFish copyright headers
- No directories named after external projects
- No `lczero::` or `lc0` namespace/naming anywhere in code
- No comments referencing Lc0/Leela - this is MetalFish
References
- Reference implementation source available at [lc0](https://github.com/LeelaChessZero/lc0)
- Network format documentation: https://lczero.org/dev/wiki/technical-explanation-of-leela-chess-zero/
- BT4 architecture: Big Transformer with a 1024-dimensional embedding, 15 layers, and 32 attention heads
Test Command
```bash
cd metalfish/build
cmake --build . --target metalfish_tests
./metalfish_tests nn_comparison
```
Expected output:
```
=== Neural Network Comparison Tests ===
Position 1/15: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
  Reference best move: e2e4
  MetalFish best move: e2e4
  ✓ MATCH
Position 2/15: r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10
  Reference best move: e2a6
  MetalFish best move: e2a6
  ✓ MATCH
... (all 15 positions)
Results: 15/15 positions match (100%)
```