diff --git a/extensions/lgpilot/codeToNode.py b/extensions/lgpilot/codeToNode.py new file mode 100644 index 00000000..728a204c --- /dev/null +++ b/extensions/lgpilot/codeToNode.py @@ -0,0 +1,49 @@ +import sys +import re + +def create_func(writer, library_name, function_name, function_args): + writer.write(f"from {library_name} import {function_name}\n") + writer.write(f"def {function_name}({function_args}):\n\t") + writer.write(f"y = {function_name}({function_args})\n\t") + writer.write("return y\n") + +def create_setup(writer, function_name): + writer.write(f"\nclass {function_name}Node(lg.Node):\n\t") + writer.write("INPUT = lg.Topic(InputMessage)\n\tOUTPUT = lg.Topic(OutputMessage)\n\n\t") + writer.write(f"def setup(self):\n\t\tself.func = {function_name}\n\n\t") + +def create_feature(writer, function_name, function_args): + writer.write("@lg.subscriber(INPUT)\n\t@lg.publisher(OUTPUT)\n\n\t") + writer.write(f"def {function_name}_feature(self, message: InputMessage):\n\t\t") + # turn a string containing a list of function args (ex: 'x, y, z') into InputMessage attributes (ex: message.x, message.y, message.z) + params = [f"message.{i.strip()}" for i in function_args.split(",")] + # turn a list of parameters (ex: ["message.x", "message.y", "message.z"]) into a string of arguments + # ex: y = self.func(message.x, message.y, message.z) + writer.write("y = self.func(" + ", ".join(params) + ")\n\t\tyield self.OUTPUT, y") + + +def code_to_node(filename): + """ Take a python file containing a function and + output a file named node.py containing labgraph node. + + """ + library_name, function_name, function_args = "", "", "" + with open(filename, 'r') as reader: + with open("node.py", 'w') as writer: + for line in reader: + # first check if line contains library_name and function name + result = re.search("from (.*) import (.*)", line) + if result is not None: # a match was found + library_name, function_name = result.group(1), result.group(2) + # next check if line contains function arguments + result = re.search(f"[a-zA-z]* = {function_name}\((.*)\)", line) + if result is not None: + function_args = result.group(1) + + create_func(writer, library_name, function_name, function_args) + create_setup(writer, function_name) + create_feature(writer, function_name, function_args) + + +if __name__ == "__main__": + code_to_node(sys.argv[1]) \ No newline at end of file diff --git a/extensions/lgpilot/convolve.py b/extensions/lgpilot/convolve.py new file mode 100644 index 00000000..bf30158f --- /dev/null +++ b/extensions/lgpilot/convolve.py @@ -0,0 +1,12 @@ +import numpy as np +from scipy.signal import convolve + +# Create two arrays +x = np.array([1, 2, 3, 4]) +h = np.array([1, 2, 3]) + +# Perform convolution +y = convolve(x, h) + +# Print result +print(y) \ No newline at end of file diff --git a/extensions/lgpilot/node.py b/extensions/lgpilot/node.py new file mode 100644 index 00000000..7740f1c3 --- /dev/null +++ b/extensions/lgpilot/node.py @@ -0,0 +1,18 @@ +from scipy.signal import convolve +def convolve(x, h): + y = convolve(x, h) + return y + +class convolveNode(lg.Node): + INPUT = lg.Topic(InputMessage) + OUTPUT = lg.Topic(OutputMessage) + + def setup(self): + self.func = convolve + + @lg.subscriber(INPUT) + @lg.publisher(OUTPUT) + + def convolve_feature(self, message: InputMessage): + y = self.func(message.x, message.h) + yield self.OUTPUT, y \ No newline at end of file diff --git a/extensions/lgpilot/test.py b/extensions/lgpilot/test.py new file mode 100644 index 00000000..88efd559 --- /dev/null +++ b/extensions/lgpilot/test.py @@ -0,0 +1,136 @@ +# Import labgraph +import labgraph as lg +# Imports required for this example +from scipy.signal import convolve +import numpy as np + +import labgraph as lg +import numpy as np +import pytest +from ...generators.sine_wave_generator import ( + SineWaveChannelConfig, + SineWaveGenerator, +) + +from ..mixer_one_input_node import MixerOneInputConfig, MixerOneInputNode +from ..signal_capture_node import SignalCaptureConfig, SignalCaptureNode +from ..signal_generator_node import SignalGeneratorNode + + +# A data type used in streaming, see docs: Messages +class InputMessage(lg.Message): + x: np.ndarray + h: np.ndarray + +class OutputMessage(lg.Message): + data: np.ndarray + + +# ================================= CONVOLUTION =================================== + + +def convolve(x, h): + y = convolve(x, h) + return y + +class ConvolveNode(lg.Node): + INPUT = lg.Topic(InputMessage) + OUTPUT = lg.Topic(OutputMessage) + + def setup(self): + self.func = convolve + + @lg.subscriber(INPUT) + @lg.publisher(OUTPUT) + + def convolve_feature(self, message: InputMessage): + y = self.func(message.x, message.h) + yield self.OUTPUT, y + +# ====================================================================== + + +# class MixerOneInputConfig(lg.Config): +# # This is an NxM matrix (for M inputs, N outputs) +# weights: np.ndarray + +class ConvolveInputConfig(lg.Config): + array: np.ndarray + kernel: np.ndarray + + +class MyGraphConfig(lg.Config): + sine_wave_channel_config: SineWaveChannelConfig + convolve_config: ConvolveInputConfig + capture_config: SignalCaptureConfig + + +class MyGraph(lg.Graph): + + sample_source: SignalGeneratorNode + convolve_node: ConvolveNode + capture_node: SignalCaptureNode + + def setup(self) -> None: + self.capture_node.configure(self.config.capture_config) + self.sample_source.set_generator( + SineWaveGenerator(self.config.sine_wave_channel_config) + ) + self.convolve_node.configure(self.config.convolve_config) + + def connections(self) -> lg.Connections: + return ( + (self.convolve_node.INPUT, self.sample_source.SAMPLE_TOPIC), + (self.capture_node.SAMPLE_TOPIC, self.mixer_node.OUTPUT), + ) + + +def test_convolve_input_node() -> None: + """ + Tests that node convolves correctly, uses numpy arrays and kernel sizes as input + """ + + sample_rate = 1 # Hz + test_duration = 10 # sec + + # Test configurations + shape = (2,) + amplitudes = np.array([5.0, 3.0]) + frequencies = np.array([5, 10]) + phase_shifts = np.array([1.0, 5.0]) + midlines = np.array([3.0, -2.5]) + + test_array = [1, 2, 3] + test_kernel = [2] + + # Generate expected values + + expected = convolve(test_array, test_kernel) # use the convolve from the library to generate the expected values + + # Create the graph + generator_config = SineWaveChannelConfig( + shape, amplitudes, frequencies, phase_shifts, midlines, sample_rate + ) + capture_config = SignalCaptureConfig(int(test_duration / sample_rate)) + + # mixer_weights = np.identity(2) + # mixer_config = MixerOneInputConfig(mixer_weights) + + convolve_input_array = [1, 2, 3] + convolve_input_kernel = [2] + + convolve_config = ConvolveInputConfig(convolve_input_array, convolve_input_kernel) + + my_graph_config = MyGraphConfig(generator_config, convolve_config, capture_config) + + graph = MyGraph() + graph.configure(my_graph_config) + + runner = lg.LocalRunner(module=graph) + runner.run() + received = np.array(graph.capture_node.samples).T + np.testing.assert_almost_equal(received, expected) + +# 1. test the convolve function +# 2. create the graph and run it +# 3. repeat the same thing for other APIs -- just need to create simple test cases \ No newline at end of file