From 9f1c8a3fc199ea8d8d876a4edcc754f44f472ae9 Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Fri, 17 Jan 2025 17:52:35 +0100 Subject: [PATCH 1/9] don't use key_stack for test_two_entities --- deathstar/test_demo.py | 96 +++--- src/cascade/dataflow/dataflow.py | 55 ++- src/cascade/runtime/flink_runtime.py | 35 +- tests/integration/flink-runtime/common.py | 69 ++-- .../flink-runtime/test_select_all.py | 320 +++++++++--------- .../flink-runtime/test_two_entities.py | 20 +- 6 files changed, 308 insertions(+), 287 deletions(-) diff --git a/deathstar/test_demo.py b/deathstar/test_demo.py index a7e674e..f709948 100644 --- a/deathstar/test_demo.py +++ b/deathstar/test_demo.py @@ -1,49 +1,49 @@ -from deathstar.demo import DeathstarDemo, DeathstarClient -import time -import pytest - -@pytest.mark.integration -def test_deathstar_demo(): - ds = DeathstarDemo("deathstardemo-test", "dsd-out") - ds.init_runtime() - ds.runtime.run(run_async=True) - print("Populating, press enter to go to the next step when done") - ds.populate() - - client = DeathstarClient("deathstardemo-test", "dsd-out") - input() - print("testing user login") - event = client.user_login() - client.send(event) - - input() - print("testing reserve") - event = client.reserve() - client.send(event) - - input() - print("testing search") - event = client.search_hotel() - client.send(event) - - input() - print("testing recommend (distance)") - time.sleep(0.5) - event = client.recommend(req_param="distance") - client.send(event) - - input() - print("testing recommend (price)") - time.sleep(0.5) - event = client.recommend(req_param="price") - client.send(event) - - print(client.client._futures) - input() - print("done!") - print(client.client._futures) - - -if __name__ == "__main__": - test_deathstar_demo() \ No newline at end of file +# from deathstar.demo import DeathstarDemo, DeathstarClient +# import time +# import pytest + +# @pytest.mark.integration +# def test_deathstar_demo(): +# ds = 
DeathstarDemo("deathstardemo-test", "dsd-out") +# ds.init_runtime() +# ds.runtime.run(run_async=True) +# print("Populating, press enter to go to the next step when done") +# ds.populate() + +# client = DeathstarClient("deathstardemo-test", "dsd-out") +# input() +# print("testing user login") +# event = client.user_login() +# client.send(event) + +# input() +# print("testing reserve") +# event = client.reserve() +# client.send(event) + +# input() +# print("testing search") +# event = client.search_hotel() +# client.send(event) + +# input() +# print("testing recommend (distance)") +# time.sleep(0.5) +# event = client.recommend(req_param="distance") +# client.send(event) + +# input() +# print("testing recommend (price)") +# time.sleep(0.5) +# event = client.recommend(req_param="price") +# client.send(event) + +# print(client.client._futures) +# input() +# print("done!") +# print(client.client._futures) + + +# if __name__ == "__main__": +# test_deathstar_demo() \ No newline at end of file diff --git a/src/cascade/dataflow/dataflow.py b/src/cascade/dataflow/dataflow.py index 2e3d890..ee18005 100644 --- a/src/cascade/dataflow/dataflow.py +++ b/src/cascade/dataflow/dataflow.py @@ -23,10 +23,10 @@ class Filter: @dataclass class Node(ABC): """Base class for Nodes.""" + id: int = field(init=False) """This node's unique id.""" - _id_counter: int = field(init=False, default=0, repr=False) outgoing_edges: list['Edge'] = field(init=False, default_factory=list, repr=False) @@ -41,8 +41,27 @@ class OpNode(Node): A `Dataflow` may reference the same `StatefulOperator` multiple times. 
The `StatefulOperator` that this node belongs to is referenced by `cls`.""" - operator: Operator + entity: Type method_type: Union[InitClass, InvokeMethod, Filter] + read_key_from: str + """Which variable to take as the key for this StatefulOperator""" + assign_result_to: Optional[str] = field(default=None) + """What variable to assign the result of this node to, if any.""" + is_conditional: bool = field(default=False) + """Whether or not the boolean result of this node dictates the following path.""" + collect_target: Optional['CollectTarget'] = field(default=None) + """Whether the result of this node should go to a CollectNode.""" + +@dataclass +class StatelessOpNode(Node): + """A node in a `Dataflow` corresponding to a method call of a `StatelessOperator`. + + A `Dataflow` may reference the same `StatefulOperator` multiple times. + The `StatefulOperator` that this node belongs to is referenced by `cls`.""" + dataflow: 'DataFlow' + method_type: InvokeMethod + """Which variable to take as the key for this StatefulOperator""" + assign_result_to: Optional[str] = None is_conditional: bool = False """Whether or not the boolean result of this node dictates the following path.""" @@ -176,7 +195,7 @@ class Event(): target: 'Node' """The Node that this Event wants to go to.""" - key_stack: list[str] + # key_stack: list[str] """The keys this event is concerned with. The top of the stack, i.e. 
`key_stack[-1]`, should always correspond to a key on the StatefulOperator of `target.cls` if `target` is an `OpNode`.""" @@ -203,7 +222,7 @@ def __post_init__(self): self._id = Event._id_counter Event._id_counter += 1 - def propogate(self, key_stack, result) -> Union['EventResult', list['Event']]: + def propogate(self, result) -> Union['EventResult', list['Event']]: """Propogate this event through the Dataflow.""" # TODO: keys should be structs containing Key and Opnode (as we need to know the entity (cls) and method to invoke for that particular key) @@ -216,23 +235,23 @@ def propogate(self, key_stack, result) -> Union['EventResult', list['Event']]: if len(targets) == 0: return EventResult(self._id, result) else: - keys = key_stack.pop() - if not isinstance(keys, list): - keys = [keys] + # keys = key_stack.pop() + # if not isinstance(keys, list): + # keys = [keys] collect_targets: list[Optional[CollectTarget]] # Events with SelectAllNodes need to be assigned a CollectTarget if isinstance(self.target, SelectAllNode): collect_targets = [ - CollectTarget(self.target.collect_target, len(keys), i) - for i in range(len(keys)) + CollectTarget(self.target.collect_target, len(targets), i) + for i in range(len(targets)) ] elif isinstance(self.target, OpNode) and self.target.collect_target is not None: collect_targets = [ - self.target.collect_target for i in range(len(keys)) + self.target.collect_target for i in range(len(targets)) ] else: - collect_targets = [self.collect_target for i in range(len(keys))] + collect_targets = [self.collect_target for i in range(len(targets))] if isinstance(self.target, OpNode) and self.target.is_conditional: # In this case there will be two targets depending on the condition @@ -249,13 +268,13 @@ def propogate(self, key_stack, result) -> Union['EventResult', list['Event']]: return [Event( target_true if result else target_false, - key_stack + [key], + # key_stack + [key], self.variable_map, self.dataflow, _id=self._id, collect_target=ct) - 
for key, ct in zip(keys, collect_targets)] + for ct in collect_targets] elif len(targets) == 1: # We assume that all keys need to go to the same target @@ -263,26 +282,26 @@ def propogate(self, key_stack, result) -> Union['EventResult', list['Event']]: return [Event( targets[0], - key_stack + [key], + # key_stack + [key], self.variable_map, self.dataflow, _id=self._id, collect_target=ct) - for key, ct in zip(keys, collect_targets)] + for ct in collect_targets] else: # An event with multiple targets should have the same number of # keys in a list on top of its key stack - assert len(targets) == len(keys) + # assert len(targets) == len(keys) return [Event( target, - key_stack + [key], + # key_stack + [key], self.variable_map, self.dataflow, _id=self._id, collect_target=ct) - for target, key, ct in zip(targets, keys, collect_targets)] + for target, ct in zip(targets, collect_targets)] @dataclass class EventResult(): diff --git a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py index 763582e..15e5a8f 100644 --- a/src/cascade/runtime/flink_runtime.py +++ b/src/cascade/runtime/flink_runtime.py @@ -12,7 +12,7 @@ from pyflink.datastream.connectors.kafka import KafkaOffsetsInitializer, KafkaRecordSerializationSchema, KafkaSource, KafkaSink from pyflink.datastream import ProcessFunction, StreamExecutionEnvironment import pickle -from cascade.dataflow.dataflow import Arrived, CollectNode, CollectTarget, Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, Node, NotArrived, OpNode, Operator, Result, SelectAllNode +from cascade.dataflow.dataflow import Arrived, CollectNode, CollectTarget, Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, Node, NotArrived, OpNode, Operator, Result, SelectAllNode, StatelessOpNode from cascade.dataflow.operator import StatefulOperator, StatelessOperator from confluent_kafka import Producer, Consumer import logging @@ -49,12 +49,13 @@ def open(self, runtime_context: RuntimeContext): self.state: 
ValueState = runtime_context.get_state(descriptor) def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): - key_stack = event.key_stack + # key_stack = event.key_stack # should be handled by filters on this FlinkOperator assert(isinstance(event.target, OpNode)) - assert(isinstance(event.target.operator, StatefulOperator)) - assert(event.target.operator.entity == self.operator.entity) + assert(event.target.entity == self.operator.entity) + key = ctx.get_current_key() + assert(key is not None) logger.debug(f"FlinkOperator {self.operator.entity.__name__}[{ctx.get_current_key()}]: Processing: {event.target.method_type}") if isinstance(event.target.method_type, InitClass): @@ -64,8 +65,8 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): # Register the created key in FlinkSelectAllOperator register_key_event = Event( - FlinkRegisterKeyNode(key_stack[-1], self.operator.entity), - [], + FlinkRegisterKeyNode(key, self.operator.entity), + # [], {}, None, _id = event._id @@ -74,11 +75,11 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): yield register_key_event # Pop this key from the key stack so that we exit - key_stack.pop() + # key_stack.pop() self.state.update(pickle.dumps(result)) elif isinstance(event.target.method_type, InvokeMethod): state = pickle.loads(self.state.value()) - result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, state=state, key_stack=key_stack) + result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, state=state, key_stack=[]) # TODO: check if state actually needs to be updated if state is not None: @@ -93,7 +94,7 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): if event.target.assign_result_to is not None: event.variable_map[event.target.assign_result_to] = result - new_events = event.propogate(key_stack, result) + new_events = 
event.propogate(result) if isinstance(new_events, EventResult): logger.debug(f"FlinkOperator {self.operator.entity.__name__}[{ctx.get_current_key()}]: Returned {new_events}") yield new_events @@ -113,8 +114,7 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): key_stack = event.key_stack # should be handled by filters on this FlinkOperator - assert(isinstance(event.target, OpNode)) - assert(isinstance(event.target.operator, StatelessOperator)) + assert(isinstance(event.target, StatelessOpNode)) logger.debug(f"FlinkStatelessOperator {self.operator.dataflow.name}[{event._id}]: Processing: {event.target.method_type}") if isinstance(event.target.method_type, InvokeMethod): @@ -422,18 +422,17 @@ def init(self, kafka_broker="localhost:9092", bundle_time=1, bundle_size=5, para not (isinstance(e.target, SelectAllNode) or isinstance(e.target, FlinkRegisterKeyNode))) ) - event_stream_2 = select_all_stream.union(not_select_all_stream) + operator_stream = select_all_stream.union(not_select_all_stream) - operator_stream = event_stream_2.filter(lambda e: isinstance(e.target, OpNode)).name("OPERATOR STREAM") self.stateful_op_stream = ( operator_stream - .filter(lambda e: isinstance(e.target.operator, StatefulOperator)) + .filter(lambda e: isinstance(e.target, OpNode)) ) self.stateless_op_stream = ( operator_stream - .filter(lambda e: isinstance(e.target.operator, StatelessOperator)) + .filter(lambda e: isinstance(e.target, StatelessOpNode)) ) self.merge_op_stream = ( @@ -455,8 +454,8 @@ def add_operator(self, flink_op: FlinkOperator): """Add a `FlinkOperator` to the Flink datastream.""" op_stream = ( - self.stateful_op_stream.filter(lambda e: e.target.operator.entity == flink_op.operator.entity) - .key_by(lambda e: e.key_stack[-1]) + self.stateful_op_stream.filter(lambda e: e.target.entity == flink_op.operator.entity) + .key_by(lambda e: e.variable_map[e.target.read_key_from]) .process(flink_op) .name("STATEFUL OP: " + flink_op.operator.entity.__name__) 
) @@ -467,7 +466,7 @@ def add_stateless_operator(self, flink_op: FlinkStatelessOperator): op_stream = ( self.stateless_op_stream - .filter(lambda e: e.target.operator.dataflow.name == flink_op.operator.dataflow.name) + .filter(lambda e: e.target.dataflow.name == flink_op.operator.dataflow.name) .process(flink_op) .name("STATELESS DATAFLOW: " + flink_op.operator.dataflow.name) ) diff --git a/tests/integration/flink-runtime/common.py b/tests/integration/flink-runtime/common.py index 105fdbd..5a63bdb 100644 --- a/tests/integration/flink-runtime/common.py +++ b/tests/integration/flink-runtime/common.py @@ -40,25 +40,25 @@ def __repr__(self): return f"Item(key='{self.key}', price={self.price})" def update_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() # final function + # key_stack.pop() # final function state.balance += variable_map["amount"] return state.balance >= 0 def get_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() # final function + # key_stack.pop() # final function return state.balance def get_price_compiled(variable_map: dict[str, Any], state: Item, key_stack: list[str]) -> Any: - key_stack.pop() # final function + # key_stack.pop() # final function return state.price # Items (or other operators) are passed by key always def buy_item_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map["item_key"]) + # key_stack.append(variable_map["item_key"]) return None def buy_item_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() + # key_stack.pop() state.balance = state.balance - variable_map["item_price"] return state.balance >= 0 @@ -94,40 +94,43 @@ def buy_2_items_1_compiled(variable_map: dict[str, Any], state: User, key_stack: def user_buy_item_df(): df = DataFlow("user.buy_item") - n0 = OpNode(user_op, InvokeMethod("buy_item_0")) - 
n1 = OpNode(item_op, InvokeMethod("get_price"), assign_result_to="item_price") - n2 = OpNode(user_op, InvokeMethod("buy_item_1")) + n0 = OpNode(User, InvokeMethod("buy_item_0"), read_key_from="user_key") + n1 = OpNode(Item, + InvokeMethod("get_price"), + assign_result_to="item_price", + read_key_from="item_key") + n2 = OpNode(User, InvokeMethod("buy_item_1"), read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) df.entry = n0 return df -def user_buy_2_items_df(): - df = DataFlow("user.buy_2_items") - n0 = OpNode(user_op, InvokeMethod("buy_2_items_0")) - n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price") - n1 = OpNode( - item_op, - InvokeMethod("get_price"), - assign_result_to="item_price", - collect_target=CollectTarget(n3, 2, 0) - ) - n2 = OpNode( - item_op, - InvokeMethod("get_price"), - assign_result_to="item_price", - collect_target=CollectTarget(n3, 2, 1) - ) - n4 = OpNode(user_op, InvokeMethod("buy_2_items_1")) - df.add_edge(Edge(n0, n1)) - df.add_edge(Edge(n0, n2)) - df.add_edge(Edge(n1, n3)) - df.add_edge(Edge(n2, n3)) - df.add_edge(Edge(n3, n4)) - df.entry = n0 - return df +# def user_buy_2_items_df(): +# df = DataFlow("user.buy_2_items") +# n0 = OpNode(user_op, InvokeMethod("buy_2_items_0")) +# n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price") +# n1 = OpNode( +# item_op, +# InvokeMethod("get_price"), +# assign_result_to="item_price", +# collect_target=CollectTarget(n3, 2, 0) +# ) +# n2 = OpNode( +# item_op, +# InvokeMethod("get_price"), +# assign_result_to="item_price", +# collect_target=CollectTarget(n3, 2, 1) +# ) +# n4 = OpNode(user_op, InvokeMethod("buy_2_items_1")) +# df.add_edge(Edge(n0, n1)) +# df.add_edge(Edge(n0, n2)) +# df.add_edge(Edge(n1, n3)) +# df.add_edge(Edge(n2, n3)) +# df.add_edge(Edge(n3, n4)) +# df.entry = n0 +# return df user_op.dataflows = { - "buy_2_items": user_buy_2_items_df(), + # "buy_2_items": user_buy_2_items_df(), "buy_item": 
user_buy_item_df() } \ No newline at end of file diff --git a/tests/integration/flink-runtime/test_select_all.py b/tests/integration/flink-runtime/test_select_all.py index 62c371e..2b4de65 100644 --- a/tests/integration/flink-runtime/test_select_all.py +++ b/tests/integration/flink-runtime/test_select_all.py @@ -1,164 +1,164 @@ -""" -Basically we need a way to search through all state. -""" -import math -import random -from dataclasses import dataclass -from typing import Any - -from pyflink.datastream.data_stream import CloseableIterator - -from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, OpNode, SelectAllNode -from cascade.dataflow.operator import StatefulOperator, StatelessOperator -from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime, FlinkStatelessOperator -from confluent_kafka import Producer -import time -import pytest - -@dataclass -class Geo: - x: int - y: int - -class Hotel: - def __init__(self, name: str, loc: Geo): - self.name = name - self.loc = loc - - def get_name(self) -> str: - return self.name +# """ +# Basically we need a way to search through all state. 
+# """ +# import math +# import random +# from dataclasses import dataclass +# from typing import Any + +# from pyflink.datastream.data_stream import CloseableIterator + +# from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, OpNode, SelectAllNode +# from cascade.dataflow.operator import StatefulOperator, StatelessOperator +# from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime, FlinkStatelessOperator +# from confluent_kafka import Producer +# import time +# import pytest + +# @dataclass +# class Geo: +# x: int +# y: int + +# class Hotel: +# def __init__(self, name: str, loc: Geo): +# self.name = name +# self.loc = loc + +# def get_name(self) -> str: +# return self.name - def distance(self, loc: Geo) -> float: - return math.sqrt((self.loc.x - loc.x) ** 2 + (self.loc.y - loc.y) ** 2) +# def distance(self, loc: Geo) -> float: +# return math.sqrt((self.loc.x - loc.x) ** 2 + (self.loc.y - loc.y) ** 2) - def __repr__(self) -> str: - return f"Hotel({self.name}, {self.loc})" - - -def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: - key_stack.pop() - loc = variable_map["loc"] - return math.sqrt((state.loc.x - loc.x) ** 2 + (state.loc.y - loc.y) ** 2) - -def get_name_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: - key_stack.pop() - return state.name - -hotel_op = StatefulOperator(Hotel, - {"distance": distance_compiled, - "get_name": get_name_compiled}, {}) - - - -def get_nearby(hotels: list[Hotel], loc: Geo, dist: float): - return [hotel.get_name() for hotel in hotels if hotel.distance(loc) < dist] - - -# We compile just the predicate, the select is implemented using a selectall node -def get_nearby_predicate_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): - # the top of the key_stack is already the right key, so in this case we don't need to do anything - # loc = variable_map["loc"] 
- # we need the hotel_key for later. (body_compiled_0) - variable_map["hotel_key"] = key_stack[-1] - pass - -def get_nearby_predicate_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> bool: - loc = variable_map["loc"] - dist = variable_map["dist"] - hotel_dist = variable_map["hotel_distance"] - # key_stack.pop() # shouldn't pop because this function is stateless - return hotel_dist < dist - -def get_nearby_body_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): - key_stack.append(variable_map["hotel_key"]) - -def get_nearby_body_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> str: - return variable_map["hotel_name"] - -get_nearby_op = StatelessOperator({ - "get_nearby_predicate_compiled_0": get_nearby_predicate_compiled_0, - "get_nearby_predicate_compiled_1": get_nearby_predicate_compiled_1, - "get_nearby_body_compiled_0": get_nearby_body_compiled_0, - "get_nearby_body_compiled_1": get_nearby_body_compiled_1, -}, None) - -# dataflow for getting all hotels within region -df = DataFlow("get_nearby") -n7 = CollectNode("get_nearby_result", "get_nearby_body") -n0 = SelectAllNode(Hotel, n7) -n1 = OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_0")) -n2 = OpNode(hotel_op, InvokeMethod("distance"), assign_result_to="hotel_distance") -n3 = OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_1"), is_conditional=True) -n4 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_0")) -n5 = OpNode(hotel_op, InvokeMethod("get_name"), assign_result_to="hotel_name") -n6 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_1"), assign_result_to="get_nearby_body") - -df.add_edge(Edge(n0, n1)) -df.add_edge(Edge(n1, n2)) -df.add_edge(Edge(n2, n3)) -df.add_edge(Edge(n3, n4, if_conditional=True)) -df.add_edge(Edge(n3, n7, if_conditional=False)) -df.add_edge(Edge(n4, n5)) -df.add_edge(Edge(n5, n6)) -df.add_edge(Edge(n6, n7)) -get_nearby_op.dataflow = df - -@pytest.mark.integration -def 
test_nearby_hotels(): - runtime = FlinkRuntime("test_nearby_hotels") - runtime.init() - runtime.add_operator(FlinkOperator(hotel_op)) - runtime.add_stateless_operator(FlinkStatelessOperator(get_nearby_op)) - - # Create Hotels - hotels = [] - init_hotel = OpNode(hotel_op, InitClass()) - random.seed(42) - for i in range(20): - coord_x = random.randint(-10, 10) - coord_y = random.randint(-10, 10) - hotel = Hotel(f"h_{i}", Geo(coord_x, coord_y)) - event = Event(init_hotel, [hotel.name], {"name": hotel.name, "loc": hotel.loc}, None) - runtime.send(event) - hotels.append(hotel) - - collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True) - records = [] - def wait_for_event_id(id: int) -> EventResult: - for record in collected_iterator: - records.append(record) - print(f"Collected record: {record}") - if record.event_id == id: - return record +# def __repr__(self) -> str: +# return f"Hotel({self.name}, {self.loc})" + + +# def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: +# key_stack.pop() +# loc = variable_map["loc"] +# return math.sqrt((state.loc.x - loc.x) ** 2 + (state.loc.y - loc.y) ** 2) + +# def get_name_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: +# key_stack.pop() +# return state.name + +# hotel_op = StatefulOperator(Hotel, +# {"distance": distance_compiled, +# "get_name": get_name_compiled}, {}) + + + +# def get_nearby(hotels: list[Hotel], loc: Geo, dist: float): +# return [hotel.get_name() for hotel in hotels if hotel.distance(loc) < dist] + + +# # We compile just the predicate, the select is implemented using a selectall node +# def get_nearby_predicate_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): +# # the top of the key_stack is already the right key, so in this case we don't need to do anything +# # loc = variable_map["loc"] +# # we need the hotel_key for later. 
(body_compiled_0) +# variable_map["hotel_key"] = key_stack[-1] +# pass + +# def get_nearby_predicate_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> bool: +# loc = variable_map["loc"] +# dist = variable_map["dist"] +# hotel_dist = variable_map["hotel_distance"] +# # key_stack.pop() # shouldn't pop because this function is stateless +# return hotel_dist < dist + +# def get_nearby_body_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): +# key_stack.append(variable_map["hotel_key"]) + +# def get_nearby_body_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> str: +# return variable_map["hotel_name"] + +# get_nearby_op = StatelessOperator({ +# "get_nearby_predicate_compiled_0": get_nearby_predicate_compiled_0, +# "get_nearby_predicate_compiled_1": get_nearby_predicate_compiled_1, +# "get_nearby_body_compiled_0": get_nearby_body_compiled_0, +# "get_nearby_body_compiled_1": get_nearby_body_compiled_1, +# }, None) + +# # dataflow for getting all hotels within region +# df = DataFlow("get_nearby") +# n7 = CollectNode("get_nearby_result", "get_nearby_body") +# n0 = SelectAllNode(Hotel, n7) +# n1 = OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_0")) +# n2 = OpNode(hotel_op, InvokeMethod("distance"), assign_result_to="hotel_distance") +# n3 = OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_1"), is_conditional=True) +# n4 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_0")) +# n5 = OpNode(hotel_op, InvokeMethod("get_name"), assign_result_to="hotel_name") +# n6 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_1"), assign_result_to="get_nearby_body") + +# df.add_edge(Edge(n0, n1)) +# df.add_edge(Edge(n1, n2)) +# df.add_edge(Edge(n2, n3)) +# df.add_edge(Edge(n3, n4, if_conditional=True)) +# df.add_edge(Edge(n3, n7, if_conditional=False)) +# df.add_edge(Edge(n4, n5)) +# df.add_edge(Edge(n5, n6)) +# df.add_edge(Edge(n6, n7)) +# get_nearby_op.dataflow = df + +# 
@pytest.mark.integration +# def test_nearby_hotels(): +# runtime = FlinkRuntime("test_nearby_hotels") +# runtime.init() +# runtime.add_operator(FlinkOperator(hotel_op)) +# runtime.add_stateless_operator(FlinkStatelessOperator(get_nearby_op)) + +# # Create Hotels +# hotels = [] +# init_hotel = OpNode(hotel_op, InitClass()) +# random.seed(42) +# for i in range(20): +# coord_x = random.randint(-10, 10) +# coord_y = random.randint(-10, 10) +# hotel = Hotel(f"h_{i}", Geo(coord_x, coord_y)) +# event = Event(init_hotel, [hotel.name], {"name": hotel.name, "loc": hotel.loc}, None) +# runtime.send(event) +# hotels.append(hotel) + +# collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True) +# records = [] +# def wait_for_event_id(id: int) -> EventResult: +# for record in collected_iterator: +# records.append(record) +# print(f"Collected record: {record}") +# if record.event_id == id: +# return record - def wait_for_n_records(num: int) -> list[EventResult]: - i = 0 - n_records = [] - for record in collected_iterator: - i += 1 - records.append(record) - n_records.append(record) - print(f"Collected record: {record}") - if i == num: - return n_records - - print("creating hotels") - # Wait for hotels to be created - wait_for_n_records(20) - time.sleep(3) # wait for all hotels to be registered - - dist = 5 - loc = Geo(0, 0) - # because of how the key stack works, we need to supply a key here - event = Event(n0, ["workaround_key"], {"loc": loc, "dist": dist}, df) - runtime.send(event, flush=True) +# def wait_for_n_records(num: int) -> list[EventResult]: +# i = 0 +# n_records = [] +# for record in collected_iterator: +# i += 1 +# records.append(record) +# n_records.append(record) +# print(f"Collected record: {record}") +# if i == num: +# return n_records + +# print("creating hotels") +# # Wait for hotels to be created +# wait_for_n_records(20) +# time.sleep(3) # wait for all hotels to be registered + +# dist = 5 +# loc = Geo(0, 0) +# # because of how the key 
stack works, we need to supply a key here +# event = Event(n0, ["workaround_key"], {"loc": loc, "dist": dist}, df) +# runtime.send(event, flush=True) - nearby = [] - for hotel in hotels: - if hotel.distance(loc) < dist: - nearby.append(hotel.name) - - event_result = wait_for_event_id(event._id) - results = [r for r in event_result.result if r != None] - print(nearby) - assert set(results) == set(nearby) \ No newline at end of file +# nearby = [] +# for hotel in hotels: +# if hotel.distance(loc) < dist: +# nearby.append(hotel.name) + +# event_result = wait_for_event_id(event._id) +# results = [r for r in event_result.result if r != None] +# print(nearby) +# assert set(results) == set(nearby) \ No newline at end of file diff --git a/tests/integration/flink-runtime/test_two_entities.py b/tests/integration/flink-runtime/test_two_entities.py index 9d2e0cf..54309fa 100644 --- a/tests/integration/flink-runtime/test_two_entities.py +++ b/tests/integration/flink-runtime/test_two_entities.py @@ -15,19 +15,19 @@ def test_two_entities(): # Create a User object foo_user = User("foo", 100) - init_user_node = OpNode(user_op, InitClass()) - event = Event(init_user_node, ["foo"], {"key": "foo", "balance": 100}, None) + init_user_node = OpNode(User, InitClass(), read_key_from="key") + event = Event(init_user_node, {"key": "foo", "balance": 100}, None) runtime.send(event) # Create an Item object fork_item = Item("fork", 5) - init_item_node = OpNode(item_op, InitClass()) - event = Event(init_item_node, ["fork"], {"key": "fork", "price": 5}, None) + init_item_node = OpNode(Item, InitClass(), read_key_from="key") + event = Event(init_item_node, {"key": "fork", "price": 5}, None) runtime.send(event) # Create an expensive Item house_item = Item("house", 1000) - event = Event(init_item_node, ["house"], {"key": "house", "price": 1000}, None) + event = Event(init_item_node, {"key": "house", "price": 1000}, None) runtime.send(event) # Have the User object buy the item @@ -35,10 +35,10 @@ def 
test_two_entities(): df = user_op.dataflows["buy_item"] # User with key "foo" buys item with key "fork" - user_buys_fork = Event(df.entry, ["foo"], {"item_key": "fork"}, df) + user_buys_fork = Event(df.entry, {"user_key": "foo", "item_key": "fork"}, df) runtime.send(user_buys_fork, flush=True) - collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True) + collected_iterator: CloseableIterator = runtime.run(run_async=True, output="collect") records = [] def wait_for_event_id(id: int) -> EventResult: @@ -53,8 +53,8 @@ def wait_for_event_id(id: int) -> EventResult: assert buy_fork_result.result == True # Send an event to check if the balance was updated - user_get_balance_node = OpNode(user_op, InvokeMethod("get_balance")) - user_get_balance = Event(user_get_balance_node, ["foo"], {}, None) + user_get_balance_node = OpNode(User, InvokeMethod("get_balance"), read_key_from="key") + user_get_balance = Event(user_get_balance_node, {"key": "foo"}, None) runtime.send(user_get_balance, flush=True) # See that the user's balance has gone down @@ -63,7 +63,7 @@ def wait_for_event_id(id: int) -> EventResult: # User with key "foo" buys item with key "house" foo_user.buy_item(house_item) - user_buys_house = Event(df.entry, ["foo"], {"item_key": "house"}, df) + user_buys_house = Event(df.entry, {"user_key": "foo", "item_key": "house"}, df) runtime.send(user_buys_house, flush=True) # Balance becomes negative when house is bought From b2d33293ff04c443aacf3ebffe29c060efdc2a21 Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 10:49:34 +0100 Subject: [PATCH 2/9] remove key stack from collect operator --- src/cascade/runtime/flink_runtime.py | 2 +- tests/integration/flink-runtime/common.py | 56 ++++--- ...e_operator.py => test_collect_operator.py} | 140 +++++++++--------- 3 files changed, 98 insertions(+), 100 deletions(-) rename tests/integration/flink-runtime/{test_merge_operator.py => test_collect_operator.py} (71%) diff --git 
a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py index 15e5a8f..2a04b34 100644 --- a/src/cascade/runtime/flink_runtime.py +++ b/src/cascade/runtime/flink_runtime.py @@ -209,7 +209,7 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): collection = [r.val for r in collection if r.val is not None] # type: ignore (r is of type Arrived) event.variable_map[target_node.assign_result_to] = collection - new_events = event.propogate(event.key_stack, collection) + new_events = event.propogate(collection) self.collection.clear() if isinstance(new_events, EventResult): diff --git a/tests/integration/flink-runtime/common.py b/tests/integration/flink-runtime/common.py index 5a63bdb..7a676e2 100644 --- a/tests/integration/flink-runtime/common.py +++ b/tests/integration/flink-runtime/common.py @@ -64,13 +64,9 @@ def buy_item_1_compiled(variable_map: dict[str, Any], state: User, key_stack: li def buy_2_items_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append( - [variable_map["item1_key"], variable_map["item2_key"]] - ) return None def buy_2_items_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() state.balance -= variable_map["item_prices"][0] + variable_map["item_prices"][1] return state.balance >= 0 @@ -105,32 +101,34 @@ def user_buy_item_df(): df.entry = n0 return df -# def user_buy_2_items_df(): -# df = DataFlow("user.buy_2_items") -# n0 = OpNode(user_op, InvokeMethod("buy_2_items_0")) -# n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price") -# n1 = OpNode( -# item_op, -# InvokeMethod("get_price"), -# assign_result_to="item_price", -# collect_target=CollectTarget(n3, 2, 0) -# ) -# n2 = OpNode( -# item_op, -# InvokeMethod("get_price"), -# assign_result_to="item_price", -# collect_target=CollectTarget(n3, 2, 1) -# ) -# n4 = OpNode(user_op, InvokeMethod("buy_2_items_1")) -# df.add_edge(Edge(n0, n1)) 
-# df.add_edge(Edge(n0, n2)) -# df.add_edge(Edge(n1, n3)) -# df.add_edge(Edge(n2, n3)) -# df.add_edge(Edge(n3, n4)) -# df.entry = n0 -# return df +def user_buy_2_items_df(): + df = DataFlow("user.buy_2_items") + n0 = OpNode(User, InvokeMethod("buy_2_items_0"), read_key_from="user_key") + n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price") + n1 = OpNode( + Item, + InvokeMethod("get_price"), + assign_result_to="item_price", + collect_target=CollectTarget(n3, 2, 0), + read_key_from="item1_key" + ) + n2 = OpNode( + Item, + InvokeMethod("get_price"), + assign_result_to="item_price", + collect_target=CollectTarget(n3, 2, 1), + read_key_from="item2_key" + ) + n4 = OpNode(User, InvokeMethod("buy_2_items_1"), read_key_from="user_key") + df.add_edge(Edge(n0, n1)) + df.add_edge(Edge(n0, n2)) + df.add_edge(Edge(n1, n3)) + df.add_edge(Edge(n2, n3)) + df.add_edge(Edge(n3, n4)) + df.entry = n0 + return df user_op.dataflows = { - # "buy_2_items": user_buy_2_items_df(), + "buy_2_items": user_buy_2_items_df(), "buy_item": user_buy_item_df() } \ No newline at end of file diff --git a/tests/integration/flink-runtime/test_merge_operator.py b/tests/integration/flink-runtime/test_collect_operator.py similarity index 71% rename from tests/integration/flink-runtime/test_merge_operator.py rename to tests/integration/flink-runtime/test_collect_operator.py index d136d99..574c739 100644 --- a/tests/integration/flink-runtime/test_merge_operator.py +++ b/tests/integration/flink-runtime/test_collect_operator.py @@ -1,71 +1,71 @@ -"""A test script for dataflows with merge operators""" - -from pyflink.datastream.data_stream import CloseableIterator -from common import Item, User, item_op, user_op -from cascade.dataflow.dataflow import Event, EventResult, InitClass, InvokeMethod, OpNode -from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime -import pytest - -@pytest.mark.integration -def test_merge_operator(): - runtime = 
FlinkRuntime("test_merge_operator") - runtime.init() - runtime.add_operator(FlinkOperator(item_op)) - runtime.add_operator(FlinkOperator(user_op)) - - - # Create a User object - foo_user = User("foo", 100) - init_user_node = OpNode(user_op, InitClass()) - event = Event(init_user_node, ["foo"], {"key": "foo", "balance": 100}, None) - runtime.send(event) - - # Create an Item object - fork_item = Item("fork", 5) - init_item_node = OpNode(item_op, InitClass()) - event = Event(init_item_node, ["fork"], {"key": "fork", "price": 5}, None) - runtime.send(event) - - # Create another Item - spoon_item = Item("spoon", 3) - event = Event(init_item_node, ["spoon"], {"key": "spoon", "price": 3}, None) - runtime.send(event, flush=True) - - collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True) - records = [] - - def wait_for_event_id(id: int) -> EventResult: - for record in collected_iterator: - records.append(record) - print(f"Collected record: {record}") - if record.event_id == id: - return record - - # Make sure the user & items are initialised - wait_for_event_id(event._id) - - # Have the User object buy the item - foo_user.buy_2_items(fork_item, spoon_item) - df = user_op.dataflows["buy_2_items"] - - # User with key "foo" buys item with key "fork" - user_buys_cutlery = Event(df.entry, ["foo"], {"item1_key": "fork", "item2_key": "spoon"}, df) - runtime.send(user_buys_cutlery, flush=True) - - - # Check that we were able to buy the fork - buy_fork_result = wait_for_event_id(user_buys_cutlery._id) - assert buy_fork_result.result == True - - # Send an event to check if the balance was updated - user_get_balance_node = OpNode(user_op, InvokeMethod("get_balance")) - user_get_balance = Event(user_get_balance_node, ["foo"], {}, None) - runtime.send(user_get_balance, flush=True) - - # See that the user's balance has gone down - get_balance = wait_for_event_id(user_get_balance._id) - assert get_balance.result == 92 - - collected_iterator.close() - +"""A test 
script for dataflows with merge operators""" + +from pyflink.datastream.data_stream import CloseableIterator +from common import Item, User, item_op, user_op +from cascade.dataflow.dataflow import Event, EventResult, InitClass, InvokeMethod, OpNode +from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime +import pytest + +@pytest.mark.integration +def test_merge_operator(): + runtime = FlinkRuntime("test_collect_operator") + runtime.init() + runtime.add_operator(FlinkOperator(item_op)) + runtime.add_operator(FlinkOperator(user_op)) + + + # Create a User object + foo_user = User("foo", 100) + init_user_node = OpNode(User, InitClass(), read_key_from="key") + event = Event(init_user_node, {"key": "foo", "balance": 100}, None) + runtime.send(event) + + # Create an Item object + fork_item = Item("fork", 5) + init_item_node = OpNode(Item, InitClass(), read_key_from="key") + event = Event(init_item_node, {"key": "fork", "price": 5}, None) + runtime.send(event) + + # Create another Item + spoon_item = Item("spoon", 3) + event = Event(init_item_node, {"key": "spoon", "price": 3}, None) + runtime.send(event, flush=True) + + collected_iterator: CloseableIterator = runtime.run(run_async=True, output="collect") + records = [] + + def wait_for_event_id(id: int) -> EventResult: + for record in collected_iterator: + records.append(record) + print(f"Collected record: {record}") + if record.event_id == id: + return record + + # Make sure the user & items are initialised + wait_for_event_id(event._id) + + # Have the User object buy the item + foo_user.buy_2_items(fork_item, spoon_item) + df = user_op.dataflows["buy_2_items"] + + # User with key "foo" buys item with key "fork" + user_buys_cutlery = Event(df.entry, {"user_key": "foo", "item1_key": "fork", "item2_key": "spoon"}, df) + runtime.send(user_buys_cutlery, flush=True) + + + # Check that we were able to buy the fork + buy_fork_result = wait_for_event_id(user_buys_cutlery._id) + assert buy_fork_result.result == 
True + + # Send an event to check if the balance was updated + user_get_balance_node = OpNode(User, InvokeMethod("get_balance"), read_key_from="key") + user_get_balance = Event(user_get_balance_node, {"key": "foo"}, None) + runtime.send(user_get_balance, flush=True) + + # See that the user's balance has gone down + get_balance = wait_for_event_id(user_get_balance._id) + assert get_balance.result == 92 + + collected_iterator.close() + print(records) \ No newline at end of file From 7d3677007c458da6ae4b91bbd4560bd533f2d38f Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 12:05:54 +0100 Subject: [PATCH 3/9] remove key stack from select all operator --- src/cascade/dataflow/dataflow.py | 34 +- src/cascade/runtime/flink_runtime.py | 14 +- .../flink-runtime/test_select_all.py | 316 +++++++++--------- 3 files changed, 184 insertions(+), 180 deletions(-) diff --git a/src/cascade/dataflow/dataflow.py b/src/cascade/dataflow/dataflow.py index ee18005..2e12559 100644 --- a/src/cascade/dataflow/dataflow.py +++ b/src/cascade/dataflow/dataflow.py @@ -39,8 +39,8 @@ def __post_init__(self): class OpNode(Node): """A node in a `Dataflow` corresponding to a method call of a `StatefulOperator`. - A `Dataflow` may reference the same `StatefulOperator` multiple times. - The `StatefulOperator` that this node belongs to is referenced by `cls`.""" + A `Dataflow` may reference the same entity multiple times. + The `StatefulOperator` that this node belongs to is referenced by `entity`.""" entity: Type method_type: Union[InitClass, InvokeMethod, Filter] read_key_from: str @@ -58,7 +58,7 @@ class StatelessOpNode(Node): A `Dataflow` may reference the same `StatefulOperator` multiple times. The `StatefulOperator` that this node belongs to is referenced by `cls`.""" - dataflow: 'DataFlow' + operator: Operator # should be StatelessOperator but circular import! 
method_type: InvokeMethod """Which variable to take as the key for this StatefulOperator""" @@ -76,6 +76,7 @@ class SelectAllNode(Node): Think of this as executing `SELECT * FROM cls`""" cls: Type collect_target: 'CollectNode' + assign_key_to: str @dataclass @@ -222,12 +223,10 @@ def __post_init__(self): self._id = Event._id_counter Event._id_counter += 1 - def propogate(self, result) -> Union['EventResult', list['Event']]: + def propogate(self, result, select_all_keys: Optional[list[str]]=None) -> Union['EventResult', list['Event']]: """Propogate this event through the Dataflow.""" - # TODO: keys should be structs containing Key and Opnode (as we need to know the entity (cls) and method to invoke for that particular key) - # the following method only works because we assume all the keys have the same entity and method - if self.dataflow is None:# or len(key_stack) == 0: + if self.dataflow is None: return EventResult(self._id, result) targets = self.dataflow.get_neighbors(self.target) @@ -235,17 +234,26 @@ def propogate(self, result) -> Union['EventResult', list['Event']]: if len(targets) == 0: return EventResult(self._id, result) else: - # keys = key_stack.pop() - # if not isinstance(keys, list): - # keys = [keys] collect_targets: list[Optional[CollectTarget]] # Events with SelectAllNodes need to be assigned a CollectTarget if isinstance(self.target, SelectAllNode): + assert select_all_keys + assert len(targets) == 1 + n = len(select_all_keys) collect_targets = [ - CollectTarget(self.target.collect_target, len(targets), i) - for i in range(len(targets)) + CollectTarget(self.target.collect_target, n, i) + for i in range(n) ] + return [Event( + targets[0], + self.variable_map | {self.target.assign_key_to: key}, + self.dataflow, + _id=self._id, + collect_target=ct) + + for ct, key in zip(collect_targets, select_all_keys)] + elif isinstance(self.target, OpNode) and self.target.collect_target is not None: collect_targets = [ self.target.collect_target for i in 
range(len(targets)) @@ -253,7 +261,7 @@ def propogate(self, result) -> Union['EventResult', list['Event']]: else: collect_targets = [self.collect_target for i in range(len(targets))] - if isinstance(self.target, OpNode) and self.target.is_conditional: + if (isinstance(self.target, OpNode) or isinstance(self.target, StatelessOpNode)) and self.target.is_conditional: # In this case there will be two targets depending on the condition edges = self.dataflow.nodes[self.target.id].outgoing_edges diff --git a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py index 2a04b34..d9dfd33 100644 --- a/src/cascade/runtime/flink_runtime.py +++ b/src/cascade/runtime/flink_runtime.py @@ -111,21 +111,20 @@ def __init__(self, operator: StatelessOperator) -> None: def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): - key_stack = event.key_stack # should be handled by filters on this FlinkOperator assert(isinstance(event.target, StatelessOpNode)) logger.debug(f"FlinkStatelessOperator {self.operator.dataflow.name}[{event._id}]: Processing: {event.target.method_type}") if isinstance(event.target.method_type, InvokeMethod): - result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, key_stack=key_stack) + result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, key_stack=[]) else: raise Exception(f"A StatelessOperator cannot compute event type: {event.target.method_type}") if event.target.assign_result_to is not None: event.variable_map[event.target.assign_result_to] = result - new_events = event.propogate(key_stack, result) + new_events = event.propogate(result) if isinstance(new_events, EventResult): logger.debug(f"FlinkStatelessOperator {self.operator.dataflow.name}[{event._id}]: Returned {new_events}") yield new_events @@ -157,11 +156,13 @@ def process_element(self, event: Event, ctx: 'ProcessFunction.Context'): logger.debug(f"SelectAllOperator 
[{event.target.cls.__name__}]: Selecting all") # Yield all the keys we now about - event.key_stack.append(state) + # event.key_stack.append(state) + new_keys = state num_events = len(state) # Propogate the event to the next node - new_events = event.propogate(event.key_stack, None) + new_events = event.propogate(None, select_all_keys=new_keys) + print(len(new_events), num_events) assert num_events == len(new_events) logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Propogated {num_events} events with target: {event.target.collect_target}") @@ -173,7 +174,6 @@ class FlinkCollectOperator(KeyedProcessFunction): """Flink implementation of a merge operator.""" def __init__(self): #, merge_node: MergeNode) -> None: self.collection: ValueState = None # type: ignore (expect state to be initialised on .open()) - #self.node = merge_node def open(self, runtime_context: RuntimeContext): descriptor = ValueStateDescriptor("merge_state", Types.PICKLED_BYTE_ARRAY()) @@ -466,7 +466,7 @@ def add_stateless_operator(self, flink_op: FlinkStatelessOperator): op_stream = ( self.stateless_op_stream - .filter(lambda e: e.target.dataflow.name == flink_op.operator.dataflow.name) + .filter(lambda e: e.target.operator.dataflow.name == flink_op.operator.dataflow.name) .process(flink_op) .name("STATELESS DATAFLOW: " + flink_op.operator.dataflow.name) ) diff --git a/tests/integration/flink-runtime/test_select_all.py b/tests/integration/flink-runtime/test_select_all.py index 2b4de65..9ade211 100644 --- a/tests/integration/flink-runtime/test_select_all.py +++ b/tests/integration/flink-runtime/test_select_all.py @@ -1,164 +1,160 @@ -# """ -# Basically we need a way to search through all state. 
-# """ -# import math -# import random -# from dataclasses import dataclass -# from typing import Any - -# from pyflink.datastream.data_stream import CloseableIterator - -# from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, OpNode, SelectAllNode -# from cascade.dataflow.operator import StatefulOperator, StatelessOperator -# from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime, FlinkStatelessOperator -# from confluent_kafka import Producer -# import time -# import pytest - -# @dataclass -# class Geo: -# x: int -# y: int - -# class Hotel: -# def __init__(self, name: str, loc: Geo): -# self.name = name -# self.loc = loc - -# def get_name(self) -> str: -# return self.name +""" +The select all operator is used to fetch all keys for a single entity +""" +import math +import random +from dataclasses import dataclass +from typing import Any + +from pyflink.datastream.data_stream import CloseableIterator + +from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, Event, EventResult, InitClass, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode +from cascade.dataflow.operator import StatefulOperator, StatelessOperator +from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime, FlinkStatelessOperator +import time +import pytest + +@dataclass +class Geo: + x: int + y: int + +class Hotel: + def __init__(self, name: str, loc: Geo): + self.name = name + self.loc = loc + + def get_name(self) -> str: + return self.name -# def distance(self, loc: Geo) -> float: -# return math.sqrt((self.loc.x - loc.x) ** 2 + (self.loc.y - loc.y) ** 2) + def distance(self, loc: Geo) -> float: + return math.sqrt((self.loc.x - loc.x) ** 2 + (self.loc.y - loc.y) ** 2) -# def __repr__(self) -> str: -# return f"Hotel({self.name}, {self.loc})" - - -# def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: -# key_stack.pop() -# loc = 
variable_map["loc"] -# return math.sqrt((state.loc.x - loc.x) ** 2 + (state.loc.y - loc.y) ** 2) - -# def get_name_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: -# key_stack.pop() -# return state.name - -# hotel_op = StatefulOperator(Hotel, -# {"distance": distance_compiled, -# "get_name": get_name_compiled}, {}) - - - -# def get_nearby(hotels: list[Hotel], loc: Geo, dist: float): -# return [hotel.get_name() for hotel in hotels if hotel.distance(loc) < dist] - - -# # We compile just the predicate, the select is implemented using a selectall node -# def get_nearby_predicate_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): -# # the top of the key_stack is already the right key, so in this case we don't need to do anything -# # loc = variable_map["loc"] -# # we need the hotel_key for later. (body_compiled_0) -# variable_map["hotel_key"] = key_stack[-1] -# pass - -# def get_nearby_predicate_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> bool: -# loc = variable_map["loc"] -# dist = variable_map["dist"] -# hotel_dist = variable_map["hotel_distance"] -# # key_stack.pop() # shouldn't pop because this function is stateless -# return hotel_dist < dist - -# def get_nearby_body_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): -# key_stack.append(variable_map["hotel_key"]) - -# def get_nearby_body_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> str: -# return variable_map["hotel_name"] - -# get_nearby_op = StatelessOperator({ -# "get_nearby_predicate_compiled_0": get_nearby_predicate_compiled_0, -# "get_nearby_predicate_compiled_1": get_nearby_predicate_compiled_1, -# "get_nearby_body_compiled_0": get_nearby_body_compiled_0, -# "get_nearby_body_compiled_1": get_nearby_body_compiled_1, -# }, None) - -# # dataflow for getting all hotels within region -# df = DataFlow("get_nearby") -# n7 = CollectNode("get_nearby_result", "get_nearby_body") -# n0 = SelectAllNode(Hotel, n7) -# n1 
= OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_0")) -# n2 = OpNode(hotel_op, InvokeMethod("distance"), assign_result_to="hotel_distance") -# n3 = OpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_1"), is_conditional=True) -# n4 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_0")) -# n5 = OpNode(hotel_op, InvokeMethod("get_name"), assign_result_to="hotel_name") -# n6 = OpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_1"), assign_result_to="get_nearby_body") - -# df.add_edge(Edge(n0, n1)) -# df.add_edge(Edge(n1, n2)) -# df.add_edge(Edge(n2, n3)) -# df.add_edge(Edge(n3, n4, if_conditional=True)) -# df.add_edge(Edge(n3, n7, if_conditional=False)) -# df.add_edge(Edge(n4, n5)) -# df.add_edge(Edge(n5, n6)) -# df.add_edge(Edge(n6, n7)) -# get_nearby_op.dataflow = df - -# @pytest.mark.integration -# def test_nearby_hotels(): -# runtime = FlinkRuntime("test_nearby_hotels") -# runtime.init() -# runtime.add_operator(FlinkOperator(hotel_op)) -# runtime.add_stateless_operator(FlinkStatelessOperator(get_nearby_op)) - -# # Create Hotels -# hotels = [] -# init_hotel = OpNode(hotel_op, InitClass()) -# random.seed(42) -# for i in range(20): -# coord_x = random.randint(-10, 10) -# coord_y = random.randint(-10, 10) -# hotel = Hotel(f"h_{i}", Geo(coord_x, coord_y)) -# event = Event(init_hotel, [hotel.name], {"name": hotel.name, "loc": hotel.loc}, None) -# runtime.send(event) -# hotels.append(hotel) - -# collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True) -# records = [] -# def wait_for_event_id(id: int) -> EventResult: -# for record in collected_iterator: -# records.append(record) -# print(f"Collected record: {record}") -# if record.event_id == id: -# return record + def __repr__(self) -> str: + return f"Hotel({self.name}, {self.loc})" + + +def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: + loc = variable_map["loc"] + return math.sqrt((state.loc.x 
- loc.x) ** 2 + (state.loc.y - loc.y) ** 2) + +def get_name_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: + return state.name + +hotel_op = StatefulOperator(Hotel, + {"distance": distance_compiled, + "get_name": get_name_compiled}, {}) + + + +def get_nearby(hotels: list[Hotel], loc: Geo, dist: float): + return [hotel.get_name() for hotel in hotels if hotel.distance(loc) < dist] + + +# We compile just the predicate, the select is implemented using a selectall node +def get_nearby_predicate_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): + # the top of the key_stack is already the right key, so in this case we don't need to do anything + # loc = variable_map["loc"] + # we need the hotel_key for later. (body_compiled_0) + # variable_map["hotel_key"] = key_stack[-1] + pass + +def get_nearby_predicate_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> bool: + loc = variable_map["loc"] + dist = variable_map["dist"] + hotel_dist = variable_map["hotel_distance"] + # key_stack.pop() # shouldn't pop because this function is stateless + return hotel_dist < dist + +def get_nearby_body_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): + pass + +def get_nearby_body_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> str: + return variable_map["hotel_name"] + +get_nearby_op = StatelessOperator({ + "get_nearby_predicate_compiled_0": get_nearby_predicate_compiled_0, + "get_nearby_predicate_compiled_1": get_nearby_predicate_compiled_1, + "get_nearby_body_compiled_0": get_nearby_body_compiled_0, + "get_nearby_body_compiled_1": get_nearby_body_compiled_1, +}, None) + +# dataflow for getting all hotels within region +df = DataFlow("get_nearby") +n7 = CollectNode("get_nearby_result", "get_nearby_body") +n0 = SelectAllNode(Hotel, n7, assign_key_to="hotel_key") +n1 = StatelessOpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_0")) +n2 = OpNode(Hotel, InvokeMethod("distance"), 
assign_result_to="hotel_distance", read_key_from="hotel_key") +n3 = StatelessOpNode(get_nearby_op, InvokeMethod("get_nearby_predicate_compiled_1"), is_conditional=True) +n4 = StatelessOpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_0")) +n5 = OpNode(Hotel, InvokeMethod("get_name"), assign_result_to="hotel_name", read_key_from="hotel_key") +n6 = StatelessOpNode(get_nearby_op, InvokeMethod("get_nearby_body_compiled_1"), assign_result_to="get_nearby_body") + +df.add_edge(Edge(n0, n1)) +df.add_edge(Edge(n1, n2)) +df.add_edge(Edge(n2, n3)) +df.add_edge(Edge(n3, n4, if_conditional=True)) +df.add_edge(Edge(n3, n7, if_conditional=False)) +df.add_edge(Edge(n4, n5)) +df.add_edge(Edge(n5, n6)) +df.add_edge(Edge(n6, n7)) +get_nearby_op.dataflow = df + +@pytest.mark.integration +def test_nearby_hotels(): + runtime = FlinkRuntime("test_nearby_hotels") + runtime.init() + runtime.add_operator(FlinkOperator(hotel_op)) + runtime.add_stateless_operator(FlinkStatelessOperator(get_nearby_op)) + + # Create Hotels + hotels = [] + init_hotel = OpNode(Hotel, InitClass(), read_key_from="name") + random.seed(42) + for i in range(20): + coord_x = random.randint(-10, 10) + coord_y = random.randint(-10, 10) + hotel = Hotel(f"h_{i}", Geo(coord_x, coord_y)) + event = Event(init_hotel, {"name": hotel.name, "loc": hotel.loc}, None) + runtime.send(event) + hotels.append(hotel) + + collected_iterator: CloseableIterator = runtime.run(run_async=True, output='collect') + records = [] + def wait_for_event_id(id: int) -> EventResult: + for record in collected_iterator: + records.append(record) + print(f"Collected record: {record}") + if record.event_id == id: + return record -# def wait_for_n_records(num: int) -> list[EventResult]: -# i = 0 -# n_records = [] -# for record in collected_iterator: -# i += 1 -# records.append(record) -# n_records.append(record) -# print(f"Collected record: {record}") -# if i == num: -# return n_records - -# print("creating hotels") -# # Wait for hotels to be 
created -# wait_for_n_records(20) -# time.sleep(3) # wait for all hotels to be registered - -# dist = 5 -# loc = Geo(0, 0) -# # because of how the key stack works, we need to supply a key here -# event = Event(n0, ["workaround_key"], {"loc": loc, "dist": dist}, df) -# runtime.send(event, flush=True) + def wait_for_n_records(num: int) -> list[EventResult]: + i = 0 + n_records = [] + for record in collected_iterator: + i += 1 + records.append(record) + n_records.append(record) + print(f"Collected record: {record}") + if i == num: + return n_records + + print("creating hotels") + # Wait for hotels to be created + wait_for_n_records(20) + time.sleep(10) # wait for all hotels to be registered + + dist = 5 + loc = Geo(0, 0) + event = Event(n0, {"loc": loc, "dist": dist}, df) + runtime.send(event, flush=True) -# nearby = [] -# for hotel in hotels: -# if hotel.distance(loc) < dist: -# nearby.append(hotel.name) - -# event_result = wait_for_event_id(event._id) -# results = [r for r in event_result.result if r != None] -# print(nearby) -# assert set(results) == set(nearby) \ No newline at end of file + nearby = [] + for hotel in hotels: + if hotel.distance(loc) < dist: + nearby.append(hotel.name) + + event_result = wait_for_event_id(event._id) + results = [r for r in event_result.result if r != None] + print(nearby) + assert set(results) == set(nearby) \ No newline at end of file From 70b9bc8a47d247533d36ffb078312c0f61b6dff8 Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 13:04:06 +0100 Subject: [PATCH 4/9] Cleanup propagation --- src/cascade/dataflow/dataflow.py | 160 ++++++++++++---------- src/cascade/runtime/flink_runtime.py | 8 +- tests/integration/flink-runtime/common.py | 6 - 3 files changed, 90 insertions(+), 84 deletions(-) diff --git a/src/cascade/dataflow/dataflow.py b/src/cascade/dataflow/dataflow.py index 2e12559..a788252 100644 --- a/src/cascade/dataflow/dataflow.py +++ b/src/cascade/dataflow/dataflow.py @@ -1,4 +1,4 @@ -from abc import ABC 
+from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Type, Union @@ -35,6 +35,10 @@ def __post_init__(self): self.id = Node._id_counter Node._id_counter += 1 + @abstractmethod + def propogate(self, event: 'Event', targets: list['Node'], result: Any, **kwargs) -> list['Event']: + pass + @dataclass class OpNode(Node): """A node in a `Dataflow` corresponding to a method call of a `StatefulOperator`. @@ -45,13 +49,57 @@ class OpNode(Node): method_type: Union[InitClass, InvokeMethod, Filter] read_key_from: str """Which variable to take as the key for this StatefulOperator""" + assign_result_to: Optional[str] = field(default=None) """What variable to assign the result of this node to, if any.""" is_conditional: bool = field(default=False) """Whether or not the boolean result of this node dictates the following path.""" collect_target: Optional['CollectTarget'] = field(default=None) """Whether the result of this node should go to a CollectNode.""" - + + def propogate(self, event: 'Event', targets: List[Node], result: Any) -> list['Event']: + return OpNode.propogate_opnode(self, event, targets, result) + + @staticmethod + def propogate_opnode(node: Union['OpNode', 'StatelessOpNode'], event: 'Event', targets: list[Node], result: Any) -> list['Event']: + if event.collect_target is not None: + # Assign new collect targets + collect_targets = [ + event.collect_target for i in range(len(targets)) + ] + else: + # Keep old collect targets + collect_targets = [node.collect_target for i in range(len(targets))] + + if node.is_conditional: + edges = event.dataflow.nodes[event.target.id].outgoing_edges + true_edges = [edge for edge in edges if edge.if_conditional] + false_edges = [edge for edge in edges if not edge.if_conditional] + if not (len(true_edges) == len(false_edges) == 1): + print(edges) + assert len(true_edges) == len(false_edges) == 1 + target_true = true_edges[0].to_node + target_false = 
false_edges[0].to_node + + + return [Event( + target_true if result else target_false, + event.variable_map, + event.dataflow, + _id=event._id, + collect_target=ct) + + for ct in collect_targets] + else: + return [Event( + target, + event.variable_map, + event.dataflow, + _id=event._id, + collect_target=ct) + + for target, ct in zip(targets, collect_targets)] + @dataclass class StatelessOpNode(Node): """A node in a `Dataflow` corresponding to a method call of a `StatelessOperator`. @@ -63,11 +111,15 @@ class StatelessOpNode(Node): """Which variable to take as the key for this StatefulOperator""" assign_result_to: Optional[str] = None + """What variable to assign the result of this node to, if any.""" is_conditional: bool = False """Whether or not the boolean result of this node dictates the following path.""" collect_target: Optional['CollectTarget'] = None """Whether the result of this node should go to a CollectNode.""" + def propogate(self, event: 'Event', targets: List[Node], result: Any) -> List['Event']: + return OpNode.propogate_opnode(self, event, targets, result) + @dataclass class SelectAllNode(Node): """A node type that will yield all items of an entity filtered by @@ -78,6 +130,21 @@ class SelectAllNode(Node): collect_target: 'CollectNode' assign_key_to: str + def propogate(self, event: 'Event', targets: List[Node], result: Any, keys: list[str]): + targets = event.dataflow.get_neighbors(event.target) + assert len(targets) == 1 + n = len(keys) + collect_targets = [ + CollectTarget(self.collect_target, n, i) + for i in range(n) + ] + return [Event( + targets[0], + event.variable_map | {self.assign_key_to: key}, + event.dataflow, + _id=event._id, + collect_target=ct) + for ct, key in zip(collect_targets, keys)] @dataclass class MergeNode(Node): @@ -98,6 +165,16 @@ class CollectNode(Node): read_results_from: str """The variable name in the variable map that the individual items put their result in.""" + def propogate(self, event: 'Event', targets: 
List[Node], result: Any, **kwargs) -> List['Event']: + collect_targets = [event.collect_target for i in range(len(targets))] + return [Event( + target, + event.variable_map, + event.dataflow, + _id=event._id, + collect_target=ct) + + for target, ct in zip(targets, collect_targets)] @dataclass class Edge(): @@ -234,82 +311,13 @@ def propogate(self, result, select_all_keys: Optional[list[str]]=None) -> Union[ if len(targets) == 0: return EventResult(self._id, result) else: - - collect_targets: list[Optional[CollectTarget]] - # Events with SelectAllNodes need to be assigned a CollectTarget - if isinstance(self.target, SelectAllNode): - assert select_all_keys - assert len(targets) == 1 - n = len(select_all_keys) - collect_targets = [ - CollectTarget(self.target.collect_target, n, i) - for i in range(n) - ] - return [Event( - targets[0], - self.variable_map | {self.target.assign_key_to: key}, - self.dataflow, - _id=self._id, - collect_target=ct) - - for ct, key in zip(collect_targets, select_all_keys)] - - elif isinstance(self.target, OpNode) and self.target.collect_target is not None: - collect_targets = [ - self.target.collect_target for i in range(len(targets)) - ] - else: - collect_targets = [self.collect_target for i in range(len(targets))] - - if (isinstance(self.target, OpNode) or isinstance(self.target, StatelessOpNode)) and self.target.is_conditional: - # In this case there will be two targets depending on the condition - - edges = self.dataflow.nodes[self.target.id].outgoing_edges - true_edges = [edge for edge in edges if edge.if_conditional] - false_edges = [edge for edge in edges if not edge.if_conditional] - if not (len(true_edges) == len(false_edges) == 1): - print(edges) - assert len(true_edges) == len(false_edges) == 1 - target_true = true_edges[0].to_node - target_false = false_edges[0].to_node - - - return [Event( - target_true if result else target_false, - # key_stack + [key], - self.variable_map, - self.dataflow, - _id=self._id, - collect_target=ct) 
+ current_node = self.target - for ct in collect_targets] - - elif len(targets) == 1: - # We assume that all keys need to go to the same target - # this is only used for SelectAll propogation - - return [Event( - targets[0], - # key_stack + [key], - self.variable_map, - self.dataflow, - _id=self._id, - collect_target=ct) - - for ct in collect_targets] + if isinstance(current_node, SelectAllNode): + assert select_all_keys + return current_node.propogate(self, targets, result, select_all_keys) else: - # An event with multiple targets should have the same number of - # keys in a list on top of its key stack - # assert len(targets) == len(keys) - return [Event( - target, - # key_stack + [key], - self.variable_map, - self.dataflow, - _id=self._id, - collect_target=ct) - - for target, ct in zip(targets, collect_targets)] + return current_node.propogate(self, targets, result) @dataclass class EventResult(): diff --git a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py index d9dfd33..975526c 100644 --- a/src/cascade/runtime/flink_runtime.py +++ b/src/cascade/runtime/flink_runtime.py @@ -3,7 +3,7 @@ import time import uuid import threading -from typing import Literal, Optional, Type, Union +from typing import Any, Literal, Optional, Type, Union from pyflink.common.typeinfo import Types, get_gateway from pyflink.common import Configuration, DeserializationSchema, SerializationSchema, WatermarkStrategy from pyflink.datastream.connectors import DeliveryGuarantee @@ -36,6 +36,10 @@ class FlinkRegisterKeyNode(Node): key: str cls: Type + def propogate(self, event: Event, targets: list[Node], result: Any, **kwargs) -> list[Event]: + """A key registration event does not propogate.""" + return [] + class FlinkOperator(KeyedProcessFunction): """Wraps an `cascade.dataflow.datflow.StatefulOperator` in a KeyedProcessFunction so that it can run in Flink. 
""" @@ -162,7 +166,7 @@ def process_element(self, event: Event, ctx: 'ProcessFunction.Context'): # Propogate the event to the next node new_events = event.propogate(None, select_all_keys=new_keys) - print(len(new_events), num_events) + assert isinstance(new_events, list), "SelectAll nodes shouldn't directly produce EventResults" assert num_events == len(new_events) logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Propogated {num_events} events with target: {event.target.collect_target}") diff --git a/tests/integration/flink-runtime/common.py b/tests/integration/flink-runtime/common.py index 7a676e2..67baee6 100644 --- a/tests/integration/flink-runtime/common.py +++ b/tests/integration/flink-runtime/common.py @@ -40,25 +40,19 @@ def __repr__(self): return f"Item(key='{self.key}', price={self.price})" def update_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - # key_stack.pop() # final function state.balance += variable_map["amount"] return state.balance >= 0 def get_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - # key_stack.pop() # final function return state.balance def get_price_compiled(variable_map: dict[str, Any], state: Item, key_stack: list[str]) -> Any: - # key_stack.pop() # final function return state.price -# Items (or other operators) are passed by key always def buy_item_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - # key_stack.append(variable_map["item_key"]) return None def buy_item_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - # key_stack.pop() state.balance = state.balance - variable_map["item_price"] return state.balance >= 0 From 97545d7d53c52cd44feea192643aec4e5fbc710f Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 14:01:31 +0100 Subject: [PATCH 5/9] Remove mention of key_stack in main lib --- src/cascade/dataflow/dataflow.py | 13 ++++--- 
src/cascade/dataflow/operator.py | 38 +++++++++---------- src/cascade/runtime/flink_runtime.py | 8 +--- tests/integration/flink-runtime/common.py | 14 +++---- .../flink-runtime/test_select_all.py | 17 +++------ 5 files changed, 40 insertions(+), 50 deletions(-) diff --git a/src/cascade/dataflow/dataflow.py b/src/cascade/dataflow/dataflow.py index a788252..960a865 100644 --- a/src/cascade/dataflow/dataflow.py +++ b/src/cascade/dataflow/dataflow.py @@ -1,6 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Callable, List, Optional, Type, Union +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # Prevent circular imports + from cascade.dataflow.operator import StatelessOperator + class Operator(ABC): pass @@ -106,7 +112,7 @@ class StatelessOpNode(Node): A `Dataflow` may reference the same `StatefulOperator` multiple times. The `StatefulOperator` that this node belongs to is referenced by `cls`.""" - operator: Operator # should be StatelessOperator but circular import! + operator: 'StatelessOperator' method_type: InvokeMethod """Which variable to take as the key for this StatefulOperator""" @@ -273,11 +279,6 @@ class Event(): target: 'Node' """The Node that this Event wants to go to.""" - # key_stack: list[str] - """The keys this event is concerned with. - The top of the stack, i.e. `key_stack[-1]`, should always correspond to a key - on the StatefulOperator of `target.cls` if `target` is an `OpNode`.""" - variable_map: dict[str, Any] """A mapping of variable identifiers to values. 
If `target` is an `OpNode` this map should include the variables needed for that method."""
diff --git a/src/cascade/dataflow/operator.py b/src/cascade/dataflow/operator.py
index 6fca4d6..56d3e45 100644
--- a/src/cascade/dataflow/operator.py
+++ b/src/cascade/dataflow/operator.py
@@ -10,20 +10,19 @@ class MethodCall(Generic[T], Protocol):
     It corresponds to functions with the following signature:
 
     ```py
-    def my_compiled_method(*args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
+    def my_compiled_method(variable_map: dict[str, Any], state: T) -> Any:
         ...
     ```
 
-    `T` corresponds to a Python class, which, if modified, should return as the 2nd item in the tuple.
-    
-    The first item in the returned tuple corresponds to the actual return value of the function.
+    The variable_map contains a mapping from identifiers (variables/keys) to
+    their values.
+    The state of type `T` corresponds to a Python class.
 
-    The third item in the tuple corresponds to the `key_stack` which should be updated accordingly.
-    Notably, a terminal function should pop a key off the `key_stack`, whereas a function that calls
-    other functions should push the correct key(s) onto the `key_stack`.
+    
+    The value returned corresponds to the value returned by the function.
     """
-    def __call__(self, variable_map: dict[str, Any], state: T, key_stack: list[str]) -> dict[str, Any]: ...
+    def __call__(self, variable_map: dict[str, Any], state: T) -> Any: ...
"""@private""" @@ -61,14 +60,13 @@ def buy_item(self, item: Item) -> bool: Here, the class could be turned into a StatefulOperator as follows: ```py - def user_get_balance(variable_map: dict[str, Any], state: User, key_stack: list[str]): - key_stack.pop() + def user_get_balance(variable_map: dict[str, Any], state: User): return state.balance - def user_buy_item_0(variable_map: dict[str, Any], state: User, key_stack: list[str]): - key_stack.append(variable_map['item_key']) + def user_buy_item_0(variable_map: dict[str, Any], state: User): + pass - def user_buy_item_1(variable_map: dict[str, Any], state: User, key_stack: list[str]): + def user_buy_item_1(variable_map: dict[str, Any], state: User): state.balance -= variable_map['item_get_price'] return state.balance >= 0 @@ -100,19 +98,19 @@ def handle_init_class(self, *args, **kwargs) -> T: """Create an instance of the underlying class. Equivalent to `T.__init__(*args, **kwargs)`.""" return self.entity(*args, **kwargs) - def handle_invoke_method(self, method: InvokeMethod, variable_map: dict[str, Any], state: T, key_stack: list[str]) -> dict[str, Any]: + def handle_invoke_method(self, method: InvokeMethod, variable_map: dict[str, Any], state: T) -> dict[str, Any]: """Invoke the method of the underlying class. The `cascade.dataflow.dataflow.InvokeMethod` object must contain a method identifier that exists on the underlying compiled class functions. - The state `T` and key_stack is passed along to the function, and may be modified. + The state `T` is passed along to the function, and may be modified. """ - return self._methods[method.method_name](variable_map=variable_map, state=state, key_stack=key_stack) + return self._methods[method.method_name](variable_map=variable_map, state=state) class StatelessMethodCall(Protocol): - def __call__(self, variable_map: dict[str, Any], key_stack: list[str]) -> Any: ... + def __call__(self, variable_map: dict[str, Any]) -> Any: ... 
"""@private""" @@ -123,13 +121,13 @@ def __init__(self, methods: dict[str, StatelessMethodCall], dataflow: DataFlow) self._methods = methods self.dataflow = dataflow - def handle_invoke_method(self, method: InvokeMethod, variable_map: dict[str, Any], key_stack: list[str]) -> dict[str, Any]: + def handle_invoke_method(self, method: InvokeMethod, variable_map: dict[str, Any]) -> dict[str, Any]: """Invoke the method of the underlying class. The `cascade.dataflow.dataflow.InvokeMethod` object must contain a method identifier that exists on the underlying compiled class functions. - The state `T` and key_stack is passed along to the function, and may be modified. + The state `T` is passed along to the function, and may be modified. """ - return self._methods[method.method_name](variable_map=variable_map, key_stack=key_stack) + return self._methods[method.method_name](variable_map=variable_map) diff --git a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py index 975526c..67cf194 100644 --- a/src/cascade/runtime/flink_runtime.py +++ b/src/cascade/runtime/flink_runtime.py @@ -53,7 +53,6 @@ def open(self, runtime_context: RuntimeContext): self.state: ValueState = runtime_context.get_state(descriptor) def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): - # key_stack = event.key_stack # should be handled by filters on this FlinkOperator assert(isinstance(event.target, OpNode)) @@ -70,7 +69,6 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): # Register the created key in FlinkSelectAllOperator register_key_event = Event( FlinkRegisterKeyNode(key, self.operator.entity), - # [], {}, None, _id = event._id @@ -79,11 +77,10 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): yield register_key_event # Pop this key from the key stack so that we exit - # key_stack.pop() self.state.update(pickle.dumps(result)) elif isinstance(event.target.method_type, InvokeMethod): state = 
pickle.loads(self.state.value()) - result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, state=state, key_stack=[]) + result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, state=state) # TODO: check if state actually needs to be updated if state is not None: @@ -121,7 +118,7 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context): logger.debug(f"FlinkStatelessOperator {self.operator.dataflow.name}[{event._id}]: Processing: {event.target.method_type}") if isinstance(event.target.method_type, InvokeMethod): - result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map, key_stack=[]) + result = self.operator.handle_invoke_method(event.target.method_type, variable_map=event.variable_map) else: raise Exception(f"A StatelessOperator cannot compute event type: {event.target.method_type}") @@ -160,7 +157,6 @@ def process_element(self, event: Event, ctx: 'ProcessFunction.Context'): logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Selecting all") # Yield all the keys we now about - # event.key_stack.append(state) new_keys = state num_events = len(state) diff --git a/tests/integration/flink-runtime/common.py b/tests/integration/flink-runtime/common.py index 67baee6..ccec426 100644 --- a/tests/integration/flink-runtime/common.py +++ b/tests/integration/flink-runtime/common.py @@ -39,28 +39,28 @@ def get_price(self) -> int: def __repr__(self): return f"Item(key='{self.key}', price={self.price})" -def update_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def update_balance_compiled(variable_map: dict[str, Any], state: User) -> Any: state.balance += variable_map["amount"] return state.balance >= 0 -def get_balance_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def get_balance_compiled(variable_map: dict[str, Any], state: User) 
-> Any: return state.balance -def get_price_compiled(variable_map: dict[str, Any], state: Item, key_stack: list[str]) -> Any: +def get_price_compiled(variable_map: dict[str, Any], state: Item) -> Any: return state.price -def buy_item_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def buy_item_0_compiled(variable_map: dict[str, Any], state: User) -> Any: return None -def buy_item_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def buy_item_1_compiled(variable_map: dict[str, Any], state: User) -> Any: state.balance = state.balance - variable_map["item_price"] return state.balance >= 0 -def buy_2_items_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def buy_2_items_0_compiled(variable_map: dict[str, Any], state: User) -> Any: return None -def buy_2_items_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def buy_2_items_1_compiled(variable_map: dict[str, Any], state: User) -> Any: state.balance -= variable_map["item_prices"][0] + variable_map["item_prices"][1] return state.balance >= 0 diff --git a/tests/integration/flink-runtime/test_select_all.py b/tests/integration/flink-runtime/test_select_all.py index 9ade211..35c3265 100644 --- a/tests/integration/flink-runtime/test_select_all.py +++ b/tests/integration/flink-runtime/test_select_all.py @@ -34,11 +34,11 @@ def __repr__(self) -> str: return f"Hotel({self.name}, {self.loc})" -def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: +def distance_compiled(variable_map: dict[str, Any], state: Hotel) -> Any: loc = variable_map["loc"] return math.sqrt((state.loc.x - loc.x) ** 2 + (state.loc.y - loc.y) ** 2) -def get_name_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: +def get_name_compiled(variable_map: dict[str, Any], state: Hotel) -> Any: return state.name hotel_op = StatefulOperator(Hotel, @@ -52,24 
+52,19 @@ def get_nearby(hotels: list[Hotel], loc: Geo, dist: float): # We compile just the predicate, the select is implemented using a selectall node -def get_nearby_predicate_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): - # the top of the key_stack is already the right key, so in this case we don't need to do anything - # loc = variable_map["loc"] - # we need the hotel_key for later. (body_compiled_0) - # variable_map["hotel_key"] = key_stack[-1] +def get_nearby_predicate_compiled_0(variable_map: dict[str, Any]): pass -def get_nearby_predicate_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> bool: +def get_nearby_predicate_compiled_1(variable_map: dict[str, Any]) -> bool: loc = variable_map["loc"] dist = variable_map["dist"] hotel_dist = variable_map["hotel_distance"] - # key_stack.pop() # shouldn't pop because this function is stateless return hotel_dist < dist -def get_nearby_body_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): +def get_nearby_body_compiled_0(variable_map: dict[str, Any]): pass -def get_nearby_body_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) -> str: +def get_nearby_body_compiled_1(variable_map: dict[str, Any]) -> str: return variable_map["hotel_name"] get_nearby_op = StatelessOperator({ From 4f0679e0047fb90c43ab680daff57510454f2f97 Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 15:15:18 +0100 Subject: [PATCH 6/9] Remove key stack from deathstar --- deathstar/demo.py | 29 +++++---- deathstar/entities/flight.py | 3 +- deathstar/entities/hotel.py | 11 ++-- deathstar/entities/recommendation.py | 60 ++++++++--------- deathstar/entities/search.py | 32 ++++------ deathstar/entities/user.py | 33 +++++----- deathstar/test_demo.py | 96 ++++++++++++++-------------- 7 files changed, 124 insertions(+), 140 deletions(-) diff --git a/deathstar/demo.py b/deathstar/demo.py index c42bdb7..68987a1 100644 --- a/deathstar/demo.py +++ b/deathstar/demo.py @@ -21,9 +21,9 @@ 
class DeathstarDemo(): def __init__(self, input_topic, output_topic): - self.init_user = OpNode(user_op, InitClass()) - self.init_hotel = OpNode(hotel_op, InitClass()) - self.init_flight = OpNode(flight_op, InitClass()) + self.init_user = OpNode(User, InitClass(), read_key_from="user_id") + self.init_hotel = OpNode(Hotel, InitClass(), read_key_from="key") + self.init_flight = OpNode(Flight, InitClass(), read_key_from="id") self.runtime = FlinkRuntime(input_topic, output_topic) def init_runtime(self): @@ -140,7 +140,7 @@ def populate(self): # populate users self.users = [User(f"Cornell_{i}", str(i) * 10) for i in range(501)] for user in self.users: - event = Event(self.init_user, [user.id], {"user_id": user.id, "password": user.password}, None) + event = Event(self.init_user, {"user_id": user.id, "password": user.password}, None) self.runtime.send(event) # populate hotels @@ -151,7 +151,7 @@ def populate(self): price = prices[i] hotel = Hotel(str(i), 10, geo, rate, price) self.hotels.append(hotel) - event = Event(self.init_hotel, [hotel.key], + event = Event(self.init_hotel, { "key": hotel.key, "cap": hotel.cap, @@ -164,13 +164,13 @@ def populate(self): # populate flights self.flights = [Flight(str(i), 10) for i in range(100)] for flight in self.flights[:-1]: - event = Event(self.init_flight, [flight.id], { + event = Event(self.init_flight, { "id": flight.id, "cap": flight.cap }, None) self.runtime.send(event) flight = self.flights[-1] - event = Event(self.init_flight, [flight.id], { + event = Event(self.init_flight, { "id": flight.id, "cap": flight.cap }, None) @@ -201,7 +201,7 @@ def search_hotel(self): lon = -122.095 + (random.randint(0, 325) - 157.0) / 1000.0 # We don't really use the in_date, out_date information - return Event(search_op.dataflow.entry, ["tempkey"], {"lat": lat, "lon": lon}, search_op.dataflow) + return Event(search_op.dataflow.entry, {"lat": lat, "lon": lon}, search_op.dataflow) def recommend(self, req_param=None): if req_param is None: @@ 
-214,13 +214,13 @@ def recommend(self, req_param=None): lat = 38.0235 + (random.randint(0, 481) - 240.5) / 1000.0 lon = -122.095 + (random.randint(0, 325) - 157.0) / 1000.0 - return Event(recommend_op.dataflow.entry, ["tempkey"], {"requirement": req_param, "lat": lat, "lon": lon}, recommend_op.dataflow) + return Event(recommend_op.dataflow.entry, {"requirement": req_param, "lat": lat, "lon": lon}, recommend_op.dataflow) def user_login(self): user_id = random.randint(0, 500) username = f"Cornell_{user_id}" password = str(user_id) * 10 - return Event(OpNode(user_op, InvokeMethod("login")), [username], {"password": password}, None) + return Event(OpNode(User, InvokeMethod("login"), read_key_from="user_key"), {"user_key": username, "password": password}, None) def reserve(self): hotel_id = random.randint(0, 99) @@ -230,7 +230,14 @@ def reserve(self): # user.order(flight, hotel) user_id = "Cornell_" + str(random.randint(0, 500)) - return Event(user_op.dataflows["order"].entry, [user_id], {"flight": str(flight_id), "hotel": str(hotel_id)}, user_op.dataflows["order"]) + return Event( + user_op.dataflows["order"].entry, + { + "user_key": user_id, + "flight_key": str(flight_id), + "hotel_key": str(hotel_id) + }, + user_op.dataflows["order"]) def deathstar_workload_generator(self): search_ratio = 0.6 diff --git a/deathstar/entities/flight.py b/deathstar/entities/flight.py index 60af68b..445ff9e 100644 --- a/deathstar/entities/flight.py +++ b/deathstar/entities/flight.py @@ -18,8 +18,7 @@ def reserve(self) -> bool: #### COMPILED FUNCTIONS (ORACLE) ##### -def reserve_compiled(variable_map: dict[str, Any], state: Flight, key_stack: list[str]) -> Any: - key_stack.pop() +def reserve_compiled(variable_map: dict[str, Any], state: Flight) -> Any: if state.cap <= 0: return False return True diff --git a/deathstar/entities/hotel.py b/deathstar/entities/hotel.py index 6689168..e57386d 100644 --- a/deathstar/entities/hotel.py +++ b/deathstar/entities/hotel.py @@ -1,5 +1,5 @@ from 
dataclasses import dataclass -from typing import Any, Optional +from typing import Any from cascade.dataflow.operator import StatefulOperator from geopy.distance import distance @@ -59,18 +59,15 @@ def __key__(self) -> int: #### COMPILED FUNCTIONS (ORACLE) ##### -def reserve_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: - key_stack.pop() +def reserve_compiled(variable_map: dict[str, Any], state: Hotel) -> Any: if state.cap <= 0: return False return True -def get_geo_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: - key_stack.pop() +def get_geo_compiled(variable_map: dict[str, Any], state: Hotel) -> Any: return state.geo -def get_price_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any: - key_stack.pop() +def get_price_compiled(variable_map: dict[str, Any], state: Hotel) -> Any: return state.price hotel_op = StatefulOperator( diff --git a/deathstar/entities/recommendation.py b/deathstar/entities/recommendation.py index 7667210..99883ea 100644 --- a/deathstar/entities/recommendation.py +++ b/deathstar/entities/recommendation.py @@ -1,7 +1,7 @@ from typing import Any, Literal -from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode +from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode from cascade.dataflow.operator import StatelessOperator -from deathstar.entities.hotel import Geo, Hotel, hotel_op +from deathstar.entities.hotel import Geo, Hotel # Stateless class Recommendation(): @@ -23,51 +23,43 @@ def get_recommendations(requirement: Literal["distance", "price"], lat: float, l #### COMPILED FUNCTIONS (ORACLE) #### -def get_recs_if_cond(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_if_cond(variable_map: dict[str, Any]): return variable_map["requirement"] == "distance" # list comprehension entry -def get_recs_if_body_0(variable_map: 
dict[str, Any], key_stack: list[str]): - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) - variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def get_recs_if_body_0(variable_map: dict[str, Any]): + pass # list comprehension body -def get_recs_if_body_1(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_if_body_1(variable_map: dict[str, Any]): hotel_geo: Geo = variable_map["hotel_geo"] lat, lon = variable_map["lat"], variable_map["lon"] dist = hotel_geo.distance_km(lat, lon) return (dist, variable_map["hotel_key"]) # after list comprehension -def get_recs_if_body_2(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_if_body_2(variable_map: dict[str, Any]): distances = variable_map["distances"] min_dist = min(distances, key=lambda x: x[0])[0] variable_map["res"] = [hotel for dist, hotel in distances if dist == min_dist] -def get_recs_elif_cond(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_cond(variable_map: dict[str, Any]): return variable_map["requirement"] == "price" # list comprehension entry -def get_recs_elif_body_0(variable_map: dict[str, Any], key_stack: list[str]): - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) 
- variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def get_recs_elif_body_0(variable_map: dict[str, Any]): + pass # list comprehension body -def get_recs_elif_body_1(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_body_1(variable_map: dict[str, Any]): return (variable_map["hotel_price"], variable_map["hotel_key"]) # after list comprehension -def get_recs_elif_body_2(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_body_2(variable_map: dict[str, Any]): prices = variable_map["prices"] min_price = min(prices, key=lambda x: x[0])[0] variable_map["res"] = [hotel for price, hotel in prices if price == min_price] @@ -76,7 +68,7 @@ def get_recs_elif_body_2(variable_map: dict[str, Any], key_stack: list[str]): # a future optimization might instead duplicate this piece of code over the two # branches, in order to reduce the number of splits by one -def get_recs_final(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_final(variable_map: dict[str, Any]): return variable_map["res"] @@ -93,24 +85,24 @@ def get_recs_final(variable_map: dict[str, Any], key_stack: list[str]): }, None) df = DataFlow("get_recommendations") -n1 = OpNode(recommend_op, InvokeMethod("get_recs_if_cond"), is_conditional=True) -n2 = OpNode(recommend_op, InvokeMethod("get_recs_if_body_0")) -n3 = OpNode(hotel_op, InvokeMethod("get_geo"), assign_result_to="hotel_geo") -n4 = OpNode(recommend_op, InvokeMethod("get_recs_if_body_1"), assign_result_to="distance") +n1 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_cond"), is_conditional=True) +n2 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_0")) +n3 = OpNode(Hotel, InvokeMethod("get_geo"), assign_result_to="hotel_geo", read_key_from="hotel_key") +n4 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_1"), assign_result_to="distance") n5 = CollectNode("distances", "distance") -n6 = 
OpNode(recommend_op, InvokeMethod("get_recs_if_body_2")) -ns1 = SelectAllNode(Hotel, n5) +n6 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_2")) +ns1 = SelectAllNode(Hotel, n5, assign_key_to="hotel_key") -n7 = OpNode(recommend_op, InvokeMethod("get_recs_elif_cond"), is_conditional=True) -n8 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_0")) -n9 = OpNode(hotel_op, InvokeMethod("get_price"), assign_result_to="hotel_price") -n10 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_1"), assign_result_to="price") +n7 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_cond"), is_conditional=True) +n8 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_0")) +n9 = OpNode(Hotel, InvokeMethod("get_price"), assign_result_to="hotel_price", read_key_from="hotel_key") +n10 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_1"), assign_result_to="price") n11 = CollectNode("prices", "price") -n12 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_2")) -ns2 = SelectAllNode(Hotel, n11) +n12 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_2")) +ns2 = SelectAllNode(Hotel, n11, assign_key_to="hotel_key") -n13 = OpNode(recommend_op, InvokeMethod("get_recs_final")) +n13 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_final")) df.add_edge(Edge(n1, ns1, if_conditional=True)) df.add_edge(Edge(n1, n7, if_conditional=False)) diff --git a/deathstar/entities/search.py b/deathstar/entities/search.py index a2782d2..0b508d3 100644 --- a/deathstar/entities/search.py +++ b/deathstar/entities/search.py @@ -1,5 +1,5 @@ from typing import Any -from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode +from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode from cascade.dataflow.operator import StatelessOperator from deathstar.entities.hotel import Geo, Hotel, hotel_op @@ -21,19 +21,11 @@ def nearby(lat: 
float, lon: float, in_date: int, out_date: int): # predicate 1 -def search_nearby_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): - # We assume that the top of the key stack is the hotel key. - # This assumption holds if the node before this one is a correctly - # configure SelectAllNode. - - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) - variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def search_nearby_compiled_0(variable_map: dict[str, Any]): + pass # predicate 2 -def search_nearby_compiled_1(variable_map: dict[str, Any], key_stack: list[str]): +def search_nearby_compiled_1(variable_map: dict[str, Any]): hotel_geo: Geo = variable_map["hotel_geo"] lat, lon = variable_map["lat"], variable_map["lon"] dist = hotel_geo.distance_km(lat, lon) @@ -42,11 +34,11 @@ def search_nearby_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) # body -def search_nearby_compiled_2(variable_map: dict[str, Any], key_stack: list[str]): +def search_nearby_compiled_2(variable_map: dict[str, Any]): return (variable_map["dist"], variable_map["hotel_key"]) # next line -def search_nearby_compiled_3(variable_map: dict[str, Any], key_stack: list[str]): +def search_nearby_compiled_3(variable_map: dict[str, Any]): distances = variable_map["distances"] hotels = [hotel for dist, hotel in sorted(distances)[:5]] return hotels @@ -60,14 +52,14 @@ def search_nearby_compiled_3(variable_map: dict[str, Any], key_stack: list[str]) }, None) df = DataFlow("search_nearby") -n1 = OpNode(search_op, InvokeMethod("search_nearby_compiled_0")) -n2 = OpNode(hotel_op, InvokeMethod("get_geo"), assign_result_to="hotel_geo") -n3 = OpNode(search_op, InvokeMethod("search_nearby_compiled_1"), is_conditional=True) -n4 = OpNode(search_op, InvokeMethod("search_nearby_compiled_2"), assign_result_to="search_body") +n1 = 
StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_0")) +n2 = OpNode(Hotel, InvokeMethod("get_geo"), assign_result_to="hotel_geo", read_key_from="hotel_key") +n3 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_1"), is_conditional=True) +n4 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_2"), assign_result_to="search_body") n5 = CollectNode("distances", "search_body") -n0 = SelectAllNode(Hotel, n5) +n0 = SelectAllNode(Hotel, n5, assign_key_to="hotel_key") -n6 = OpNode(search_op, InvokeMethod("search_nearby_compiled_3")) +n6 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_3")) df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) diff --git a/deathstar/entities/user.py b/deathstar/entities/user.py index 0234e91..95b135f 100644 --- a/deathstar/entities/user.py +++ b/deathstar/entities/user.py @@ -21,25 +21,22 @@ def order(self, flight: Flight, hotel: Hotel): #### COMPILED FUNCTIONS (ORACLE) ##### -def check_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def check_compiled(variable_map: dict[str, Any], state: User) -> Any: return state.password == variable_map["password"] -def order_compiled_entry_0(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map["hotel"]) +def order_compiled_entry_0(variable_map: dict[str, Any], state: User) -> Any: + pass -def order_compiled_entry_1(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map["flight"]) +def order_compiled_entry_1(variable_map: dict[str, Any], state: User) -> Any: + pass -def order_compiled_if_cond(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def order_compiled_if_cond(variable_map: dict[str, Any], state: User) -> Any: return variable_map["hotel_reserve"] and variable_map["flight_reserve"] -def order_compiled_if_body(variable_map: dict[str, Any], state: User, key_stack: 
list[str]) -> Any: - key_stack.pop() +def order_compiled_if_body(variable_map: dict[str, Any], state: User) -> Any: return True -def order_compiled_else_body(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def order_compiled_else_body(variable_map: dict[str, Any], state: User) -> Any: return False user_op = StatefulOperator( @@ -59,13 +56,13 @@ def order_compiled_else_body(variable_map: dict[str, Any], state: User, key_stac # will try to automatically parallelize this. # There is also no user entry (this could also be an optimization) df = DataFlow("user_order") -n0 = OpNode(user_op, InvokeMethod("order_compiled_entry_0")) -n1 = OpNode(hotel_op, InvokeMethod("reserve"), assign_result_to="hotel_reserve") -n2 = OpNode(user_op, InvokeMethod("order_compiled_entry_1")) -n3 = OpNode(flight_op, InvokeMethod("reserve"), assign_result_to="flight_reserve") -n4 = OpNode(user_op, InvokeMethod("order_compiled_if_cond"), is_conditional=True) -n5 = OpNode(user_op, InvokeMethod("order_compiled_if_body")) -n6 = OpNode(user_op, InvokeMethod("order_compiled_else_body")) +n0 = OpNode(User, InvokeMethod("order_compiled_entry_0"), read_key_from="user_key") +n1 = OpNode(Hotel, InvokeMethod("reserve"), assign_result_to="hotel_reserve", read_key_from="hotel_key") +n2 = OpNode(User, InvokeMethod("order_compiled_entry_1"), read_key_from="user_key") +n3 = OpNode(Flight, InvokeMethod("reserve"), assign_result_to="flight_reserve", read_key_from="flight_key") +n4 = OpNode(User, InvokeMethod("order_compiled_if_cond"), is_conditional=True, read_key_from="user_key") +n5 = OpNode(User, InvokeMethod("order_compiled_if_body"), read_key_from="user_key") +n6 = OpNode(User, InvokeMethod("order_compiled_else_body"), read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) diff --git a/deathstar/test_demo.py b/deathstar/test_demo.py index f709948..a7e674e 100644 --- a/deathstar/test_demo.py +++ b/deathstar/test_demo.py @@ -1,49 +1,49 @@ -# 
from deathstar.demo import DeathstarDemo, DeathstarClient -# import time -# import pytest - -# @pytest.mark.integration -# def test_deathstar_demo(): -# ds = DeathstarDemo("deathstardemo-test", "dsd-out") -# ds.init_runtime() -# ds.runtime.run(run_async=True) -# print("Populating, press enter to go to the next step when done") -# ds.populate() - -# client = DeathstarClient("deathstardemo-test", "dsd-out") -# input() -# print("testing user login") -# event = client.user_login() -# client.send(event) - -# input() -# print("testing reserve") -# event = client.reserve() -# client.send(event) - -# input() -# print("testing search") -# event = client.search_hotel() -# client.send(event) - -# input() -# print("testing recommend (distance)") -# time.sleep(0.5) -# event = client.recommend(req_param="distance") -# client.send(event) - -# input() -# print("testing recommend (price)") -# time.sleep(0.5) -# event = client.recommend(req_param="price") -# client.send(event) - -# print(client.client._futures) -# input() -# print("done!") -# print(client.client._futures) - - -# if __name__ == "__main__": -# test_deathstar_demo() \ No newline at end of file +from deathstar.demo import DeathstarDemo, DeathstarClient +import time +import pytest + +@pytest.mark.integration +def test_deathstar_demo(): + ds = DeathstarDemo("deathstardemo-test", "dsd-out") + ds.init_runtime() + ds.runtime.run(run_async=True) + print("Populating, press enter to go to the next step when done") + ds.populate() + + client = DeathstarClient("deathstardemo-test", "dsd-out") + input() + print("testing user login") + event = client.user_login() + client.send(event) + + input() + print("testing reserve") + event = client.reserve() + client.send(event) + + input() + print("testing search") + event = client.search_hotel() + client.send(event) + + input() + print("testing recommend (distance)") + time.sleep(0.5) + event = client.recommend(req_param="distance") + client.send(event) + + input() + print("testing 
recommend (price)") + time.sleep(0.5) + event = client.recommend(req_param="price") + client.send(event) + + print(client.client._futures) + input() + print("done!") + print(client.client._futures) + + +if __name__ == "__main__": + test_deathstar_demo() \ No newline at end of file From fa6049727622c5e0466e967e9c8f8c6c0cb13f8b Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 15:49:32 +0100 Subject: [PATCH 7/9] remove key_stack from test_programs --- test_programs/expected/checkout_item.py | 20 +++---- test_programs/expected/checkout_two_items.py | 20 +++---- .../expected/deathstar_recommendation.py | 60 ++++++++----------- test_programs/expected/deathstar_search.py | 35 ++++------- test_programs/expected/deathstar_user.py | 33 +++++----- 5 files changed, 71 insertions(+), 97 deletions(-) diff --git a/test_programs/expected/checkout_item.py b/test_programs/expected/checkout_item.py index fd256bf..75a32fa 100644 --- a/test_programs/expected/checkout_item.py +++ b/test_programs/expected/checkout_item.py @@ -1,29 +1,27 @@ from typing import Any -# from ..target.checkout_item import User, Item -# from cascade.dataflow.dataflow import DataFlow, OpNode, InvokeMethod, Edge -def buy_item_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map['item_key']) +from cascade.dataflow.dataflow import DataFlow, Edge, InvokeMethod, OpNode +from test_programs.target.checkout_item import User, Item + +def buy_item_0_compiled(variable_map: dict[str, Any], state: User) -> Any: return None -def buy_item_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def buy_item_1_compiled(variable_map: dict[str, Any], state: User) -> Any: item_price_0 = variable_map['item_price_0'] state.balance -= item_price_0 return state.balance >= 0 -def get_price_0_compiled(variable_map: dict[str, Any], state: Item, key_stack: list[str]) -> Any: - key_stack.pop() +def 
get_price_0_compiled(variable_map: dict[str, Any], state: Item) -> Any: return state.price def user_buy_item_df(): df = DataFlow("user.buy_item") - n0 = OpNode(User, InvokeMethod("buy_item_0")) - n1 = OpNode(Item, InvokeMethod("get_price"), assign_result_to="item_price") - n2 = OpNode(User, InvokeMethod("buy_item_1")) + n0 = OpNode(User, InvokeMethod("buy_item_0"), read_key_from="user_key") + n1 = OpNode(Item, InvokeMethod("get_price"), assign_result_to="item_price", read_key_from="item_key") + n2 = OpNode(User, InvokeMethod("buy_item_1"), read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) df.entry = n0 diff --git a/test_programs/expected/checkout_two_items.py b/test_programs/expected/checkout_two_items.py index c3784bd..4784fa0 100644 --- a/test_programs/expected/checkout_two_items.py +++ b/test_programs/expected/checkout_two_items.py @@ -1,15 +1,12 @@ from typing import Any -# from ..target.checkout_item import User, Item -# from cascade.dataflow.dataflow import DataFlow, OpNode, InvokeMethod, Edge +from cascade.dataflow.dataflow import DataFlow, OpNode, InvokeMethod, Edge +from test_programs.target.checkout_two_items import User, Item -def buy_two_items_0_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map['item_1_key']) - key_stack.append(variable_map['item_2_key']) +def buy_two_items_0_compiled(variable_map: dict[str, Any], state: User) -> Any: return None -def buy_two_items_1_compiled(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def buy_two_items_1_compiled(variable_map: dict[str, Any], state: User) -> Any: item_price_1_0 = variable_map['item_price_1_0'] item_price_2_0 = variable_map['item_price_2_0'] total_price_0 = item_price_1_0 + item_price_2_0 @@ -17,16 +14,15 @@ def buy_two_items_1_compiled(variable_map: dict[str, Any], state: User, key_stac return state.balance >= 0 -def get_price_0_compiled(variable_map: dict[str, 
Any], state: Item, key_stack: list[str]) -> Any: - key_stack.pop() +def get_price_0_compiled(variable_map: dict[str, Any], state: Item) -> Any: return state.price def user_buy_item_df(): df = DataFlow("user.buy_item") - n0 = OpNode(User, InvokeMethod("buy_item_0")) - n1 = OpNode(Item, InvokeMethod("get_price"), assign_result_to="item_price") - n2 = OpNode(User, InvokeMethod("buy_item_1")) + n0 = OpNode(User, InvokeMethod("buy_item_0"), read_key_from="user_key") + n1 = OpNode(Item, InvokeMethod("get_price"), assign_result_to="item_price", read_key_from="item_key") + n2 = OpNode(User, InvokeMethod("buy_item_1"), read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) df.entry = n0 diff --git a/test_programs/expected/deathstar_recommendation.py b/test_programs/expected/deathstar_recommendation.py index 436aa5d..8a8a727 100644 --- a/test_programs/expected/deathstar_recommendation.py +++ b/test_programs/expected/deathstar_recommendation.py @@ -1,53 +1,45 @@ from typing import Any, Literal -from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode +from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode from cascade.dataflow.operator import StatelessOperator -def get_recs_if_cond(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_if_cond(variable_map: dict[str, Any]): return variable_map["requirement"] == "distance" # list comprehension entry -def get_recs_if_body_0(variable_map: dict[str, Any], key_stack: list[str]): - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) 
- variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def get_recs_if_body_0(variable_map: dict[str, Any]): + pass # list comprehension body -def get_recs_if_body_1(variable_map: dict[str, Any], key_stack: list[str]): - hotel_geo: Geo = variable_map["hotel_geo"] +def get_recs_if_body_1(variable_map: dict[str, Any]): + hotel_geo = variable_map["hotel_geo"] lat, lon = variable_map["lat"], variable_map["lon"] dist = hotel_geo.distance_km(lat, lon) return (dist, variable_map["hotel_key"]) # after list comprehension -def get_recs_if_body_2(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_if_body_2(variable_map: dict[str, Any]): distances = variable_map["distances"] min_dist = min(distances, key=lambda x: x[0])[0] variable_map["res"] = [hotel for dist, hotel in distances if dist == min_dist] -def get_recs_elif_cond(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_cond(variable_map: dict[str, Any]): return variable_map["requirement"] == "price" # list comprehension entry -def get_recs_elif_body_0(variable_map: dict[str, Any], key_stack: list[str]): - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) 
- variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def get_recs_elif_body_0(variable_map: dict[str, Any]): + pass # list comprehension body -def get_recs_elif_body_1(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_body_1(variable_map: dict[str, Any]): return (variable_map["hotel_price"], variable_map["hotel_key"]) # after list comprehension -def get_recs_elif_body_2(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_elif_body_2(variable_map: dict[str, Any]): prices = variable_map["prices"] min_price = min(prices, key=lambda x: x[0])[0] variable_map["res"] = [hotel for price, hotel in prices if price == min_price] @@ -56,7 +48,7 @@ def get_recs_elif_body_2(variable_map: dict[str, Any], key_stack: list[str]): # a future optimization might instead duplicate this piece of code over the two # branches, in order to reduce the number of splits by one -def get_recs_final(variable_map: dict[str, Any], key_stack: list[str]): +def get_recs_final(variable_map: dict[str, Any]): return variable_map["res"] @@ -74,24 +66,24 @@ def get_recs_final(variable_map: dict[str, Any], key_stack: list[str]): def get_recommendations_df(): df = DataFlow("get_recommendations") - n1 = OpNode(recommend_op, InvokeMethod("get_recs_if_cond"), is_conditional=True) - n2 = OpNode(recommend_op, InvokeMethod("get_recs_if_body_0")) - n3 = OpNode(hotel_op, InvokeMethod("get_geo"), assign_result_to="hotel_geo") - n4 = OpNode(recommend_op, InvokeMethod("get_recs_if_body_1"), assign_result_to="distance") + n1 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_cond"), is_conditional=True) + n2 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_0")) + n3 = OpNode(Hotel, InvokeMethod("get_geo"), assign_result_to="hotel_geo", read_key_from="hotel_key") + n4 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_1"), assign_result_to="distance") n5 = 
CollectNode("distances", "distance") - n6 = OpNode(recommend_op, InvokeMethod("get_recs_if_body_2")) - ns1 = SelectAllNode(Hotel, n5) + n6 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_if_body_2")) + ns1 = SelectAllNode(Hotel, n5, assign_key_to="hotel_key") - n7 = OpNode(recommend_op, InvokeMethod("get_recs_elif_cond"), is_conditional=True) - n8 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_0")) - n9 = OpNode(hotel_op, InvokeMethod("get_price"), assign_result_to="hotel_price") - n10 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_1"), assign_result_to="price") + n7 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_cond"), is_conditional=True) + n8 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_0")) + n9 = OpNode(Hotel, InvokeMethod("get_price"), assign_result_to="hotel_price", read_key_from="hotel_key") + n10 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_1"), assign_result_to="price") n11 = CollectNode("prices", "price") - n12 = OpNode(recommend_op, InvokeMethod("get_recs_elif_body_2")) - ns2 = SelectAllNode(Hotel, n11) + n12 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_elif_body_2")) + ns2 = SelectAllNode(Hotel, n11, assign_key_to="hotel_key") - n13 = OpNode(recommend_op, InvokeMethod("get_recs_final")) + n13 = StatelessOpNode(recommend_op, InvokeMethod("get_recs_final")) df.add_edge(Edge(n1, ns1, if_conditional=True)) df.add_edge(Edge(n1, n7, if_conditional=False)) diff --git a/test_programs/expected/deathstar_search.py b/test_programs/expected/deathstar_search.py index 06cbec0..cd20593 100644 --- a/test_programs/expected/deathstar_search.py +++ b/test_programs/expected/deathstar_search.py @@ -1,24 +1,15 @@ from typing import Any -from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode +from cascade.dataflow.dataflow import CollectNode, DataFlow, Edge, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode from cascade.dataflow.operator 
import StatelessOperator - # predicate 1 -def search_nearby_compiled_0(variable_map: dict[str, Any], key_stack: list[str]): - # We assume that the top of the key stack is the hotel key. - # This assumption holds if the node before this one is a correctly - # configure SelectAllNode. - - hotel_key = key_stack[-1] - # The body will need the hotel key (actually, couldn't we just take the top of the key stack again?) - variable_map["hotel_key"] = hotel_key - # The next node (Hotel.get_geo) will need the hotel key - key_stack.append(hotel_key) +def search_nearby_compiled_0(variable_map: dict[str, Any]): + pass # predicate 2 -def search_nearby_compiled_1(variable_map: dict[str, Any], key_stack: list[str]): - hotel_geo = variable_map["hotel_geo"] +def search_nearby_compiled_1(variable_map: dict[str, Any]): + hotel_geo: Geo = variable_map["hotel_geo"] lat, lon = variable_map["lat"], variable_map["lon"] dist = hotel_geo.distance_km(lat, lon) variable_map["dist"] = dist @@ -26,11 +17,11 @@ def search_nearby_compiled_1(variable_map: dict[str, Any], key_stack: list[str]) # body -def search_nearby_compiled_2(variable_map: dict[str, Any], key_stack: list[str]): +def search_nearby_compiled_2(variable_map: dict[str, Any]): return (variable_map["dist"], variable_map["hotel_key"]) # next line -def search_nearby_compiled_3(variable_map: dict[str, Any], key_stack: list[str]): +def search_nearby_compiled_3(variable_map: dict[str, Any]): distances = variable_map["distances"] hotels = [hotel for dist, hotel in sorted(distances)[:5]] return hotels @@ -45,14 +36,14 @@ def search_nearby_compiled_3(variable_map: dict[str, Any], key_stack: list[str]) def search_nearby_df(): df = DataFlow("search_nearby") - n1 = OpNode(search_op, InvokeMethod("search_nearby_compiled_0")) - n2 = OpNode(hotel_op, InvokeMethod("get_geo"), assign_result_to="hotel_geo") - n3 = OpNode(search_op, InvokeMethod("search_nearby_compiled_1"), is_conditional=True) - n4 = OpNode(search_op, 
InvokeMethod("search_nearby_compiled_2"), assign_result_to="search_body") + n1 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_0")) + n2 = OpNode(Hotel, InvokeMethod("get_geo"), assign_result_to="hotel_geo", read_key_from="hotel_key") + n3 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_1"), is_conditional=True) + n4 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_2"), assign_result_to="search_body") n5 = CollectNode("distances", "search_body") - n0 = SelectAllNode(Hotel, n5) + n0 = SelectAllNode(Hotel, n5, assign_key_to="hotel_key") - n6 = OpNode(search_op, InvokeMethod("search_nearby_compiled_3")) + n6 = StatelessOpNode(search_op, InvokeMethod("search_nearby_compiled_3")) df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) diff --git a/test_programs/expected/deathstar_user.py b/test_programs/expected/deathstar_user.py index 5aea434..64985ea 100644 --- a/test_programs/expected/deathstar_user.py +++ b/test_programs/expected/deathstar_user.py @@ -2,24 +2,21 @@ from cascade.dataflow.dataflow import DataFlow, Edge, InvokeMethod, OpNode from cascade.dataflow.operator import StatefulOperator -def order_compiled_entry_0(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map["hotel"]) +def order_compiled_entry_0(variable_map: dict[str, Any], state: User) -> Any: + pass -def order_compiled_entry_1(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.append(variable_map["flight"]) +def order_compiled_entry_1(variable_map: dict[str, Any], state: User) -> Any: + pass -def order_compiled_if_cond(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: +def order_compiled_if_cond(variable_map: dict[str, Any], state: User) -> Any: return variable_map["hotel_reserve"] and variable_map["flight_reserve"] -def order_compiled_if_body(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def 
order_compiled_if_body(variable_map: dict[str, Any], state: User) -> Any: return True -def order_compiled_else_body(variable_map: dict[str, Any], state: User, key_stack: list[str]) -> Any: - key_stack.pop() +def order_compiled_else_body(variable_map: dict[str, Any], state: User) -> Any: return False - user_op = StatefulOperator( User, { @@ -29,7 +26,7 @@ def order_compiled_else_body(variable_map: dict[str, Any], state: User, key_stac "order_compiled_if_body": order_compiled_if_body, "order_compiled_else_body": order_compiled_else_body }, - {} # dataflows (filled later) + {} ) # For now, the dataflow will be serial instead of parallel (calling hotel, then @@ -39,13 +36,13 @@ def order_compiled_else_body(variable_map: dict[str, Any], state: User, key_stac # before the first entity call). def user_order_df(): df = DataFlow("user_order") - n0 = OpNode(user_op, InvokeMethod("order_compiled_entry_0")) - n1 = OpNode(hotel_op, InvokeMethod("reserve"), assign_result_to="hotel_reserve") - n2 = OpNode(user_op, InvokeMethod("order_compiled_entry_1")) - n3 = OpNode(flight_op, InvokeMethod("reserve"), assign_result_to="flight_reserve") - n4 = OpNode(user_op, InvokeMethod("order_compiled_if_cond"), is_conditional=True) - n5 = OpNode(user_op, InvokeMethod("order_compiled_if_body")) - n6 = OpNode(user_op, InvokeMethod("order_compiled_else_body")) + n0 = OpNode(User, InvokeMethod("order_compiled_entry_0"), read_key_from="user_key") + n1 = OpNode(Hotel, InvokeMethod("reserve"), assign_result_to="hotel_reserve", read_key_from="hotel_key") + n2 = OpNode(User, InvokeMethod("order_compiled_entry_1"), read_key_from="user_key") + n3 = OpNode(Flight, InvokeMethod("reserve"), assign_result_to="flight_reserve", read_key_from="flight_key") + n4 = OpNode(User, InvokeMethod("order_compiled_if_cond"), is_conditional=True, read_key_from="user_key") + n5 = OpNode(User, InvokeMethod("order_compiled_if_body"), read_key_from="user_key") + n6 = OpNode(User, InvokeMethod("order_compiled_else_body"), 
read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n1, n2)) From 438f1da4e42fdabf755ee8d9e40b7e1e33b39b6a Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Mon, 20 Jan 2025 16:36:51 +0100 Subject: [PATCH 8/9] Fix tests --- src/cascade/runtime/python_runtime.py | 36 ++++++++----------- .../flink-runtime/test_collect_operator.py | 4 +-- .../flink-runtime/test_select_all.py | 4 +-- .../flink-runtime/test_two_entities.py | 4 +-- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/cascade/runtime/python_runtime.py b/src/cascade/runtime/python_runtime.py index cf936f3..8743014 100644 --- a/src/cascade/runtime/python_runtime.py +++ b/src/cascade/runtime/python_runtime.py @@ -1,7 +1,8 @@ from logging import Filter import threading +from typing import Type from cascade.dataflow.operator import StatefulOperator, StatelessOperator -from cascade.dataflow.dataflow import CollectNode, Event, EventResult, InitClass, InvokeMethod, OpNode, SelectAllNode +from cascade.dataflow.dataflow import CollectNode, Event, EventResult, InitClass, InvokeMethod, OpNode, SelectAllNode, StatelessOpNode from queue import Empty, Queue class PythonStatefulOperator(): @@ -11,17 +12,15 @@ def __init__(self, operator: StatefulOperator): def process(self, event: Event): assert(isinstance(event.target, OpNode)) - assert(isinstance(event.target.operator, StatefulOperator)) - assert(event.target.operator.entity == self.operator.entity) - key_stack = event.key_stack - key = key_stack[-1] + assert(event.target.entity == self.operator.entity) + + key = event.variable_map[event.target.read_key_from] print(f"PythonStatefulOperator: {event}") if isinstance(event.target.method_type, InitClass): result = self.operator.handle_init_class(*event.variable_map.values()) self.states[key] = result - key_stack.pop() elif isinstance(event.target.method_type, InvokeMethod): state = self.states[key] @@ -29,7 +28,6 @@ def process(self, event: Event): event.target.method_type, 
variable_map=event.variable_map, state=state, - key_stack=key_stack ) self.states[key] = state @@ -39,7 +37,7 @@ def process(self, event: Event): if event.target.assign_result_to is not None: event.variable_map[event.target.assign_result_to] = result - new_events = event.propogate(key_stack, result) + new_events = event.propogate(result) if isinstance(new_events, EventResult): yield new_events else: @@ -50,17 +48,14 @@ def __init__(self, operator: StatelessOperator): self.operator = operator def process(self, event: Event): - assert(isinstance(event.target, OpNode)) - assert(isinstance(event.target.operator, StatelessOperator)) + assert(isinstance(event.target, StatelessOpNode)) - key_stack = event.key_stack if isinstance(event.target.method_type, InvokeMethod): result = self.operator.handle_invoke_method( event.target.method_type, variable_map=event.variable_map, - key_stack=key_stack ) else: raise Exception(f"A StatelessOperator cannot compute event type: {event.target.method_type}") @@ -68,7 +63,7 @@ def process(self, event: Event): if event.target.assign_result_to is not None: event.variable_map[event.target.assign_result_to] = result - new_events = event.propogate(key_stack, result) + new_events = event.propogate(result) if isinstance(new_events, EventResult): yield new_events else: @@ -81,8 +76,8 @@ def __init__(self): self.events = Queue() self.results = Queue() self.running = False - self.statefuloperators: dict[StatefulOperator, PythonStatefulOperator] = {} - self.statelessoperators: dict[StatelessOperator, PythonStatelessOperator] = {} + self.statefuloperators: dict[Type, PythonStatefulOperator] = {} + self.statelessoperators: dict[str, PythonStatelessOperator] = {} def init(self): pass @@ -91,10 +86,9 @@ def _consume_events(self): self.running = True def consume_event(event: Event): if isinstance(event.target, OpNode): - if isinstance(event.target.operator, StatefulOperator): - yield from self.statefuloperators[event.target.operator].process(event) - 
elif isinstance(event.target.operator, StatelessOperator): - yield from self.statelessoperators[event.target.operator].process(event) + yield from self.statefuloperators[event.target.entity].process(event) + elif isinstance(event.target, StatelessOpNode): + yield from self.statelessoperators[event.target.operator.dataflow.name].process(event) elif isinstance(event.target, SelectAllNode): raise NotImplementedError() @@ -121,11 +115,11 @@ def consume_event(event: Event): def add_operator(self, op: StatefulOperator): """Add a `StatefulOperator` to the datastream.""" - self.statefuloperators[op] = PythonStatefulOperator(op) + self.statefuloperators[op.entity] = PythonStatefulOperator(op) def add_stateless_operator(self, op: StatelessOperator): """Add a `StatelessOperator` to the datastream.""" - self.statelessoperators[op] = PythonStatelessOperator(op) + self.statelessoperators[op.dataflow.name] = PythonStatelessOperator(op) def send(self, event: Event, flush=None): self.events.put(event) diff --git a/tests/integration/flink-runtime/test_collect_operator.py b/tests/integration/flink-runtime/test_collect_operator.py index 574c739..d14418f 100644 --- a/tests/integration/flink-runtime/test_collect_operator.py +++ b/tests/integration/flink-runtime/test_collect_operator.py @@ -10,8 +10,8 @@ def test_merge_operator(): runtime = FlinkRuntime("test_collect_operator") runtime.init() - runtime.add_operator(FlinkOperator(item_op)) - runtime.add_operator(FlinkOperator(user_op)) + runtime.add_operator(item_op) + runtime.add_operator(user_op) # Create a User object diff --git a/tests/integration/flink-runtime/test_select_all.py b/tests/integration/flink-runtime/test_select_all.py index 35c3265..602858d 100644 --- a/tests/integration/flink-runtime/test_select_all.py +++ b/tests/integration/flink-runtime/test_select_all.py @@ -99,8 +99,8 @@ def get_nearby_body_compiled_1(variable_map: dict[str, Any]) -> str: def test_nearby_hotels(): runtime = FlinkRuntime("test_nearby_hotels") 
runtime.init() - runtime.add_operator(FlinkOperator(hotel_op)) - runtime.add_stateless_operator(FlinkStatelessOperator(get_nearby_op)) + runtime.add_operator(hotel_op) + runtime.add_stateless_operator(get_nearby_op) # Create Hotels hotels = [] diff --git a/tests/integration/flink-runtime/test_two_entities.py b/tests/integration/flink-runtime/test_two_entities.py index 54309fa..3d89bd2 100644 --- a/tests/integration/flink-runtime/test_two_entities.py +++ b/tests/integration/flink-runtime/test_two_entities.py @@ -10,8 +10,8 @@ def test_two_entities(): runtime = FlinkRuntime("test_two_entities") runtime.init() - runtime.add_operator(FlinkOperator(item_op)) - runtime.add_operator(FlinkOperator(user_op)) + runtime.add_operator(item_op) + runtime.add_operator(user_op) # Create a User object foo_user = User("foo", 100) From 610104c3ee6719b51117b3adb228818c25e462de Mon Sep 17 00:00:00 2001 From: lucasvanmol Date: Tue, 21 Jan 2025 12:00:05 +0100 Subject: [PATCH 9/9] Fix dataflow test --- src/cascade/dataflow/test_dataflow.py | 68 +++++++++++++++------------ 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/cascade/dataflow/test_dataflow.py b/src/cascade/dataflow/test_dataflow.py index 1e29aad..a5b42af 100644 --- a/src/cascade/dataflow/test_dataflow.py +++ b/src/cascade/dataflow/test_dataflow.py @@ -12,14 +12,12 @@ def buy_item(self, item: 'DummyItem') -> bool: self.balance -= item_price return self.balance >= 0 -def buy_item_0_compiled(variable_map: dict[str, Any], state: DummyUser, key_stack: list[str]) -> dict[str, Any]: - key_stack.append(variable_map["item_key"]) +def buy_item_0_compiled(variable_map: dict[str, Any], state: DummyUser): return -def buy_item_1_compiled(variable_map: dict[str, Any], state: DummyUser, key_stack: list[str]) -> dict[str, Any]: - key_stack.pop() +def buy_item_1_compiled(variable_map: dict[str, Any], state: DummyUser): state.balance -= variable_map["item_price"] - return {"user_postive_balance": state.balance >= 0} + return 
state.balance >= 0 class DummyItem: def __init__(self, key: str, price: int): @@ -29,10 +27,8 @@ def __init__(self, key: str, price: int): def get_price(self) -> int: return self.price -def get_price_compiled(variable_map: dict[str, Any], state: DummyItem, key_stack: list[str]) -> dict[str, Any]: - key_stack.pop() # final function - variable_map["item_price"] = state.price - # return {"item_price": state.price} +def get_price_compiled(variable_map: dict[str, Any], state: DummyItem): + return state.price ################## TESTS ####################### @@ -46,53 +42,60 @@ def get_price_compiled(variable_map: dict[str, Any], state: DummyItem, key_stack def test_simple_df_propogation(): df = DataFlow("user.buy_item") - n1 = OpNode(DummyUser, InvokeMethod("buy_item_0_compiled")) - n2 = OpNode(DummyItem, InvokeMethod("get_price")) - n3 = OpNode(DummyUser, InvokeMethod("buy_item_1")) + n1 = OpNode(DummyUser, InvokeMethod("buy_item_0_compiled"), read_key_from="user_key") + n2 = OpNode(DummyItem, InvokeMethod("get_price"), read_key_from="item_key", assign_result_to="item_price") + n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"), read_key_from="user_key") df.add_edge(Edge(n1, n2)) df.add_edge(Edge(n2, n3)) user.buy_item(item) - event = Event(n1, ["user"], {"item_key":"fork"}, df) + event = Event(n1, {"user_key": "user", "item_key":"fork"}, df) # Manually propogate - item_key = buy_item_0_compiled(event.variable_map, state=user, key_stack=event.key_stack) - next_event = event.propogate(event.key_stack, item_key) + item_key = buy_item_0_compiled(event.variable_map, state=user) + next_event = event.propogate(event, item_key) + assert isinstance(next_event, list) assert len(next_event) == 1 assert next_event[0].target == n2 - assert next_event[0].key_stack == ["user", "fork"] event = next_event[0] - item_price = get_price_compiled(event.variable_map, state=item, key_stack=event.key_stack) - next_event = event.propogate(event.key_stack, item_price) + # manually add the price 
to the variable map + item_price = get_price_compiled(event.variable_map, state=item) + assert n2.assign_result_to + event.variable_map[n2.assign_result_to] = item_price + next_event = event.propogate(item_price) + + assert isinstance(next_event, list) assert len(next_event) == 1 assert next_event[0].target == n3 event = next_event[0] - positive_balance = buy_item_1_compiled(event.variable_map, state=user, key_stack=event.key_stack) - next_event = event.propogate(event.key_stack, None) + positive_balance = buy_item_1_compiled(event.variable_map, state=user) + next_event = event.propogate(None) assert isinstance(next_event, EventResult) def test_merge_df_propogation(): df = DataFlow("user.buy_2_items") - n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0")) + n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"), read_key_from="user_key") n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price") n1 = OpNode( DummyItem, InvokeMethod("get_price"), assign_result_to="item_price", - collect_target=CollectTarget(n3, 2, 0) + collect_target=CollectTarget(n3, 2, 0), + read_key_from="item_1_key" ) n2 = OpNode( DummyItem, InvokeMethod("get_price"), assign_result_to="item_price", - collect_target=CollectTarget(n3, 2, 1) + collect_target=CollectTarget(n3, 2, 1), + read_key_from="item_2_key" ) - n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1")) + n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"), read_key_from="user_key") df.add_edge(Edge(n0, n1)) df.add_edge(Edge(n0, n2)) df.add_edge(Edge(n1, n3)) @@ -100,25 +103,30 @@ def test_merge_df_propogation(): df.add_edge(Edge(n3, n4)) # User with key "foo" buys items with keys "fork" and "spoon" - event = Event(n0, ["foo"], {"item_1_key": "fork", "item_2_key": "spoon"}, df) + event = Event(n0, {"user_key": "foo", "item_1_key": "fork", "item_2_key": "spoon"}, df) # Propogate the event (without actually doing any calculation) # Normally, the key_stack should've been updated by the runtime here: - key_stack = 
["foo", ["fork", "spoon"]] - next_event = event.propogate(key_stack, None) + next_event = event.propogate(None) + assert isinstance(next_event, list) assert len(next_event) == 2 assert next_event[0].target == n1 assert next_event[1].target == n2 event1, event2 = next_event - next_event = event1.propogate(event1.key_stack, None) + next_event = event1.propogate(None) + + assert isinstance(next_event, list) assert len(next_event) == 1 assert next_event[0].target == n3 - next_event = event2.propogate(event2.key_stack, None) + next_event = event2.propogate(None) + + assert isinstance(next_event, list) assert len(next_event) == 1 assert next_event[0].target == n3 - final_event = next_event[0].propogate(next_event[0].key_stack, None) + final_event = next_event[0].propogate(None) + assert isinstance(final_event, list) assert final_event[0].target == n4