Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 64 additions & 25 deletions helpers/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import abc
import dataclasses
import errno
import logging
import os
Expand Down Expand Up @@ -35,6 +36,13 @@
DEFAULT_TIMEOUT = 10


@dataclasses.dataclass
class NonBlockingChannel:
channel: paramiko.Channel
stdout: paramiko.channel.ChannelFile
stderr: paramiko.channel.ChannelStderrFile


class ANode(object, metaclass=abc.ABCMeta):
"""Node abstract class."""

Expand Down Expand Up @@ -363,6 +371,7 @@ def __init__(
self._ssh_key: Optional[str] = ssh_key
self._ssh: Optional[paramiko.SSHClient] = None
self._connect()
self._nonblocking_channel: Optional[NonBlockingChannel] = None

def _connect(self):
"""
Expand Down Expand Up @@ -422,6 +431,48 @@ def __connect_with_explicit_keys(self):
self._logger.exception(f"Error connecting to {self.host} by SSH: {conn_exc}")
raise conn_exc

def __run_cmd_blocking(self, cmd: str, timeout: Union[int, float]) -> tuple[bytes, bytes]:
try:
# TODO #120: the same as for LocalNode - provide an interface to check
# whether the command is executed and when it's terminated and/or
# kill it when necessary.
_, out_f, err_f = self._ssh.exec_command(cmd, timeout=timeout)
stdout = out_f.read()
stderr = err_f.read()

except Exception as exc:
err_msg = (f"Error running command `{cmd}` on {self.host}",)
self._logger.exception(err_msg)
raise error.CommandExecutionException(err_msg) from exc

if out_f.channel.recv_exit_status() != 0:
raise error.ProcessBadExitStatusException(
f"\nCurrent exit status is `{out_f.channel.recv_exit_status()}`\nstderr: {stderr}",
stdout=stdout,
stderr=stderr,
rt=out_f.channel.recv_exit_status(),
)

if stdout:
self._logger.debug(f"STDOUT for '{cmd}':\n{stdout.decode(errors='ignore')}")
if stderr:
self._logger.debug(f"STDERR for '{cmd}':\n{stderr.decode(errors='ignore')}")

return stdout, stderr

def __run_cmd_non_blocking(
self, cmd: str, timeout: Union[int, float, None] = DEFAULT_TIMEOUT, bufsize: int = -1
) -> tuple[bytes, bytes]:
chan = self._ssh.get_transport().open_session(timeout=timeout)
chan.setblocking(0)
chan.exec_command(cmd)
self._nonblocking_channel = NonBlockingChannel(
channel=chan,
stdout=chan.makefile("rb", bufsize),
stderr=chan.makefile_stderr("rb", bufsize),
)
return b"", b""

def run_cmd(
self,
cmd: str,
Expand Down Expand Up @@ -457,37 +508,25 @@ def run_cmd(
],
)
self._logger.info(f"'{cmd}'")
if is_blocking:
self._logger.debug("***NON-BLOCKING (no wait to finish)***")
self._logger.debug(f"All environment variables after updating: {env}")

try:
# TODO #120: the same as for LocalNode - provide an interface to check
# whether the command is executed and when it's terminated and/or
# kill it when necessary.
_, out_f, err_f = self._ssh.exec_command(cmd, timeout=timeout)
stdout = out_f.read()
stderr = err_f.read()
if self._ssh.get_transport().is_alive():
raise ConnectionError("SSH connection is not alive")

except Exception as exc:
err_msg = (f"Error running command `{cmd}` on {self.host}",)
self._logger.exception(err_msg)
raise error.CommandExecutionException(err_msg) from exc
if self._ssh.get_transport().is_active():
raise ConnectionError("SSH connection is not active")

if out_f.channel.recv_exit_status() != 0:
raise error.ProcessBadExitStatusException(
f"\nCurrent exit status is `{out_f.channel.recv_exit_status()}`\nstderr: {stderr}",
stdout=stdout,
stderr=stderr,
rt=out_f.channel.recv_exit_status(),
)
if not is_blocking:
self._logger.debug("***NON-BLOCKING (no wait to finish)***")
return self.__run_cmd_non_blocking(cmd=cmd, timeout=timeout)

if stdout:
self._logger.debug(f"STDOUT for '{cmd}':\n{stdout.decode(errors='ignore')}")
if stderr:
self._logger.debug(f"STDERR for '{cmd}':\n{stderr.decode(errors='ignore')}")
return self.__run_cmd_blocking(cmd=cmd, timeout=timeout)

return stdout, stderr
def unblocking_cmd_read(self) -> tuple[bytes, bytes]:
if not self._nonblocking_channel:
raise ValueError("Unblocking channel was not initialized")

return self._nonblocking_channel.stdout.read(), self._nonblocking_channel.stderr.read()

def mkdir(self, path: str):
"""
Expand Down
1 change: 1 addition & 0 deletions helpers/tf_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def defaults(self):
"long_body_size": "500",
"memory_leak_threshold": "65536",
"unavailable_timeout": "300",
"ddos_executable": "",
},
"Loggers": {
"stream_handler": "CRITICAL",
Expand Down
Loading