@@ -14,7 +14,7 @@ def test_V_value_iteration(self):
1414 # states:
1515 states = np .array ([i for i in range (0 ,N )])
1616 # actions
17- actions = np .array ([a for a in range (0 ,N )])
17+ actions = np .array ([str ( a ) for a in range (0 ,N )])
1818 # immediate returns:
1919 immediate_returns = np .array ([[3 , 1 ], [2 , 3 ]])
2020 # discount factor:
@@ -33,7 +33,7 @@ def test_policy_value_iteration(self):
3333 # states:
3434 states = np .array ([i for i in range (0 ,N )])
3535 # actions
36- actions = np .array ([a for a in range (0 ,N )])
36+ actions = np .array ([str ( a ) for a in range (0 ,N )])
3737 # immediate returns:
3838 immediate_returns = np .array ([[3 , 1 ], [2 , 3 ]])
3939 # discount factor:
@@ -45,14 +45,14 @@ def test_policy_value_iteration(self):
4545
4646 mdp = dtmdp (states , actions , transition_matrices , immediate_returns , discount_factor )
4747 result = mdp .solve (0 , minimize = True )[1 ]
48- self .assertEqual (result , {0 : 1 , 1 : 0 })
48+ self .assertEqual (result , {0 : '1' , 1 : '0' })
4949 def test_V_policy_iteration (self ):
5050 # number of states:
5151 N = 2
5252 # states:
5353 states = np .array ([i for i in range (0 ,N )])
5454 # actions
55- actions = np .array ([a for a in range (0 ,N )])
55+ actions = np .array ([str ( a ) for a in range (0 ,N )])
5656 # immediate returns:
5757 immediate_returns = np .array ([[3 , 1 ], [2 , 3 ]])
5858 # discount factor:
@@ -71,7 +71,7 @@ def test_policy_policy_iteration(self):
7171 # states:
7272 states = np .array ([i for i in range (0 ,N )])
7373 # actions
74- actions = np .array ([a for a in range (0 ,N )])
74+ actions = np .array ([str ( a ) for a in range (0 ,N )])
7575 # immediate returns:
7676 immediate_returns = np .array ([[3 , 1 ], [2 , 3 ]])
7777 # discount factor:
@@ -83,7 +83,7 @@ def test_policy_policy_iteration(self):
8383
8484 mdp = dtmdp (states , actions , transition_matrices , immediate_returns , discount_factor )
8585 result = mdp .solve (0 , minimize = True )[1 ]
86- self .assertEqual (result , {0 : 1 , 1 : 0 })
86+ self .assertEqual (result , {0 : '1' , 1 : '0' })
8787
8888if __name__ == '__main__' :
8989 unittest .main ()
0 commit comments