diff --git a/poetry.lock b/poetry.lock index f950a52..5da89f5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -145,6 +145,22 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope.interface"] tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] +[[package]] +name = "autopep8" +version = "2.0.1" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "autopep8-2.0.1-py2.py3-none-any.whl", hash = "sha256:be5bc98c33515b67475420b7b1feafc8d32c1a69862498eda4983b45bffd2687"}, + {file = "autopep8-2.0.1.tar.gz", hash = "sha256:d27a8929d8dcd21c0f4b3859d2d07c6c25273727b98afc984c039df0f0d86566"}, +] + +[package.dependencies] +pycodestyle = ">=2.10.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + [[package]] name = "certifi" version = "2022.12.7" @@ -1437,6 +1453,18 @@ files = [ [package.dependencies] numpy = ">=1.16.6" +[[package]] +name = "pycodestyle" +version = "2.10.0" +description = "Python style guide checker" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pycodestyle-2.10.0-py2.py3-none-any.whl", hash = "sha256:8a4eaf0d0495c7395bdab3589ac2db602797d76207242c17d470186815706610"}, + {file = "pycodestyle-2.10.0.tar.gz", hash = "sha256:347187bdb476329d98f695c213d7295a846d1152ff4fe9bacb8a9590b8ee7053"}, +] + [[package]] name = "pydantic" version = "1.10.4" @@ -2633,4 +2661,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "16d8d643ba8f3ae9cceb9e66a69c0b73135e36db1d5adb983bc79d184e9304ee" 
+content-hash = "288ed66f3a7b8d6a1ee831ab80db26cc68819224b4146a3e1691964f354aaaeb" diff --git a/pyproject.toml b/pyproject.toml index d1bc391..64dfeb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ torch = "^1.13.0" [tool.poetry.group.dev.dependencies] pytest = "^7.2.1" mypy = "^0.991" +autopep8 = "^2.0.1" [build-system] requires = ["poetry-core"] diff --git a/soft_optim/game.py b/soft_optim/game.py index 96fe705..cb064f6 100644 --- a/soft_optim/game.py +++ b/soft_optim/game.py @@ -1,190 +1,323 @@ """Tic Tac Toe Game implementation""" import re -from typing import Tuple, List +from enum import Enum +from typing import List, Tuple, Optional import numpy as np -class TicTacToeBoard: - """Tic Tac Toe Board +class Player(Enum): + """Tic Tac Toe Player""" - Contains the board state at a single point in time, i.e. 9 squares with 3 - possible values [-,x,o]. - """ + X = "x" + O = "o" + + +class Square(Player): + """Tic Tac Toe Square""" + + # Note this extends player, so it already has X and O (and thus just needs + # the third option of EMPTY). 
+ X = "x" + O = "o" + EMPTY = "-" - board_state: np.ndarray - """Board state as a numpy array of shape (3,3) - Note we use integers to represent the state of each square, rather than - strings, which are defined below.""" +class Board: + """Tic Tac Toe Board""" - blank: int = 0 - """Integer representation of blank on the board (numpy array)""" + contains_illegal_move: bool = False + """Flag that the board contains an illegal move""" - x: int = 1 - """Integer representation of x on the board (numpy array)""" + number_columns: int = 3 + """Number of columns""" - o: int = 2 - """Integer representation of o on the board (numpy array)""" + number_rows: int = 3 + """Number of rows""" - def __init__(self, string=None): - # Initialise an empty board - self.board_state = np.full((3,3), self.blank, int) + number_joined_to_win: int = 3 + """Number in a row/column/diagonal to win""" - # Setup the mapping of strings to square values - self.map = {self.x:'x', self.o:'o', self.blank: '-'} + @property + def board_squares(self) -> List[List[Square]]: + """Board squares + + Nested list of squares, in the format [rows x columns]. + + >>> Board().board_squares + [[Square.EMPTY, Square.EMPTY, Square.EMPTY], + [Square.EMPTY, Square.EMPTY, Square.EMPTY], + [Square.EMPTY, Square.EMPTY, Square.EMPTY]] + """ + return self._board_squares - # Parse a string representation of the board state, if given - if string is not None: - self.parse_str(string) + @board_squares.setter + def board_squares(self, value: List[List[Square]]) -> None: + """Board squares setter - def get_valid_moves(self): - ''' return a list of valid (i,j,player) moves ''' - # work out whose turn it is - num_x = np.sum(self.board_state == self.x) - num_o = np.sum(self.board_state == self.o) - if num_x == num_o: - turn = self.x - elif num_x == num_o + 1: - turn = self.o + Checks that the size of the board is correct when setting all squares at + once. + + Params: + value: Nested list in the format rows x columns. 
For example, the + middle square is `self.board[1][1]`. + """ + # Check number of rows + assert len(value) == self.number_rows, \ + ValueError("Invalid number of rows") + + # Check number of columns + for row in value: + assert len(row) == self.number_columns, \ + ValueError("Invalid number of columns") + + # Set the board + self._board_squares = value + + # Check the number of moves makes sense (note this must be done after + # the board is set) + x_moves = self.number_squares_played(Player.X) + o_moves = self.number_squares_played(Player.O) + assert x_moves == o_moves or x_moves == o_moves + 1, \ + ValueError("Invalid number of moves") + + # TODO: Check there isn't more than one winner + + def __init__( + self, + board_string: Optional[str] = None, + allow_illegal_moves: Optional[bool] = False + ): + """Initialise the board + + Args: + board_string: Board string to parse (of the format "- x o" on three + lines). + """ + # Settings (must be applied first) + self.allow_illegal_moves = allow_illegal_moves + + # Parse the board string if provided + if board_string: + self.board_squares = self._parse_board_string(board_string) + + # Default to an empty board else: - print("Invalid board state") - - # make list - l = [] - for i in range(3): - for j in range(3): - if self.board_state[i,j] == self.blank: - l.append((i,j,turn)) - return l - - - def make_move(self, i, j, player): - # check if legal - if i >= 3 or i < 0 or j >= 3 or j < 0: - print("Index out of bounds") - elif self.board_state[i,j] != self.blank: - print("Not a blank square") - elif player != self.x and player != self.o: - print("Invalid player") - - # modify board - self.board_state[i,j] = player - - - def check_win(self): - for player in [self.x, self.o]: - won = False - # check columns - if np.any(np.all(self.board_state == player, axis=0)): - won = True - # check rows - if np.any(np.all(self.board_state == player, axis=1)): - won = True - - # check diagonals - elif np.all(np.diag(self.board_state) == 
player) \ - or np.all(np.diag(np.fliplr(self.board_state))== player): - won = True - - if won: - return player - - return False - - def __str__(self): - b = self.board_state - out = '' - # convert state to string - for i in range(3): - for j in range(3): - out += f" {self.map[b[i,j]]}" - out += "\n" - return out - - def get_counts(self): - blank_count = np.sum(self.board_state == self.blank) - x_count = np.sum(self.board_state == self.x) - o_count = np.sum(self.board_state == self.o) - return blank_count, x_count, o_count - - def validate_state(self): - """ Check that the board state is valid, including that the number of - x's and o's is correct. + self.board_squares = \ + [[Square.EMPTY] * self.number_columns] * self.number_rows + + @property + def valid_moves(self) -> List[Tuple[int, int]]: + """Valid Moves + + Returns: + List of empty squares + + >>> Board().valid_moves + [[0, 0], [0, 1], [0, 2], + [1, 0], [1, 1], [1, 2], + [2, 0], [2, 1], [2, 2]] """ - blank_count, x_count, o_count = self.get_counts() - assert blank_count + x_count + o_count == 9 - assert (x_count == o_count) or (x_count == o_count + 1) + empty_squares: List[Tuple[int, int]] = [] - return True + for row_idx, row in enumerate(self.board_squares): + for col_idx, square in enumerate(row): + if square == Square.EMPTY: + empty_squares.append((row_idx, col_idx)) - def validate_str(self, string): - """ Check that the string representation of the board state is valid. 
""" - lines = string.strip("\n").split("\n") + return empty_squares - # Check there are 3 lines - if len(lines) != 3: - raise ValueError("Invalid game string - incorrect number of rows") + def play_move(self, row: int, col: int, player: Player): + """Play a move - # Check that each line is in the correct format - for line in lines: - if re.match(r"^[ ]*[-xo][ ][-xo][ ][-xo][ ]*$", line) is None: - raise ValueError("Invalid game string - invalid row format") + Args: + row: Row + col: Column + player: Player + """ + # Check the square exists + assert self.board_squares[row], \ + ValueError("Invalid move - row doesn't exist") + assert self.board_squares[row][col], \ + ValueError("Invalid move - column doesn't exist") + + # Check for illegal moves, if they're not allowed + if not self.allow_illegal_moves: + # Check the square is empty + assert self.board_squares[row][col] == Square.EMPTY, \ + ValueError("Invalid move - square is not empty") + + # Check it is the player's turn + assert self.current_turn_player == player, \ + ValueError("Invalid move - not this player's turn") + + # TODO: Check there isn't a winner already + + # Modify the board + self._board_squares[row][col] = player + + def get_col(self, col_idx: int) -> List[Square]: + column: List[Square] = [] + + for row_idx in len(self.board_squares): + square = self.board_squares[row_idx][col_idx] + column.append(square) + + return column + + @property + def winner(self) -> Optional[Player]: + # TODO: Allow for winning streak to be less than the full length - return True + # Check rows + for row in self.board_squares: + if len(set(row)) == 1 and row[0] != Square.EMPTY: + return row[0] - def parse_str(self, string: str): - # Ensure the state string is of the correct format - self.validate_str(string) + # Check columns + for col_idx in self.board_squares: + col = self.get_col(col_idx) + if len(set(col)) == 1 and col[0] != Square.EMPTY: + return col[0] + + # Check diagonals TODO + + # Otherwise return None + return 
None + + def __str__(self) -> str: + """String representation of the board""" + view: str = "" + + for row in self.board_squares: + for square in row: + square_view: str = str(square) + " " + view += square_view + view += "\n" + + return view + + def number_squares_played(self, player: Player) -> int: + """Number of squares played by a specific player + + Args: + player: Player + + Returns: + int: Number of squares the player has played + """ + count: int = 0 + + for row in self.board_squares: + for square in row: + if square == player: + count += 1 + + return count + + @property + def current_turn_player(self) -> Optional[Player]: + """The player whose turn it is currently""" + # Count the number of squares played + x_squares = self.number_squares_played(Player.X) + o_squares = self.number_squares_played(Player.O) + + # Allow for a full board + total_squares = self.number_rows * self.number_columns + if x_squares + o_squares == total_squares: + return None + + # Otherwise look at who has played most squares + if o_squares < x_squares: + return Player.O + else: + return Player.X + + # def validate_state(self): + # """ Check that the board state is valid, including that the number of + # x's and o's is correct. 
+ # """ + # blank_count, x_count, o_count = self.get_counts() + # assert blank_count + x_count + o_count == 9 + # assert (x_count == o_count) or (x_count == o_count + 1) + + def _parse_board_string(self, board_string: str): + # Trim leading and trailing spaces & new lines + trimmed: str = board_string.strip() # Split the string into lines - lines = string.strip("\n").split("\n") + lines: List[str] = trimmed.split("\n") + + # Initialise the board squares + board_squares: List[List[Square]] = [] + + for line in lines: + row: List[Square] = [] + + trimmed_line: str = line.strip() + + for char in trimmed_line.split(" "): + # Check the character is one of the Square ENUM values + assert Square(char), ValueError( + "Invalid character in board string") + + row.append(Square(char)) + + board_squares.append(row) + + # Check the number of squares + + return board_squares # create a dict that does the opposite of self.map - rev_map = {v:k for k,v in self.map.items()} + rev_map = {v: k for k, v in self.map.items()} # iterate over it and convert to state for i, line in enumerate(lines): l = line.strip(" ").split(" ") for j, char in enumerate(l): - self.board_state[i,j] = rev_map[char] + self.board_squares[i, j] = rev_map[char] + class TicTacToeGame: """ Class to represent a game of Tic Tac Toe at multiple points in time. 
""" + def __init__(self, - init_board:str = None, - game_string:str = None, - check_valid_move:bool = True, - check_valid_state:bool = True ): + init_board: str = None, + game_string: str = None, + check_valid_move: bool = True, + check_valid_state: bool = True): - self.board = TicTacToeBoard(init_board) - self.history : List[TicTacToeBoard] = [self.board] + self.board = Board(init_board) + self.history: List[Board] = [self.board] # Choose which validity checks to perform - # Check that the board state has only 'x', 'o', and '-' characters + #  Check that the board state has only 'x', 'o', and '-' characters self.check_valid_string = True - # Check that on state change, only one blank piece has changed - self.check_valid_move = check_valid_move + #  Check that on state change, only one blank piece has changed + self.check_valid_move = check_valid_move # Check that the board state is valid, such that #x == #o or #x == #o + 1 self.check_valid_state = check_valid_state def reset(self): - self.board = TicTacToeBoard() + self.board = Board() self.history = [self.board] return self - def validate_move(self, old_state: TicTacToeBoard, new_state: TicTacToeBoard): + def validate_move(self, old_state: Board, new_state: Board): if self.check_valid_state: assert old_state.validate_state() assert new_state.validate_state() if self.check_valid_move: # Check that the new state has one more piece than the old state - assert np.sum(new_state.board_state != old_state.board_state) == 1 - assert np.sum(new_state.board_state == old_state.board_state) == 8 + assert np.sum(new_state.board_squares != + old_state.board_squares) == 1 + assert np.sum(new_state.board_squares == + old_state.board_squares) == 8 old_blank_count, _old_x_count, _old_o_count = old_state.get_counts() new_blank_count, _new_x_count, _new_o_count = new_state.get_counts() @@ -201,7 +334,7 @@ def add_state(self, board_string: str) -> Tuple[int, bool]: # Perform validity checks try: # Load the new state and save it to the 
history - self.board = TicTacToeBoard(board_string) + self.board = Board(board_string) self.history.append(self.board) if self.check_valid_state: @@ -220,12 +353,12 @@ def add_state(self, board_string: str) -> Tuple[int, bool]: print(e) return 0, False - # If valid, perform win checks + #  If valid, perform win checks outcome = self.board.check_win() return outcome, True def validate_game_string(self, game_string: str) -> Tuple[int, bool]: - self.board = TicTacToeBoard() + self.board = Board() self.history = [self.board] # split game string into board states @@ -246,8 +379,8 @@ def validate_game_string(self, game_string: str) -> Tuple[int, bool]: return final_outcome, True def evaluate_game_string(self, - game_string:str, - ) -> int: + game_string: str, + ) -> int: outcome, valid = self.validate_game_string(game_string) # If the game is not valid, return -1 @@ -262,14 +395,17 @@ def evaluate_game_string(self, return 0.0 + def generate_random_game(): - b = TicTacToeBoard() - game_state_history = [ str(b) ] + b = Board() + game_state_history = [str(b)] for t in range(9): + # Get the player + # Make a random valid move valid_moves = b.get_valid_moves() move = np.random.choice(len(valid_moves)) - b.make_move(*valid_moves[move]) - game_state_history.append( str(b) ) + b.play_move(*valid_moves[move]) + game_state_history.append(str(b)) return "Let's play Tic Tac Toe:\n" + "\n".join(game_state_history) + "<|endoftext|>" @@ -283,7 +419,7 @@ def generate_dataset(number_games: int) -> List[str]: Returns: List: List of games (strings with a full game) """ - return [ generate_random_game() for _ in range(number_games) ] + return [generate_random_game() for _ in range(number_games)] if __name__ == "__main__": diff --git a/soft_optim/tests/test_finetune.py b/soft_optim/tests/test_finetune.py index 278f890..029e457 100644 --- a/soft_optim/tests/test_finetune.py +++ b/soft_optim/tests/test_finetune.py @@ -35,7 +35,7 @@ def test_plain_gpt(self): tokenizer = 
AutoTokenizer.from_pretrained(model_name) # Infer the full game - full_game:str = infer_game(model, tokenizer) + full_game: str = infer_game(model, tokenizer) # Check it throws an error with pytest.raises(Exception) as exc_info: @@ -43,7 +43,6 @@ def test_plain_gpt(self): assert exc_info - def test_fine_tuned_gpt(self): # Run the model if it hasn't already been run if not valid_games_fine_tuned_checkpoint.exists(): @@ -52,13 +51,12 @@ def test_fine_tuned_gpt(self): # Load the fine-tuned model model_name = "gpt2" tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained(valid_games_fine_tuned_checkpoint) + model = AutoModelForCausalLM.from_pretrained( + valid_games_fine_tuned_checkpoint) # Infer the game - full_game:str = infer_game(model, tokenizer) + full_game: str = infer_game(model, tokenizer) # Check it is valid res = evaluate_game_string(full_game) assert type(res) == int - - diff --git a/soft_optim/tests/test_game_generator.py b/soft_optim/tests/test_game_generator.py index 0c011d6..8619870 100644 --- a/soft_optim/tests/test_game_generator.py +++ b/soft_optim/tests/test_game_generator.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from soft_optim.game import TicTacToeBoard, TicTacToeGame +from soft_optim.game import Board, TicTacToeGame class TestTicTacToeBoardValidStates: @@ -18,32 +18,33 @@ class TestTicTacToeBoardValidStates: - o -""" def test_initializes_empty_board(self): - board = TicTacToeBoard() + board = Board() expected = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) np.testing.assert_array_equal( - board.board_state, + board.board_squares, expected) assert board.check_win() == False assert board.validate_state() def test_parses_empty_board(self): - board = TicTacToeBoard(self.mock_empty_board_str) + board = Board(self.mock_empty_board_str) expected = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) np.testing.assert_array_equal( - board.board_state, + board.board_squares, expected) assert board.check_win() 
== False assert board.validate_state() def test_parses_x_won_board(self): - board = TicTacToeBoard(self.mock_x_won_board_str) + board = Board(self.mock_x_won_board_str) expected = np.array([[1, 1, 1], [2, 0, 2], [0, 2, 0]]) np.testing.assert_array_equal( - board.board_state, + board.board_squares, expected) assert board.check_win() == board.x assert board.validate_state() + class TestTicTacToeBoardInvalidStates: """Test invalid states for TicTacToeBoard""" @@ -54,7 +55,7 @@ def test_parser_errors_too_many_lines(self): o - o - o - x x x""" - TicTacToeBoard(invalid_str) + Board(invalid_str) def test_parser_errors_too_many_columns(self): with pytest.raises(ValueError, match='Invalid'): @@ -62,7 +63,7 @@ def test_parser_errors_too_many_columns(self): """x x x x o - o - o -""" - TicTacToeBoard(invalid_str) + Board(invalid_str) def test_parser_errors_invalid_row_character(self): with pytest.raises(ValueError, match='Invalid'): @@ -70,22 +71,23 @@ def test_parser_errors_invalid_row_character(self): """x x y o - o - o -""" - TicTacToeBoard(invalid_str) + Board(invalid_str) def test_parser_errors_x_won_cheat_board(self): mock_x_won_cheat_board_str: str = \ - """x x x + """x x x - o - - - -""" - board = TicTacToeBoard(mock_x_won_cheat_board_str) + board = Board(mock_x_won_cheat_board_str) expected = np.array([[1, 1, 1], [0, 2, 0], [0, 0, 0]]) np.testing.assert_array_equal( - board.board_state, + board.board_squares, expected) assert board.check_win() == board.x with pytest.raises(AssertionError): board.validate_state() + class TestTicTacToeGameValidGames: """Test invalid games for TicTacToeGame""" mock_game_x_win_str: str = \ @@ -119,6 +121,7 @@ def test_parses_valid_game(self): assert outcome == game.board.x assert valid + class TestTicTacToeGameInvalidGames: mock_game_o_win_invalid_wrong_player_place_many: str = \ """- - - @@ -165,60 +168,72 @@ class TestTicTacToeGameInvalidGames: def test_game_invalid_wrong_player_place_many(self): game = 
TicTacToeGame(check_valid_move=False, check_valid_state=False) - outcome, valid = game.validate_game_string(self.mock_game_o_win_invalid_wrong_player_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_o_win_invalid_wrong_player_place_many) assert valid and game.board.o == outcome game = TicTacToeGame(check_valid_move=True, check_valid_state=False) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_o_win_invalid_wrong_player_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_o_win_invalid_wrong_player_place_many) assert valid game = TicTacToeGame(check_valid_move=False, check_valid_state=True) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_o_win_invalid_wrong_player_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_o_win_invalid_wrong_player_place_many) assert valid game = TicTacToeGame(check_valid_move=True, check_valid_state=True) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_o_win_invalid_wrong_player_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_o_win_invalid_wrong_player_place_many) assert valid def test_game_invalid_wrong_player(self): game = TicTacToeGame(check_valid_move=False, check_valid_state=False) - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_wrong_player) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_wrong_player) assert valid and game.board.x == outcome game = TicTacToeGame(check_valid_move=True, check_valid_state=False) - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_wrong_player) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_wrong_player) assert valid and game.board.x == outcome game = TicTacToeGame(check_valid_move=False, check_valid_state=True) with pytest.raises(AssertionError): - outcome, 
valid = game.validate_game_string(self.mock_game_x_win_invalid_wrong_player) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_wrong_player) assert valid game = TicTacToeGame(check_valid_move=True, check_valid_state=True) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_wrong_player) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_wrong_player) assert valid def test_game_invalid_place_many(self): game = TicTacToeGame(check_valid_move=False, check_valid_state=False) - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_place_many) assert valid and game.board.x == outcome game = TicTacToeGame(check_valid_move=True, check_valid_state=False) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_place_many) assert valid and game.board.x == outcome # should return true?? 
game = TicTacToeGame(check_valid_move=False, check_valid_state=True) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_place_many) assert valid game = TicTacToeGame(check_valid_move=True, check_valid_state=True) with pytest.raises(AssertionError): - outcome, valid = game.validate_game_string(self.mock_game_x_win_invalid_place_many) + outcome, valid = game.validate_game_string( + self.mock_game_x_win_invalid_place_many) assert valid diff --git a/soft_optim/tests/test_quantilizer.py b/soft_optim/tests/test_quantilizer.py index ba2a546..8bd00d5 100644 --- a/soft_optim/tests/test_quantilizer.py +++ b/soft_optim/tests/test_quantilizer.py @@ -3,31 +3,32 @@ from soft_optim.quantilizer import empirical_error_bound + class TestEmpiricalErrorBound: def test_error_bounds_work_with_lots_samples(self): # Set the random seed to avoid a brittle test np.random.seed(0) - + epsilon: float = 0.05 number_distributions_compare: int = 1000 experienced_outside_of_bounds: List[bool] = [] - + # For x meta-samples: for _ in range(number_distributions_compare): # Highest variance we can have occurs with a mean of 0.5 & bernoulli - # distribution + # distribution error_distribution_mean: float = np.random.uniform(0, 1) - sample_errors = np.random.binomial(1, error_distribution_mean, 1000) - + sample_errors = np.random.binomial( + 1, error_distribution_mean, 1000) + error_bound: float = empirical_error_bound( np.zeros(1000), sample_errors, epsilon) - + # See if we were in the error bound is_outside_bounds: bool = error_bound < error_distribution_mean experienced_outside_of_bounds.append(is_outside_bounds) - + # Check epsilon percent of x samples are outside the error bound assert np.mean(experienced_outside_of_bounds) < epsilon - \ No newline at end of file