import os
import traceback
from typing import Any, Dict

from globus_compute_sdk import Executor, Client
from globus_compute_sdk.errors.error_types import TaskExecutionFailed
from globus_compute_sdk.sdk.shell_function import ShellFunction, ShellResult

from cryptography.fernet import Fernet

from pandaharvester.harvestercore import core_utils


def _remote_write_token(encrypted_token: str, remote_token_path: str, key_file: str) -> str:
    """
    Remote function executed on the Globus Compute endpoint (an HPC site, for example).

    Reads the Fernet key from key_file, decrypts encrypted_token with it, and
    writes the plaintext to remote_token_path atomically (temp file + fsync +
    os.replace, so a reader never observes a partially written token).

    NOTE: the imports are local on purpose — this function's source is shipped
    to and executed on the remote endpoint, where this module's globals do not
    exist.

    :param encrypted_token: Fernet-encrypted token, ASCII string
    :param remote_token_path: final token file path on the endpoint
    :param key_file: path of the Fernet key file on the endpoint
    :return: remote_token_path on success
    """
    import os

    from cryptography.fernet import Fernet

    with open(key_file, "rb") as f:
        key = f.read().strip()
    plaintext = Fernet(key).decrypt(encrypted_token.encode("utf-8"))

    dirname = os.path.dirname(remote_token_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # per-process temp name so concurrent sync tasks for the same token file
    # cannot clobber each other's half-written temp files
    tmp_path = f"{remote_token_path}.tmp.{os.getpid()}"

    try:
        # 0o600: the token is a credential; restrict access to the owner
        fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as fp:
            fp.write(plaintext)
            fp.flush()
            os.fsync(fp.fileno())
        os.replace(tmp_path, remote_token_path)
    except Exception:
        # do not leak the temp file if decrypt/write/replace fails
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    return remote_token_path


class GlobusTokenReplicator:
    """
    Wrapper over the Globus Compute client, used by
    IamTokenCredManagerRemoteGlobusCompute to sync tokens with a remote site
    (like an HPC center). It first encrypts the token on the local harvester
    machine using Fernet, then submits a Globus Compute task which decrypts
    and writes the token file on the remote site.
    """

    def __init__(self, gc_cfg: Dict[str, Any], logger):
        """
        An example of gc_cfg is expected to look like:
        {
            "endpoint_id": "",
            "local_key_file": "/path/on/harvester/harvester_gc_token.key",
            "remote_key_file": "/path/on/endpoint/harvester_gc_token.key",
            "task_timeout": 120
        }
        endpoint_id and local_key_file are mandatory; remote_key_file defaults
        to local_key_file and task_timeout defaults to 120 seconds.

        :raises RuntimeError: if a mandatory config key is missing
        :raises Exception: if the local Fernet key cannot be loaded
        """
        self.logger = logger

        try:
            self.endpoint_id = gc_cfg["endpoint_id"]
            self.local_key_file = gc_cfg["local_key_file"]
        except KeyError as e:
            raise RuntimeError(f"GlobusTokenReplicator missing required config key: {e}") from e
        # optional keys: .get() cannot raise, so they stay outside the try
        self.remote_key_file = gc_cfg.get("remote_key_file", self.local_key_file)
        self.task_timeout = gc_cfg.get("task_timeout", 120)

        self.executor = Executor(endpoint_id=self.endpoint_id)

        try:
            with open(self.local_key_file, "rb") as f:
                key = f.read().strip()
            self.fernet = Fernet(key)
        except Exception:
            self.logger.error(f"Failed to load local Fernet key from {self.local_key_file}\n{traceback.format_exc()}")
            raise

    def do_it(self, token_str: str, remote_token_path: str) -> bool:
        """
        Encrypt token_str locally and submit a Globus Compute task that
        decrypts it and writes it to remote_token_path on the endpoint.

        :param token_str: plaintext token to replicate
        :param remote_token_path: destination path on the remote site
        :return: True on success, False on any failure (failures are logged)
        """
        encrypted = self.fernet.encrypt(token_str.encode("utf-8")).decode("ascii")

        try:
            future = self.executor.submit(
                _remote_write_token, encrypted, remote_token_path, self.remote_key_file
            )
            result = future.result(timeout=self.task_timeout)
            self.logger.debug(f"Remote token sync to {remote_token_path} finished with result: {result}")
            return True
        except TaskExecutionFailed as e:
            self.logger.error(f"Globus Compute task failed for {remote_token_path}: {e}")
            return False
        except Exception:
            self.logger.error(f"Unexpected error during remote token sync for {remote_token_path}\n{traceback.format_exc()}")
            return False
import os
import traceback

from pandaharvester.harvestercore import core_utils
from pandaharvester.harvestermisc.token_utils import endpoint_to_filename

from .iam_token_cred_manager import IamTokenCredManager
from .globus_token_sync import GlobusTokenReplicator

_logger = core_utils.setup_logger("iam_token_cred_manager_gc")


class IamTokenCredManagerRemoteGlobusCompute(IamTokenCredManager):
    """
    Credential manager that renews IAM tokens locally (via the parent class
    machinery) and, when configured, replicates each renewed token to a
    remote site through Globus Compute.

    Extra setupMap keys:
      remote_out_dir: directory on the remote site where token files go
      globus_compute: config dict passed through to GlobusTokenReplicator
    """

    def __init__(self, **kwarg):
        super().__init__(**kwarg)
        tmp_log = self.make_logger(_logger, f"config={self.setup_name}", method_name="__init__")

        self.remote_out_dir = self.setupMap.get("remote_out_dir", "")
        gc_cfg = self.setupMap.get("globus_compute", {})
        # the replicator is optional: without it this plugin behaves exactly
        # like the plain IamTokenCredManager and only writes tokens locally
        if gc_cfg and self.remote_out_dir:
            self.replicator = GlobusTokenReplicator(gc_cfg, tmp_log)
        else:
            tmp_log.debug(f"replicator is not initialized as either gc_cfg or remote_out_dir is missing. They are gc_cfg = {gc_cfg} and remote_out_dir = {self.remote_out_dir}")
            self.replicator = None

    def _sync_remote(self, token_filename: str, token_str: str, logger):
        """
        Replicate one token file to the remote site; no-op when no remote
        token replicator is configured.

        The replicator encrypts the token locally as a string, submits a
        Globus Compute task that transfers it, and the task decrypts it and
        atomically replaces the old remote token file.

        :param token_filename: file name of the token (joined to remote_out_dir)
        :param token_str: plaintext token content
        :param logger: logger to report success/failure on
        """
        if not self.replicator:
            return
        remote_token_path = os.path.join(self.remote_out_dir, token_filename)
        status = self.replicator.do_it(token_str, remote_token_path)
        if status:
            logger.info(f"Synchronized token '{token_filename}' to remote: {remote_token_path}")
        else:
            logger.error(f"FAILED to synchronize token '{token_filename}' to remote: {remote_token_path}")

    def renew_credential(self):
        """
        Renew the access token of every configured target locally, then sync
        each renewed token to the remote site when a replicator is configured.

        :return: (all_ok, error_string) — all_ok is False if any token failed
                 to renew or to sync; error_string summarizes the failure kind
        """
        # make logger
        tmp_log = self.make_logger(_logger, f"config={self.setup_name}", method_name="renew_credential")
        all_ok = True
        all_err_str = ""
        for target in self.targets_dict:
            try:
                # resolve the local token file name for this target
                if self.target_type == "panda":
                    token_filename = self.panda_token_filename
                else:
                    token_filename = endpoint_to_filename(target)
                token_path = os.path.join(self.out_dir, token_filename)
                # check token freshness; skip tokens that are still fresh
                if self._is_fresh(token_path):
                    tmp_log.debug(f"token for {target} at {token_path} still fresh; skipped")
                    continue
                # renew access token of target locally
                access_token = self.issuer_broker.get_access_token(aud=target, scope=self.scope)
                # NOTE(review): token file is written with default permissions;
                # consider restricting to 0o600 — confirm against parent class
                with open(token_path, "w") as f:
                    f.write(access_token)
                tmp_log.info(f"renewed token for {target} at {token_path}")
                # sync to remote (if configured); a sync failure must not
                # prevent renewal of the remaining targets
                try:
                    self._sync_remote(token_filename, access_token, tmp_log)
                except Exception:
                    all_ok = False
                    tmp_log.error(f"Remote sync failed for {target}. {traceback.format_exc()}")
                    all_err_str = "failed to sync some tokens; see plugin log for details "
            except Exception:
                tmp_log.error(f"Problem getting token for {target}. {traceback.format_exc()}")
                all_ok = False
                all_err_str = "failed to get some tokens. Check the plugin log for details "
        # update last timestamp
        self._update_ts()
        tmp_log.debug("done")
        return all_ok, all_err_str