From 776ad6f0b05f512e0d9c4bf5a7bb7d2003c0b2c0 Mon Sep 17 00:00:00 2001 From: anton-gasse Date: Tue, 4 Jun 2024 23:10:39 +0200 Subject: [PATCH] Feat: added load and save progress methods --- README.md | 13 ++++++++++--- mcts.py | 42 ++++++++++++++++++++++++++++++++---------- naughtsandcrosses.py | 12 ++++++++++-- 3 files changed, 52 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 8436258..0bfe6e7 100755 --- a/README.md +++ b/README.md @@ -25,13 +25,20 @@ Once these have been implemented, running MCTS is as simple as initializing your ```python from mcts import mcts -searcher = mcts(timeLimit=1000) -bestAction = searcher.search(initialState=initialState) +searcher = mcts(initialState=initialState, timeLimit=1000) +bestAction = searcher.search() +searcher.save(path='./root') + +newSearcher = mcts(path="./root", iterationLimit=80000) +action = newSearcher.search() + +print(action) ``` Here the unit of `timeLimit=1000` is millisecond. You can also use `iterationLimit=1600` to specify the number of rollouts. Exactly one of `timeLimit` and `iterationLimit` should be specified. The expected reward of best action can be got by setting `needDetails` to `True` in `searcher`. +You can save the current progress with the save() method and create a new instance providing the path to load the progress. ```python -resultDict = searcher.search(initialState=initialState, needDetails=True) +resultDict = searcher.search(needDetails=True) print(resultDict.keys()) #currently includes dict_keys(['action', 'expectedReward']) ``` diff --git a/mcts.py b/mcts.py index 3ea88f8..94ebcdc 100755 --- a/mcts.py +++ b/mcts.py @@ -3,6 +3,7 @@ import time import math import random +import pickle def randomPolicy(state): @@ -34,16 +35,23 @@ def __str__(self): return "%s: {%s}"%(self.__class__.__name__, ', '.join(s)) class mcts(): - def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2), - rolloutPolicy=randomPolicy): - if timeLimit != None: - if iterationLimit != None: + def __init__(self, initialState: treeNode = None, path: str = None, timeLimit=None, iterationLimit=None, + explorationConstant=1 / math.sqrt(2),rolloutPolicy=randomPolicy): + if initialState is not None: + self.root = treeNode(initialState, None) + elif path is not None: + self.load(path=path) + else: + raise ValueError("Either initialState or path must be provided.") + + if timeLimit is not None: + if iterationLimit is not None: raise ValueError("Cannot have both a time limit and an iteration limit") # time taken for each MCTS search in milliseconds self.timeLimit = timeLimit self.limitType = 'time' else: - if iterationLimit == None: + if iterationLimit is None: raise ValueError("Must have either a time limit or an iteration limit") # number of iterations of the search if iterationLimit < 1: @@ -53,8 +61,7 @@ def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / self.explorationConstant = explorationConstant self.rollout = rolloutPolicy - def search(self, initialState, needDetails=False): - self.root = treeNode(initialState, None) + def search(self, needDetails=False): if self.limitType == 'time': timeLimit = time.time() + self.timeLimit / 1000 @@ -79,7 +86,7 @@ def executeRound(self): reward = self.rollout(node.state) self.backpropogate(node, reward) - def selectNode(self, node): + def selectNode(self, node: treeNode): while not node.isTerminal: if node.isFullyExpanded: node = self.getBestChild(node, self.explorationConstant) @@ -87,7 +94,7 @@ def selectNode(self, node): return self.expand(node) return node - def expand(self, node): + def expand(self, node: treeNode): actions = node.state.getPossibleActions() for action in actions: if action not in node.children: @@ -116,4 +123,19 @@ def getBestChild(self, node, explorationValue): bestNodes = [child] elif nodeValue == bestValue: bestNodes.append(child) - return random.choice(bestNodes) \ No newline at end of file + return random.choice(bestNodes) + + def save(self, path: str) -> bool: + try: + with open( path, "wb" ) as f: + pickle.dump(self.root, f) + return True + except Exception as e: + print(e) + + def load(self, path: str) -> bool: + try: + with open( path, "rb" ) as f: + self.root = pickle.load(f) + except Exception as e: + print(e) \ No newline at end of file diff --git a/naughtsandcrosses.py b/naughtsandcrosses.py index 9d490a3..a5ba9e7 100755 --- a/naughtsandcrosses.py +++ b/naughtsandcrosses.py @@ -75,7 +75,15 @@ def __hash__(self): if __name__=="__main__": initialState = NaughtsAndCrossesState() - searcher = mcts(timeLimit=1000) - action = searcher.search(initialState=initialState) + + searcher = mcts(initialState=initialState, timeLimit=1000) + action = searcher.search() print(action) + + searcher.save("./root") + + newSearcher = mcts(path="./root", iterationLimit=8000) + action = newSearcher.search() + + print(action) \ No newline at end of file