Skip to content

Commit f35697d

Browse files
authored
Merge pull request #45 from copa-uniandes/fixes
fixes on dtmdp test
2 parents e3f5878 + 1882515 commit f35697d

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

tests/tests_dtmdp.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ def test_V_value_iteration(self):
1414
# states:
1515
states = np.array([i for i in range(0,N)])
1616
# actions
17-
actions = np.array([a for a in range(0,N)])
17+
actions = np.array([str(a) for a in range(0,N)])
1818
# immediate returns:
1919
immediate_returns = np.array([[3, 1], [2, 3]])
2020
# discount factor:
@@ -33,7 +33,7 @@ def test_policy_value_iteration(self):
3333
# states:
3434
states = np.array([i for i in range(0,N)])
3535
# actions
36-
actions = np.array([a for a in range(0,N)])
36+
actions = np.array([str(a) for a in range(0,N)])
3737
# immediate returns:
3838
immediate_returns = np.array([[3, 1], [2, 3]])
3939
# discount factor:
@@ -45,14 +45,14 @@ def test_policy_value_iteration(self):
4545

4646
mdp = dtmdp(states, actions, transition_matrices, immediate_returns, discount_factor)
4747
result = mdp.solve(0, minimize = True)[1]
48-
self.assertEqual(result, {0: 1, 1: 0})
48+
self.assertEqual(result, {0: '1', 1: '0'})
4949
def test_V_policy_iteration(self):
5050
# number of states:
5151
N = 2
5252
# states:
5353
states = np.array([i for i in range(0,N)])
5454
# actions
55-
actions = np.array([a for a in range(0,N)])
55+
actions = np.array([str(a) for a in range(0,N)])
5656
# immediate returns:
5757
immediate_returns = np.array([[3, 1], [2, 3]])
5858
# discount factor:
@@ -71,7 +71,7 @@ def test_policy_policy_iteration(self):
7171
# states:
7272
states = np.array([i for i in range(0,N)])
7373
# actions
74-
actions = np.array([a for a in range(0,N)])
74+
actions = np.array([str(a) for a in range(0,N)])
7575
# immediate returns:
7676
immediate_returns = np.array([[3, 1], [2, 3]])
7777
# discount factor:
@@ -83,7 +83,7 @@ def test_policy_policy_iteration(self):
8383

8484
mdp = dtmdp(states, actions, transition_matrices, immediate_returns, discount_factor)
8585
result = mdp.solve(0, minimize = True)[1]
86-
self.assertEqual(result, {0: 1, 1: 0})
86+
self.assertEqual(result, {0: '1', 1: '0'})
8787

8888
if __name__ == '__main__':
8989
unittest.main()

0 commit comments

Comments (0)