[Feature] Implement Lc0-compatible neural network inference for MCTS with Metal backend #14

@NripeshN

Summary

Implement neural network inference for MetalFish's MCTS search that uses Lc0-format network weights (.pb files) and produces identical results to Lc0 for the same positions. The implementation should be heavily optimized for Apple Silicon's unified memory architecture.

Background

MetalFish currently has a working MCTS implementation (src/mcts/thread_safe_mcts.cpp) that uses NNUE evaluation. To achieve stronger play, we need to integrate transformer-based neural network evaluation similar to Lc0 (Leela Chess Zero).

We have:

  • Network weights: networks/BT4-1024x15x32h-swa-6147500.pb (365MB transformer network)
  • Reference implementation: the official Lc0 source code
  • Target platform: Apple Silicon Macs (M1/M2/M3/M4) with Metal GPU acceleration

Implementation Strategy: Copy from Lc0

Lc0 is open source (GPL-3.0). You are encouraged to directly copy entire code files from [lc0](https://github.com/LeelaChessZero/lc0) using cp or mv commands. Do not rewrite from scratch what already exists.

Copyright Header Requirement

All copied files MUST have their copyright headers replaced with the MetalFish header:

/*
  MetalFish - A GPU-accelerated UCI chess engine
  Copyright (C) 2025 Nripesh Niketan

  Licensed under GPL-3.0
*/

Namespace and Naming Requirements (CRITICAL)

There must be NO mention of "lc0", "lczero", "leela", or "Leela" anywhere in the final code.

This includes:

  • Namespaces: lczero:: → ✅ MetalFish:: or MF::
  • Class names: Lc0Network → ✅ NeuralNetwork
  • Function names: lc0_encode() → ✅ encode_position()
  • Variable names: lc0_weights → ✅ nn_weights
  • Comments: // Lc0-style encoding → ✅ // Position encoding
  • File names: lc0_backend.cpp → ✅ nn_backend.cpp
  • Macros: LC0_API → ✅ METALFISH_API or remove
  • Include guards: LC0_NEURAL_H → ✅ METALFISH_NN_H

Example transformation:

// BEFORE (Lc0 original)
namespace lczero {
class Lc0Network {
  void Lc0Encode(const lczero::Position& pos);
};
}  // namespace lczero

// AFTER (MetalFish)
namespace MetalFish {
namespace NN {
class Network {
  void encode(const Position& pos);
};
}  // namespace NN
}  // namespace MetalFish

Directory Guidelines

DO:

  • Copy files into our existing directory structure (src/nn/, src/mcts/, src/gpu/)
  • Create sensible new directories if needed (e.g., src/nn/, src/nn/metal/)
  • Maintain a clean, professional codebase structure

DO NOT:

  • Create directories like lc0_implementation/, lc0_copy/, external_lc0/
  • Keep Lc0-specific directory structures that don't fit our layout
  • Leave Lc0 copyright headers in any file
  • Leave any lczero:: namespace references

Example Workflow

# Clone the reference source into reference/lc0 locally
git clone https://github.com/LeelaChessZero/lc0 reference/lc0

# Copy the network interface and position encoder
cp reference/lc0/src/neural/network.h src/nn/network.h
cp reference/lc0/src/neural/encoder.cc src/nn/encoder.cpp

# Copy Metal backend (create the destination directory first)
mkdir -p src/nn/metal
cp reference/lc0/src/neural/metal/*.mm src/nn/metal/

# Then for EACH copied file:
# 1. Replace copyright header with MetalFish header
# 2. Change namespace lczero:: to MetalFish::NN::
# 3. Rename any Lc0-prefixed classes/functions
# 4. Update all comments to remove Lc0 references
# 5. Update include guards

Files to Consider Copying

From reference/lc0/src/:

  • neural/network.h - Network interface
  • neural/encoder.cc - Position encoding (112 planes)
  • neural/loader.cc - Protobuf parsing
  • neural/metal/ - Metal backend (MPSGraph)
  • mcts/node.cc - MCTS node structure
  • mcts/search.cc - Search algorithms
  • chess/board.cc - Board representation (for encoding compatibility)
  • utils/weights_adapter.cc - Weight loading utilities

Requirements

1. Neural Network Components

Create src/nn/ directory with:

  • Weight Loading (loader.h, loader.cpp)

    • Parse Lc0 protobuf format (.pb and .pb.gz)
    • Extract transformer weights, policy head, value head, moves-left head
    • Support for BT4 (Big Transformer 4) architecture
  • Position Encoding (encoder.h, encoder.cpp)

    • Encode chess positions into 112-plane input format (identical to Lc0)
    • 8 history positions × 13 planes + 8 auxiliary planes
    • Handle board flipping for black-to-move positions
    • Support canonical format transformations
  • Policy Tables (policy_tables.h, policy_tables.cpp)

    • Map between UCI moves and neural network policy indices
    • 1858 policy outputs for standard chess
    • Attention policy map for transformer networks
  • Metal Backend (metal/ or nn/metal/)

    • Use MPSGraph for transformer inference
    • Optimize for unified memory (zero-copy between CPU/GPU)
    • Support batch inference for MCTS
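
The plane arithmetic in the Position Encoding item can be pinned down at compile time. The following is a minimal sketch, not the reference encoder: the constant and function names are hypothetical, and it assumes history slot 0 holds the current position with the auxiliary planes placed after all history planes.

```cpp
#include <cassert>
#include <cstddef>

// Illustrative constants; names are hypothetical, numbers come from the list above.
constexpr std::size_t kHistoryDepth = 8;        // history positions
constexpr std::size_t kPlanesPerPosition = 13;  // planes per history position
constexpr std::size_t kAuxPlanes = 8;           // auxiliary planes
constexpr std::size_t kInputPlanes =
    kHistoryDepth * kPlanesPerPosition + kAuxPlanes;
static_assert(kInputPlanes == 112, "input tensor must have 112 planes");

// First plane index of a history slot (assumed: 0 = current position).
inline std::size_t history_plane_base(std::size_t history_index) {
  assert(history_index < kHistoryDepth);
  return history_index * kPlanesPerPosition;
}

// First plane index of the auxiliary block (assumed to follow the history planes).
inline std::size_t aux_plane_base() {
  return kHistoryDepth * kPlanesPerPosition;
}
```

Keeping the total behind a static_assert means a change to any one constant fails the build instead of silently producing a mis-shaped input tensor.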

2. MCTS Integration

Update src/mcts/:

  • NNMCTSEvaluator (nn_mcts_evaluator.h, nn_mcts_evaluator.cpp)

    • Bridge between MCTS and neural network
    • Cache evaluations (transposition table)
    • Apply policy to MCTS edges
    • Return a Q value (derived from the win/draw/loss outputs) for backpropagation
  • ThreadSafeMCTS updates

    • Use NN policy for move ordering
    • Use NN value for leaf evaluation
    • Match Lc0's PUCT formula exactly
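
For orientation, child selection under PUCT can be sketched as below. This is the textbook form only, under stated assumptions: the reference search layers refinements on top of it (for example first-play urgency and a visit-dependent exploration constant), so this sketch alone will not reproduce its move choices. The struct and function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-edge statistics; field names are illustrative.
struct EdgeStats {
  float prior;  // policy probability P from the network
  float q;      // mean value of the child, from the parent's point of view
  int visits;   // completed visits through this edge
};

// Scalar Q from win/loss probabilities (the draw probability carries no sign).
inline float wdl_to_q(float win, float loss) { return win - loss; }

// Textbook PUCT: Q + cpuct * P * sqrt(N_parent) / (1 + N_child).
inline float puct_score(const EdgeStats& e, int parent_visits, float cpuct) {
  const float explore = cpuct * e.prior *
                        std::sqrt(static_cast<float>(parent_visits)) /
                        (1.0f + static_cast<float>(e.visits));
  return e.q + explore;
}

// Index of the edge with the highest score; ties keep the earlier edge.
inline std::size_t select_child(const std::vector<EdgeStats>& edges,
                                int parent_visits, float cpuct) {
  assert(!edges.empty());
  std::size_t best = 0;
  float best_score = puct_score(edges[0], parent_visits, cpuct);
  for (std::size_t i = 1; i < edges.size(); ++i) {
    const float score = puct_score(edges[i], parent_visits, cpuct);
    if (score > best_score) {
      best_score = score;
      best = i;
    }
  }
  return best;
}
```

With equal Q, an unvisited edge with a smaller prior can outrank a heavily visited edge with a larger prior, which is the exploration behaviour the formula is meant to provide.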

3. Verification Test Suite

Create tests/test_nn_comparison.cpp that verifies:

// Standard benchmark positions - MUST return identical moves
const std::vector<std::string> kBenchmarkPositions = {
    // Starting position
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
    
    // Kiwipete - famous test position
    "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10",
    
    // Endgame positions
    "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 11",
    
    // Complex middlegame
    "4rrk1/pp1n3p/3q2pQ/2p1pb2/2PP4/2P3N1/P2B2PP/4RRK1 b - - 7 19",
    
    // Tactical positions
    "r3r1k1/2p2ppp/p1p1bn2/8/1q2P3/2NPQN2/PPP3PP/R4RK1 b - - 2 15",
    "r1bbk1nr/pp3p1p/2n5/1N4p1/2Np1B2/8/PPP2PPP/2KR1B1R w kq - 0 13",
    "r1bq1rk1/ppp1nppp/4n3/3p3Q/3P4/1BP1B3/PP1N2PP/R4RK1 w - - 1 16",
    "4r1k1/r1q2ppp/ppp2n2/4P3/5Rb1/1N1BQ3/PPP3PP/R5K1 w - - 1 17",
    
    // More complex positions
    "2rqkb1r/ppp2p2/2npb1p1/1N1Nn2p/2P1PP2/8/PP2B1PP/R1BQK2R b KQ - 0 11",
    "r1bq1r1k/b1p1npp1/p2p3p/1p6/3PP3/1B2NN2/PP3PPP/R2Q1RK1 w - - 1 16",
    
    // Pawn endgames
    "8/1p3pp1/7p/5P1P/2k3P1/8/2K2P2/8 w - - 0 1",
    "8/pp2r1k1/2p1p3/3pP2p/1P1P1P1P/P5KR/8/8 w - - 0 1",
    
    // Rook endgames
    "5k2/7R/4P2p/5K2/p1r2P1p/8/8/8 b - - 0 1",
    "6k1/6p1/P6p/r1N5/5p2/7P/1b3PP1/4R1K1 w - - 0 1",
    
    // Queen vs pieces
    "3q2k1/pb3p1p/4pbp1/2r5/PpN2N2/1P2P2P/5PP1/Q2R2K1 b - - 4 26",
};

Test criteria (100% match required):

  1. Raw NN Output Match

    • Policy logits must match reference within floating-point tolerance (1e-5)
    • WDL (Win/Draw/Loss) outputs must match exactly
    • Q value must match exactly
  2. MCTS Best Move Match

    • With identical parameters (nodes=800, cpuct=1.745, etc.)
    • MetalFish must select the same best move as reference for all 15 positions
    • This validates the entire pipeline: encoding → inference → MCTS
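
The tolerance check on raw outputs could look roughly like the helper below. It is a sketch with a hypothetical name (outputs_match), using the 1e-5 absolute tolerance from the criteria above as its default; whether the test suite uses absolute or relative error is an assumption here.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Returns true when both vectors have the same length and every pair of
// elements differs by at most `tol` in absolute terms.
inline bool outputs_match(const std::vector<float>& got,
                          const std::vector<float>& want,
                          float tol = 1e-5f) {
  assert(tol >= 0.0f);
  if (got.size() != want.size()) return false;
  for (std::size_t i = 0; i < got.size(); ++i) {
    // Written as !(diff <= tol) so a NaN on either side counts as a mismatch.
    if (!(std::fabs(got[i] - want[i]) <= tol)) return false;
  }
  return true;
}
```

The same helper can be applied to the 1858 policy logits and to the three WDL values, with a tighter tolerance passed explicitly where one is wanted.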

Implementation Notes

Conventions to Match (from reference implementation)

  1. Board Representation: Store the board from the side-to-move's perspective. When black is to move, the board is mirrored vertically.

  2. Q Value Convention: Q values are stored from the perspective of "the player who just moved" to reach a node. During backpropagation, negate at each level.

  3. Policy Encoding: Use MoveToNNIndex() and MoveFromNNIndex() with the correct transform for the position.

  4. Edge Sorting: After applying NN policy, sort edges by policy value (descending).
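
Conventions 2 and 4 can be illustrated with a short sketch. All type and function names here are hypothetical. It assumes the value passed in is already expressed from the perspective of the player who moved into the leaf; a network output given from the side to move would need one negation before this call.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-node accumulator.
struct NodeStats {
  double value_sum = 0.0;
  int visits = 0;
  double q() const { return visits > 0 ? value_sum / visits : 0.0; }
};

// path.front() is the root, path.back() is the evaluated leaf.
// The sign flips at each level because the players alternate.
inline void backpropagate(const std::vector<NodeStats*>& path,
                          double leaf_value) {
  double v = leaf_value;
  for (auto it = path.rbegin(); it != path.rend(); ++it) {
    assert(*it != nullptr);
    (*it)->value_sum += v;
    (*it)->visits += 1;
    v = -v;
  }
}

// Hypothetical edge record: a move identifier plus its policy prior.
struct Edge {
  int move_id;
  float policy;
};

// Sort edges by policy, highest first; equal priors keep their original order.
inline void sort_edges_by_policy(std::vector<Edge>& edges) {
  std::stable_sort(
      edges.begin(), edges.end(),
      [](const Edge& a, const Edge& b) { return a.policy > b.policy; });
}
```

A stable sort is used so that equal priors keep a deterministic order, which matters when best moves are compared against a reference.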

Apple Silicon Optimizations

  • Use unified memory for zero-copy GPU access
  • Leverage Metal Performance Shaders Graph (MPSGraph) for transformer ops
  • 128-byte cache line alignment for M-series chips
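
On the CPU side, the cache-line point can be expressed with alignas. The sketch below uses hypothetical names and shows only the alignment technique for a per-thread counter; Metal buffer allocation is a separate concern and is not covered.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Cache line size stated above for M-series chips.
constexpr std::size_t kCacheLineBytes = 128;

// A counter padded to a full cache line so that neighbouring counters,
// each updated by a different thread, never share a line (false sharing).
struct alignas(kCacheLineBytes) PaddedCounter {
  std::atomic<std::uint64_t> value{0};
};
static_assert(alignof(PaddedCounter) == kCacheLineBytes,
              "counter must start on a cache-line boundary");
static_assert(sizeof(PaddedCounter) % kCacheLineBytes == 0,
              "counter must occupy whole cache lines");

// Runtime check for a pointer obtained elsewhere, e.g. from an allocator.
inline bool is_cache_line_aligned(const void* p) {
  assert(p != nullptr);
  return reinterpret_cast<std::uintptr_t>(p) % kCacheLineBytes == 0;
}
```

Heap-allocating an over-aligned type like this with plain new relies on C++17 aligned allocation; with older standards an aligned allocator would be needed.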

Final Integration

The NN implementation must integrate cleanly with our existing MCTS in src/mcts/thread_safe_mcts.cpp. The evaluator should be a drop-in replacement that:

  1. Takes a Position from our codebase
  2. Returns policy probabilities for legal moves
  3. Returns a Q value (or WDL) for backpropagation
  4. Supports batched evaluation for efficiency
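
One possible shape for such an evaluator interface is sketched below. Every name here (Evaluation, Evaluator, DrawEvaluator) is hypothetical, and the empty Position struct is only a placeholder for the engine's existing type; the actual signatures in thread_safe_mcts.cpp may call for something different.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

struct Position {};  // placeholder for the engine's existing Position type

// Result of evaluating one position.
struct Evaluation {
  std::vector<std::pair<int, float>> policy;  // (move id, probability), legal moves only
  float win;
  float draw;
  float loss;
  float q() const { return win - loss; }  // scalar value for backpropagation
};

// Batched interface: one call evaluates many positions.
class Evaluator {
 public:
  virtual ~Evaluator() = default;
  virtual std::vector<Evaluation> evaluate_batch(
      const std::vector<const Position*>& batch) = 0;

  // Convenience wrapper for a single position.
  Evaluation evaluate(const Position& pos) {
    const std::vector<const Position*> one{&pos};
    std::vector<Evaluation> out = evaluate_batch(one);
    assert(out.size() == 1);
    return out.front();
  }
};

// Trivial stand-in that scores everything as a certain draw; useful as a
// test double while the real backend is not yet wired in.
class DrawEvaluator : public Evaluator {
 public:
  std::vector<Evaluation> evaluate_batch(
      const std::vector<const Position*>& batch) override {
    const Evaluation result{{}, 0.0f, 1.0f, 0.0f};
    return std::vector<Evaluation>(batch.size(), result);
  }
};
```

Making the batched call the only virtual function keeps a single code path to the GPU, with the single-position call as a thin wrapper over it.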

Acceptance Criteria

  • All 15 benchmark positions return identical best moves to reference
  • Raw NN outputs match reference within tolerance
  • No memory leaks or crashes during extended testing
  • Performance: inference path optimized for Apple Silicon unified memory (zero-copy CPU/GPU buffers, batched evaluation)
  • Clean integration with existing MCTS code
  • All files have MetalFish copyright headers
  • No directories named after external projects
  • No lczero:: or lc0 namespace/naming anywhere in code
  • No comments referencing Lc0/Leela - this is MetalFish
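
The naming criteria lend themselves to an automated check. The sketch below is an assumption about how that could be done, with hypothetical function names: a case-insensitive substring scan over a file's text for the names listed under the naming requirements. Plain substring matching can raise false positives on unrelated identifiers that happen to contain one of the tokens.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

// ASCII lower-casing; the cast to unsigned char avoids undefined
// behaviour in std::tolower for negative char values.
inline std::string to_lower_ascii(std::string s) {
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
    return static_cast<char>(std::tolower(c));
  });
  return s;
}

// True if the text mentions any of the disallowed project names,
// regardless of letter case.
inline bool contains_forbidden_name(const std::string& text) {
  static const char* const kForbidden[] = {"lc0", "lczero", "leela"};
  const std::string lowered = to_lower_ascii(text);
  for (const char* token : kForbidden) {
    if (lowered.find(token) != std::string::npos) return true;
  }
  return false;
}
```

Run over every source file and file path in the tree, a check like this could gate the naming criteria in CI.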

Test Command

cd metalfish/build
cmake --build . --target metalfish_tests
./metalfish_tests nn_comparison

Expected output:

=== Neural Network Comparison Tests ===
Position 1/15: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1
  Reference best move: e2e4
  MetalFish best move: e2e4
  ✓ MATCH

Position 2/15: r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 10
  Reference best move: e2a6
  MetalFish best move: e2a6
  ✓ MATCH

... (all 15 positions)

Results: 15/15 positions match (100%)
