Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,20 @@ Once these have been implemented, running MCTS is as simple as initializing your
```python
from mcts import mcts

searcher = mcts(timeLimit=1000)
bestAction = searcher.search(initialState=initialState)
searcher = mcts(initialState=initialState, timeLimit=1000)
bestAction = searcher.search()
searcher.save(path='./root')

newSearcher = mcts(path="./root", iterationLimit=80000)
action = newSearcher.search()

print(action)
```
Here the unit of `timeLimit=1000` is millisecond. You can also use `iterationLimit=1600` to specify the number of rollouts. Exactly one of `timeLimit` and `iterationLimit` should be specified. The expected reward of best action can be got by setting `needDetails` to `True` in `searcher`.

You can save the current progress with the save() method and create a new instance providing the path to load the progress.
```python
resultDict = searcher.search(initialState=initialState, needDetails=True)
resultDict = searcher.search(needDetails=True)
print(resultDict.keys()) #currently includes dict_keys(['action', 'expectedReward'])
```

Expand Down
42 changes: 32 additions & 10 deletions mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import math
import random
import pickle


def randomPolicy(state):
Expand Down Expand Up @@ -34,16 +35,23 @@ def __str__(self):
return "%s: {%s}"%(self.__class__.__name__, ', '.join(s))

class mcts():
def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
rolloutPolicy=randomPolicy):
if timeLimit != None:
if iterationLimit != None:
def __init__(self, initialState: treeNode = None, path: str = None, timeLimit=None, iterationLimit=None,
explorationConstant=1 / math.sqrt(2),rolloutPolicy=randomPolicy):
if initialState is not None:
self.root = treeNode(initialState, None)
elif path is not None:
self.load(path=path)
else:
raise ValueError("Either initialState or path must be provided.")

if timeLimit is not None:
if iterationLimit is not None:
raise ValueError("Cannot have both a time limit and an iteration limit")
# time taken for each MCTS search in milliseconds
self.timeLimit = timeLimit
self.limitType = 'time'
else:
if iterationLimit == None:
if iterationLimit is None:
raise ValueError("Must have either a time limit or an iteration limit")
# number of iterations of the search
if iterationLimit < 1:
Expand All @@ -53,8 +61,7 @@ def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 /
self.explorationConstant = explorationConstant
self.rollout = rolloutPolicy

def search(self, initialState, needDetails=False):
self.root = treeNode(initialState, None)
def search(self, needDetails=False):

if self.limitType == 'time':
timeLimit = time.time() + self.timeLimit / 1000
Expand All @@ -79,15 +86,15 @@ def executeRound(self):
reward = self.rollout(node.state)
self.backpropogate(node, reward)

def selectNode(self, node):
def selectNode(self, node: treeNode):
while not node.isTerminal:
if node.isFullyExpanded:
node = self.getBestChild(node, self.explorationConstant)
else:
return self.expand(node)
return node

def expand(self, node):
def expand(self, node: treeNode):
actions = node.state.getPossibleActions()
for action in actions:
if action not in node.children:
Expand Down Expand Up @@ -116,4 +123,19 @@ def getBestChild(self, node, explorationValue):
bestNodes = [child]
elif nodeValue == bestValue:
bestNodes.append(child)
return random.choice(bestNodes)
return random.choice(bestNodes)

def save(self, path: str) -> bool:
try:
with open( path, "wb" ) as f:
pickle.dump(self.root, f)
return True
except Exception as e:
print(e)

def load(self, path: str) -> bool:
try:
with open( path, "rb" ) as f:
self.root = pickle.load(f)
except Exception as e:
print(e)
12 changes: 10 additions & 2 deletions naughtsandcrosses.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ def __hash__(self):

if __name__=="__main__":
initialState = NaughtsAndCrossesState()
searcher = mcts(timeLimit=1000)
action = searcher.search(initialState=initialState)

searcher = mcts(initialState=initialState, timeLimit=1000)
action = searcher.search()

print(action)

searcher.save("./root")

newSearcher = mcts(path="./root", iterationLimit=8000)
action = newSearcher.search()

print(action)