diff --git a/README.md b/README.md
index 8436258..95ea7db 100755
--- a/README.md
+++ b/README.md
@@ -35,7 +35,8 @@ resultDict = searcher.search(initialState=initialState, needDetails=True)
 print(resultDict.keys()) #currently includes dict_keys(['action', 'expectedReward'])
 ```
 
-See [naughtsandcrosses.py](https://github.com/pbsinclair42/MCTS/blob/master/naughtsandcrosses.py) for a simple example.
+See [naughtsandcrosses.py](./naughtsandcrosses.py) for a simple example.
+See also [connectmnk.py](./connectmnk.py) for an example running a full game between two MCTS agents.
 
 ## Slow Usage
 //TODO
diff --git a/__pycache__/mcts.cpython-310.pyc b/__pycache__/mcts.cpython-310.pyc
new file mode 100644
index 0000000..0177992
Binary files /dev/null and b/__pycache__/mcts.cpython-310.pyc differ
diff --git a/__pycache__/mcts.cpython-38.pyc b/__pycache__/mcts.cpython-38.pyc
new file mode 100644
index 0000000..18dfd6f
Binary files /dev/null and b/__pycache__/mcts.cpython-38.pyc differ
diff --git a/connectmnk.py b/connectmnk.py
new file mode 100644
index 0000000..5acbed5
--- /dev/null
+++ b/connectmnk.py
@@ -0,0 +1,273 @@
+from __future__ import division
+
+import copy
+from mcts import mcts
+import random
+
+
+class ConnectMNKState:
+    """ConnectMNKState models a Connect(m,n,k,1,1) game that generalizes
+    the famous "Connect Four" itself equal to the Connect(7,6,4,1,1) game.
+
+    Background from wikipedia:
+    Connect(m,n,k,p,q) games are another generalization of gomoku to a board
+    with m×n intersections, k in a row needed to win, p stones for each player
+    to place, and q stones for the first player to place for the first move
+    only. Each player may play only at the lowest unoccupied place in a column.
+    In particular, Connect(m,n,6,2,1) is called Connect6.
+    (see also https://en.wikipedia.org/wiki/Gomoku#Theoretical_generalizations)
+    """
+
+    playerNames = {1:'O', -1:'X'}
+
+    def __init__(self, mColumns=7, nRows=6, kConnections=4):
+        self.mColumns = mColumns
+        self.nRows = nRows
+        self.kConnections = kConnections
+        self.board = [ [0 for _ in range(self.mColumns)] for _ in range(self.nRows)]
+        self.currentPlayer = max(ConnectMNKState.playerNames.keys())
+        self.isTerminated = None
+        self.reward = None
+        self.possibleActions = None
+        self.winingPattern = None
+
+    def show(self):
+        rowText = ""
+        for columnIndex in range(self.mColumns):
+            rowText += f" {columnIndex % 10} "
+        print(rowText)
+
+        for rowIndex in reversed(range(self.nRows)):
+            rowText = ""
+            for x in self.board[rowIndex]:
+                if x in self.playerNames:
+                    rowText += f" {self.playerNames[x]} "
+                else:
+                    rowText += " . "
+            rowText += f" {rowIndex % 10} "
+            print(rowText)
+
+    def getCurrentPlayer(self):
+        return self.currentPlayer
+
+    def getPossibleActions(self):
+        if self.possibleActions is None:
+            self.possibleActions = []
+            for columnIndex in range(self.mColumns):
+                for rowIndex in range(self.nRows):
+                    if self.board[rowIndex][columnIndex] == 0:
+                        action = Action(player=self.currentPlayer,
+                                        columnIndex=columnIndex,
+                                        rowIndex=rowIndex)
+                        self.possibleActions.append(action)
+                        break
+            # Shuffle actions in order to be less predictable when MCTS is setup with a few explorations
+            # Maybe better to have it here than in the MCTS engine?
+            random.shuffle(self.possibleActions)
+        return self.possibleActions
+
+    def takeAction(self, action):
+        newState = copy.copy(self)
+        newState.board = copy.deepcopy(newState.board)
+        newState.board[action.rowIndex][action.columnIndex] = action.player
+        newState.currentPlayer = self.currentPlayer * -1
+        newState.isTerminated = None
+        newState.possibleActions = None
+        newState.winingPattern = None
+        return newState
+
+    def isTerminal(self):
+        if self.isTerminated is None:
+            self.isTerminated = False
+            for rowIndex in range(self.nRows):
+                line = self.board[rowIndex]
+                lineReward = self.__getLineReward(line)
+                if lineReward != 0:
+                    self.isTerminated = True
+                    self.reward = lineReward
+                    self.winingPattern = "k-in-row"
+                    break
+
+            if not self.isTerminated:
+                for columnIndex in range(self.mColumns):
+                    line = []
+                    for rowIndex in range(self.nRows):
+                        line.append(self.board[rowIndex][columnIndex])
+                    lineReward = self.__getLineReward(line)
+                    if lineReward != 0:
+                        self.isTerminated = True
+                        self.reward = lineReward
+                        self.winingPattern = "k-in-column"
+                        break
+
+            if not self.isTerminated:
+                # diagonal: rowIndex = columnIndex + parameter
+                for parameter in range(1 - self.mColumns, self.nRows):
+                    line = []
+                    for columnIndex in range(self.mColumns):
+                        rowIndex = columnIndex + parameter
+                        if 0 <= rowIndex < self.nRows:
+                            line.append(self.board[rowIndex][columnIndex])
+                    lineReward = self.__getLineReward(line)
+                    if lineReward != 0:
+                        self.isTerminated = True
+                        self.reward = lineReward
+                        self.winingPattern = "k-in-diagonal"
+                        break
+
+            if not self.isTerminated:
+                # antidiagonal: rowIndex = - columnIndex + parameter
+                for parameter in range(0, self.mColumns + self.nRows):
+                    line = []
+                    for columnIndex in range(self.mColumns):
+                        rowIndex = -columnIndex + parameter
+                        if 0 <= rowIndex < self.nRows:
+                            line.append(self.board[rowIndex][columnIndex])
+                    lineReward = self.__getLineReward(line)
+                    if lineReward != 0:
+                        self.isTerminated = True
+                        self.reward = lineReward
+                        self.winingPattern = "k-in-antidiagonal"
+                        break
+
+            if not self.isTerminated and len(self.getPossibleActions()) == 0:
+                self.isTerminated = True
+                self.reward = 0
+
+        return self.isTerminated
+
+    def __getLineReward(self, line):
+        lineReward = 0
+        if len(line) >= self.kConnections:
+            for player in ConnectMNKState.playerNames.keys():
+                playerLine = [x == player for x in line]
+                playerConnections = 0
+                for x in playerLine:
+                    if x:
+                        playerConnections += 1
+                        if playerConnections == self.kConnections:
+                            lineReward = player
+                            break
+                    else:
+                        playerConnections = 0
+                if lineReward != 0:
+                    break
+        return lineReward
+
+    def getReward(self):
+        assert self.isTerminal()
+        assert self.reward is not None
+        return self.reward
+
+
+class Action():
+    def __init__(self, player, columnIndex, rowIndex):
+        self.player = player
+        self.rowIndex = rowIndex
+        self.columnIndex = columnIndex
+
+    def __str__(self):
+        return str((self.columnIndex, self.rowIndex))
+
+    def __repr__(self):
+        return str(self)
+
+    def __eq__(self, other):
+        # Compare classes first so a non-Action `other` short-circuits before any attribute access.
+        return (self.__class__ == other.__class__ and
+                self.player == other.player and
+                self.columnIndex == other.columnIndex and
+                self.rowIndex == other.rowIndex)
+
+    def __hash__(self):
+        return hash((self.columnIndex, self.rowIndex, self.player))
+
+
+def extractStatistics(searcher, action):
+    statistics = {}
+    statistics['rootNumVisits'] = searcher.root.numVisits
+    statistics['rootTotalReward'] = searcher.root.totalReward
+    statistics['actionNumVisits'] = searcher.root.children[action].numVisits
+    statistics['actionTotalReward'] = searcher.root.children[action].totalReward
+    return statistics
+
+
+def main():
+    """Run a full match between two MCTS searchers, possibly with different
+    parametrization, playing a Connect(m,n,k) game.
+
+    Extraction of MCTS statistics is exemplified.
+
+    The game parameters (m,n,k) are randomly chosen.
+    """
+
+    searchers = {}
+    searchers["mcts-1500ms"] = mcts(timeLimit=1_500)
+    searchers["mcts-1000ms"] = mcts(timeLimit=1_000)
+    searchers["mcts-500ms"] = mcts(timeLimit=500)
+    searchers["mcts-250ms"] = mcts(timeLimit=250)
+
+    playerNames = ConnectMNKState.playerNames
+
+    playerSearcherNames = {}
+    for player in sorted(playerNames.keys()):
+        playerSearcherNames[player] = random.choice(sorted(searchers.keys()))
+
+    runnableGames = list()
+    runnableGames.append((3, 3, 3))
+    runnableGames.append((7, 6, 4))
+    runnableGames.append((8, 7, 5))
+    runnableGames.append((9, 8, 6))
+    (m, n, k) = random.choice(runnableGames)
+    currentState = ConnectMNKState(mColumns=m, nRows=n, kConnections=k)
+
+    print()
+    print(f"Connect m={m} n={n} k={k}")
+
+    print()
+    for player in sorted(playerNames.keys()):
+        print(f"player {playerNames[player]} = {player} = {playerSearcherNames[player]}")
+
+    print()
+    _ = input("main: press enter to start")
+
+    turn = 0
+    currentState.show()
+    while not currentState.isTerminal():
+        turn += 1
+        player = currentState.getCurrentPlayer()
+        action_count = len(currentState.getPossibleActions())
+
+        searcherName = playerSearcherNames[player]
+        searcher = searchers[searcherName]
+
+        action = searcher.search(initialState=currentState)
+        statistics = extractStatistics(searcher, action)
+        currentState = currentState.takeAction(action)
+
+        print(f"at turn {turn} player {playerNames[player]}={player} ({searcherName})" +
+              f" takes action (column, row)={action} amongst {action_count} possibilities")
+
+        print("mcts statitics:" +
+              f" chosen action= {statistics['actionTotalReward']} total reward" +
+              f" over {statistics['actionNumVisits']} visits /"
+              f" all explored actions= {statistics['rootTotalReward']} total reward" +
+              f" over {statistics['rootNumVisits']} visits")
+
+        print('-'*120)
+        currentState.show()
+
+    print('-'*120)
+    if currentState.getReward() == 0:
+        print(f"Connect(m={m},n={n},k={k}) game terminates; nobody wins")
+    else:
+        print(f"Connect(m={m},n={n},k={k}) game terminates;" +
+              f" player {playerNames[player]}={player} ({searcherName}) wins" +
+              f" with pattern {currentState.winingPattern}")
+
+    print()
+    _ = input("main: done ; press enter to terminate")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mcts.py b/mcts.py
index 3ea88f8..bda1e94 100755
--- a/mcts.py
+++ b/mcts.py
@@ -34,7 +34,7 @@ def __str__(self):
         return "%s: {%s}"%(self.__class__.__name__, ', '.join(s))
 
 class mcts():
-    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
+    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=math.sqrt(2),
                  rolloutPolicy=randomPolicy):
         if timeLimit != None:
             if iterationLimit != None:
@@ -109,8 +109,8 @@ def getBestChild(self, node, explorationValue):
         bestValue = float("-inf")
         bestNodes = []
         for child in node.children.values():
-            nodeValue = node.state.getCurrentPlayer() * child.totalReward / child.numVisits + explorationValue * math.sqrt(
-                2 * math.log(node.numVisits) / child.numVisits)
+            nodeValue = (node.state.getCurrentPlayer() * child.totalReward / child.numVisits +
+                explorationValue * math.sqrt(math.log(node.numVisits) / child.numVisits))
             if nodeValue > bestValue:
                 bestValue = nodeValue
                 bestNodes = [child]
diff --git a/naughtsandcrosses.py b/naughtsandcrosses.py
index 9d490a3..87abebe 100755
--- a/naughtsandcrosses.py
+++ b/naughtsandcrosses.py
@@ -39,7 +39,8 @@ def isTerminal(self):
                          [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
             if abs(sum(diagonal)) == 3:
                 return True
-        return reduce(operator.mul, sum(self.board, []), 1)
+        return reduce(operator.mul, sum(self.board, []), 1) != 0
+
 
     def getReward(self):
         for row in self.board:
@@ -52,7 +53,7 @@ def getReward(self):
                          [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
             if abs(sum(diagonal)) == 3:
                 return sum(diagonal) / 3
-        return False
+        return 0
 
 
 class Action():