diff --git a/ait/core/__init__.py b/ait/core/__init__.py index 647afcd7..e24770c3 100644 --- a/ait/core/__init__.py +++ b/ait/core/__init__.py @@ -46,3 +46,6 @@ def deprecated_func(*args, **kwargs): sys.modules["ait"].SERVER_DEFAULT_XSUB_URL = "tcp://*:5559" # type: ignore[attr-defined] sys.modules["ait"].SERVER_DEFAULT_XPUB_URL = "tcp://*:5560" # type: ignore[attr-defined] + +sys.modules["ait"].MIN_PORT = 1024 # type: ignore[attr-defined] +sys.modules["ait"].MAX_PORT = 65535 # type: ignore[attr-defined] diff --git a/ait/core/server/broker.py b/ait/core/server/broker.py index 6908c7af..92ec2ad9 100644 --- a/ait/core/server/broker.py +++ b/ait/core/server/broker.py @@ -9,6 +9,7 @@ import ait.core.server from ait.core import log from .config import ZmqConfig +from .utils import is_valid_address_spec class Broker(gevent.Greenlet): @@ -69,7 +70,11 @@ def _subscribe_all(self): """ for stream in self.inbound_streams + self.outbound_streams: for input_ in stream.inputs: - if not type(input_) is int and input_ is not None: + if ( + not type(input_) is int + and input_ is not None + and not is_valid_address_spec(input_) + ): Broker.subscribe(stream, input_) for plugin in self.plugins: diff --git a/ait/core/server/client.py b/ait/core/server/client.py index 8cf8baef..bfd6297c 100644 --- a/ait/core/server/client.py +++ b/ait/core/server/client.py @@ -117,10 +117,10 @@ def _run(self): raise (e) -class PortOutputClient(ZMQInputClient): +class UDPOutputClient(ZMQInputClient): """ This is the parent class for all outbound streams which publish - to a port. It opens a UDP port to publish to and publishes + to a UDP port. It opens a UDP port to publish to and publishes outgoing message data to this port. """ @@ -131,20 +131,74 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - super(PortOutputClient, self).__init__( + super(UDPOutputClient, self).__init__( zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url ) - self.out_port = kwargs["output"] + if "output" in kwargs: + output = kwargs["output"] + if type(output) is int: + self.addr_spec = ("localhost", output) + elif utils.is_valid_address_spec(output): + protocol, hostname, port = output.split(":") + if protocol.lower() != "udp": + raise ( + ValueError(f"UDPOutputClient: Invalid Specification {output}") + ) + self.addr_spec = (hostname, int(port)) + else: + raise (ValueError(f"UDPOutputClient: Invalid Specification {output}")) + else: + raise (ValueError("UDPOutputClient: Invalid Specification")) + self.context = zmq_context # override pub to be udp socket self.pub = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) def publish(self, msg): - self.pub.sendto(msg, ("localhost", int(self.out_port))) + self.pub.sendto(msg, self.addr_spec) log.debug("Published message from {}".format(self)) -class PortInputClient(ZMQClient, gs.DatagramServer): +class TCPOutputClient(ZMQInputClient): + """ + This is the parent class for all outbound streams which publish + to a TCP port. It opens a TCP connection to publish to and publishes + outgoing message data to this port. + """ + + def __init__( + self, + zmq_context, + zmq_proxy_xsub_url=ait.SERVER_DEFAULT_XSUB_URL, + zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, + **kwargs, + ): + super(TCPOutputClient, self).__init__( + zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url + ) + if "output" in kwargs: + output = kwargs["output"] + if utils.is_valid_address_spec(output): + protocol, hostname, port = output.split(":") + if protocol.lower() != "tcp": + raise ( + ValueError(f"TCPOutputClient: Invalid Specification {output}") + ) + self.addr_spec = (hostname, int(port)) + else: + raise (ValueError(f"TCPOutputClient: Invalid Specification {output}")) + else: + raise (ValueError("TCPOutputClient: Invalid Specification")) + + self.context = zmq_context + self.pub = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + def publish(self, msg): + self.pub.connect(self.addr_spec) + self.pub.sendall(msg) + + +class UDPInputServer(ZMQClient, gs.DatagramServer): """ This is the parent class for all inbound streams which receive messages on a port. It opens a UDP port for receiving messages, listens for them, @@ -158,15 +212,31 @@ def __init__( zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, **kwargs, ): - if "input" in kwargs and type(kwargs["input"][0]) is int: - super(PortInputClient, self).__init__( + if "input" in kwargs: + input = kwargs["input"] + if type(input) is int: + host_spec = input + elif utils.is_valid_address_spec(input): + protocol, hostname, port = input.split(":") + if protocol.lower() != "udp": + raise (ValueError(f"UDPInputServer: Invalid Specification {input}")) + if hostname in ["127.0.0.1", "localhost"]: + host_spec = port + elif hostname in ["0.0.0.0", "server"]: + host_spec = f"0.0.0.0:{port}" + else: + raise (ValueError(f"UDPInputServer: Invalid Specification {input}")) + + else: + raise (ValueError(f"UDPInputServer: Invalid Specification {input}")) + super(UDPInputServer, self).__init__( zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url, - listener=int(kwargs["input"][0]), + listener=host_spec, ) else: - raise (ValueError("Input must be port in order to create PortInputClient")) + raise (ValueError("UDPInputServer: Invalid Specification")) # open sub socket self.sub = gevent.socket.socket(gevent.socket.AF_INET, gevent.socket.SOCK_DGRAM) @@ -175,3 +245,195 @@ def handle(self, packet, address): # This function provided for gs.DatagramServer class log.debug("{} received message from port {}".format(self, address)) self.process(packet) + + +class TCPInputServer(ZMQClient, gs.StreamServer): + """ + This class is similar to UDPInputServer except its TCP instead of UDP. + """ + + def __init__( + self, + zmq_context, + zmq_proxy_xsub_url=ait.SERVER_DEFAULT_XSUB_URL, + zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, + buffer=1024, + **kwargs, + ): + self.cur_socket = None + self.buffer = buffer + if "input" in kwargs: + input = kwargs["input"] + if not utils.is_valid_address_spec(input): + raise (ValueError(f"TCPInputServer: Invalid Specification {input}")) + protocol, hostname, port = input.split(":") + if protocol.lower() != "tcp" or hostname not in [ + "127.0.0.1", + "localhost", + "server", + "0.0.0.0", + ]: + raise (ValueError(f"TCPInputServer: Invalid Specification {input}")) + + self.sub = gevent.socket.socket( + gevent.socket.AF_INET, gevent.socket.SOCK_STREAM + ) + hostname = ( + "127.0.0.1" if hostname in ["127.0.0.1", "localhost"] else "0.0.0.0" + ) + super(TCPInputServer, self).__init__( + zmq_context, + zmq_proxy_xsub_url, + zmq_proxy_xpub_url, + listener=(hostname, int(port)), + ) + else: + raise (ValueError("TCPInputServer: Invalid Specification")) + + def handle(self, socket, address): + self.cur_socket = socket + with socket: + while True: + data = socket.recv(self.buffer) + if not data: + break + log.debug("{} received message from port {}".format(self, address)) + self.process(data) + gevent.sleep(0) # pass control back + + +class TCPInputClient(ZMQClient): + """ + This class creates a TCP input client. Unlike TCPInputServer and UDPInputServer, + this class will proactively initiate a connection with an input source and begin + receiving data from that source. This class does not inherit directly from gevent + servers and thus implements its own housekeeping functions. It also implements a + start function that spawns a process to stay consistent with the behavior of + TCPInputServer and UDPInputServer. + + """ + + def __init__( + self, + zmq_context, + zmq_proxy_xsub_url=ait.SERVER_DEFAULT_XSUB_URL, + zmq_proxy_xpub_url=ait.SERVER_DEFAULT_XPUB_URL, + connection_reattempts=5, + buffer=1024, + **kwargs, + ): + self.connection_reattempts = connection_reattempts + self.buffer = buffer + self.connection_status = -1 + self.proc = None + self.protocol = gevent.socket.SOCK_STREAM + + if "buffer" in kwargs and type(kwargs["buffer"]) == int: + self.buffer = kwargs["buffer"] + + if "input" in kwargs: + input = kwargs["input"] + if not utils.is_valid_address_spec(input): + raise (ValueError(f"TCPInputClient: Invalid Specification {input}")) + protocol, hostname, port = input.split(":") + if protocol.lower() != "tcp": + raise (ValueError(f"TCPInputClient: Invalid Specification {input}")) + super(TCPInputClient, self).__init__( + zmq_context, zmq_proxy_xsub_url, zmq_proxy_xpub_url + ) + + self.sub = gevent.socket.socket(gevent.socket.AF_INET, self.protocol) + + self.hostname = hostname + self.port = int(port) + self.address = (hostname, int(port)) + + else: + raise (ValueError("TCPInputClient: Invalid Specification")) + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def __del__(self): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def __repr__(self): + return "<%s at %s %s>" % ( + type(self).__name__, + hex(id(self)), + self._formatinfo(), + ) + + def __str__(self): + return "<%s %s>" % (type(self).__name__, self._formatinfo()) + + def start(self): + self.proc = gevent.spawn(self._client) + + def _connect(self): + while self.connection_reattempts: + try: + res = self.sub.connect_ex((self.hostname, self.port)) + if res == 0: + self.connection_reattempts = 5 + return res + else: + self.connection_reattempts -= 1 + gevent.sleep(1) + except Exception as e: + log.error(e) + self.connection_reattempts -= 1 + gevent.sleep(1) + + def _exit(self): + try: + if self.sub: + self.sub.close() + if self.proc: + self.proc.kill() + except Exception as e: + log.error(e) + + def _client(self): + self.connection_status = self._connect() + if self.connection_status != 0: + log.error( + f"Unable to connect to client: {self.address[0]}:{self.address[1]}" + ) + self._exit() + while True: + packet = self.sub.recv(self.buffer) + if not packet: + gevent.sleep(1) + log.info( + f"Trying to reconnect to client: {self.address[0]}:{self.address[1]}" + ) + if self._connect() != 0: + log.error( + f"Unable to connect to client: {self.address[0]}:{self.address[1]}" + ) + self._exit() + self.process(packet) + + def _formatinfo(self): + result = "" + try: + if isinstance(self.address, tuple) and len(self.address) == 2: + result += "address=%s:%s" % self.address + else: + result += "address=%s" % (self.address,) + except Exception as ex: + result += str(ex) or "" + return result diff --git a/ait/core/server/server.py b/ait/core/server/server.py index 11b45686..c0b85168 100644 --- a/ait/core/server/server.py +++ b/ait/core/server/server.py @@ -11,9 +11,11 @@ from .plugin import PluginConfig from .plugin import PluginType from .process import PluginsProcess -from .stream import PortInputStream -from .stream import PortOutputStream -from .stream import ZMQStream +from .stream import input_stream_factory +from .stream import output_stream_factory +from .stream import TCPInputClientStream +from .stream import TCPInputServerStream +from .stream import UDPInputServerStream from ait.core import cfg from ait.core import log @@ -117,7 +119,6 @@ def _load_streams(self): common_err_msg.format(stream_type) + specific_err_msg[stream_type] ) streams = ait.config.get(f"server.{stream_type}-streams") - if streams is None: log.warn(err_msgs[stream_type]) else: @@ -125,7 +126,11 @@ def _load_streams(self): try: if stream_type == "inbound": strm = self._create_inbound_stream(s["stream"]) - if type(strm) == PortInputStream: + if ( + type(strm) == UDPInputServerStream + or type(strm) == TCPInputClientStream + or type(strm) == TCPInputServerStream + ): self.servers.append(strm) else: self.inbound_streams.append(strm) @@ -263,7 +268,6 @@ def _create_inbound_stream(self, config=None): """ if config is None: raise ValueError("No stream config to create stream from.") - name = self._get_stream_name(config) stream_handlers = self._get_stream_handlers(config, name) stream_input = config.get("input", None) @@ -273,20 +277,12 @@ def _create_inbound_stream(self, config=None): # Create ZMQ args re-using the Broker's context zmq_args_dict = self._create_zmq_args(True) - if type(stream_input[0]) is int: - return PortInputStream( - name, - stream_input, - stream_handlers, - zmq_args=zmq_args_dict, - ) - else: - return ZMQStream( - name, - stream_input, - stream_handlers, - zmq_args=zmq_args_dict, - ) + return input_stream_factory( + name, + stream_input, + stream_handlers, + zmq_args=zmq_args_dict, + ) def _create_outbound_stream(self, config=None): """ @@ -316,26 +312,9 @@ def _create_outbound_stream(self, config=None): # Create ZMQ args re-using the Broker's context zmq_args_dict = self._create_zmq_args(True) - if type(stream_output) is int: - ostream = PortOutputStream( - name, - stream_input, - stream_output, - stream_handlers, - zmq_args=zmq_args_dict, - ) - else: - if stream_output is not None: - log.warn( - f"Output of stream {name} is not an integer port. " - "Stream outputs can only be ports." - ) - ostream = ZMQStream( - name, - stream_input, - stream_handlers, - zmq_args=zmq_args_dict, - ) + ostream = output_stream_factory( + name, stream_input, stream_output, stream_handlers, zmq_args=zmq_args_dict + ) # Set the cmd subscriber field for the stream ostream.cmd_subscriber = stream_cmd_sub is True diff --git a/ait/core/server/stream.py b/ait/core/server/stream.py index 86b1a2ae..65a7f122 100644 --- a/ait/core/server/stream.py +++ b/ait/core/server/stream.py @@ -1,7 +1,11 @@ import ait.core.log -from .client import PortInputClient -from .client import PortOutputClient +from .client import TCPInputClient +from .client import TCPInputServer +from .client import TCPOutputClient +from .client import UDPInputServer +from .client import UDPOutputClient from .client import ZMQInputClient +from .utils import is_valid_address_spec class Stream: @@ -51,7 +55,9 @@ def __init__(self, name, inputs, handlers, zmq_args=None, **kwargs): input=self.inputs, output=kwargs["output"], **zmq_args ) else: - super(Stream, self).__init__(input=self.inputs, **zmq_args) + super(Stream, self).__init__( + input=self.inputs, protocol=kwargs.get("protocol", None), **zmq_args + ) def __repr__(self): return "<{} name={}>".format( @@ -71,7 +77,6 @@ def process(self, input_data, topic=None): """ for handler in self.handlers: output = handler.handle(input_data) - if output: input_data = output else: @@ -81,7 +86,6 @@ def process(self, input_data, topic=None): ) ait.core.log.info(msg) return - self.publish(input_data) def valid_workflow(self): @@ -101,13 +105,146 @@ def valid_workflow(self): return True -class PortInputStream(Stream, PortInputClient): +def output_stream_factory(name, inputs, outputs, handlers, zmq_args=None): + """ + This factory preempts the creating of output streams directly. It accepts + the same args as any given stream class and then based primarily on the + values in 'outputs' decides on the appropriate stream to instantiate and + then returns it. + """ + + parsed_output = outputs + if type(parsed_output) is list and len(parsed_output) > 0: + if len(parsed_output) > 1: + ait.core.log.warn(f"Additional output args discarded {parsed_output[1:]}") + parsed_output = parsed_output[0] + if type(parsed_output) is int: + if ait.MIN_PORT <= parsed_output <= ait.MAX_PORT: + return UDPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args + ) + else: + raise ValueError(f"Output stream specification invalid: {outputs}") + + elif type(parsed_output) is str and is_valid_address_spec(parsed_output): + protocol, hostname, port = parsed_output.split(":") + if protocol.lower() == "udp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT: + return UDPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args + ) + elif protocol.lower() == "tcp" and ait.MIN_PORT <= int(port) <= ait.MAX_PORT: + return TCPOutputStream( + name, inputs, parsed_output, handlers, zmq_args=zmq_args + ) + else: + raise ValueError(f"Output stream specification invalid: {outputs}") + elif parsed_output is None or ( + type(parsed_output) is list and len(parsed_output) == 0 + ): + return ZMQStream( + name, + inputs, + handlers, + zmq_args=zmq_args, + ) + else: + raise ValueError(f"Output stream specification invalid: {outputs}") + + +def input_stream_factory(name, inputs, handlers, zmq_args=None): + """ + This factory preempts the creating of input streams directly. It accepts + the same args as any given stream class and then based primarily on the + values in 'inputs' decides on the appropriate stream to instantiate and + then returns it. + """ + + stream = None + parsed_inputs = inputs + if type(parsed_inputs) is int: + parsed_inputs = [parsed_inputs] + if type(parsed_inputs) is str: + parsed_inputs = [parsed_inputs] + + if type(parsed_inputs) is not list or ( + type(parsed_inputs) is list and len(parsed_inputs) == 0 + ): + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + + # backwards compatability with original UDP server spec + if ( + type(parsed_inputs) is list + and type(parsed_inputs[0]) is int + and ait.MIN_PORT <= parsed_inputs[0] <= ait.MAX_PORT + ): + stream = UDPInputServerStream( + name, parsed_inputs[0], handlers, zmq_args=zmq_args + ) + elif is_valid_address_spec(parsed_inputs[0]): + protocol, hostname, port = parsed_inputs[0].split(":") + if int(port) < ait.MIN_PORT or int(port) > ait.MAX_PORT: + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + if protocol.lower() == "tcp": + if hostname.lower() in [ + "server", + "localhost", + "127.0.0.1", + "0.0.0.0", + ]: + stream = TCPInputServerStream( + name, parsed_inputs[0], handlers, zmq_args + ) + else: + stream = TCPInputClientStream( + name, parsed_inputs[0], handlers, zmq_args + ) + else: + if hostname.lower() in [ + "server", + "localhost", + "127.0.0.1", + "0.0.0.0", + ]: + stream = UDPInputServerStream( + name, parsed_inputs[0], handlers, zmq_args=zmq_args + ) + else: + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + elif all(isinstance(item, str) for item in parsed_inputs): + stream = ZMQStream(name, parsed_inputs, handlers, zmq_args=zmq_args) + else: + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + + if stream is None: + raise ValueError(f"Input stream specification invalid: {parsed_inputs}") + return stream + + +class UDPInputServerStream(Stream, UDPInputServer): """ This stream type listens for messages from a UDP port and publishes to a ZMQ socket. """ def __init__(self, name, inputs, handlers, zmq_args=None): - super(PortInputStream, self).__init__(name, inputs, handlers, zmq_args) + super(UDPInputServerStream, self).__init__(name, inputs, handlers, zmq_args) + + +class TCPInputServerStream(Stream, TCPInputServer): + """ + This stream type listens for messages from a TCP port and publishes to a ZMQ socket. + """ + + def __init__(self, name, inputs, handlers, zmq_args=None): + super(TCPInputServerStream, self).__init__(name, inputs, handlers, zmq_args) + + +class TCPInputClientStream(Stream, TCPInputClient): + """ + This stream type connects to a TCP server and publishes to a ZMQ socket. + """ + + def __init__(self, name, inputs, handlers, zmq_args=None): + super(TCPInputClientStream, self).__init__(name, inputs, handlers, zmq_args) class ZMQStream(Stream, ZMQInputClient): @@ -120,13 +257,25 @@ def __init__(self, name, inputs, handlers, zmq_args=None): super(ZMQStream, self).__init__(name, inputs, handlers, zmq_args) -class PortOutputStream(Stream, PortOutputClient): +class UDPOutputStream(Stream, UDPOutputClient): """ This stream type listens for messages from another stream or plugin and publishes to a UDP port. """ def __init__(self, name, inputs, output, handlers, zmq_args=None): - super(PortOutputStream, self).__init__( + super(UDPOutputStream, self).__init__( + name, inputs, handlers, zmq_args, output=output + ) + + +class TCPOutputStream(Stream, TCPOutputClient): + """ + This stream type listens for messages from another stream or plugin and + publishes to a TCP port. + """ + + def __init__(self, name, inputs, output, handlers, zmq_args=None): + super(TCPOutputStream, self).__init__( name, inputs, handlers, zmq_args, output=output ) diff --git a/ait/core/server/utils.py b/ait/core/server/utils.py index 02eb49ee..7f5c2ff3 100644 --- a/ait/core/server/utils.py +++ b/ait/core/server/utils.py @@ -12,6 +12,7 @@ # or other export authority as may be required before exporting such # information to foreign countries or providing access to foreign persons. import pickle +import re def encode_message(topic, data): @@ -62,3 +63,10 @@ def decode_message(msg): msg = None return (tpc, msg) + + +def is_valid_address_spec(address): + if type(address) is not str: + return False + pattern = r"^(TCP|UDP|tcp|udp):.*:\d{1,5}$" + return bool(re.match(pattern, address)) diff --git a/doc/source/server_architecture.rst b/doc/source/server_architecture.rst index 895d1ef1..e1a542e3 100644 --- a/doc/source/server_architecture.rst +++ b/doc/source/server_architecture.rst @@ -61,21 +61,29 @@ AIT provides a number of default plugins. Check the `Plugins API documentation < Streams ^^^^^^^ - Streams must be listed under either **inbound-streams** or **outbound-streams**, and must have a **name**. -- **Inbound streams** can have an integer port or inbound streams as their **input**. Inbound streams can have multiple inputs. A port input should always be listed as the first input to an inbound stream. +- **Inbound streams** can have an address specification or inbound streams as their **input**. Inbound streams can have multiple inputs. - The server sets up an input stream that emits properly formed telemetry packet messages over a globally configured topic. This is used internally by the ground script API for telemetry monitoring. The input streams that pass data to this stream must output data in the Packet UID annotated format that the core packet handlers use. The input streams used can be configured via the **server.api-telemetry-streams** field. If no configuration is provided the server will default to all valid input streams if possible. See :ref:`the Ground Script API documentation ` for additional information. - **Outbound streams** can have plugins or outbound streams as their **input**. Outbound streams can have multiple inputs. - - Outbound streams also have the option to **output** to an integer port (see :ref:`example config below `). + - Outbound streams also have the option to **output** to an address specification (see :ref:`example config below `). - The server exposes an entry point for commands submitted by other processes. During initialization, this entry point will be connected to a single outbound stream, either explicitly declared by the stream (by setting the **command-subscriber** field; see :ref:`example config below `), or decided by the server (select the first outbound stream in the configuration file). - Streams can have any number of **handlers**. A stream passes each received *packet* through its handlers in order and publishes the result. -- There are several stream classes that inherit from the base stream class. These child classes exist for handling the input and output of streams differently based on whether the inputs/output are ports or other streams and plugins. The appropriate stream type will be instantiated based on whether the stream is an inbound or outbound stream and based on the inputs/output specified in the stream's configs. If the input type of an inbound stream is an integer, it will be assumed to be a port. If it is a string, it will be assumed to be another stream name or plugin. Only outbound streams can have an output, and the output must be a port, not another stream or plugin. +- There are several stream classes that inherit from the base stream class. These child classes exist for handling the input and output of streams differently based on whether the inputs/output are remote hosts, ports or other streams and plugins. The appropriate stream type will be instantiated based on whether the stream is an inbound or outbound stream and based on the inputs/output specified in the stream's configs. Only outbound streams can have an output, and the output must be an address specification, not another stream or plugin. .. _Stream_config: +TCP/UDP Address Specification: + +.. code-block:: none + + [TCP|UDP|tcp|udp]:[0.0.0.0|127.0.0.1|server|localhost]:[1024 - 65535] # UDP/TCP Server Spec + + [TCP|tcp]:[remote hostname|remote ip]:[1024 - 65535] # TCP Client Spec + Example configuration: .. code-block:: none @@ -86,17 +94,42 @@ Example configuration: input: - 3077 + # UDP Input Server - stream: - name: telem_port_in_stream + name: telem_port_in_stream_1 input: - 3076 handlers: - my_custom_handlers.TestbedTelemHandler + # UDP Input Server + - stream: + name: telem_port_in_stream_2 + input: + - "UDP:server:3077" + handlers: + - my_custom_handlers.TestbedTelemHandler + + # TCP Input Server + - stream: + name: telem_port_in_stream_3 + input: + - "TCP:server:3078" + handlers: + - my_custom_handlers.TestbedTelemHandler + + # TCP Input Client + - stream: + name: telem_port_in_stream_4 + input: + - "TCP:1.2.3.4:3079 + handlers: + - my_custom_handlers.TestbedTelemHandler + - stream: name: telem_testbed_stream input: - - telem_port_in_stream + - telem_port_in_stream_1 handlers: - name: ait.server.handlers.PacketHandler packet: 1553_HS_Packet @@ -114,14 +147,33 @@ Example configuration: - name: my_custom_handlers.FlightlikeCommandHandler command-subscriber: True + # UDP Output to localhost:3075 - stream: - name: command_port_out_stream + name: command_port_out_stream_1 input: - command_testbed_stream - command_flightlike_stream output: - 3075 + # UDP Output to remote host + - stream: + name: command_port_out_stream_2 + input: + - command_testbed_stream + - command_flightlike_stream + output: + - "UDP:1.2.3.4:3075" + + # TCP Output to remote host + - stream: + name: command_port_out_stream_3 + input: + - command_testbed_stream + - command_flightlike_stream + output: + - "TCP:1.2.3.4:3075" + Handlers ^^^^^^^^ diff --git a/tests/ait/core/server/test_client.py b/tests/ait/core/server/test_client.py index e69de29b..b86a694b 100644 --- a/tests/ait/core/server/test_client.py +++ b/tests/ait/core/server/test_client.py @@ -0,0 +1,82 @@ +import gevent + +from ait.core.server.broker import Broker +from ait.core.server.client import TCPInputClient +from ait.core.server.client import TCPInputServer + +broker = Broker() +TEST_BYTES = "Howdy".encode() +TEST_PORT = 6666 + + +class SimpleServer(gevent.server.StreamServer): + def handle(self, socket, address): + socket.sendall(TEST_BYTES) + + +class TCPServer(TCPInputServer): + def __init__(self, name, inputs, **kwargs): + super(TCPServer, self).__init__(broker.context, input=inputs) + + def process(self, input_data): + self.cur_socket.sendall(input_data) + + +class TCPClient(TCPInputClient): + def __init__(self, name, inputs, **kwargs): + super(TCPClient, self).__init__( + broker.context, input=inputs, protocol=gevent.socket.SOCK_STREAM + ) + self.input_data = None + + def process(self, input_data): + self.input_data = input_data + self._exit() + + +class TestTCPServer: + def setup_method(self): + self.server = TCPServer("test_tcp_server", inputs=f"tcp:server:{TEST_PORT}") + self.server.start() + self.client = gevent.socket.create_connection(("127.0.0.1", TEST_PORT)) + + def teardown_method(self): + self.server.stop() + self.client.close() + + def test_TCP_server(self): + nbytes = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes == len(TEST_BYTES) + assert response == TEST_BYTES + + def test_null_send(self): + nbytes1 = self.client.send(b"") + nbytes2 = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes1 == 0 + assert nbytes2 == len(TEST_BYTES) + assert response == TEST_BYTES + + +class TestTCPClient: + def setup_method(self): + self.server = SimpleServer(("127.0.0.1", 0)) + self.server.start() + self.client = TCPClient( + "test_tcp_client", inputs=f"tcp:127.0.0.1:{self.server.server_port}" + ) + + def teardown_method(self): + self.server.stop() + + def test_TCP_client(self): + self.client.start() + gevent.sleep(1) + assert self.client.input_data == TEST_BYTES + + def test_bad_connection(self): + self.client.port = 1 + self.client.connection_reattempts = 2 + self.client.start() + assert self.client.connection_status != 0 diff --git a/tests/ait/core/server/test_server.py b/tests/ait/core/server/test_server.py index 28736a7b..25c93765 100644 --- a/tests/ait/core/server/test_server.py +++ b/tests/ait/core/server/test_server.py @@ -354,9 +354,9 @@ def test_successful_inbound_stream_creation( # Testing creation of inbound stream with port input config = cfg.AitConfig(config={"name": "some_stream", "input": [3333]}) created_stream = server._create_inbound_stream(config) - assert type(created_stream) == ait.core.server.stream.PortInputStream + assert type(created_stream) == ait.core.server.stream.UDPInputServerStream assert created_stream.name == "some_stream" - assert created_stream.inputs == [3333] + assert created_stream.inputs == 3333 assert created_stream.handlers == [] @mock.patch.object(ait.core.server.server.Server, "_create_handler") @@ -375,11 +375,11 @@ def test_successful_outbound_stream_creation( assert type(created_stream.handlers) == list # Testing creation of outbound stream with port output - config = cfg.AitConfig(config={"name": "some_stream", "output": 3333}) + config = cfg.AitConfig(config={"name": "some_stream", "output": [3333]}) created_stream = server._create_outbound_stream(config) - assert type(created_stream) == ait.core.server.stream.PortOutputStream + assert type(created_stream) == ait.core.server.stream.UDPOutputStream assert created_stream.name == "some_stream" - assert created_stream.out_port == 3333 + assert created_stream.addr_spec == ("localhost", 3333) assert created_stream.handlers == [] diff --git a/tests/ait/core/server/test_stream.py b/tests/ait/core/server/test_stream.py index 6d89a190..aa2e1df2 100644 --- a/tests/ait/core/server/test_stream.py +++ b/tests/ait/core/server/test_stream.py @@ -1,69 +1,272 @@ from unittest import mock +import gevent import pytest import zmq.green -import ait.core from ait.core.server.broker import Broker from ait.core.server.handlers import PacketHandler +from ait.core.server.stream import input_stream_factory +from ait.core.server.stream import output_stream_factory +from ait.core.server.stream import TCPInputClientStream +from ait.core.server.stream import TCPInputServerStream +from ait.core.server.stream import TCPOutputStream +from ait.core.server.stream import UDPInputServerStream +from ait.core.server.stream import UDPOutputStream from ait.core.server.stream import ZMQStream +broker = Broker() + + class TestStream: + invalid_stream_args = [ + "some_stream", + "input_stream", + [ + PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER"), + PacketHandler(input_type=int, packet="CCSDS_HEADER"), + ], + {"zmq_context": broker}, + ] + test_data = [ + ( + "zmq", + { + "name": "some_zmq_stream", + "inputs": ["input_stream"], + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": zmq.green.core._Socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "udp_server", + { + "name": "some_udp_stream", + "inputs": 1234, + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "tcp_server", + { + "name": "some_tcp_stream_server", + "inputs": "TCP:server:1234", + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ( + "tcp_client", + { + "name": "some_tcp_stream_client", + "inputs": "TCP:127.0.0.1:1234", + "handlers_len": 1, + "handler_type": PacketHandler, + "broker_context": broker.context, + "sub_type": gevent._socket3.socket, + "pub_type": zmq.green.core._Socket, + "repr": "", + }, + ), + ] + def setup_method(self): - self.broker = Broker() - self.stream = ZMQStream( - "some_stream", - ["input_stream"], - [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], - zmq_args={"zmq_context": self.broker.context}, - ) - self.stream.handlers = [ - PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER") - ] + self.streams = { + "zmq": ZMQStream( + "some_zmq_stream", + ["input_stream"], + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "udp_server": UDPInputServerStream( + "some_udp_stream", + 1234, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "tcp_server": TCPInputServerStream( + "some_tcp_stream_server", + "TCP:server:1234", + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + "tcp_client": TCPInputClientStream( + "some_tcp_stream_client", + "TCP:127.0.0.1:1234", + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + zmq_args={"zmq_context": broker.context}, + ), + } + for stream in self.streams.values(): + stream.handlers = [ + PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER") + ] - def test_stream_creation(self): - assert self.stream.name is "some_stream" - assert self.stream.inputs == ["input_stream"] - assert len(self.stream.handlers) == 1 - assert type(self.stream.handlers[0]) == PacketHandler - assert self.stream.context == self.broker.context - assert type(self.stream.pub) == zmq.green.core._Socket - assert type(self.stream.sub) == zmq.green.core._Socket + @pytest.mark.parametrize("stream,expected", test_data) + def test_stream_creation(self, stream, expected): + assert self.streams[stream].name is expected["name"] + assert self.streams[stream].inputs == expected["inputs"] + assert len(self.streams[stream].handlers) == expected["handlers_len"] + assert type(self.streams[stream].handlers[0]) == expected["handler_type"] + assert self.streams[stream].context == expected["broker_context"] + assert type(self.streams[stream].pub) == expected["pub_type"] + assert type(self.streams[stream].sub) == expected["sub_type"] - def test_repr(self): - assert self.stream.__repr__() == "" + @pytest.mark.parametrize("stream,expected", test_data) + def test_repr(self, stream, expected): + assert self.streams[stream].__repr__() == expected["repr"] + @pytest.mark.parametrize("stream,_", test_data) @mock.patch.object(PacketHandler, "handle") - def test_process(self, execute_handler_mock): - self.stream.process("input_data") + def test_process(self, execute_handler_mock, stream, _): + self.streams[stream].process("input_data") execute_handler_mock.assert_called_with("input_data") - def test_valid_workflow_one_handler(self): - assert self.stream.valid_workflow() is True + @pytest.mark.parametrize("stream,_", test_data) + def test_valid_workflow_one_handler(self, stream, _): + assert self.streams[stream].valid_workflow() is True - def test_valid_workflow_more_handlers(self): - self.stream.handlers.append( + @pytest.mark.parametrize("stream,_", test_data) + def test_valid_workflow_more_handlers(self, stream, _): + self.streams[stream].handlers.append( PacketHandler(input_type=str, packet="CCSDS_HEADER") ) - assert self.stream.valid_workflow() is True + assert self.streams[stream].valid_workflow() is True - def test_invalid_workflow_more_handlers(self): - self.stream.handlers.append( + @pytest.mark.parametrize("stream,_", test_data) + def test_invalid_workflow_more_handlers(self, stream, _): + self.streams[stream].handlers.append( PacketHandler(input_type=int, packet="CCSDS_HEADER") ) - assert self.stream.valid_workflow() is False + assert self.streams[stream].valid_workflow() is False - def test_stream_creation_invalid_workflow(self): + @pytest.mark.parametrize( + "stream,args", + [ + (ZMQStream, invalid_stream_args), + (UDPInputServerStream, invalid_stream_args), + (TCPInputServerStream, invalid_stream_args), + (TCPInputClientStream, invalid_stream_args), + ], + ) + def test_stream_creation_invalid_workflow(self, stream, args): with pytest.raises(ValueError): - ZMQStream( - "some_stream", - "input_stream", - [ - PacketHandler( - input_type=int, output_type=str, packet="CCSDS_HEADER" - ), - PacketHandler(input_type=int, packet="CCSDS_HEADER"), - ], - zmq_args={"zmq_context": self.broker.context}, - ) + stream(*args) + + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP:127.0.0.1:1234"], TCPInputServerStream), + (["TCP:server:1234"], TCPInputServerStream), + (["TCP:0.0.0.0:1234"], TCPInputServerStream), + (["TCP:localhost:1234"], TCPInputServerStream), + (["TCP:foo:1234"], TCPInputClientStream), + ([1234], UDPInputServerStream), + (1234, UDPInputServerStream), + (["UDP:server:1234"], UDPInputServerStream), + (["UDP:localhost:1234"], UDPInputServerStream), + (["UDP:0.0.0.0:1234"], UDPInputServerStream), + (["UDP:127.0.0.1:1234"], UDPInputServerStream), + ("UDP:127.0.0.1:1234", UDPInputServerStream), + (["FOO"], ZMQStream), + (["FOO", "BAR"], ZMQStream), + ( + [1234, "FOO", "BAR"], + UDPInputServerStream, + ), # Technically valid but not really correct + ], + ) + def test_valid_input_stream_factory(self, args, expected): + full_args = [ + "foo", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + stream = input_stream_factory(*full_args) + assert isinstance(stream, expected) + + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP:127.0.0.1:1"], ValueError), + ([1], ValueError), + (1, ValueError), + ([], ValueError), + (None, ValueError), + (["foo", "bar", "foo", 1], ValueError), + ], + ) + def test_invalid_input_stream_factory(self, args, expected): + full_args = [ + "foo", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + with pytest.raises(expected): + _ = input_stream_factory(*full_args) + + @pytest.mark.parametrize( + "args,expected", + [ + (["TCP:127.0.0.1:1234"], TCPOutputStream), + (["TCP:localhost:1234"], TCPOutputStream), + (["TCP:foo:1234"], TCPOutputStream), + (["UDP:127.0.0.1:1234"], UDPOutputStream), + (["UDP:localhost:1234"], UDPOutputStream), + (["UDP:foo:1234"], UDPOutputStream), + ([1234], UDPOutputStream), + (1234, UDPOutputStream), + ("UDP:foo:1234", UDPOutputStream), + ([], ZMQStream), + (None, ZMQStream), + ( + [1234, "TCP:foo:1234"], + UDPOutputStream, + ), # Technically valid but not really correct + ], + ) + def test_valid_output_stream_factory(self, args, expected): + full_args = [ + "foo", + "bar", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + stream = output_stream_factory(*full_args) + assert isinstance(stream, expected) + + @pytest.mark.parametrize( + "args,expected", + [ + (["FOO:127.0.0.1:1234"], ValueError), + (["UDP", "127.0.0.1", "1234"], ValueError), + (["FOO"], ValueError), + ], + ) + def test_invalid_output_stream_factory(self, args, expected): + full_args = [ + "foo", + "bar", + args, + [PacketHandler(input_type=int, output_type=str, packet="CCSDS_HEADER")], + {"zmq_context": broker.context}, + ] + with pytest.raises(expected): + _ = output_stream_factory(*full_args)