From 4563a61ff04d28df085c2b97066af913f77b1cc9 Mon Sep 17 00:00:00 2001 From: ligerlac Date: Mon, 14 Oct 2024 15:55:59 +0200 Subject: [PATCH] bugfix: recognize single layer as first and last --- difflogic/compiled_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/difflogic/compiled_model.py b/difflogic/compiled_model.py index 307e578..e64e765 100644 --- a/difflogic/compiled_model.py +++ b/difflogic/compiled_model.py @@ -139,18 +139,16 @@ def get_gate_code(self, var1, var2, gate_op): def get_layer_code(self, layer_a, layer_b, layer_op, layer_id, prefix_sums): code = [] for var_id, (gate_a, gate_b, gate_op) in enumerate(zip(layer_a, layer_b, layer_op)): - if self.device == 'cpu' and layer_id == len(prefix_sums) - 1: + if layer_id == 0: + a = f"inp[{gate_a}]" + b = f"inp[{gate_b}]" + else: a = f"v{prefix_sums[layer_id - 1] + gate_a}" b = f"v{prefix_sums[layer_id - 1] + gate_b}" + if self.device == 'cpu' and layer_id == len(prefix_sums) - 1: code.append(f"\tout[{var_id}] = {self.get_gate_code(a, b, gate_op)};") else: assert not (self.device == 'cpu' and layer_id >= len(prefix_sums) - 1), (layer_id, len(prefix_sums)) - if layer_id == 0: - a = f"inp[{gate_a}]" - b = f"inp[{gate_b}]" - else: - a = f"v{prefix_sums[layer_id - 1] + gate_a}" - b = f"v{prefix_sums[layer_id - 1] + gate_b}" code.append( f"\tconst {BITS_TO_DTYPE[self.num_bits]} v{prefix_sums[layer_id] + var_id} = {self.get_gate_code(a, b, gate_op)};" )