diff --git a/mrobosub_bringup/launch/bringup_sim_launch.xml b/mrobosub_bringup/launch/bringup_sim_launch.xml index 425ab17..09ebc43 100644 --- a/mrobosub_bringup/launch/bringup_sim_launch.xml +++ b/mrobosub_bringup/launch/bringup_sim_launch.xml @@ -1,6 +1,5 @@ - diff --git a/mrobosub_localization/mrobosub_localization/localization.py b/mrobosub_localization/mrobosub_localization/localization.py index 8e439a1..6e113f0 100644 --- a/mrobosub_localization/mrobosub_localization/localization.py +++ b/mrobosub_localization/mrobosub_localization/localization.py @@ -152,8 +152,8 @@ def imu_callback(self, msg: ImuINS): self.roll_pub.publish(Float64(data=roll)) -def main(args=None): - rclpy.init(args=args) +def main(): + rclpy.init() node = StateEstimation() diff --git a/mrobosub_planning/mrobosub_planning/abstract_states.py b/mrobosub_planning/mrobosub_planning/abstract_states.py index f638577..b1e9ff2 100644 --- a/mrobosub_planning/mrobosub_planning/abstract_states.py +++ b/mrobosub_planning/mrobosub_planning/abstract_states.py @@ -1,28 +1,39 @@ from mrobosub_planning.umrsm import State, Outcome +from mrobosub_planning.io_interface import Interface import rclpy import rclpy.node -from typing import Optional, List, Union from abc import abstractmethod class TimedState(State): - """base class for States which can be timed out. + """Base class for States which can be timed out. - expects an outcome called TimedOut and parameter named timeout. - override handle_once_timedout iff cleanup is needed after timeout + Must implement the following functions: + handle_if_node_timedout + handle_once_timedout + + Must specify the following parameters: + timeout """ - def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): - super().__init__(prev_outcome, node) - self.start_time = self.io_node.get_clock().now().nanoseconds/(1e9) + def __init__(self, prev_outcome: Outcome, io: Interface): + super().__init__(prev_outcome, io) + self.clock = self.io.node.get_clock() + self.start_time = self.time() - def handle(self) -> Optional[Outcome]: - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.start_time >= self.timeout: + def handle(self) -> Outcome | None: + if self.state_runtime() >= self.timeout: return self.handle_once_timedout() return self.handle_if_not_timedout() + def time(self) -> float: + return self.clock.now().nanoseconds / 1e9 + + def state_runtime(self) -> float: + return self.time() - self.start_time + @abstractmethod - def handle_if_not_timedout(self) -> Optional[Outcome]: + def handle_if_not_timedout(self) -> Outcome | None: pass @abstractmethod @@ -35,155 +46,151 @@ def timeout(self) -> float: pass -class ForwardAndWait(State): +class DoubleTimedState(TimedState): """ - Must specify the following outcomes: - Unreached - Reached + Must implement the following functions: + handle_first_phase + handle_second_phase + handle_once_timedout Must specify the following parameters: - target_heave: float - target_surge_time: float - wait_time: float - surge_speed: float + phase_one_time: float + phase_two_time: float """ - def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): - super().__init__(prev_outcome, node) - self.start_time = self.io_node.get_clock().now().nanoseconds/(1e9) - self.waiting = False - - def handle(self) -> Optional[Outcome]: - if not self.waiting: - self.io_node.set_target_twist_surge(self.surge_speed) - self.io_node.set_target_pose_heave(self.target_heave) - - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.start_time >= self.target_surge_time: - self.io_node.set_target_twist_surge(0) - self.waiting = True - self.start_time = self.io_node.get_clock().now().nanoseconds/(1e9) - else: - self.io_node.set_target_twist_surge(0) - - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.start_time >= self.wait_time: - return self.handle_reached() - - return self.handle_unreached() - - @abstractmethod - def handle_reached(self) -> Optional[Outcome]: - pass + def __init__(self, prev_outcome: Outcome, io: Interface): + super().__init__(prev_outcome, io) - @abstractmethod - def handle_unreached(self) -> Optional[Outcome]: - pass + def handle(self) -> Outcome | None: + if self.state_runtime() < self.phase_one_time: + return self.handle_first_phase() + else: + return self.handle_second_phase() @property + def timeout(self) -> float: + return self.phase_one_time + self.phase_two_time + @abstractmethod - def target_heave(self) -> float: + def handle_first_phase(self) -> Outcome | None: pass - @property @abstractmethod - def target_surge_time(self) -> float: + def handle_second_phase(self) -> Outcome | None: pass @property @abstractmethod - def wait_time(self) -> float: + def phase_one_time(self) -> float: pass @property @abstractmethod - def surge_speed(self) -> float: + def phase_two_time(self) -> float: pass -class DoubleTimedState(State): +class ForwardAndWait(DoubleTimedState): """ - Must specify the following outcomes: - Unreached - Reached + Must implement the following functions: + handle_reached + + May optionally override the following functions: + handle_unreached Must specify the following parameters: - phase_one_time: float - phase_two_time: float + target_heave: float + target_surge_time: float + wait_time: float + surge_speed: float """ - def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): - super().__init__(prev_outcome, node) - self.start_time = self.io_node.get_clock().now().nanoseconds/(1e9) - self.timed_out_first = False - - def handle(self) -> Optional[Outcome]: - if not self.timed_out_first: - outcome = self.handle_first_phase() - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.start_time >= self.phase_one_time: - self.timed_out_first = True - self.start_time = self.io_node.get_clock().now().nanoseconds/(1e9) - else: - outcome = self.handle_second_phase() - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.start_time >= self.phase_two_time: - outcome = self.handle_once_timedout() + def __init__(self, prev_outcome: Outcome, io: Interface): + super().__init__(prev_outcome, io) + + def handle_first_phase(self) -> Outcome | None: + self.io.set_target_twist_surge(self.surge_speed) + self.io.set_target_pose_heave(self.target_heave) + return self.handle_unreached() + + def handle_second_phase(self) -> Outcome | None: + self.io.set_target_twist_surge(0.0) + return self.handle_unreached() + + def handle_once_timedout(self) -> Outcome: + return self.handle_reached() + + @property + def phase_one_time(self) -> float: + return self.target_surge_time + + @property + def phase_two_time(self) -> float: + return self.wait_time - return outcome + def handle_unreached(self) -> Outcome | None: + return None @abstractmethod - def handle_first_phase(self) -> Optional[Outcome]: + def handle_reached(self) -> Outcome: pass + @property @abstractmethod - def handle_second_phase(self) -> Optional[Outcome]: + def target_heave(self) -> float: pass + @property @abstractmethod - def handle_once_timedout(self) -> Optional[Outcome]: + def target_surge_time(self) -> float: pass @property @abstractmethod - def phase_one_time(self) -> float: + def wait_time(self) -> float: pass @property @abstractmethod - def phase_two_time(self) -> float: + def surge_speed(self) -> float: pass class TurnToYaw(TimedState): """ - Must specify following outcomes: - Reached - TimedOut - - Must specify following parameters: - target_yaw: float - yaw_threshold: float - settle_time: float - timeout: float + Must implement the following functions: + handle_reached + + May optionally override the following functions: + handle_unreached + + Must specify the following parameters: + target_yaw: float + yaw_threshold: float + settle_time: float + timeout: float """ def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): super().__init__(prev_outcome, node) - self.timer = self.io_node.get_clock().now().nanoseconds/(1e9) + self.settled_since = float("inf") - def handle_if_not_timedout(self) -> Optional[Outcome]: - self.io_node.set_target_pose_yaw(self.target_yaw) + def handle_if_not_timedout(self) -> Outcome | None: + self.io.set_target_pose_yaw(self.target_yaw) - if not self.io_node.is_yaw_within_threshold(self.yaw_threshold): - self.timer = self.io_node.get_clock().now().nanoseconds/(1e9) + if not self.io.is_yaw_within_threshold(self.yaw_threshold): + self.settled_since = self.time() - if self.io_node.get_clock().now().nanoseconds/(1e9) - self.timer >= self.settle_time: + if self.time() - self.settled_since >= self.settle_time: return self.handle_reached() return self.handle_unreached() - def handle_unreached(self) -> Optional[Outcome]: + def handle_unreached(self) -> Outcome | None: return None @abstractmethod - def handle_reached(self) -> Optional[Outcome]: + def handle_reached(self) -> Outcome | None: pass @property @@ -201,54 +208,47 @@ def yaw_threshold(self) -> float: def settle_time(self) -> float: pass - @property - @abstractmethod - def timeout(self) -> float: - pass - class AlignPathmarker(TimedState): - @property - @abstractmethod - def yaw_threshold(self) -> float: - pass - - @abstractmethod - def handle_no_measurements(self) -> Outcome: - pass + """ + Must implement the following functions: + handle_no_measurements + handle_aligned + handle_once_timedout - @abstractmethod - def handle_aligned(self) -> Outcome: - pass + Must specify the following parameters: + yaw_threshold: float + timeout: float + """ def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node) -> None: super().__init__(prev_outcome, node) - self.io_node.activate_bot_cam() - self.last_known_angle: Optional[float] = None + self.io.activate_bot_cam() + self.last_known_angle: float | None = None self.iter = 0 - self.measurements: List[float] = [] + self.measurements: list[float] = [] - def handle_if_not_timedout(self) -> Union[Outcome, None]: - self.io_node.set_target_twist_surge(0) + def handle_if_not_timedout(self) -> Outcome | None: + self.io.set_target_twist_surge(0) self.iter += 1 if self.iter < 50: return None if self.iter < 100: - pm_resp = self.io_node.query_pathmarker() - self.io_node.get_logger().info({f"{pm_resp=}"}) + pm_resp = self.io.query_pathmarker() # type: ignore # TODO: remove once ML is migrated + self.io.logger.info({f"{pm_resp=}"}) if pm_resp is not None: self.measurements.append(pm_resp) return None if self.iter == 100: - self.io_node.get_logger().info({"Calculating target"}) + self.io.logger.info({"Calculating target"}) if len(self.measurements) < 20: return self.handle_no_measurements() self.target_angle = sum(self.measurements) / len(self.measurements) - self.io_node.get_logger().info({f"{self.target_angle=}"}) + self.io.logger.info({f"{self.target_angle=}"}) self.yaw_threshold_count = 0 if self.iter >= 100: - self.io_node.set_target_pose_yaw(self.target_angle) - if self.io_node.is_yaw_within_threshold(self.yaw_threshold): + self.io.set_target_pose_yaw(self.target_angle) + if self.io.is_yaw_within_threshold(self.yaw_threshold): self.yaw_threshold_count += 1 else: self.yaw_threshold_count = 0 @@ -256,28 +256,46 @@ def handle_if_not_timedout(self) -> Union[Outcome, None]: return self.handle_aligned() return None + @property + @abstractmethod + def yaw_threshold(self) -> float: + pass + + @abstractmethod + def handle_no_measurements(self) -> Outcome: + pass -class CenterOnPathmarker(TimedState): @abstractmethod def handle_aligned(self) -> Outcome: pass - def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): - super().__init__(prev_outcome, node) - self.io_node.activate_bot_cam() + +class CenterOnPathmarker(TimedState): + """ + Must implement the following functions: + handle_aligned + handle_once_timedout + + Must specify the following parameters: + timeout: float + """ + + def __init__(self, prev_outcome: Outcome, io: Interface): + super().__init__(prev_outcome, io) + self.io.activate_bot_cam() self.centered_count = 0 - def handle_if_not_timedout(self) -> Optional[Outcome]: - pm_resp = self.io_node.query_pathmarker_full() + def handle_if_not_timedout(self) -> Outcome | None: + pm_resp = self.io.query_pathmarker_full() # type: ignore # TODO: remove once ML is migrated if pm_resp is None: - self.io_node.set_target_twist_surge(0.0) - self.io_node.set_target_twist_sway(0.0) + self.io.set_target_twist_surge(0.0) + self.io.set_target_twist_sway(0.0) return None x_diff = pm_resp.centroid_x - 0.5 y_diff = pm_resp.centroid_y - 0.5 - self.io_node.set_target_twist_sway(3 * x_diff) - self.io_node.set_target_twist_surge(-3 * y_diff) + self.io.set_target_twist_sway(3 * x_diff) + self.io.set_target_twist_surge(-3 * y_diff) if abs(x_diff) < 0.1 and abs(y_diff) < 0.1: self.centered_count += 1 else: @@ -285,3 +303,7 @@ def handle_if_not_timedout(self) -> Optional[Outcome]: if self.centered_count > 50: return self.handle_aligned() return None + + @abstractmethod + def handle_aligned(self) -> Outcome: + pass diff --git a/mrobosub_planning/mrobosub_planning/captain.py b/mrobosub_planning/mrobosub_planning/captain.py index 5e9d104..9886d86 100755 --- a/mrobosub_planning/mrobosub_planning/captain.py +++ b/mrobosub_planning/mrobosub_planning/captain.py @@ -1,96 +1,123 @@ -from tokenize import Single -from typing import Dict, Type, Optional, Sequence from importlib import import_module from mrobosub_planning.umrsm import StateMachine, State, TransitionMap, Outcome import mrobosub_planning.common_states as common_states import mrobosub_planning.standard_run as standard_run -# import prequal_strafe +from mrobosub_lib import Node + import rclpy from rclpy.executors import SingleThreadedExecutor import threading -from rclpy.node import Node import sys -from mrobosub_planning.periodic_io import Captain +from mrobosub_planning.io_interface import Interface import traceback # maybe change this to something hacky like getting .transitions from the machine name module? -transition_maps: Dict[str, TransitionMap] = { +transition_maps: dict[str, TransitionMap] = { "standard": standard_run.transitions, # "heave_test": heave_test.transitions, } -def state_class_from_str(full_state: str, transitions: TransitionMap) -> Type[State]: +def state_class_from_str(full_state: str, transitions: TransitionMap) -> type[State]: """ Find the associated class object from a given state name. Arguments: full_state (str): The state passed into roslaunch. Could be in the format `state` or `module.state`. - transitions (TransitionMap): The transition map from the associated machine name. - + transitions (TransitionMap): The transition map from the associated machine name. + Returns: The class object for the state or a ValueError if there is an error finding the state. """ - def outcome_to_state_str(outcome: Type[Outcome]) -> str: - return outcome.__qualname__.split('.')[0] - - def outcome_to_state(outcome: Type[Outcome]) -> Type[State]: + + def outcome_to_state_str(outcome: type[Outcome]) -> str: + return outcome.__qualname__.split(".")[0] + + def outcome_to_state(outcome: type[Outcome]) -> type[State]: module = import_module(outcome.__module__) return getattr(module, outcome_to_state_str(outcome)) - - if full_state.count('.') > 1: + + if full_state.count(".") > 1: raise ValueError(f"{full_state} should have at most one '.'") - state = full_state.split('.')[-1] - found_outcomes = [outcome for outcome in transitions.keys() if outcome_to_state_str(outcome) == state] + state = full_state.split(".")[-1] + found_outcomes = [ + outcome + for outcome in transitions.keys() + if outcome_to_state_str(outcome) == state + ] unique_found_states = set(outcome_to_state(outcome) for outcome in found_outcomes) - - if '.' in full_state: - unique_found_states = set(state for state in unique_found_states if state.__module__.removeprefix("mrobosub_planning.") == full_state.split(".")[0]) + + if "." in full_state: + unique_found_states = set( + state + for state in unique_found_states + if state.__module__.removeprefix("mrobosub_planning.") + == full_state.split(".")[0] + ) if len(unique_found_states) != 1: - raise ValueError(f'{full_state=} does not uniquely describe a state. {unique_found_states=}') - + raise ValueError( + f"{full_state=} does not uniquely describe a state. {unique_found_states=}" + ) + found_state = unique_found_states.pop() return found_state -def main(args: Optional[Sequence[str]]=None) -> None: + +class Captain(Node): + def __init__(self, machine_name: str, full_state: str) -> None: + super().__init__("captain") + self.machine_name = machine_name + self.starting_state = state_class_from_str( + full_state, transition_maps[machine_name] + ) + + def run(self) -> None: + io = Interface(self) + machine = StateMachine( + self.machine_name, + transition_maps[self.machine_name], + self.starting_state, + common_states.Stop, + io, + ) + + try: + machine.run() + except Exception: + self.get_logger().info(f"{traceback.format_exc()}") + self.tick = 0 + self.timer = self.create_timer(0.1, self.reset_node) + while(self.tick < 20): + pass + self.timer.cancel() + + def reset_node(self) -> None: + self.tick += 1 + self.io.reset_target_twist() + +def main() -> None: rclpy.init() - captain_node = Captain(name="captain") - captain_node.get_logger().info("Captain Node Created") + + # Syntax `ros2 launch mrobosub_planning captain.launch machine:= state:=` + machine_name = sys.argv[1] + full_state = sys.argv[2] + + captain_node = Captain(machine_name, full_state) executor = SingleThreadedExecutor() executor.add_node(captain_node) t = threading.Thread(target=executor.spin, daemon=False) t.start() - captain_node.get_logger().info("Captain Node Spinning") - # Syntax `roslaunch mrobosub_planning captain.launch machine:= state:=` - machine_name = sys.argv[1] - full_state = sys.argv[2] + captain_node.run() + + executor.shutdown() + t.join() - starting_state = state_class_from_str(full_state, transition_maps[machine_name]) - - machine = StateMachine( - machine_name, - transition_maps[machine_name], - starting_state, - common_states.Stop, - captain_node - ) - try: - machine.run() - except Exception as e: - captain_node.get_logger().info(f"{traceback.format_exc()}") - rate = captain_node.create_rate(50) - for _ in range(20): - captain_node.reset_target_twist() - rate.sleep() - finally: - executor.shutdown() - t.join() if __name__ == "__main__": main() diff --git a/mrobosub_planning/mrobosub_planning/common_states.py b/mrobosub_planning/mrobosub_planning/common_states.py index 8d834db..258dd7f 100644 --- a/mrobosub_planning/mrobosub_planning/common_states.py +++ b/mrobosub_planning/mrobosub_planning/common_states.py @@ -1,7 +1,6 @@ -from typing import Optional, Union from mrobosub_planning.umrsm import State, Outcome +from mrobosub_planning.io_interface import Interface from mrobosub_planning.abstract_states import TimedState -import rclpy class Start(State): @@ -22,16 +21,16 @@ class TimedOut(Outcome): target_heave: float = 0.75 heave_threshold: float = 0.1 timeout: float = 15.0 - yaw_threshold: float = 2. - target_yaw: float = 0. + yaw_threshold: float = 2.0 + target_yaw: float = 0.0 - def handle_if_not_timedout(self) -> Union[Submerged, None]: - self.io_node.set_target_pose_heave(self.target_heave) - self.io_node.set_target_pose_yaw(self.target_yaw) + def handle_if_not_timedout(self) -> Submerged | None: + self.io.set_target_pose_heave(self.target_heave) + self.io.set_target_pose_yaw(self.target_yaw) - if self.io_node.is_heave_within_threshold( + if self.io.is_heave_within_threshold( self.heave_threshold - ) and self.io_node.is_yaw_within_threshold(self.yaw_threshold): + ) and self.io.is_yaw_within_threshold(self.yaw_threshold): return self.Submerged() return None @@ -43,14 +42,14 @@ class Stop(State): class Surfaced(Outcome): pass - def __init__(self, prev_outcome: Outcome, node: rclpy.node.Node): - super().__init__(prev_outcome, node) - self.io_node.reset_target_twist() - self.rate = self.io_node.create_rate(50) + def __init__(self, prev_outcome: Outcome, io: Interface): + super().__init__(prev_outcome, io) + self.io.reset_target_twist() + self.rate = self.io.node.create_rate(50) def handle(self) -> None: for _ in range(20): - self.io_node.reset_target_twist() + self.io.reset_target_twist() self.rate.sleep() return None diff --git a/mrobosub_planning/mrobosub_planning/periodic_io.py b/mrobosub_planning/mrobosub_planning/io_interface.py similarity index 51% rename from mrobosub_planning/mrobosub_planning/periodic_io.py rename to mrobosub_planning/mrobosub_planning/io_interface.py index 966b54d..daa9d4c 100644 --- a/mrobosub_planning/mrobosub_planning/periodic_io.py +++ b/mrobosub_planning/mrobosub_planning/io_interface.py @@ -1,11 +1,8 @@ import math -from typing_extensions import NamedTuple -import rclpy from rclpy.node import Node -from rclpy.service import Service -from std_msgs.msg import Float64, Bool, Int32 +from std_msgs.msg import Float64, Int32 + # from mrobosub_msgs.srv import ObjectPosition, ObjectPositionResponse, PathmarkerAngle # type: ignore -from typing import Dict, Type, Mapping, Optional, Tuple from enum import Enum, auto from std_srvs.srv import SetBool from dataclasses import dataclass @@ -15,78 +12,112 @@ def angle_error(setpoint: float, state: float) -> float: return (((setpoint - state) % 360) + 360) % 360 -Namespace = Type - - class ImageTarget(Enum): GATE_BLUE = auto() GATE_RED = auto() - @dataclass class Pose: - yaw:float = 0.0 - heave:float = 0.0 - roll:float = 0.0 - x:float = 0.0 - y:float = 0.0 - -class Captain(Node): - ''' + yaw: float = 0.0 + pitch: float = 0.0 + roll: float = 0.0 + x: float = 0.0 + y: float = 0.0 + heave: float = 0.0 + + +class Interface: + """ Public interface class for publishers and subscribers - ''' - def __init__(self, name:str='captain'): - ''' - @param name - name of the node (should be captain) - ''' - super().__init__(name) - - #Poses + """ + + def __init__(self, node: Node): + self.logger = node.get_logger() + self.node = node + + # Poses self.pose = Pose() self.target_pose = Pose() # Subscribers - self._yaw_sub = self.create_subscription(Float64, "/pose/yaw", self.yaw_callback, 10) - self._heave_sub = self.create_subscription(Float64, "/pose/heave", self.heave_callback, 10) - self._roll_sub = self.create_subscription(Float64, "/pose/roll", self.roll_callback, 10) - self._x_sub = self.create_subscription(Float64, "/pose/x", self.x_callback, 10) - self._y_sub = self.create_subscription(Float64, "/pose/y", self.y_callback, 10) + self._yaw_sub = node.create_subscription( + Float64, "/pose/yaw", self.yaw_callback, 10 + ) + self._pitch_sub = node.create_subscription( + Float64, "/pose/pitch", self.pitch_callback, 10 + ) + self._roll_sub = node.create_subscription( + Float64, "/pose/roll", self.roll_callback, 10 + ) + self._x_sub = node.create_subscription(Float64, "/pose/x", self.x_callback, 10) + self._y_sub = node.create_subscription(Float64, "/pose/y", self.y_callback, 10) + self._heave_sub = node.create_subscription( + Float64, "/pose/heave", self.heave_callback, 10 + ) # Publishers - self._target_pose_heave_pub = self.create_publisher(Float64, "/target_pose/heave", 1) - self._target_pose_yaw_pub = self.create_publisher(Float64, "/target_pose/yaw", 1) - self._target_pose_roll_pub = self.create_publisher(Float64, "/target_pose/roll", 1) - self._target_pose_x_pub = self.create_publisher(Float64, "/target_pose/x", 1) - self._target_pose_y_pub = self.create_publisher(Float64, "/target_pose/y", 1) - - self._target_twist_yaw_pub = self.create_publisher(Float64, "/target_twist/yaw", 1) - self._target_twist_roll_pub = self.create_publisher(Float64, "/target_twist/roll", 1) - self._target_twist_surge_pub = self.create_publisher(Float64, "/target_twist/surge", 1) - self._target_twist_sway_pub = self.create_publisher(Float64, "/target_twist/sway", 1) - self._target_twist_heave_pub = self.create_publisher(Float64, "/target_twist/heave", 1) - - self._left_dropper_pub = self.create_publisher(Int32, "/left_servo/angle", 1) - self._right_dropper_pub = self.create_publisher(Int32, "/right_servo/angle", 1) + self._target_pose_yaw_pub = node.create_publisher( + Float64, "/target_pose/yaw", 1 + ) + self._target_pose_pitch_pub = node.create_publisher( + Float64, "/target_pose/pitch", 1 + ) + self._target_pose_roll_pub = node.create_publisher( + Float64, "/target_pose/roll", 1 + ) + self._target_pose_x_pub = node.create_publisher(Float64, "/target_pose/x", 1) + self._target_pose_y_pub = node.create_publisher(Float64, "/target_pose/y", 1) + self._target_pose_heave_pub = node.create_publisher( + Float64, "/target_pose/heave", 1 + ) + + self._target_twist_yaw_pub = node.create_publisher( + Float64, "/target_twist/yaw", 1 + ) + self._target_twist_pitch_pub = node.create_publisher( + Float64, "/target_twist/pitch", 1 + ) + self._target_twist_roll_pub = node.create_publisher( + Float64, "/target_twist/roll", 1 + ) + self._target_twist_surge_pub = node.create_publisher( + Float64, "/target_twist/surge", 1 + ) + self._target_twist_sway_pub = node.create_publisher( + Float64, "/target_twist/sway", 1 + ) + self._target_twist_heave_pub = node.create_publisher( + Float64, "/target_twist/heave", 1 + ) + + self._left_dropper_pub = node.create_publisher(Int32, "/left_servo/angle", 1) + self._right_dropper_pub = node.create_publisher(Int32, "/right_servo/angle", 1) # Services # TODO: Add services for perception topics when those are created. - - self._zed_on_srv = self.create_client(SetBool, "/zed/on") + + self._zed_on_srv = node.create_client(SetBool, "/zed/on") attempt_counter = 0 - while not self._zed_on_srv.wait_for_service(timeout_sec=1.0) and attempt_counter < 5: - self.get_logger().info('\"/zed/on\" service not available, waiting again...') + while ( + not self._zed_on_srv.wait_for_service(timeout_sec=1.0) + and attempt_counter < 5 + ): + self.logger.info('"/zed/on" service not available, waiting again...') attempt_counter += 1 if attempt_counter == 5: - self.get_logger().error('Failed to connect to \"/zed/on\" service') + self.logger.error('Failed to connect to "/zed/on" service') attempt_counter = 0 - self._bot_cam_on_srv = self.create_client(SetBool, "/bot_cam/on") - while not self._bot_cam_on_srv.wait_for_service(timeout_sec=1.0) and attempt_counter < 5: - self.get_logger().info('\"/bot_cam/on\" service not available, waiting again...') + self._bot_cam_on_srv = node.create_client(SetBool, "/bot_cam/on") + while ( + not self._bot_cam_on_srv.wait_for_service(timeout_sec=1.0) + and attempt_counter < 5 + ): + self.logger.info('"/bot_cam/on" service not available, waiting again...') attempt_counter += 1 if attempt_counter == 5: - self.get_logger().error('Failed to connect to "/bot_cam/on\" service') + self.logger.error('Failed to connect to "/bot_cam/on" service') def is_yaw_within_threshold(self, threshold: float) -> float: return abs(angle_error(self.target_pose.yaw, self.pose.yaw)) <= threshold @@ -117,11 +148,11 @@ def set_target_pose_yaw(self, target_yaw: float) -> None: self._target_pose_yaw_pub.publish(msg) self.target_pose.yaw = target_yaw - def set_target_pose_heave(self, target_heave: float) -> None: + def set_target_pose_pitch(self, target_pitch: float) -> None: msg = Float64() - msg.data = float(target_heave) - self._target_pose_heave_pub.publish(msg) - self.target_pose.heave = target_heave + msg.data = float(target_pitch) + self._target_pose_pitch_pub.publish(msg) + self.target_pose.pitch = target_pitch def set_target_pose_roll(self, target_roll: float) -> None: msg = Float64() @@ -141,16 +172,27 @@ def set_target_pose_y(self, target_y: float) -> None: self._target_pose_y_pub.publish(msg) self.target_pose.y = target_y - def set_target_twist_roll(self, override_roll: float) -> None: + def set_target_pose_heave(self, target_heave: float) -> None: msg = Float64() - msg.data = float(override_roll) - self._target_twist_roll_pub.publish(msg) + msg.data = float(target_heave) + self._target_pose_heave_pub.publish(msg) + self.target_pose.heave = target_heave def set_target_twist_yaw(self, override_yaw: float) -> None: msg = Float64() msg.data = float(override_yaw) self._target_twist_yaw_pub.publish(msg) + def set_target_twist_pitch(self, override_pitch: float) -> None: + msg = Float64() + msg.data = float(override_pitch) + self._target_twist_pitch_pub.publish(msg) + + def set_target_twist_roll(self, override_roll: float) -> None: + msg = Float64() + msg.data = float(override_roll) + self._target_twist_roll_pub.publish(msg) + def set_target_twist_surge(self, override_surge: float) -> None: msg = Float64() msg.data = float(override_surge) @@ -167,11 +209,12 @@ def set_target_twist_heave(self, override_heave: float) -> None: self._target_twist_heave_pub.publish(msg) def reset_target_twist(self) -> None: - self.set_target_twist_heave(0.) - self.set_target_twist_yaw(0.) - self.set_target_twist_surge(0.) - self.set_target_twist_roll(0.) - self.set_target_twist_sway(0.) + self.set_target_twist_yaw(0.0) + self.set_target_twist_pitch(0.0) + self.set_target_twist_roll(0.0) + self.set_target_twist_surge(0.0) + self.set_target_twist_sway(0.0) + self.set_target_twist_heave(0.0) def set_left_dropper_angle(self, angle: int) -> None: msg = Int32() @@ -183,63 +226,41 @@ def set_right_dropper_angle(self, angle: int) -> None: msg.data = int(angle) self._right_dropper_pub.publish(msg) - def _call_service(self, service: Service, request: SetBool.Request, error_string:str) -> bool: - success = True - future = service.call_async(request) - # TODO: Potentially can add future callbacks to perform this async, but for now like this - # Timeout is set to 2s - rclpy.spin_until_future_complete(self, future, timeout_sec=2.0) - if not future.done(): - self.get_logger().error(f"{error_string}: Service call timed out") - success = False - exc = future.execption() - if success and exc: - self.get_logger().error(f"{error_string}: {exc!r}") - success = False - return success - - - def activate_zed(self) -> bool: - self.req = self._bot_cam_on_srv.Request() - self.req.data = False + bot_cam_req = self._bot_cam_on_srv.Request(data=False) + bot_cam_res = self._bot_cam_on_srv.call(bot_cam_req) - success = self._call_service(self._bot_cam_on_srv, self.req, "Turning BotCam Off") + zed_req = self._zed_on_srv.Request(data=True) + zed_res = self._zed_on_srv.call(zed_req) - self.req = self._zed_on_srv.Request() - self.req.data = True - - success = success and self._call_service(self._zed_on_srv, self.req, "Turning Zed On") + success = bot_cam_res.success and zed_res.success return success def activate_bot_cam(self) -> bool: - self.req = self._zed_on_srv.Request() - self.req.data = False - - success = self._call_service(self._zed_on_srv, self.req, "Turning Zed Off") + zed_req = self._zed_on_srv.Request(data=False) + zed_res = self._zed_on_srv.call(zed_req) - self.req = self._bot_cam_on_srv.Request() - self.req.data = True + bot_cam_req = self._bot_cam_on_srv.Request(data=True) + bot_cam_res = self._bot_cam_on_srv.call(bot_cam_req) - success = success and self._call_service(self._bot_cam_on_srv, self.req, "Turning BotCam On") + success = bot_cam_res.success and zed_res.success return success def deactivate_cameras(self) -> bool: - self.req = self._zed_on_srv.Request() - self.req.data = False - success = self._call_service(self._zed_on_srv, self.req, "Turning Zed Off") + zed_req = self._zed_on_srv.Request(data=False) + zed_res = self._zed_on_srv.call(zed_req) - self.req = self._bot_cam_on_srv.Request() - self.req.data = False + bot_cam_req = self._bot_cam_on_srv.Request(data=False) + bot_cam_res = self._bot_cam_on_srv.call(bot_cam_req) - success = success and self._call_service(self._bot_cam_on_srv, self.req, "Turning BotCam Off") + success = bot_cam_res.success and zed_res.success return success def yaw_callback(self, msg: Float64) -> None: self.pose.yaw = msg.data - def heave_callback(self, msg: Float64) -> None: - self.pose.heave = msg.data + def pitch_callback(self, msg: Float64) -> None: + self.pose.pitch = msg.data def roll_callback(self, msg: Float64) -> None: self.pose.roll = msg.data @@ -249,3 +270,6 @@ def x_callback(self, msg: Float64) -> None: def y_callback(self, msg: Float64) -> None: self.pose.y = msg.data + + def heave_callback(self, msg: Float64) -> None: + self.pose.heave = msg.data diff --git a/mrobosub_planning/mrobosub_planning/typecheck.py b/mrobosub_planning/mrobosub_planning/typecheck.py index b8375c1..f5903e5 100644 --- a/mrobosub_planning/mrobosub_planning/typecheck.py +++ b/mrobosub_planning/mrobosub_planning/typecheck.py @@ -1,11 +1,16 @@ -from captain import transition_maps +from mrobosub_planning.captain import transition_maps +from mrobosub_planning.umrsm import State from inspect import signature errors = 0 for map_name, map in transition_maps.items(): + expected_init_params = len(signature(State.__init__).parameters) for outcome, state in map.items(): - if len(signature(state.__init__).parameters) != 2: - print(f"{state}.__init__ has wrong number of parameters") + init_params = len(signature(state.__init__).parameters) + if init_params != expected_init_params: + print( + f"{state}.__init__ has wrong number of parameters (expected {expected_init_params}, has {init_params})" + ) errors += 1 if not state.is_valid_income_type(outcome): print( diff --git a/mrobosub_planning/mrobosub_planning/umrsm.py b/mrobosub_planning/mrobosub_planning/umrsm.py index 73570b7..d860d9e 100644 --- a/mrobosub_planning/mrobosub_planning/umrsm.py +++ b/mrobosub_planning/mrobosub_planning/umrsm.py @@ -5,19 +5,17 @@ from abc import abstractmethod from typing import ( Any, - Dict, - Optional, - Type, - Tuple, TYPE_CHECKING, ) import warnings +from dataclasses import dataclass +from typing_extensions import dataclass_transform, Self + import rclpy from std_msgs.msg import String from std_srvs.srv import Trigger -from dataclasses import dataclass -from typing_extensions import dataclass_transform, Self -from mrobosub_planning.periodic_io import Captain + +from mrobosub_planning.io_interface import Interface STATE_TOPIC = "captain/current_state" SOFT_STOP_SERVICE = "captain/soft_stop" @@ -61,7 +59,7 @@ class SoftStopTransition(Outcome): class StateMeta(type): - _outcomes: dict[str, Type[Outcome]] + _outcomes: dict[str, type[Outcome]] def __new__(cls, name: str, bases: tuple, dict_: dict) -> "StateMeta": state = super().__new__(cls, name, bases, dict_) @@ -108,16 +106,12 @@ class State(metaclass=StateMeta): _num_unexpected_params = 0 - def __init__(self, prev_outcome: Outcome, node: Captain): - """ - node: The io_node which can be used via the Periodic_IO interface to access various publishers - and subscribers, and can be used to create new publishers/subscribers - """ + def __init__(self, prev_outcome: Outcome, io: Interface): self.prev_outcome = prev_outcome - self.io_node = node + self.io = io @abstractmethod - def handle(self) -> Optional[Outcome]: + def handle(self) -> Outcome | None: """Contains the logic to be run for a particular state. Is called repeatedly for each iteration of the state, including the first one. @@ -125,11 +119,11 @@ def handle(self) -> Optional[Outcome]: pass @classmethod - def is_valid_income_type(cls, outcome_type: Type[Outcome]) -> bool: + def is_valid_income_type(cls, outcome_type: type[Outcome]) -> bool: return True @classmethod - def with_params(cls, **kwargs: Any) -> Type[Self]: + def with_params(cls, **kwargs: Any) -> type[Self]: num_unexpected = 0 for k in kwargs: if not hasattr(cls, k): @@ -150,7 +144,7 @@ def __repr__(cls) -> str: return f"" -TransitionMap = Dict[Type[Outcome], Type[State]] +TransitionMap = dict[type[Outcome], type[State]] class StateMachine: @@ -160,9 +154,9 @@ def __init__( self, name: str, transitions: TransitionMap, - StartState: Type[State], - StopState: Type[State], - captainNode: Captain + StartState: type[State], + StopState: type[State], + io: Interface, ): """Creates a new state machine. @@ -180,25 +174,30 @@ def __init__( self.StartState = StartState self.transitions = transitions self.StopState = StopState - self.node = captainNode + self.node = io.node + self.io = io - self._soft_stop_srv = self.node.create_service(Trigger, SOFT_STOP_SERVICE, self.soft_stop) + self._soft_stop_srv = self.node.create_service( + Trigger, SOFT_STOP_SERVICE, self.soft_stop + ) self.stop_signal_recvd = False - def soft_stop(self, req: Trigger.Request, res: Trigger.Response) -> Trigger.Response: + def soft_stop( + self, req: Trigger.Request, res: Trigger.Response + ) -> Trigger.Response: self.stop_signal_recvd = True res.success = True res.message = type(self.current_state).__qualname__ return res - def run(self, hz: int = 50) -> Optional[Outcome]: + def run(self, hz: int = 50) -> Outcome | None: """Performs a run, beginning with the StartState and ending when it reaches StopState. Returns the Outcome from calling handle() on StopState. """ rate = self.node.create_rate(hz) publisher = self.node.create_publisher(String, STATE_TOPIC, 1) - self.current_state = self.StartState(InitTransition(), self.node) + self.current_state = self.StartState(InitTransition(), self.io) while type(self.current_state) != self.StopState: self.run_once(publisher) rate.sleep() @@ -219,7 +218,7 @@ def run_once(self, state_topic_pub: rclpy.publisher.Publisher) -> None: outcome = SoftStopTransition() NextState = self.StopState outcome_name = "!! Abort !!" - self.node.get_logger().info( + self.io.logger.info( f"Aborting from state {type(self.current_state).__qualname__} and moving to stop state" ) else: @@ -230,11 +229,11 @@ def run_once(self, state_topic_pub: rclpy.publisher.Publisher) -> None: NextState = self.transitions[outcome_type] if type(self.current_state) == NextState: - self.node.get_logger().warn( + self.io.logger.warn( f"{type(self.current_state).__qualname__} contains a type which returns itself!" ) - self.node.get_logger().info( + self.io.logger.info( f"transition {type(self.current_state).__qualname__} --[{outcome_name}]--> {NextState.__qualname__}" ) - self.current_state = NextState(outcome, self.node) + self.current_state = NextState(outcome, self.io) diff --git a/mrobosub_planning/mrobosub_planning/viz.py b/mrobosub_planning/mrobosub_planning/viz.py index 94aef72..bb9a132 100644 --- a/mrobosub_planning/mrobosub_planning/viz.py +++ b/mrobosub_planning/mrobosub_planning/viz.py @@ -1,7 +1,6 @@ -from umrsm import * import graphviz -from captain import transition_maps +from mrobosub_planning.captain import transition_maps from mrobosub_planning.umrsm import TransitionMap from pathlib import Path