Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions pandaharvester/globus_token_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import traceback
from typing import Dict, Any

from globus_compute_sdk import Executor, Client
from globus_compute_sdk.errors.error_types import TaskExecutionFailed
from globus_compute_sdk.sdk.shell_function import ShellFunction, ShellResult

from cryptography.fernet import Fernet

from pandaharvester.harvestercore import core_utils

def _remote_write_token(encrypted_token: str, remote_token_path: str, key_file: str) -> str:
    """
    Remote function executed on the Globus Compute endpoint (HPC site, for example).

    It first reads the Fernet key from key_file, then decrypts encrypted_token with
    that key, and finally writes the plaintext to remote_token_path atomically
    (tmp file + os.replace) so readers never observe a partially written token.

    NOTE: imports are inside the function body on purpose — Globus Compute
    serializes this function and executes it on the remote endpoint, where this
    module's top-level imports are not in scope.

    :param encrypted_token: Fernet-encrypted token as an ASCII string
    :param remote_token_path: final token file path on the remote filesystem
    :param key_file: path of the Fernet key file on the remote endpoint
    :return: remote_token_path on success (propagated back as the task result)
    """
    import os
    from cryptography.fernet import Fernet

    with open(key_file, "rb") as f:
        key = f.read().strip()
    fernet = Fernet(key)
    plaintext = fernet.decrypt(encrypted_token.encode("utf-8"))

    dirname = os.path.dirname(remote_token_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    tmp_path = remote_token_path + ".tmp"

    # Create the tmp file with owner-only permissions (token is a secret),
    # fsync it, then atomically rename over the destination so a concurrent
    # reader sees either the old or the new token, never a partial one.
    fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "wb") as fp:
            fp.write(plaintext)
            fp.flush()
            os.fsync(fp.fileno())
        os.replace(tmp_path, remote_token_path)
    except Exception:
        # Best-effort cleanup so a failed run does not leave a stale .tmp
        # file (possibly containing a decrypted token) on the remote side.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    return remote_token_path


class GlobusTokenReplicator:
    """
    Wrapper over the Globus Compute client that is used by
    IamTokenCredManagerRemoteGlobusCompute to sync tokens with a remote site (like HPC).

    It first encrypts the token on the local harvester machine using Fernet,
    then submits a Globus Compute task which decrypts and writes the token file
    on the remote site.
    """

    def __init__(self, gc_cfg: Dict[str, Any], logger):
        """
        Initialize the replicator from a config dict.

        An example of gc_cfg is expected to look like:
        {
            "endpoint_id": "<UUID of GC endpoint>",
            "local_key_file": "/path/on/harvester/harvester_gc_token.key",
            "remote_key_file": "/path/on/endpoint/harvester_gc_token.key",
            "task_timeout": 120
        }
        endpoint_id and local_key_file are mandatory; remote_key_file defaults
        to local_key_file and task_timeout defaults to 120 seconds.

        :param gc_cfg: Globus Compute configuration dict as above
        :param logger: logger object used for error/debug reporting
        :raises RuntimeError: if a mandatory config key is missing
        :raises Exception: if the local Fernet key cannot be loaded
        """
        self.logger = logger

        try:
            self.endpoint_id = gc_cfg["endpoint_id"]
            self.local_key_file = gc_cfg["local_key_file"]
            self.remote_key_file = gc_cfg.get("remote_key_file", self.local_key_file)
        except KeyError as e:
            raise RuntimeError(f"GlobusTokenReplicator missing required config key: {e}") from e

        self.task_timeout = gc_cfg.get("task_timeout", 120)

        # Load the Fernet key BEFORE constructing the Executor: Executor starts
        # background machinery, so raising after its creation would leak it.
        try:
            with open(self.local_key_file, "rb") as f:
                key = f.read().strip()
            self.fernet = Fernet(key)
        except Exception:
            self.logger.error(f"Failed to load local Fernet key from {self.local_key_file}\n{traceback.format_exc()}")
            raise

        self.executor = Executor(endpoint_id=self.endpoint_id)

    def do_it(self, token_str: str, remote_token_path: str) -> bool:
        """
        Encrypt token_str locally and ship it to remote_token_path on the
        endpoint via a Globus Compute task (see _remote_write_token).

        :param token_str: plaintext token to replicate
        :param remote_token_path: destination file path on the remote site
        :return: True on success, False on any failure (errors are logged,
                 never raised, so callers can treat remote sync as best-effort)
        """
        encrypted = self.fernet.encrypt(token_str.encode("utf-8")).decode("ascii")

        try:
            future = self.executor.submit(
                _remote_write_token, encrypted, remote_token_path, self.remote_key_file
            )
            # Bound the wait so a stuck endpoint cannot block the cred manager forever.
            result = future.result(timeout=self.task_timeout)
            self.logger.debug(f"Remote token sync to {remote_token_path} finished with result: {result}")
            return True
        except TaskExecutionFailed as e:
            self.logger.error(f"Globus Compute task failed for {remote_token_path}: {e}")
            return False
        except Exception:
            self.logger.error(f"Unexpected error during remote token sync for {remote_token_path}\n{traceback.format_exc()}")
            return False
82 changes: 82 additions & 0 deletions pandaharvester/iam_token_cred_manager_globus_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import traceback

from pandaharvester.harvestercore import core_utils
from pandaharvester.harvestermisc.token_utils import endpoint_to_filename

from .iam_token_cred_manager import IamTokenCredManager
from .globus_token_sync import GlobusTokenReplicator

# module-level logger shared by all instances of this plugin
_logger = core_utils.setup_logger("iam_token_cred_manager_gc")

class IamTokenCredManagerRemoteGlobusCompute(IamTokenCredManager):
    """
    Credential manager that renews IAM tokens locally (reusing the parent-class
    configuration) and additionally replicates each renewed token file to a
    remote site (e.g. HPC) through Globus Compute.

    Replication is optional: it is enabled only when both "remote_out_dir" and
    "globus_compute" are present in the plugin's setupMap.
    """

    def __init__(self, **kwarg):
        super().__init__(**kwarg)
        tmp_log = self.make_logger(_logger, f"config={self.setup_name}", method_name="__init__")

        # directory on the remote endpoint where token files are placed
        self.remote_out_dir = self.setupMap.get("remote_out_dir", "")
        gc_cfg = self.setupMap.get("globus_compute", {})
        if gc_cfg and self.remote_out_dir:
            self.replicator = GlobusTokenReplicator(gc_cfg, tmp_log)
        else:
            # without both config pieces, remote sync is silently disabled
            tmp_log.debug(f"replicator is not initialized as either gc_cfg or remote_out_dir is missing. They are gc_cfg = {gc_cfg} and remote_out_dir = {self.remote_out_dir}")
            self.replicator = None

    def _sync_remote(self, token_filename: str, token_str: str, logger):
        """
        Replicate one token to the remote site, if a replicator is configured.

        The replicator encrypts the token locally, submits a Globus Compute
        task which decrypts it remotely and atomically replaces the old remote
        token file. No-op when no replicator was configured in __init__.

        :param token_filename: basename of the token file (joined with remote_out_dir)
        :param token_str: plaintext token content to replicate
        :param logger: logger to report the outcome to
        """
        if not self.replicator:
            return
        remote_token_path = os.path.join(self.remote_out_dir, token_filename)
        status = self.replicator.do_it(token_str, remote_token_path)
        if status:
            logger.info(f"Synchronized token '{token_filename}' to remote: {remote_token_path}")
        else:
            logger.error(f"FAILED to synchronize token '{token_filename}' to remote: {remote_token_path}")

    def renew_credential(self):
        """
        Renew the access token of every configured target locally, then sync
        each newly renewed token to the remote site (when configured).

        :return: tuple (all_ok, all_err_str); all_ok is False when any target
                 failed to renew or sync, all_err_str summarizes the failure
        """
        # make logger
        tmp_log = self.make_logger(_logger, f"config={self.setup_name}", method_name="renew_credential")
        # go
        all_ok = True
        all_err_str = ""
        for target in self.targets_dict:
            try:
                # resolve the local token file path for this target
                if self.target_type == "panda":
                    token_filename = self.panda_token_filename
                else:
                    token_filename = endpoint_to_filename(target)
                token_path = os.path.join(self.out_dir, token_filename)
                # check token freshness; skip targets whose token is still valid
                if self._is_fresh(token_path):
                    tmp_log.debug(f"token for {target} at {token_path} still fresh; skipped")
                    continue
                # renew access token of target locally
                # NOTE(review): written with default umask — consider 0o600 since
                # this is a secret; confirm against parent-class convention
                access_token = self.issuer_broker.get_access_token(aud=target, scope=self.scope)
                with open(token_path, "w") as f:
                    f.write(access_token)
                tmp_log.info(f"renewed token for {target} at {token_path}")
                # sync to remote (if configured); a sync failure is recorded but
                # must not abort renewal of the remaining targets
                try:
                    self._sync_remote(token_filename, access_token, tmp_log)
                except Exception:
                    all_ok = False
                    tmp_log.error(f"Remote sync failed for {target}. {traceback.format_exc()}")
                    all_err_str = "failed to sync some tokens; see plugin log for details "
            except Exception:
                # one bad target must not block renewal of the others
                err_str = f"Problem getting token for {target}. {traceback.format_exc()}"
                tmp_log.error(err_str)
                all_ok = False
                all_err_str = "failed to get some tokens. Check the plugin log for details "
        # update last timestamp
        self._update_ts()
        tmp_log.debug("done")
        # return
        return all_ok, all_err_str