From 93fa326160f7a047fe884610880a1a7344cbf11f Mon Sep 17 00:00:00 2001 From: Gershom A Date: Mon, 8 Jul 2019 13:14:43 +0100 Subject: [PATCH 1/2] Update sum_tree.py --- sum_tree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sum_tree.py b/sum_tree.py index 43c906c..4041ddf 100644 --- a/sum_tree.py +++ b/sum_tree.py @@ -7,6 +7,7 @@ class SumTree: write = 0 def __init__(self, capacity): + assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." self.capacity = capacity self.tree = numpy.zeros(2 * capacity - 1) self.data = numpy.zeros(capacity, dtype=object) From 04c9c803f871427577281dd0c9ee30f4a80c9880 Mon Sep 17 00:00:00 2001 From: Gershom A Date: Mon, 8 Jul 2019 13:20:06 +0100 Subject: [PATCH 2/2] Update prioritized_memory.py --- prioritized_memory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/prioritized_memory.py b/prioritized_memory.py index a6973f4..b1f4744 100644 --- a/prioritized_memory.py +++ b/prioritized_memory.py @@ -5,6 +5,9 @@ class Memory: # stored as ( s, a, r, s_ ) in SumTree def __init__(self, capacity, alpha=0.6, beta=0.4, beta_anneal_step=0.001, epsilon=0.00000001): + tree_capacity = 1 + while tree_capacity < size: + tree_capacity *= 2 self.tree = SumTree(capacity) self.capacity = capacity self.a = alpha