From 64153c8b8aa93efb7de0c9a9524927d634e2b578 Mon Sep 17 00:00:00 2001 From: Lucas Borboleta Date: Mon, 22 Feb 2021 16:57:46 +0100 Subject: [PATCH 1/3] Connect (m,n,k) game + a few minor changes --- README.md | 3 +- connectmnk.py | 258 +++++++++++++++++++++++++++++++++++++++++++ mcts.py | 6 +- naughtsandcrosses.py | 5 +- 4 files changed, 266 insertions(+), 6 deletions(-) create mode 100644 connectmnk.py diff --git a/README.md b/README.md index 8436258..95ea7db 100755 --- a/README.md +++ b/README.md @@ -35,7 +35,8 @@ resultDict = searcher.search(initialState=initialState, needDetails=True) print(resultDict.keys()) #currently includes dict_keys(['action', 'expectedReward']) ``` -See [naughtsandcrosses.py](https://github.com/pbsinclair42/MCTS/blob/master/naughtsandcrosses.py) for a simple example. +See [naughtsandcrosses.py](./naughtsandcrosses.py) for a simple example. +See also [connectmnk.py](./connectmnk.py) for an example running a full game bewteen two MCTS agents. ## Slow Usage //TODO diff --git a/connectmnk.py b/connectmnk.py new file mode 100644 index 0000000..f98d472 --- /dev/null +++ b/connectmnk.py @@ -0,0 +1,258 @@ +from __future__ import division + +import copy +from mcts import mcts +import random + + +class ConnectMNKState: + """ConnectMNKState models a Connect(m,n,k,1,1) game that generalizes + the famous "Connect Four" itself equal to the Connect(7,6,4,1,1) game. + + Background from wikipedia: + Connect(m,n,k,p,q) games are another generalization of gomoku to a board + with m×n intersections, k in a row needed to win, p stones for each player + to place, and q stones for the first player to place for the first move + only. Each player may play only at the lowest unoccupied place in a column. + In particular, Connect(m,n,6,2,1) is called Connect6. + """ + + playerNames = {1:'O', -1:'X'} + + def __init__(self, mColumns=7, nRows=6, kConnections=4): + self.mColumns = mColumns + self.nRows = nRows + self.kConnections = kConnections + self.board = [ [0 for _ in range(self.mColumns)] for _ in range(self.nRows)] + self.currentPlayer = max(ConnectMNKState.playerNames.keys()) + self.isTerminated = None + self.reward = None + self.possibleActions = None + self.winingPattern = None + + def show(self): + rowText = "" + for columnIndex in range(self.mColumns): + rowText += f" {columnIndex % 10} " + print(rowText) + + for rowIndex in reversed(range(self.nRows)): + rowText = "" + for x in self.board[rowIndex]: + if x in self.playerNames: + rowText += f" {self.playerNames[x]} " + else: + rowText += " . " + rowText += f" {rowIndex % 10} " + print(rowText) + + def getCurrentPlayer(self): + return self.currentPlayer + + def getPossibleActions(self): + if self.possibleActions is None: + self.possibleActions = [] + for columnIndex in range(self.mColumns): + for rowIndex in range(self.nRows): + if self.board[rowIndex][columnIndex] == 0: + action = Action(player=self.currentPlayer, + columnIndex=columnIndex, + rowIndex=rowIndex) + self.possibleActions.append(action) + break + # Shuflle actions in order to be less predicatable when MCTS is setup with a few explorations + # Maybe better to have it here than in the MCTS engine? + random.shuffle(self.possibleActions) + return self.possibleActions + + def takeAction(self, action): + newState = copy.copy(self) + newState.board = copy.deepcopy(newState.board) + newState.board[action.rowIndex][action.columnIndex] = action.player + newState.currentPlayer = self.currentPlayer * -1 + newState.isTerminated = None + newState.possibleActions = None + newState.winingPattern = None + return newState + + def isTerminal(self): + if self.isTerminated is None: + self.isTerminated = False + for rowIndex in range(self.nRows): + line = self.board[rowIndex] + lineReward = self.__getLineReward(line) + if lineReward != 0: + self.isTerminated = True + self.reward = lineReward + self.winingPattern = "k-in-row" + break + + if not self.isTerminated: + for columnIndex in range(self.mColumns): + line = [] + for rowIndex in range(self.nRows): + line.append(self.board[rowIndex][columnIndex]) + lineReward = self.__getLineReward(line) + if lineReward != 0: + self.isTerminated = True + self.reward = lineReward + self.winingPattern = "k-in-column" + break + + if not self.isTerminated: + # diagonal: rowIndex = columnIndex + parameter + for parameter in range(1 - self.mColumns, self.nRows): + line = [] + for columnIndex in range(self.mColumns): + rowIndex = columnIndex + parameter + if 0 <= rowIndex < self.nRows: + line.append(self.board[rowIndex][columnIndex]) + lineReward = self.__getLineReward(line) + if lineReward != 0: + self.isTerminated = True + self.reward = lineReward + self.winingPattern = "k-in-diagonal" + break + + if not self.isTerminated: + # antidiagonal: rowIndex = - columnIndex + parameter + for parameter in range(0, self.mColumns + self.nRows): + line = [] + for columnIndex in range(self.mColumns): + rowIndex = -columnIndex + parameter + if 0 <= rowIndex < self.nRows: + line.append(self.board[rowIndex][columnIndex]) + lineReward = self.__getLineReward(line) + if lineReward != 0: + self.isTerminated = True + self.reward = lineReward + self.winingPattern = "k-in-antidiagonal" + break + + if not self.isTerminated and len(self.getPossibleActions()) == 0: + self.isTerminated = True + self.reward = 0 + + return self.isTerminated + + def __getLineReward(self, line): + lineReward = 0 + if len(line) >= self.kConnections: + for player in ConnectMNKState.playerNames.keys(): + playerLine = [x == player for x in line] + playerConnections = 0 + for x in playerLine: + if x: + playerConnections += 1 + if playerConnections == self.kConnections: + lineReward = player + break + else: + playerConnections = 0 + if lineReward != 0: + break + return lineReward + + def getReward(self): + assert self.isTerminal() + assert self.reward is not None + return self.reward + + +class Action(): + def __init__(self, player, columnIndex, rowIndex): + self.player = player + self.rowIndex = rowIndex + self.columnIndex = columnIndex + + def __str__(self): + return str((self.columnIndex, self.rowIndex)) + + def __repr__(self): + return str(self) + + def __eq__(self, other): + return self.__class__ == (other.__class__ and + self.player == other.player and + self.columnIndex == other.columnIndex and + self.rowIndex == other.rowIndex) + + def __hash__(self): + return hash((self.columnIndex, self.rowIndex, self.player)) + + +def extractStatistics(searcher, action): + statistics = {} + statistics['rootNumVisits'] = searcher.root.numVisits + statistics['rootTotalReward'] = searcher.root.totalReward + statistics['actionNumVisits'] = searcher.root.children[action].numVisits + statistics['actionTotalReward'] = searcher.root.children[action].totalReward + return statistics + + +def main(): + """Run a full match between two MCTS searchers, possibly with different + parametrization, playing a Connect(m,n,k) game. + + Extraction of MCTS statistics is examplified. + + The game parameters (m,n,k) are randomly chosen. + """ + + searchers = {} + searchers["mcts-1500ms"] = mcts(timeLimit=1_500) + searchers["mcts-1000ms"] = mcts(timeLimit=1_000) + searchers["mcts-500ms"] = mcts(timeLimit=500) + searchers["mcts-250ms"] = mcts(timeLimit=250) + + playerNames = ConnectMNKState.playerNames + + playerSearcherNames = {} + for player in sorted(playerNames.keys()): + playerSearcherNames[player] = random.choice(sorted(searchers.keys())) + + runnableGames = list() + runnableGames.append((3, 3, 3)) + runnableGames.append((7, 6, 4)) + runnableGames.append((8, 7, 5)) + runnableGames.append((9, 8, 6)) + (m, n, k) = random.choice(runnableGames) + currentState = ConnectMNKState(mColumns=m, nRows=n, kConnections=k) + + turn = 0 + currentState.show() + while not currentState.isTerminal(): + turn += 1 + player = currentState.getCurrentPlayer() + action_count = len(currentState.getPossibleActions()) + + searcherName = playerSearcherNames[player] + searcher = searchers[searcherName] + + action = searcher.search(initialState=currentState) + statistics = extractStatistics(searcher, action) + currentState = currentState.takeAction(action) + + print(f"at turn {turn} player {playerNames[player]}={player} ({searcherName})" + + f" takes action (column, row)={action} amongst {action_count} possibilities") + + print("mcts statitics:" + + f" chosen action= {statistics['actionTotalReward']} total reward" + + f" over {statistics['actionNumVisits']} visits /" + f" all explored actions= {statistics['rootTotalReward']} total reward" + + f" over {statistics['rootNumVisits']} visits") + + print('-'*120) + currentState.show() + + print('-'*120) + if currentState.getReward() == 0: + print(f"Connect(m={m},n={n},k={k}) game terminates; nobody wins") + else: + print(f"Connect(m={m},n={n},k={k}) game terminates;" + + f" player {playerNames[player]}={player} ({searcherName}) wins" + + f" with pattern {currentState.winingPattern}") + + +if __name__ == "__main__": + main() diff --git a/mcts.py b/mcts.py index 3ea88f8..bda1e94 100755 --- a/mcts.py +++ b/mcts.py @@ -34,7 +34,7 @@ def __str__(self): return "%s: {%s}"%(self.__class__.__name__, ', '.join(s)) class mcts(): - def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2), + def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=math.sqrt(2), rolloutPolicy=randomPolicy): if timeLimit != None: if iterationLimit != None: @@ -109,8 +109,8 @@ def getBestChild(self, node, explorationValue): bestValue = float("-inf") bestNodes = [] for child in node.children.values(): - nodeValue = node.state.getCurrentPlayer() * child.totalReward / child.numVisits + explorationValue * math.sqrt( - 2 * math.log(node.numVisits) / child.numVisits) + nodeValue = (node.state.getCurrentPlayer() * child.totalReward / child.numVisits + + explorationValue * math.sqrt(math.log(node.numVisits) / child.numVisits)) if nodeValue > bestValue: bestValue = nodeValue bestNodes = [child] diff --git a/naughtsandcrosses.py b/naughtsandcrosses.py index 9d490a3..87abebe 100755 --- a/naughtsandcrosses.py +++ b/naughtsandcrosses.py @@ -39,7 +39,8 @@ def isTerminal(self): [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]: if abs(sum(diagonal)) == 3: return True - return reduce(operator.mul, sum(self.board, []), 1) + return reduce(operator.mul, sum(self.board, []), 1) != 0 + def getReward(self): for row in self.board: @@ -52,7 +53,7 @@ def getReward(self): [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]: if abs(sum(diagonal)) == 3: return sum(diagonal) / 3 - return False + return 0 class Action(): From 8c90ac1688cdc5c12c186320d05aae35b3338b20 Mon Sep 17 00:00:00 2001 From: Lucas Borboleta Date: Thu, 30 Nov 2023 21:02:31 +0100 Subject: [PATCH 2/3] Add printed info in connectmnk example/demo --- __pycache__/mcts.cpython-310.pyc | Bin 0 -> 4091 bytes __pycache__/mcts.cpython-38.pyc | Bin 0 -> 4053 bytes connectmnk.py | 13 +++++++++++++ 3 files changed, 13 insertions(+) create mode 100644 __pycache__/mcts.cpython-310.pyc create mode 100644 __pycache__/mcts.cpython-38.pyc diff --git a/__pycache__/mcts.cpython-310.pyc b/__pycache__/mcts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017799285a0531d1d93558c84a566cf88559d104 GIT binary patch literal 4091 zcmZWsU2_{r744ocjbvGN?2tgRwG0qLv?L~wt%41yRKQM_rFLDZ#Dpr9gBYvZ_Q<0d zyL*(_VnhX;s$Kp84?IY3o4>>3yz)Z+0t*$+?U7|U8P)CS@9Dnx+;gwe`S}{d^Y?$= z{^OSy82cABW*-}k`w;y(3c&=AS&z3kr@9qeEz7*^mTlfn%MrG4Ua^+@DHE>nUNPaF zSS=5&ity3$(W(gR32Rp0U`N*Eg)gG#k&e=2qr<6d@>qELkTre`*H-)X7zRxI&hCAZeoPQNFUOmBS}WxwffZ+)@0`DClt$@J>MaZxXi z_b82{&M_%dFQM?c#g}-UFLLW`&GNtTJ++AL4&>ZJ7Pt@5AE1z|#RY@Cxv*egdtzrt zxX_|k_?eP&BNehkh&g?zPP-TzevC@abJj*V;$zsNjaHX>tQ_`i+k)zjDcy+1JyF_7 z5ipt+RmV%wro}>}ANJ$;_`%Tu6d^_79fS&cD5^=n_Y`o-bTOBuSs0fpE&R@26hmW4 z(=Hq>9u}qNzo#-R`BLAJZr5?e<|mjLQ$IZ(~?}mx?n(ni^SCfz=qHUVA8ffWKZf zeH}E^hIoxgER?9gk7Vmln_z;~L7lIR# z+a#CFsHLuA)Hb7*aQ2z{0orcaHjV{4gp&bYZ6^n)+OBad0O#YiFiFx3g!fzq+iA8J zgh3YdWDrNaC=1|QkRSp*hQwNBcj3^K!R;^lda_bR*`8GA*QRK{xAYBp^5FVs=Qj*` zw9~c>c9jJEsvz6L(ln8++Vs;JEK_?L#(g=VQaZ!ubjqS?4!e1LAd5LIL)F zL7bK+Ujuz-Aqb+P(l}208E{1iRX;+}YD_<4R;rJvbAB?jq|X^}8C((SRv`Kc3g7aH zAgG+klq}{A8J`lx!*l#dkaQsASNdxdnKAnb8JD;MY`br%mCUACM_eGRtL!ox+LIm| zroM()28KY~D`%lkUrFi|0lSqGzRsR`LyC81&f!NC`7pwTJ+}VFzGORGfWu#zzHW6? z8?$vp&C%SL4c**Qk8}4c)~!S4uOgQJ1Ccu?#;d!S!FUK2uiCpW@VqfcF#jnr|5Wf}E|@TWTIl3MWB0dCgnZZK<<0!`@S` zE~fJc)gt5N2I-Lw!qP<%UER~&1ib=q27)mG{ zm8O}xjv3~51{+0P9?4EWlaJGWA_^BFbyun;<`i{A>|ak0XL!*Ckb2s%p*%teIU~=e zS2ALqVoMapPbkr7%_UNbb(w`xtjo{w$Q-eR+fV2yCMbF3oqmB16OZe>W>F9apS`tS zHtZU2SRZl+f!(nNb2D0a&oHE`HQI*m2{;h2C^-V)kXG@ojXD50a=J+QxYrAyWRT}b?$(7!9%b#KFw-7`3b%(N?!3Cp^Su=KsYEK`k zPnB{D4Jd)>ODMT&luCh$;4ed-AOe$O#!v~(Qg)7-x`9RD6N0>=SfeP{F(guWMkz?? z&W`@>k?&ygV~B}{#2bsSt@mHY9V|@jND4N|)$Fk;VvrbHmJy@ND6t|JgWInRZ7@FX zTF<$x833`-BMRIuaG$_HxI9gFa!0l?*96MU$qz zu{pT>WH0TK;vnXv#w<8LI*5Bp&M4H-2S;`BK~GNS3TZV)R71rpVgo>T1p% z)&;tMj=)PvDOy0hY2p4|n7f>*9iphyRFpWsX_xts$&5@2c^AF!+Vz*{+kt3OfzKQG z!M6V!EE%hwL-x4?S7 zW#mJu81-Oo(>vxDDROwZm)rY}x|Mr?hr7#%&f%Qtv(zWKGpgk-Wv7R8+1#j}JA`xk zkL28plAY#!;qJs~n3XA@_nd;UE**6dRJhhxl!8Bs!(*w=hz}`GnC&6yK2#ZiknXDQ z;7*D>P7!{Q*PB3W(69&to`!4_>`VHesS-ML;hJ>5RZY+`gYPPzJTQio;Tgo(K!SON zJ8yl<0b~e3)=SFejWK44yvZLn=S|M30&JkZM+MnJty4jNmCBU!6RHv4C_8aN}MKhW{CA&NQtgn#lD)RX<9GOEHy_rK$ hW@{GA+{tJCc&s~rbWrMSe{l?vU&pr3@%SNNd@cjMX zx4-`OJY)Z%&irGevx=f0qY_N;nDzK3=hU`hYtu4sd($>=XVVe3a9*)Z_cJD3;k{zQ zJGM4G^eVzf&quE!tY@rMeS;HOix<9#UPd}fll2a#q3L7c?W3%r=nYgk8?y-n87^gf zY@ugQ*o2SxA6RnDVVSdEQ};NdU*-y)c{w}gmr(x7|Hu;SE81VlId-rEYWougDXjJ3 z&Gj_7l}Xi$k}wW*7G^To3w4mBK|j?x+Ky!qb}|U4?*y%CQH%7(bammWFcE35@H%^G z)RD!~uFRfHH+@p>DQcf5LLNR)Dpf`G!C^=C>8!%mnJQ{oxG&3vMHPFW%5Op?T2|p2 zISQ}*@N1@;(7^vz*6wV5sio3evBZa!+`6B320fW%dh4?&``uuB>x;FGXIs5ardRq$ zMZFa7Ng79;BT}ZOV;J|j#g}-KFY`KAix}z9@ccvCUq#XHqmpcs3kF+rVZpZc)W(i* zp+T?kGbQDEDrASyGJB{^rx^Qvh(^wH)W1S(AiW!&M5=68_~7w1%K zmaVs3MN5lXyWLC0Af`{$+wFrv7?&SZmocq=MAfMwO^d9lz-Vjg~B1PTz^A8(V;KS9%`RB$~HPVoNeD$Kg~g|cwf!{9-=YhKyX2Fn&hUL zwbVzLwau(0oPDMqqVJY{<5r+SxER3Ic5*g<){7vy0-PD{zx zK-^gfdZ?&0j?+N~R1rSa4OB(*sfv~QyOKfCMM{Jn- zD?~By17`;}lFJnDx|L&opS`G$VCd=0BZ_pG;liF+e`DXW9WKDzugqAtI<8IFeZKrUT1&z7}Nu&QlXv4~9n2&Bkmba`C zAA#Y&WBb0ca{ip81+1U3*L>aR6_jkPe4rZGP&f&K$ZK9yx24Y33}>JCa#6jzD-)Pi z-Mdd;C>7`-MHpT+dM?esayL#pVXW`1oGb?&nbVutEIXML(;(Y4WKcLNO*7TPndWvS zD(do3b_SVznhp|CxCo=WQvHl%tRqT)dLo|TLl+S0X~Tl@5TWCgG+SQDgmsE3Q5Zj? z=wI8hn|m?4egT^nt{ZtQ?xtciuxj`+$NZq$OgXR44O3=lw}Mo_{n zC^SqcbGxOZ-XJzSacev5?DtjL4|hx6P^7|>qd@6JtvrD;2K5nSQa7oha8NX6yHo6z zukF%Ne@RoGGP%*)&iFCr9-$aOFbAsP!}CTR^Cs*N!Ok9PpD2aC4JAN3sJVK-R0>oC ze;IL;DMB$5s03})>KE7m1|g^`iY@9lR2kavj6#sUor?XoxT{$F6h)h~+c3qV1(DB9 zEl8?0snoo{6d|Y>EAF5{0#h1At_6pGG_t|)yld@OR5PO%Src~6`LP8uzy-R90e6uY z=hiVtjs7ta%oC&~yzS+V1;@k{n;DXGiXO({?e&e}rDuETAQnL)UxJ-M(w)5kgfClu znM9EvnzX=plcH-Cb}W-pH-f$-_5MG%1*w^owv$ z1%quobiyxEeDHEFxAz_Ob?yNa?k*oW2g_#6 zQp>qBuH`Oepa;v@!nmG0glqZ_@^gFTKc^8eeFVR$)~68duG1nQ!=|V$Oa0@ stRgQj!;TTb=*);fnXB0_ckh!g75?2a9mfA%e2$p2UZx?oUu!P?A8T=(TL1t6 literal 0 HcmV?d00001 diff --git a/connectmnk.py b/connectmnk.py index f98d472..4ba824d 100644 --- a/connectmnk.py +++ b/connectmnk.py @@ -219,6 +219,16 @@ def main(): (m, n, k) = random.choice(runnableGames) currentState = ConnectMNKState(mColumns=m, nRows=n, kConnections=k) + print() + print(f"Connect m={m} n={n} k={k}") + + print() + for player in sorted(playerNames.keys()): + print(f"player {playerNames[player]} = {player} = {playerSearcherNames[player]}") + + print() + _ = input("main: press enter to start") + turn = 0 currentState.show() while not currentState.isTerminal(): @@ -253,6 +263,9 @@ def main(): f" player {playerNames[player]}={player} ({searcherName}) wins" + f" with pattern {currentState.winingPattern}") + print() + _ = input("main: done ; press enter to terminate") + if __name__ == "__main__": main() From 50b30285229d2cc06e271da11e0582ede61da826 Mon Sep 17 00:00:00 2001 From: Lucas Borboleta Date: Thu, 30 Nov 2023 21:22:46 +0100 Subject: [PATCH 3/3] Add URL about Connect(m,n,k,p,q) --- connectmnk.py | 1 + 1 file changed, 1 insertion(+) diff --git a/connectmnk.py b/connectmnk.py index 4ba824d..5acbed5 100644 --- a/connectmnk.py +++ b/connectmnk.py @@ -15,6 +15,7 @@ class ConnectMNKState: to place, and q stones for the first player to place for the first move only. Each player may play only at the lowest unoccupied place in a column. In particular, Connect(m,n,6,2,1) is called Connect6. + (see also https://en.wikipedia.org/wiki/Gomoku#Theoretical_generalizations) """ playerNames = {1:'O', -1:'X'}