diff --git a/auto_eval/robot/policy_clients.py b/auto_eval/robot/policy_clients.py index 38bb0ab..5051bc7 100644 --- a/auto_eval/robot/policy_clients.py +++ b/auto_eval/robot/policy_clients.py @@ -21,6 +21,7 @@ import requests from PIL import Image +from auto_eval.utils import get_url from auto_eval.utils.info import print_yellow @@ -57,7 +58,7 @@ def __call__( else {} ) action = self._session.post( - f"http://{self.host}:{self.port}/act", + get_url(self.host, self.port, "act"), json={ "image": obs_dict["image_primary"], "instruction": language_instruction, @@ -77,7 +78,7 @@ def reset( self, ): # Post request to the policy server to reset internal states - response = self._session.post(f"http://{self.host}:{self.port}/reset") + response = self._session.post(get_url(self.host, self.port, "reset")) # If we get a response, check if it was successful if response.status_code == 404: # if the policy server doesn't have a /reset endpoint, ignore diff --git a/auto_eval/utils/__init__.py b/auto_eval/utils/__init__.py index e69de29..4767ae1 100644 --- a/auto_eval/utils/__init__.py +++ b/auto_eval/utils/__init__.py @@ -0,0 +1,12 @@ +def get_url( + host: str, port: int, endpoint: str | None = None, protocol: str = "http://" +): + """ + Get the URL for a given host and port; if port is negative, skip it. + Cleans the host and endpoint strings + """ + # Remove http:// or https:// from host if present + host_str = host.replace("http://", "").replace("https://", "") + port_str = f":{port}" if int(port) >= 0 else "" + endpoint_str = f"/{endpoint.lstrip('/')}" if endpoint else "" + return f"{protocol}{host_str}{port_str}{endpoint_str}" diff --git a/job_scheduler.py b/job_scheduler.py index a5d9ade..e559bb8 100644 --- a/job_scheduler.py +++ b/job_scheduler.py @@ -62,6 +62,7 @@ sleep_and_torque_off, torque_on, ) +from auto_eval.utils import get_url from auto_eval.utils.info import print_red from auto_eval.web_ui.launcher import RobotIPs @@ -237,8 +238,18 @@ class PolicyServerTestRequest(BaseModel): def is_port_open(host: str, port: int, timeout: float = 2.0) -> Tuple[bool, str]: - """Check if a port is open on a host.""" + """ + Check if a port is open on a host. + If port < 0, only verify that the hostname can be resolved. + """ try: + # If port is negative, only check if hostname can be resolved + if port < 0: + # Just try to resolve the hostname + socket.gethostbyname(host) + return True, f"Hostname resolved successfully" + + # Normal port checking sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(timeout) result = sock.connect_ex((host, port)) @@ -267,7 +278,7 @@ def test_policy_server(request: PolicyServerTestRequest): status_code=400, content={ "status": "error", - "message": f"Could not connect to policy server at {request.policy_server_ip}:{request.policy_server_port}. {port_message}", + "message": f"Could not connect to policy server at {get_url(request.policy_server_ip, request.policy_server_port)}. {port_message}", "error_type": "port_closed", }, ) @@ -285,7 +296,9 @@ def test_policy_server(request: PolicyServerTestRequest): json_numpy.patch() raw_response = requests.post( - f"http://{request.policy_server_ip}:{request.policy_server_port}/act", + get_url( + request.policy_server_ip, request.policy_server_port, endpoint="act" + ), json={ "image": dummy_image, "proprio": dummy_proprio, @@ -314,12 +327,12 @@ def test_policy_server(request: PolicyServerTestRequest): error_type = "connection_error" if isinstance(e, requests.exceptions.Timeout): error_type = "timeout_error" - message = f"Connection to policy server at {request.policy_server_ip}:{request.policy_server_port} timed out. The server might be busy or unreachable." + message = f"Connection to policy server at {get_url(request.policy_server_ip, request.policy_server_port)} timed out. The server might be busy or unreachable." elif isinstance(e, requests.exceptions.ConnectionError): - message = f"Could not connect to policy server at {request.policy_server_ip}:{request.policy_server_port}. Please check the IP and port." + message = f"Could not connect to policy server at {get_url(request.policy_server_ip, request.policy_server_port)}. Please check the IP and port." else: error_type = "request_error" - message = f"Error connecting to policy server at {request.policy_server_ip}:{request.policy_server_port}: {str(e)}" + message = f"Error connecting to policy server at {get_url(request.policy_server_ip, request.policy_server_port)}. {str(e)}" return JSONResponse( status_code=400,