I encounter this error when I try to load the RL model from that path. For context, the model is a basic chess game player. I trained it with this Python code:
import gym
from gym import spaces
import numpy as np
import chess
from stable_baselines3 import PPO
import torch
import onnx
# Custom Chess Environment
class ChessEnv(gym.Env):
    def __init__(self):
        super(ChessEnv, self).__init__()
        self.board = chess.Board()
        self.action_space = spaces.Discrete(4672)  # Upper bound on possible moves in chess
        self.observation_space = spaces.Box(low=0, high=1, shape=(8, 8, 12), dtype=np.float32)

    def reset(self):
        self.board.reset()
        return self._get_observation()  # Return the numerical representation of the board

    def step(self, action):
        legal_moves = list(self.board.legal_moves)
        if action >= len(legal_moves):
            action = action % len(legal_moves)  # Map out-of-range actions onto a legal move
        self.board.push(legal_moves[action])
        done = self.board.is_game_over()
        reward = 1 if self.board.is_checkmate() else 0  # Reward for checkmate (adjust as needed)
        return self._get_observation(), reward, done, {}

    def render(self, mode="human"):
        print(self.board)  # Print the board for a human-readable representation

    def _get_observation(self):
        # Convert the board to an 8x8x12 one-hot representation
        board_state = np.zeros((8, 8, 12), dtype=np.float32)
        piece_map = self.board.piece_map()
        for square, piece in piece_map.items():
            piece_type = piece.piece_type - 1
            color = int(piece.color)  # chess.WHITE == True, so white pieces land in channels 6-11
            row, col = divmod(square, 8)
            board_state[row, col, piece_type + 6 * color] = 1
        return board_state
# Create the environment
env = ChessEnv()
# Define the RL model using PPO
model = PPO("MlpPolicy", env, verbose=1)
# Train the model
games = 5  # Note: total_timesteps counts environment steps, not games; increase this for real training
model.learn(total_timesteps=games)
# Save the trained model
model.save("chess_rl_bot")
# Load the trained model (for later use)
model = PPO.load("chess_rl_bot")
# Extract the policy (neural network)
policy = model.policy
# Create a dummy input matching the model’s observation space
dummy_input = torch.zeros(1, *policy.observation_space.shape) # Adjust shape if needed
# Export the model to ONNX format with a specific opset version
onnx_model_path = "ppo_chess_model.onnx"
torch.onnx.export(policy, dummy_input, onnx_model_path, opset_version=9) # Use opset_version=9
# The export call above already wrote the file to disk
print(f"ONNX model saved to {onnx_model_path}")
And this is the Unity file I have so far; when I try to run it, it throws the error:
using UnityEngine;
using Unity.Barracuda;
using System.IO;

public class White : Bot { }

public class WhiteState : BotState
{
    private Model model;
    private IWorker worker;
    private string modelFilePath = "C:/Users/cruzmart/Daniel/web/Chess-Unity/Assets/Model/ppo_chess_model.onnx"; // Update this path to match your model file location
    private string outputTensorName = "action"; // Replace with your model's output tensor name

    public WhiteState(string playerName, bool isWhite) : base(playerName, isWhite)
    {
        // Load the model from the file path and initialize a worker to run inference
        model = ModelLoader.Load(File.ReadAllBytes(modelFilePath));
        worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);
    }

    public WhiteState(WhiteState original) : base(original)
    {
        // Share the loaded model and worker with clones instead of leaving them null
        model = original.model;
        worker = original.worker;
    }

    public override PlayerState Clone() => new WhiteState(this);
    public override Vector2Int GetMove()
    {
        // Get the observation from the game (a numerical representation of the board state)
        float[] observation = GetCurrentBoardObservation();

        // Create a tensor from the observation array (batch, height, width, channels)
        Tensor inputTensor = new Tensor(1, 8, 8, 12, observation);
        worker.Execute(inputTensor);

        // Extract the action from the output tensor
        Tensor outputTensor = worker.PeekOutput(outputTensorName);

        // Convert the model's output to a move (this depends on how the model is structured)
        int action = outputTensor.ArgMax()[0]; // Assuming the action is a discrete move

        Vector2Int move = ConvertActionToMove(action);

        // Only the input needs disposing; PeekOutput does not transfer ownership,
        // so the worker disposes the output tensor itself
        inputTensor.Dispose();
        return move;
    }
    Vector2Int ConvertActionToMove(int action)
    {
        // Convert the action (a move number) to a chess move
        int fromSquare = action / 64;
        int toSquare = action % 64;
        return new Vector2Int(fromSquare, toSquare);
    }
    private float[] GetCurrentBoardObservation()
    {
        // Convert the board state to a 768-element array (8x8x12, row-major)
        float[] observation = new float[768];
        char[,] board = CurrentGame.StringBoard();
        for (int y = 0; y < 8; y++)
        {
            for (int x = 0; x < 8; x++)
            {
                char piece = board[y, x];
                int channel = GetChannelForPiece(piece);
                if (channel != -1)
                {
                    observation[y * 96 + x * 12 + channel] = 1f;
                }
            }
        }
        return observation;
    }
    private int GetChannelForPiece(char piece)
    {
        // Channel layout must match the training code: python-chess reports
        // chess.WHITE as True, so the Python observation puts white pieces in
        // channels 6-11 and black pieces in channels 0-5.
        switch (piece)
        {
            case 'P': return 6;  // White Pawn
            case 'N': return 7;  // White Knight
            case 'B': return 8;  // White Bishop
            case 'R': return 9;  // White Rook
            case 'Q': return 10; // White Queen
            case 'K': return 11; // White King
            case 'p': return 0;  // Black Pawn
            case 'n': return 1;  // Black Knight
            case 'b': return 2;  // Black Bishop
            case 'r': return 3;  // Black Rook
            case 'q': return 4;  // Black Queen
            case 'k': return 5;  // Black King
            default: return -1;  // Empty square
        }
    }
    void OnDestroy()
    {
        // Clean up the worker when done. Note: Unity only calls OnDestroy
        // automatically on MonoBehaviours, so if BotState is a plain class
        // this needs to be invoked manually.
        worker.Dispose();
    }
}
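To rule out a bad export before debugging on the Unity side, a quick sanity check with the onnx Python package can confirm the graph is valid and show the actual tensor names (the outputTensorName = "action" above only works if the export really assigned that name):

import onnx

onnx_model = onnx.load("ppo_chess_model.onnx")
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed

# Print the tensor names the Unity side will have to match
for inp in onnx_model.graph.input:
    print("input:", inp.name)
for out in onnx_model.graph.output:
    print("output:", out.name)

If the printed names differ, PeekOutput("action") will not find the output.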
Can anyone help with this? This is my first issue, so I apologise in advance for any problems with it.