diff --git a/Dockerfile b/Dockerfile
index e0202a91a..94f90f653 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -9,7 +9,9 @@ RUN apt-get update && \
     apt-get upgrade -y && \
     apt-get install -y --no-install-recommends \
         openssl \
-        ca-certificates && \
+        ca-certificates \
+        gcc \
+        libc6-dev && \
     pip install --no-cache-dir --upgrade pip setuptools wheel && \
     apt-get autoremove -y && \
     apt-get clean && \
diff --git a/README.md b/README.md
index 3fcdeaa38..6347c9b5a 100644
--- a/README.md
+++ b/README.md
@@ -3,16 +3,61 @@
 ![Keeper Commander](https://raw.githubusercontent.com/Keeper-Security/Commander/master/images/commander-black.png)
 
 ### About Keeper Commander
-Keeper Commander is a command-line CLI and Python SDK interface to Keeper® Password Manager and KeeperPAM. Commander can be used to access and control your Keeper vault, perform administrative functions (managing users, teams, roles, SSO, privileged access resources, data import/export), launch sessions, rotate passwords, integrate with developer tools, eliminate hardcoded passwords, run as a REST service and more. Keeper Commander is an open source project with contributions from Keeper's engineering team and partners.
+Keeper Commander is a command-line and terminal UI interface to Keeper® Password Manager and KeeperPAM. Commander can be used to access and control your Keeper vault, perform administrative actions (managing users, teams, roles, SSO, privileged access resources, device approvals, data import/export), launch sessions, rotate passwords, integrate with developer tools, eliminate hardcoded passwords, run as a REST service and more. Keeper Commander is an open source project with contributions from Keeper's engineering team, customers and partners.
 
-### Documentation and Getting Started
-- Keeper Commander documentation: [https://docs.keeper.io/en/keeperpam/commander-cli/overview](https://docs.keeper.io/en/keeperpam/commander-cli/overview)
+### Windows and macOS Binaries
+See the [Releases](https://github.com/Keeper-Security/Commander/releases)
 
-### About Keeper Security
-Keeper is the leading cybersecurity platform for preventing password-related data breaches and cyberthreats. KeeperPAM is the leading zero-trust privileged access management ("PAM") platform for securing and managing access to your critical infrastructure.
+### Linux / Python using PIP
+```
+python3 -m venv keeper-env
+source keeper-env/bin/activate
+pip install keepercommander
+```
+
+### Running from Source
+```
+git clone https://github.com/Keeper-Security/Commander
+cd Commander
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+pip install -e .
+pip install -e '.[email]'
+```
+
+### Starting Commander
+For a list of all available commands:
+```
+keeper help
+```
+
+To launch the interactive command shell:
 
-- Learn More about Keeper: [https://keepersecurity.com](https://keepersecurity.com)
+```
+keeper shell
+```
 
-- Encryption and Security Model: [https://docs.keeper.io/en/enterprise-guide/keeper-encryption-model](https://docs.keeper.io/en/enterprise-guide/keeper-encryption-model)
+Or, for a full terminal vault user interface:
+```
+keeper supershell
+```
+
+Once logged in, check out the `this-device` command to set up persistent login sessions, logout timer and 2FA frequency. Also check out the `biometric register` command to enable biometric authentication on supported platforms.
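+
+For example, on a trusted machine you can enable persistent login so that later `keeper` invocations skip the master password prompt (a typical setup; exact subcommand options can vary by version, so run `this-device` alone to review the available settings):
+```
+keeper shell
+My Vault> this-device
+My Vault> this-device persistent-login on
+My Vault> this-device timeout 30
+```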
+
+### Documentation
+- [Commander Documentation Home](https://docs.keeper.io/en/keeperpam/commander-cli/overview)
+- [Installation](https://docs.keeper.io/en/keeperpam/commander-cli/commander-installation-setup)
+- [Full Command Reference](https://docs.keeper.io/en/keeperpam/commander-cli/command-reference)
+- [Service Mode REST API](https://docs.keeper.io/en/keeperpam/commander-cli/service-mode-rest-api)
+- [Commander SDK](https://docs.keeper.io/en/keeperpam/commander-sdk/keeper-commander-sdks)
+- [All Keeper Documentation](https://docs.keeper.io/)
+
+### About Keeper Security
+Keeper Security is the creator of KeeperPAM - the zero-trust and zero-knowledge privileged access management ("PAM") platform for securing and managing access to your critical infrastructure.
+- [Keeper Security Homepage](https://keepersecurity.com)
+- [Privileged Access Management](https://www.keepersecurity.com/privileged-access-management/)
+- [Endpoint Privilege Manager](https://www.keepersecurity.com/endpoint-privilege-management/)
+- [Encryption and Security Model](https://docs.keeper.io/en/enterprise-guide/keeper-encryption-model)
+- [Downloads](https://www.keepersecurity.com/download.html?t=d)
 
-- Documentation Home: [https://docs.keeper.io/](https://docs.keeper.io/)
diff --git a/keepercommander/__init__.py b/keepercommander/__init__.py
index 31c6081c0..3c8cca3f8 100644
--- a/keepercommander/__init__.py
+++ b/keepercommander/__init__.py
@@ -10,4 +10,4 @@
 #  Contact: ops@keepersecurity.com
 #
 
-__version__ = '17.2.1'
+__version__ = '17.2.2'
diff --git a/keepercommander/__main__.py b/keepercommander/__main__.py
index 30a3faf23..95fd6d0b5 100644
--- a/keepercommander/__main__.py
+++ b/keepercommander/__main__.py
@@ -29,6 +29,7 @@
 from . import cli, utils
 from .params import KeeperParams
 from .config_storage import loader
+from .constants import resolve_server, KEEPER_SERVERS
 
 
 def get_params_from_config(config_filename=None, launched_with_shortcut=False, data_dir=None):
     # type: (Optional[str], bool, Optional[str]) -> KeeperParams
@@ -96,12 +97,54 @@ def get_env_config():
 
 
 def usage(m):
+    """Show full help with all commands - used for 'keeper help' or 'keeper ?'"""
     print(m)
     parser.print_help()
-    cli.display_command_help(show_enterprise=True, show_shell=True, show_legacy=True)
+    cli.display_command_help(show_enterprise=True, show_shell=True, show_legacy=False)
     sys.exit(1)
 
 
+def show_brief_help():
+    """Show brief help for 'keeper -h' - just global options and guidance"""
+    print('')
+    print('Keeper Commander - CLI-based vault and admin interface to the Keeper platform')
+    print('')
+    print('Usage: keeper [OPTIONS] [COMMAND] [COMMAND_OPTIONS]')
+    print('')
+    print('Global Options:')
+    print('  --server, -ks SERVER       Keeper region or host')
+    print('                             Regions: US, EU, AU, CA, JP, GOV')
+    print('                             Dev/QA: US_DEV, EU_DEV, GOV_QA, etc.')
+    print('  --user, -ku USER           Email address for the account')
+    print('  --password, -kp PASSWORD   Master password for the account')
+    print('  --config CONFIG            Config file to use')
+    print('  --debug                    Turn on debug mode')
+    print('  --batch-mode               Run in batch/non-interactive mode')
+    print('  --proxy PROXY              Proxy server')
+    print('  --new-login                Force full login (bypass persistent login)')
+    print('  --version                  Display version')
+    print('')
+    print('Getting Started:')
+    print('  keeper shell               Open interactive command shell')
+    print('  keeper supershell          Open full-screen vault browser (TUI)')
+    print('  keeper login               Login to your Keeper account')
+    print('')
+    print('Getting Help:')
+    print('  keeper help                Show hundreds of available commands')
+    print('  keeper help <command>      Show help for a specific command')
+    print('  keeper <command> -h        Show help for a specific command')
+    print('')
+    print('Examples:')
+    print('  keeper shell                           # Start interactive shell')
+    print('  keeper --server EU login               # Login to EU region')
+    print('  keeper login -h                        # Show login command help')
+    print('  keeper search "github" --format=json   # Search and output JSON')
+    print('')
+    print('User Guide: https://docs.keeper.io/en/keeperpam/commander-cli')
+    print('')
+    sys.exit(0)
+
+
 parser = argparse.ArgumentParser(prog='keeper', add_help=False, allow_abbrev=False)
 parser.add_argument('--server', '-ks', dest='server', action='store', help='Keeper Host address.')
 parser.add_argument('--user', '-ku', dest='user', action='store', help='Email address for the account.')
@@ -120,6 +163,7 @@ def usage(m):
                          'server-side throttling'
 parser.add_argument('--fail-on-throttle', action='store_true', help=fail_on_throttle_help)
 parser.add_argument('--data-dir', dest='data_dir', action='store', help='Directory to use for Commander data (config, cache, etc.). Overrides environment variables.')
+parser.add_argument('--new-login', dest='new_login', action='store_true', help='Force full login flow (bypass persistent login)')
 parser.add_argument('command', nargs='?', type=str, action='store', help='Command')
 parser.add_argument('options', nargs='*', action='store', help='Options')
 parser.error = usage
@@ -286,7 +330,19 @@ def main(from_package=False):
         params.proxy = opts.proxy
 
     if opts.server:
-        params.server = opts.server
+        resolved_server = resolve_server(opts.server)
+        if resolved_server:
+            params.server = resolved_server
+        else:
+            # Show error and valid options
+            print(f"\nError: '{opts.server}' is not a valid Keeper server.")
+            print('\nValid server codes:')
+            print('  Production: US, EU, AU, CA, JP, GOV')
+            print('  Dev:        US_DEV, EU_DEV, AU_DEV, CA_DEV, JP_DEV, GOV_DEV')
+            print('  QA:         US_QA, EU_QA, AU_QA, CA_QA, JP_QA, GOV_QA')
+            print('\nYou can also use the full hostname (e.g., keepersecurity.com, keepersecurity.eu)')
+            print('')
+            sys.exit(1)
 
     if opts.user is not None:
         params.user = opts.user
@@ -308,24 +364,62 @@ def main(from_package=False):
         print(f'Keeper Commander, version {__version__}')
         return
 
-    if flags and len(flags) > 0:
-        if flags[0] in ('-h', '--help'):
-            flags.clear()
-            opts.command = '?'
-    elif opts.command == 'help' and len(opts.options) == 0:
-        opts.command = '?'
-
-    if (opts.command or '') == '?':
+    # Handle help flags and commands
+    has_help_flag = flags and len(flags) > 0 and flags[0] in ('-h', '--help')
+
+    if has_help_flag:
+        if not opts.command:
+            # 'keeper -h' with no command → show brief help
+            show_brief_help()
+        else:
+            # 'keeper <command> -h' → pass -h to the command (keep it in original_args)
+            # The -h is already in original_args_after_command, so just continue
+            pass
+
+    # Handle 'keeper help' and 'keeper help <command>'
+    if opts.command == 'help':
+        if len(opts.options) == 0:
+            # 'keeper help' with no args → show full command list
+            usage('')
+        else:
+            # 'keeper help <command>' → convert to '<command> --help'
+            opts.command = opts.options[0]
+            original_args_after_command = ['--help']
+
+    # Handle 'keeper ?'
+    if opts.command == '?':
         usage('')
 
     if not opts.command and from_package:
         opts.command = 'shell'
 
+    # If no command provided, show helpful welcome message
+    if not opts.command and not params.commands:
+        print('')
+        print('Keeper Commander - CLI-based vault and admin interface to the Keeper platform')
+        print('')
+        print('To get started:')
+        print('  keeper shell        Open interactive command shell')
+        print('  keeper supershell   Open full-screen vault browser (TUI)')
+        print('  keeper -h           Show help and available options')
+        print('')
+        print('Learn more at https://docs.keeper.io/en/keeperpam/commander-cli/overview')
+        print('')
+        return
+
     if isinstance(params.timedelay, int) and params.timedelay >= 1 and params.commands:
         cli.runcommands(params)
     else:
-        if opts.command in {'shell', 'login', '-'}:
+        # Check if -h/--help is in the arguments for a command
+        command_wants_help = any(arg in ('-h', '--help') for arg in original_args_after_command)
+
+        if opts.command in {'shell', '-'} and not command_wants_help:
+            # Special handling for shell/- when NOT asking for help
             if opts.command == '-':
                 params.batch_mode = True
+        elif opts.command == 'login' and not original_args_after_command and not command_wants_help:
+            # 'keeper login' with no args - just open shell and let it handle login
+            pass
         elif opts.command and os.path.isfile(opts.command):
             with open(opts.command, 'r') as f:
                 lines = f.readlines()
@@ -336,12 +430,15 @@ def main(from_package=False):
             if opts.command:
                 # Use the filtered original argument order to preserve proper flag/value pairing
                 options = ' '.join([shlex.quote(x) for x in original_args_after_command]) if original_args_after_command else ''
+                # Inject --new-login into login command if main parser captured it
+                if opts.command == 'login' and opts.new_login:
+                    options = '--new-login ' + options if options else '--new-login'
                 command = ' '.join([opts.command or '', options]).strip()
                 params.commands.append(command)
                 params.commands.append('q')
                 params.batch_mode = True
-    errno = cli.loop(params)
+    errno = cli.loop(params, new_login=opts.new_login)
     sys.exit(errno)
diff --git a/keepercommander/api.py b/keepercommander/api.py
index fbe47324a..652c7fe08 100644
--- a/keepercommander/api.py
+++ b/keepercommander/api.py
@@ -69,7 +69,8 @@ def run_command(params, request):
 
 def login(params, new_login=False, login_ui=None):
     # type: (KeeperParams, bool, Optional[Any]) -> None
-    logging.info('Logging in to Keeper Commander')
+    logging.info(f'Logging in to Keeper as {params.user}')
+
     flow = loginv3.LoginV3Flow(login_ui)
     try:
         flow.login(params, new_login=new_login)
diff --git a/keepercommander/auth/console_ui.py b/keepercommander/auth/console_ui.py
index 463735246..c12fbb9b6 100644
--- a/keepercommander/auth/console_ui.py
+++ b/keepercommander/auth/console_ui.py
@@ -6,6 +6,7 @@
 import webbrowser
 from typing import Optional, List
 
+from colorama import Fore, Style
 from . import login_steps
 from .. import utils
 from ..display import bcolors
@@ -23,48 +24,77 @@ def __init__(self):
 
     def on_device_approval(self, step):
         if self._show_device_approval_help:
-            print("\nDevice Approval Required\n")
-
-            print("Approve by selecting a method below:")
-            print("\t\"" + bcolors.OKGREEN + "email_send" + bcolors.ENDC + "\" to send email")
-            print("\t\"" + bcolors.OKGREEN + "email_code=<code>" + bcolors.ENDC + "\" to validate verification code sent via email")
-            print("\t\"" + bcolors.OKGREEN + "keeper_push" + bcolors.ENDC + "\" to send Keeper Push notification")
-            print("\t\"" + bcolors.OKGREEN + "2fa_send" + bcolors.ENDC + "\" to send 2FA code")
-            print("\t\"" + bcolors.OKGREEN + "2fa_code=<code>" + bcolors.ENDC + "\" to validate a code provided by 2FA application")
-            print("\t\"" + bcolors.OKGREEN + "<Enter>" + bcolors.ENDC + "\" to resume")
-
+            print(f"\n{Style.BRIGHT}Device Approval Required{Style.RESET_ALL}\n")
+            print("Select an approval method:")
+            print(f"  {Fore.GREEN}1{Fore.RESET}. Email       - Send approval link to your email")
+            print(f"  {Fore.GREEN}2{Fore.RESET}. Keeper Push - Send notification to an approved device")
+            print(f"  {Fore.GREEN}3{Fore.RESET}. 2FA Push    - Send code via your 2FA method")
+            print()
+            print(f"  {Fore.GREEN}c{Fore.RESET}. Enter code  - Enter a verification code")
+            print(f"  {Fore.GREEN}q{Fore.RESET}. Cancel login")
+            print()
             self._show_device_approval_help = False
         else:
-            print(bcolors.BOLD + "\nWaiting for device approval." + bcolors.ENDC)
-            print("Check email, SMS message or push notification on the approved device.\n")
+            print(f"\n{Style.BRIGHT}Waiting for device approval.{Style.RESET_ALL}")
+            print(f"Check email, SMS, or push notification on the approved device.")
+            print(f"Enter {Fore.GREEN}c{Fore.RESET} to submit a verification code.\n")
 
         try:
-            selection = input('Type your selection or <Enter> to resume: ')
+            selection = input('Selection (or Enter to check status): ').strip().lower()
 
-            if selection == "email_send" or selection == "es":
+            if selection == '1' or selection == 'email_send' or selection == 'es':
                 step.send_push(login_steps.DeviceApprovalChannel.Email)
-                print(bcolors.WARNING + "\nAn email with instructions has been sent to " + step.username + bcolors.WARNING + '\nPress <Enter> when approved.')
+                print(f"\n{Fore.GREEN}Email sent to {step.username}{Fore.RESET}")
+                print("Click the approval link in the email, then press Enter.\n")
+
+            elif selection == '2' or selection == 'keeper_push' or selection == 'kp':
+                step.send_push(login_steps.DeviceApprovalChannel.KeeperPush)
+                print(f"\n{Fore.GREEN}Push notification sent.{Fore.RESET}")
+                print("Approve on your device, then press Enter.\n")
+
+            elif selection == '3' or selection == '2fa_send' or selection == '2fs':
+                step.send_push(login_steps.DeviceApprovalChannel.TwoFactor)
+                print(f"\n{Fore.GREEN}2FA code sent.{Fore.RESET}")
+                print("Enter the code using option 'c'.\n")
+
+            elif selection == 'c' or selection.startswith('c '):
+                # Support both "c" (prompts for code) and "c <code>" (code inline)
+                if selection == 'c':
+                    code_input = input('Enter verification code: ').strip()
+                else:
+                    code_input = selection[2:].strip()  # Extract code after "c "
+
+                if code_input:
+                    # Try email code first, then 2FA
+                    try:
+                        step.send_code(login_steps.DeviceApprovalChannel.Email, code_input)
+                        print(f"{Fore.GREEN}Successfully verified email code.{Fore.RESET}")
+                    except KeeperApiError:
+                        try:
+                            step.send_code(login_steps.DeviceApprovalChannel.TwoFactor, code_input)
+                            print(f"{Fore.GREEN}Successfully verified 2FA code.{Fore.RESET}")
+                        except KeeperApiError as e:
+                            print(f"{Fore.YELLOW}Invalid code. Please try again.{Fore.RESET}")
 
             elif selection.startswith("email_code="):
                 code = selection.replace("email_code=", "")
                 step.send_code(login_steps.DeviceApprovalChannel.Email, code)
-                print("Successfully verified email code.")
-
-            elif selection == "2fa_send" or selection == "2fs":
-                step.send_push(login_steps.DeviceApprovalChannel.TwoFactor)
-                print(bcolors.WARNING + "\n2FA code was sent." + bcolors.ENDC)
+                print(f"{Fore.GREEN}Successfully verified email code.{Fore.RESET}")
 
             elif selection.startswith("2fa_code="):
                 code = selection.replace("2fa_code=", "")
                 step.send_code(login_steps.DeviceApprovalChannel.TwoFactor, code)
-                print("Successfully verified 2FA code.")
+                print(f"{Fore.GREEN}Successfully verified 2FA code.{Fore.RESET}")
 
-            elif selection == "keeper_push" or selection == "kp":
-                step.send_push(login_steps.DeviceApprovalChannel.KeeperPush)
-                logging.info('Successfully made a push notification to the approved device.\nPress <Enter> when approved.')
+            elif selection == 'q':
+                step.cancel()
 
-            elif selection == "":
+            elif selection == '':
                 step.resume()
+
+            else:
+                print(f"{Fore.YELLOW}Invalid selection. Enter 1, 2, 3, c, q, or press Enter.{Fore.RESET}")
+
         except KeyboardInterrupt:
             step.cancel()
         except KeeperApiError as kae:
@@ -240,7 +270,7 @@ def on_two_factor(self, step):
 
     def on_password(self, step):
         if self._show_password_help:
-            print(f'Enter password for {step.username}')
+            print(f'Enter master password for {step.username}')
 
             if self._failed_password_attempt > 0:
                 print('Forgot password? Type "recover"')
diff --git a/keepercommander/cli.py b/keepercommander/cli.py
index a0314eb6b..a5b1189e8 100644
--- a/keepercommander/cli.py
+++ b/keepercommander/cli.py
@@ -22,11 +22,12 @@
 from pathlib import Path
 from typing import Union
 
+from colorama import Fore, Style
 from prompt_toolkit import PromptSession
 from prompt_toolkit.enums import EditingMode
 from prompt_toolkit.shortcuts import CompleteStyle
 
-from . import api, display, ttk
+from . import api, display, ttk, utils
 from . import versioning
 from .autocomplete import CommandCompleter
 from .commands import (
@@ -36,7 +37,7 @@
 from .commands.base import CliCommand, GroupCommand
 from .commands.utils import LoginCommand
 from .commands import msp
-from .constants import OS_WHICH_CMD, KEEPER_PUBLIC_HOSTS
+from .constants import OS_WHICH_CMD, KEEPER_PUBLIC_HOSTS, KEEPER_SERVERS
 from .error import CommandError, Error
 from .params import KeeperParams
 from .subfolder import BaseFolderNode
@@ -55,161 +56,138 @@
 command_info['server'] = 'Sets or displays current Keeper region'
 
+# Shell-specific commands (handled inline in the shell loop)
+command_info['clear'] = 'Clear the screen'
+command_info['history'] = 'Show command history'
+command_info['quit'] = 'Exit the shell'
+aliases['c'] = 'clear'
+aliases['h'] = 'history'
+aliases['q'] = 'quit'
+
 logging.getLogger('asyncio').setLevel(logging.WARNING)
 
 
 def display_command_help(show_enterprise=False, show_shell=False, show_legacy=False):
     from .command_categories import get_command_category, get_category_order
     from .display import bcolors
-
+    from colorama import Fore, Style
+    import shutil
+
     alias_lookup = {x[1]: x[0] for x in aliases.items()}
-
+    DIM = Fore.WHITE  # Use white for better readability (not too bright, not too dim)
+
+    # Get terminal width
+    try:
+        terminal_width = shutil.get_terminal_size(fallback=(80, 24)).columns
+    except Exception:
+        terminal_width = 80
+
+    def clean_description(desc):
+        """Remove trailing period from description"""
+        if desc and desc.endswith('.'):
+            return desc[:-1]
+        return desc
+
     # Collect all commands from all sources
     all_commands = {}
     all_commands.update(command_info)
     if show_enterprise:
         all_commands.update(enterprise_command_info)
         all_commands.update(msp_command_info)
-
-    # Group commands by category
+
+    # Group commands by category and build display info
     categorized_commands = {}
     for cmd, description in all_commands.items():
         category = get_command_category(cmd)
         if category not in categorized_commands:
            categorized_commands[category] = []
-        categorized_commands[category].append((cmd, description))
-
-    # Define colors for different categories - more variety and visual appeal
-    category_colors = {
-        'Record Commands': bcolors.OKGREEN,            # Green - primary functionality
-        'Sharing Commands': bcolors.OKBLUE,            # Blue - collaboration
-        'Record Type Commands': bcolors.HEADER,        # Purple/Magenta - special types
-        'Import and Exporting Data': bcolors.WARNING,  # Yellow - data operations
-        'Reporting Commands': '\033[96m',              # Cyan - analytics
-        'MSP Management Commands': bcolors.HIGHINTENSITYRED,  # Bright Red - MSP admin
-        'Enterprise Management Commands': '\033[94m',  # Blue - enterprise admin
-        'Automation Commands': '\033[32m',             # Dark Green - automation workflows
-        'Secrets Manager Commands': '\033[95m',        # Magenta - KSM
-        'BreachWatch Commands': bcolors.FAIL,          # Red - security alerts
-        'Device Management Commands': '\033[93m',      # Bright Yellow - devices
-        'Domain Management Commands': '\033[92m',      # Bright Green - domains
-        'Service Mode REST API': '\033[36m',           # Dark Cyan - services
-        'Email Configuration Commands': '\033[38;5;214m',  # Orange - email services
-        'Miscellaneous Commands': '\033[37m',          # Light Gray - utilities
-        'KeeperPAM Commands': '\033[92m',              # Bright Green - PAM
-        'Legacy Commands': '\033[90m',                 # Dark Gray - deprecated
-        'Other': bcolors.WHITE
-    }
-
-    print(f'\n{bcolors.BOLD}{bcolors.UNDERLINE}Commands:{bcolors.ENDC}')
-    print('=' * 80)
-
-    # Display commands in category order with colors and separators
-    first_category = True
+        categorized_commands[category].append((cmd, clean_description(description)))
+
+    # Pre-compute all command display strings and find global max width
+    # This allows alignment across all categories when terminal is wide enough
+    all_cmd_displays = []  # List of (category, cmd_display, description)
+    global_max_width = 0
+
+    # Special subcommands for certain categories
+    pam_subcommands = [
+        ('pam action', 'Execute action on the Gateway'),
+        ('pam config', 'Manage PAM Configurations'),
+        ('pam connection', 'Manage Connections'),
+        ('pam gateway', 'Manage Gateways'),
+        ('pam legacy', 'Switch to legacy PAM commands'),
+        ('pam project', 'PAM Project Import/Export'),
+        ('pam rbi', 'Manage Remote Browser Isolation'),
+        ('pam rotation', 'Manage Rotations'),
+        ('pam split', 'Split credentials from legacy PAM Machine'),
+        ('pam tunnel', 'Manage Tunnels'),
+    ]
+    domain_subcommands = [
+        ('domain list (dl)', 'List all reserved domains for the enterprise'),
+        ('domain reserve (dr)', 'Reserve, delete, or generate token for a domain'),
+    ]
+
     for category in get_category_order():
         if category not in categorized_commands:
             continue
-
-        # Skip Legacy Commands unless specifically requested
         if category == 'Legacy Commands' and not show_legacy:
             continue
-
-        # Add separator between categories (except for first one)
-        if not first_category:
-            print()  # Empty line between categories
-        first_category = False
-
-        # Sort commands within each category
-        commands_in_category = sorted(categorized_commands[category], key=lambda x: x[0])
-
-        # Display category header with color
-        color = category_colors.get(category, bcolors.WHITE)
-        print(f'{color}{bcolors.BOLD}{category}:{bcolors.ENDC}')
-        print(f'{color}{"-" * len(category)}{bcolors.ENDC}')
-
-        # Special handling for KeeperPAM Commands to show sub-commands
+
         if category == 'KeeperPAM Commands':
-            # Define PAM sub-commands with descriptions
-            pam_subcommands = [
-                ('pam action', 'Execute action on the Gateway'),
-                ('pam config', 'Manage PAM Configurations'),
-                ('pam connection', 'Manage Connections'),
-                ('pam gateway', 'Manage Gateways'),
-                ('pam legacy', 'Switch to legacy PAM commands'),
-                ('pam project', 'PAM Project Import/Export'),
-                ('pam rbi', 'Manage Remote Browser Isolation'),
-                ('pam rotation', 'Manage Rotations'),
-                ('pam split', 'Split credentials from legacy PAM Machine'),
-                ('pam tunnel', 'Manage Tunnels'),
-            ]
-
-            # Calculate width for PAM commands
-            max_cmd_width = max(len(cmd) for cmd, _ in pam_subcommands)
             for cmd_display, description in sorted(pam_subcommands):
-                # Bold only the "pam" part
-                pam_part = cmd_display.split(' ')[0]  # "pam"
-                sub_part = cmd_display.split(' ', 1)[1]  # "action", "config", etc.
-                formatted_cmd = f'{bcolors.BOLD}{pam_part}{bcolors.ENDC} {sub_part}'
-                # Adjust spacing to account for formatting codes
-                spacing = max_cmd_width - len(cmd_display) + len(bcolors.BOLD) + len(bcolors.ENDC)
-                print(f'  {formatted_cmd}{" " * spacing}  {description}')
+                all_cmd_displays.append((category, cmd_display, description))
+                global_max_width = max(global_max_width, len(cmd_display))
 
         elif category == 'Domain Management Commands':
-            # Define domain sub-commands with descriptions
-            domain_subcommands = [
-                ('domain list', 'List all reserved domains for the enterprise'),
-                ('domain reserve', 'Reserve, delete, or generate token for a domain'),
-            ]
-
-            # Calculate width for domain commands
-            max_cmd_width = max(len(cmd) for cmd, _ in domain_subcommands)
             for cmd_display, description in sorted(domain_subcommands):
-                # Bold only the "domain" part
-                domain_part = cmd_display.split(' ')[0]  # "domain"
-                sub_part = cmd_display.split(' ', 1)[1]  # "list", "reserve"
-                formatted_cmd = f'{bcolors.BOLD}{domain_part}{bcolors.ENDC} {sub_part}'
-                # Adjust spacing to account for formatting codes
-                spacing = max_cmd_width - len(cmd_display) + len(bcolors.BOLD) + len(bcolors.ENDC)
-                print(f'  {formatted_cmd}{" " * spacing}  {description}')
+                all_cmd_displays.append((category, cmd_display, description))
+                global_max_width = max(global_max_width, len(cmd_display))
 
         else:
-            # Regular command display for other categories
-            max_cmd_width = 0
-            cmd_display_list = []
+            commands_in_category = sorted(categorized_commands[category], key=lambda x: x[0])
             for cmd, description in commands_in_category:
                 alias = alias_lookup.get(cmd) or ''
                 alias_str = f' ({alias})' if alias else ''
                 cmd_display = f'{cmd}{alias_str}'
-                cmd_display_list.append((cmd, alias_str, description))
-                max_cmd_width = max(max_cmd_width, len(cmd_display))
-
-            # Display commands in this category with proper table alignment
-            for cmd, alias_str, description in cmd_display_list:
-                cmd_display = f'{cmd}{alias_str}'
-                print(f'  {bcolors.BOLD}{cmd_display:<{max_cmd_width}}{bcolors.ENDC}  {description}')
-
-    # Add shell commands if requested
-    if show_shell:
-        print()  # Separator
-        color = bcolors.WHITE
-        print(f'{color}{bcolors.BOLD}Shell Commands:{bcolors.ENDC}')
-        print(f'{color}{"-" * 14}{bcolors.ENDC}')
-        # Calculate max width for shell commands too
-        shell_commands = [
-            ('clear (c)', 'Clear the screen.'),
-            ('history (h)', 'Show command history.'),
-            ('shell', 'Use Keeper interactive shell.'),
-            ('quit (q)', 'Quit.')
-        ]
-        shell_max_width = max(len(cmd) for cmd, _ in shell_commands)
-
-        for cmd, description in shell_commands:
-            print(f'  {bcolors.BOLD}{cmd:<{shell_max_width}}{bcolors.ENDC}  {description}')
-
-    print(f'\n{bcolors.UNDERLINE}Usage:{bcolors.ENDC}')
-    print(f"Type '{bcolors.BOLD}help <command>{bcolors.ENDC}' to display help on a specific command")
-    if not show_legacy:
-        print(f"Type '{bcolors.BOLD}help --legacy{bcolors.ENDC}' to show legacy/deprecated commands")
+                all_cmd_displays.append((category, cmd_display, description))
+                global_max_width = max(global_max_width, len(cmd_display))
+
+    # Determine if we should use global alignment
+    # Use global alignment if terminal is wide enough (command + padding + reasonable description)
+    min_desc_width = 40
+    use_global_alignment = terminal_width >= (4 + global_max_width + 2 + min_desc_width)
+
+    print()
+    print(f"  {Style.BRIGHT}Available Commands{Style.RESET_ALL}")
+    print(f"  {DIM}{'─' * 70}{Fore.RESET}")
+
+    # Display commands grouped by category
+    current_category = None
+    category_cmd_widths = {}  # Cache per-category max widths for non-global alignment
+
+    # Pre-compute per-category max widths
+    if not use_global_alignment:
+        for category, cmd_display, _ in all_cmd_displays:
+            if category not in category_cmd_widths:
+                category_cmd_widths[category] = 0
+            category_cmd_widths[category] = max(category_cmd_widths[category], len(cmd_display))
+
+    for category, cmd_display, description in all_cmd_displays:
+        if category != current_category:
+            if current_category is not None:
+                print()
+            print(f"  {Style.BRIGHT}{category}{Style.RESET_ALL}")
+            current_category = category
+
+        # Use global or per-category width
+        width = global_max_width if use_global_alignment else category_cmd_widths[category]
+        print(f"    {Fore.GREEN}{cmd_display:<{width}}{Fore.RESET}  {DIM}{description}{Fore.RESET}")
+
+    print()
+    print(f"  {DIM}Type {Fore.GREEN}help <command>{DIM} to display help on command{Fore.RESET}")
+    # Only show these hints inside the shell (not from terminal)
+    if not show_shell:
+        print(f"  {DIM}Type {Fore.GREEN}help basics{DIM} for a quick start guide{Fore.RESET}")
+    if not show_legacy:
+        print(f"  {DIM}Type {Fore.GREEN}help --legacy{DIM} to show legacy/deprecated commands{Fore.RESET}")
+    print()
 
 
 def is_executing_as_msp_admin():
@@ -271,17 +249,36 @@ def is_msp(params_local):
 
         if command_line.lower().startswith('server'):
             _, sp, server = command_line.partition(' ')
+            server = server.strip() if server else ''
+
+            # Handle help flag
+            if server in ('-h', '--help'):
+                print('Usage: server [REGION]')
+                print()
+                print('Set or display the current Keeper region.')
+                print()
+                print('Valid regions:')
+                print(f'  Production: US, EU, AU, CA, JP, GOV')
+                print(f'  Dev:        US_DEV, EU_DEV, AU_DEV, CA_DEV, JP_DEV, GOV_DEV')
+                print(f'  QA:         US_QA, EU_QA, AU_QA, CA_QA, JP_QA, GOV_QA')
+                return
+
             if server:
                 if not params.session_token:
-                    server = server.strip()
-                    region = next((x for x in KEEPER_PUBLIC_HOSTS.items()
-                                   if server.casefold() in [x[0].casefold(), x[1].casefold()]), None)
-                    if region:
-                        params.server = region[1]
-                        logging.info('Keeper region is set to %s', region[0])
+                    # Look up server in KEEPER_SERVERS (case insensitive)
+                    server_upper = server.upper()
+                    if server_upper in KEEPER_SERVERS:
+                        params.server = KEEPER_SERVERS[server_upper]
+                        logging.info('Keeper region is set to %s', server_upper)
                     else:
-                        params.server = server
-                        logging.info('Keeper server is set to %s', params.server)
+                        # Check if it matches a valid hostname directly
+                        server_lower = server.lower()
+                        if server_lower in KEEPER_SERVERS.values():
+                            params.server = server_lower
+                            logging.info('Keeper server is set to %s', params.server)
+                        else:
+                            logging.error('Invalid region: %s', server)
+                            print(f'Valid regions: {", ".join(sorted(KEEPER_SERVERS.keys()))}')
                 else:
                     logging.warning('Cannot change Keeper region while logged in')
             else:
@@ -381,7 +378,9 @@ def is_msp(params_local):
             if command.is_authorised():
                 if not params.session_token:
                     try:
-                        LoginCommand().execute(params, email=params.user, password=params.password, new_login=False)
+                        # Some commands (like logout) need auth but not sync
+                        skip_sync = getattr(command, 'skip_sync_on_auth', False)
+                        LoginCommand().execute(params, email=params.user, password=params.password, new_login=False, skip_sync=skip_sync)
                     except KeyboardInterrupt:
                         logging.info('Canceled')
                 if not params.session_token:
@@ -419,6 +418,8 @@ def is_msp(params_local):
                     api.sync_down(params)
                 return result
     else:
+        if not params.session_token and utils.is_email(orig_cmd):
+            return LoginCommand().execute(params, email=orig_cmd, new_login=False)
         display_command_help(show_enterprise=(params.enterprise is not None))
@@ -690,7 +691,7 @@ def read_command_with_continuation(prompt_session, params):
     return result
 
 
-def loop(params):  # type: (KeeperParams) -> int
+def loop(params, skip_init=False, suppress_goodbye=False, new_login=False):  # type: (KeeperParams, bool, bool, bool) -> int
     global prompt_session
     error_no = 0
     suppress_errno = False
@@ -707,13 +708,17 @@ def loop(params):  # type: (KeeperParams) -> int
                                        complete_style=CompleteStyle.MULTI_COLUMN,
                                        complete_while_typing=False)
 
-    display.welcome()
-    versioning.welcome_print_version(params)
+    if not skip_init:
+        display.welcome()
+        versioning.welcome_print_version(params)
+        # Show government warning for GOV environments when entering interactive shell
+        if params.server and 'govcloud' in params.server.lower():
+            display.show_government_warning()
 
-    if not params.batch_mode:
+    if not params.batch_mode and not skip_init:
         if params.user:
             try:
-                LoginCommand().execute(params, email=params.user, password=params.password, new_login=False)
+                LoginCommand().execute(params, email=params.user, password=params.password, new_login=new_login)
             except KeyboardInterrupt:
                 print('')
             except EOFError:
@@ -722,12 +727,14 @@ def loop(params):  # type: (KeeperParams) -> int
                 logging.error(e)
         else:
             if params.device_token:
-                logging.info('Current Keeper region: %s', params.server)
-            else:
-                logging.info('Use "server" command to change Keeper region > "server US"')
-                for region in KEEPER_PUBLIC_HOSTS:
-                    logging.info('\t%s: %s', region, KEEPER_PUBLIC_HOSTS[region])
-            logging.info('To login type: login <email>')
+                logging.info('Region: %s', params.server)
+            print()
+            logging.info("You are not logged in.")
+            print(f'Type {Fore.GREEN}login <email>{Fore.RESET} to authenticate or {Fore.GREEN}server <region>{Fore.RESET} to change data centers.')
+            print(f'Type {Fore.GREEN}?{Fore.RESET} for a list of all available commands.')
+
+    # Mark that we're in the shell loop (used by supershell to know if it should start a shell on exit)
+    params._in_shell_loop = True
 
     while True:
         if params.session_token:
@@ -800,7 +807,10 @@ def loop(params):  # type: (KeeperParams) -> int
             if params.batch_mode and error_no != 0 and not suppress_errno:
                 break
 
-    if not params.batch_mode:
+    # Clear the shell loop flag
+    params._in_shell_loop = False
+
+    if not params.batch_mode and not suppress_goodbye:
         logging.info('\nGoodbye.\n')
 
     return error_no
diff --git a/keepercommander/command_categories.py b/keepercommander/command_categories.py
index 60bc61a3a..4789cc6c1 100644
--- a/keepercommander/command_categories.py
+++ b/keepercommander/command_categories.py
@@ -18,7 +18,7 @@
     # Sharing Commands
     'Sharing Commands': {
         'share-record', 'share-folder', 'record-permissions', 'record-permission', 'one-time-share',
-        'external-shares-report'
+        'ext-shares-report'
     },
 
     # Record Type Commands
@@ -36,7 +36,7 @@
     'Reporting Commands': {
         'audit-log', 'audit-report', 'audit-alert', 'user-report', 'security-audit-report',
         'share-report', 'shared-records-report', 'aging-report', 'action-report',
-        'compliance-report', 'compliance', 'external-shares-report', 'risk-management',
+        'compliance-report', 'compliance', 'ext-shares-report', 'risk-management',
         'security-audit'
     },
 
@@ -86,11 +86,6 @@
         'service-config-add'
     },
 
-    # Email Configuration Commands
-    'Email Configuration Commands': {
-        'email-config'
-    },
-
     # Email Configuration Commands
     'Email Configuration Commands': {
         'email-config'
@@ -99,9 +94,9 @@
     # Miscellaneous Commands
     'Miscellaneous Commands': {
         'this-device', 'login', 'login-status', 'biometric', 'whoami', 'logout',
-        'help', 'sync-down', 'version', 'clear', 'run-batch', 'generate',
+        'help', 'sync-down', 'version', 'clear', 'history', 'quit', 'run-batch', 'generate',
         'reset-password', 'sync-security-data', 'keeper-fill', '2fa', 'create-account',
-        'run-as', 'sleep', 'server', 'proxy', 'keep-alive'
+        'run-as', 'sleep', 'server', 'proxy', 'keep-alive', 'supershell'
     },
 
     # KeeperPAM Commands
@@ -109,9 +104,15 @@
         'pam'
     },
 
+    # EPM Commands
+    'EPM Commands': {
+        'epm'
+    },
+
     # Legacy Commands
     'Legacy Commands': {
-        'rotate', 'connect', 'ssh', 'ssh-agent', 'rdp', 'rsync', 'set', 'echo'
+        'rotate', 'connect', 'ssh', 'ssh-agent', 'rdp', 'rsync', 'set', 'echo',
+        'mysql', 'postgresql'
    }
 }

@@ -120,9 +121,9 @@ def get_command_category(command):
     for category, commands in COMMAND_CATEGORIES.items():
         if command in commands:
             return category
-    
+
     # Default category for uncategorized commands
-    return 'Other'
+    return 'Miscellaneous Commands'
 
 def get_category_order():
     """Return the preferred order for displaying categories"""
@@ -141,9 +142,8 @@ def get_category_order():
         'Domain Management Commands',
         'Email Configuration Commands',
         'Service Mode REST API',
-        'Email Configuration Commands',
         'Miscellaneous Commands',
         'KeeperPAM Commands',
-        'Legacy Commands',
-        'Other'
-    ]
\ No newline at end of file
+        'EPM Commands',
+        'Legacy Commands'
+    ]
diff --git a/keepercommander/commands/base.py b/keepercommander/commands/base.py
index d2d8de802..22f237248 100644
--- a/keepercommander/commands/base.py
+++ b/keepercommander/commands/base.py
@@ -132,6 +132,16 @@ def register_commands(commands, aliases, command_info):
     commands['email-config'] = EmailConfigCommand()
     command_info['email-config'] = 'Email provider configuration management'
 
+    # SuperShell requires textual library - only register if available
+    if sys.version_info.major > 3 or (sys.version_info.major == 3 and sys.version_info.minor >= 9):
+        try:
+            from .supershell import SuperShellCommand
+            commands['supershell'] = SuperShellCommand()
+            command_info['supershell'] = 'Launch full terminal vault UI with vim navigation'
+            aliases['ss'] = 'supershell'
+        except ImportError:
+            pass  # textual not installed, skip supershell
+
     from . import credential_provision
     credential_provision.register_commands(commands)
     credential_provision.register_command_info(aliases, command_info)
@@ -203,8 +213,9 @@ def register_enterprise_commands(commands, aliases, command_info):
     if sys.version_info.major > 3 or (sys.version_info.major == 3 and sys.version_info.minor >= 9):
         from .pedm import pedm_admin
         pedm_command = pedm_admin.PedmCommand()
-        commands['pedm'] = pedm_command
-        command_info['pedm'] = pedm_command.description
+        commands['epm'] = pedm_command
+        command_info['epm'] = pedm_command.description
+        aliases['pedm'] = 'epm'
 
 
 def register_msp_commands(commands, aliases, command_info):
diff --git a/keepercommander/commands/compliance.py b/keepercommander/commands/compliance.py
index e5f9dbb6b..49d59e8e9 100644
--- a/keepercommander/commands/compliance.py
+++ b/keepercommander/commands/compliance.py
@@ -160,6 +160,8 @@ def execute(self, params, **kwargs):  # type: (KeeperParams, any) -> any
         headers = self.report_headers if report_fmt == 'json' else [field_to_title(h) for h in self.report_headers]
         report = dump_report_data(report_data, headers, title=self.title, fmt=report_fmt,
                                   filename=kwargs.get('output'), column_width=32, group_by=self.group_by_column)
+        if no_cache:
+            sd.storage.delete_db()
         return report
 
 
@@ -454,27 +456,31 @@ def compile_user_report(user, access_events):
     def get_aging_data(rec_ids):
         if not rec_ids:
             return {}
-        aging_data = {r: {'created': None, 'last_modified': None, 'last_rotation': None} for r in rec_ids}
+        aging_data = {r: {'created': None, 'last_modified': None, 'last_rotation': None} for r in rec_ids if r}
         now = datetime.datetime.now()
         max_stored_age_dt = now - datetime.timedelta(days=1)
         max_stored_age_ts = int(max_stored_age_dt.timestamp())
-        stored_entities = sox_data.storage.get_record_aging().get_all()
-        stored_aging_data = {e.record_uid: {'created': from_ts(e.created), 'last_modified': from_ts(e.last_modified), 'last_rotation': from_ts(e.last_rotation)} for e in stored_entities}
+        stored_aging_data = {}
+        if not kwargs.get('no_cache'):
+            stored_entities = sox_data.storage.get_record_aging().get_all()
+            stored_aging_data = {e.record_uid: {'created': from_ts(e.created), 'last_modified': from_ts(e.last_modified), 'last_rotation': from_ts(e.last_rotation)} for e in stored_entities if e.record_uid}
         aging_data.update(stored_aging_data)
 
-        def get_requests(filter_recs, filter_type, order='desc', aggregate='last_created'):
+        def get_requests(filter_recs, filter_type, order='descending', aggregate='last_created'):
             columns = ['record_uid']
             requests = []
             while filter_recs:
                 chunk = filter_recs[:API_EVENT_SUMMARY_ROW_LIMIT]
                 filter_recs = filter_recs[API_EVENT_SUMMARY_ROW_LIMIT:]
+                rq_filter = {'record_uid': chunk}
+                if filter_type:
+                    rq_filter.update({'audit_event_type': filter_type})
                 request = dict(
                     command     = 'get_audit_event_reports',
                     report_type = 'span',
                     scope       = 'enterprise',
                     aggregate   = [aggregate],
                     limit       = API_EVENT_SUMMARY_ROW_LIMIT,
-                    filter      = dict(record_uid=chunk, audit_event_type=filter_type),
+                    filter      = rq_filter,
                     columns     = columns,
                     order       = order
                 )
@@ -486,13 +492,13 @@ def get_request_params(record_aging_event):
             known_events_map = get_known_aging_data(record_aging_event)
             filter_recs = [uid for uid in rec_ids if uid not in known_events_map]
             types_by_aging_event = dict(
-                created = None,
+                created = [],
                 last_modified = ['record_update'],
                 last_rotation = ['record_rotation_scheduled_ok', 'record_rotation_on_demand_ok']
             )
             filter_types = types_by_aging_event.get(record_aging_event)
-            order, aggregate = ('asc', 'first_created') if record_aging_event == 'created' \
-                else ('desc', 'last_created')
+            order, aggregate = ('ascending', 'first_created') if record_aging_event == 'created' \
+                else ('descending', 'last_created')
             return filter_recs, filter_types, order, aggregate
 
         def fetch_events(requests):
@@ -513,8 +519,8 @@ def get_known_aging_data(event_type):
         def get_aging_event_dts(event_type):
             events = get_aging_events(event_type)
             aggregate = 'first_created' if event_type == 'created' else 'last_created'
-            record_timestamps = {event.get('record_uid', ''): event.get(aggregate) for event in events}
-            return {rec: format_datetime(ts) for rec, ts in record_timestamps.items()}
+            record_timestamps = {event.get('record_uid'): event.get(aggregate) for event in events if event.get('record_uid')}
+            return {rec: from_ts(ts) for rec, ts in record_timestamps.items()}
 
         aging_stats = ['created', 'last_modified', 'last_rotation']
         record_events_by_stat = {stat: get_aging_event_dts(stat) for stat in aging_stats}
@@ -523,13 +529,16 @@ def get_aging_event_dts(event_type):
                 aging_data.get(record, {}).update({stat: dt})
                 stat == 'created' and aging_data.get(record, {}).setdefault('last_modified', dt)
 
-        save_aging_data(aging_data)
+        if not kwargs.get('no_cache'):
+            save_aging_data(aging_data)
         return aging_data
 
     def save_aging_data(aging_data):
         existing_entities = sox_data.storage.get_record_aging()
         updated_entities = []
         for r, events in aging_data.items():
+            if not r:
+                continue
             entity = existing_entities.get_entity(r) or StorageRecordAging(r)
             created_dt = events.get('created')
             created_ts = int(created_dt.timestamp()) if created_dt else 0
@@ -683,4 +692,3 @@ def generate_report_data(self, params, kwargs, sox_data, report_fmt, node, root_
             row = [sfuid, sf_team_uids, sf_team_names, records, team_users + users]
             report_data.append(row)
         return report_data
-
diff --git a/keepercommander/commands/discover/job_start.py b/keepercommander/commands/discover/job_start.py
index fd4d1003a..bfb324536 100644
--- a/keepercommander/commands/discover/job_start.py
+++ b/keepercommander/commands/discover/job_start.py
@@ -41,10 +41,10 @@ class PAMGatewayActionDiscoverJobStartCommand(PAMGatewayActionDiscoverCommandBas
                         action='store_true', help='Skip discovering directories.')
     parser.add_argument('--skip-cloud-users', required=False, dest='skip_cloud_users',
                         action='store_true', help='Skip discovering cloud users.')
-    parser.add_argument('--cred', required=False, dest='credentials',
-                        action='append', help='List resource credentials.')
-    parser.add_argument('--cred-file', required=False, dest='credential_file',
-                        action='store', help='A JSON file containing list of credentials.')
+    # parser.add_argument('--cred', required=False, dest='credentials',
+    #                     action='append', help='List resource credentials.')
+    # parser.add_argument('--cred-file', required=False, dest='credential_file',
+    #                     action='store', help='A JSON file containing list of credentials.')
 
     def get_parser(self):
         return PAMGatewayActionDiscoverJobStartCommand.parser
@@ -145,7 +145,7 @@ def execute(self, params, **kwargs):
                 if len(kv) != 2:
                     print(f"{bcolors.FAIL}A '--cred' is invalid. It does not have a value.{bcolors.ENDC}")
                     return
-                if hasattr(c, kv[0]) is False:
+                if not hasattr(c, kv[0]):
                     print(f"{bcolors.FAIL}A '--cred' is invalid. The key '{kv[0]}' is invalid.{bcolors.ENDC}")
                     return
                 if hasattr(c, kv[1]) == "":
@@ -170,14 +170,14 @@ def execute(self, params, **kwargs):
                 print(f"{bcolors.FAIL}The JSON file {credential_files} could not be imported: {err}{bcolors.ENDC}")
                 return
 
-            if isinstance(creds, list) is False:
+            if not isinstance(creds, list):
                 print(f"{bcolors.FAIL}Credential file is invalid. Structure is not an array.{bcolors.ENDC}")
                 return
             num = 1
             for obj in creds:
                 c = CredentialBase()
                 for key in obj:
-                    if hasattr(c, key) is False:
+                    if not hasattr(c, key):
                         print(f"{bcolors.FAIL}Object {num} has the invalid key {key}.{bcolors.ENDC}")
                         return
                     setattr(c, key, obj[key])
diff --git a/keepercommander/commands/discover/result_process.py b/keepercommander/commands/discover/result_process.py
index a15fb32e1..1653cbccd 100644
--- a/keepercommander/commands/discover/result_process.py
+++ b/keepercommander/commands/discover/result_process.py
@@ -7,17 +7,22 @@
 from keeper_secrets_manager_core.utils import url_safe_str_to_bytes
 from . import PAMGatewayActionDiscoverCommandBase, GatewayContext
-from ..pam.router_helper import router_get_connected_gateways, router_set_record_rotation_information
+from ..pam.router_helper import (router_get_connected_gateways, router_set_record_rotation_information,
+                                 router_configure_resource)
 from ... import api, subfolder, utils, crypto, vault, vault_extensions
 from ...display import bcolors
-from ...proto import router_pb2, record_pb2
-from ...discovery_common.jobs import Jobs
+from ...proto import router_pb2, record_pb2, pam_pb2
+from ...discovery_common.jobs import Jobs, JobItem
+from ...discovery_common.dag_sort import sort_infra_vertices
+from ...discovery_common.infrastructure import Infrastructure
 from ...discovery_common.process import Process, QuitException, NoDiscoveryDataException
 from ...discovery_common.types import (
     DiscoveryObject, UserAcl, PromptActionEnum, PromptResult, BulkRecordAdd, BulkRecordConvert,
     BulkProcessResults, BulkRecordSuccess, BulkRecordFail, DirectoryInfo, NormalizedRecord, RecordField)
+from ...discovery_common.constants import PAM_USER
+from ...discovery_common.constants import VERTICES_SORT_MAP
 from pydantic import BaseModel
-from typing import Optional, List, Any, TYPE_CHECKING
+from typing import Optional, List, Any, Tuple, Dict, TYPE_CHECKING
 from ...api import get_records_add_request
@@ -44,6 +49,10 @@ def _ok(value: str) -> str:
     return f"{bcolors.OKGREEN}{value}{bcolors.ENDC}"
 
 
+def _w(value: str) -> str:
+    return f"{bcolors.WARNING}{value}{bcolors.ENDC}"
+
+
 # This is used for the admin user search
 class AdminSearchResult(BaseModel):
     record: Any
@@ -61,13 +70,10 @@ class PAMGatewayActionDiscoverResultProcessCommand(PAMGatewayActionDiscoverComma
     parser = argparse.ArgumentParser(prog='pam-action-discover-process')
     parser.add_argument('--job-id', '-j', required=True, dest='job_id', action='store',
                         help='Discovery job to process.')
-
-    # This is not ready yet.
-    # parser.add_argument('--smart-add', required=False, dest='smart_add', action='store_true',
-    #                     help='Automatically add resources with credentials and their users.')
-
     parser.add_argument('--add-all', required=False, dest='add_all', action='store_true',
                         help='Respond with ADD for all prompts.')
+    parser.add_argument('--preview', required=False, dest='do_preview', action='store_true',
+                        help='Preview the results')
     parser.add_argument('--debug-gs-level', required=False, dest='debug_level', action='store',
                         help='GraphSync debug level. Default is 0', type=int, default=0)
@@ -106,7 +112,7 @@ def _get_shared_folder(params: KeeperParams, pad: str, gateway_context: GatewayC
             print(f"{pad}{_f('Input was not a number.')}")
 
     @staticmethod
-    def get_field_values(record: TypedRecord, field_type: str) -> List[str]:
+    def get_field_values(record: TypedRecord, field_type: str) -> List[Any]:
         return next(
             (f.value
              for f in record.fields
@@ -128,7 +134,7 @@ def get_keys_by_record(self, params: KeeperParams, gateway_context: GatewayConte
         key_field = Process.get_key_field(record.record_type)
         keys = []
         if key_field == "host_port":
-            values = self.get_field_values(record, "pamHostname")
+            values = self.get_field_values(record, "pamHostname")  # type: List[dict]
             if len(values) == 0:
                 return []
 
@@ -189,7 +195,6 @@ def _record_lookup(record_uid: str, context: Optional[Any] = None) -> Optional[
             record_uid=record.record_uid,
             record_type=record.record_type,
             title=record.title,
-            notes=record.notes
         )
         for field in record.fields:
             normalized_record.fields.append(
@@ -291,11 +296,11 @@ def _edit_record(self, content: DiscoveryObject, pad: str, editable: List[str]) 
                 break
 
             # If this is the first line, check if line is a path to a file.
-            if first_line is True:
+            if first_line:
                 try:
                     test_file = line.strip()
                     logging.debug(f"is first line, check for file path for '{test_file}'")
-                    if os.path.exists(test_file) is True:
+                    if os.path.exists(test_file):
                         with open(test_file, "r") as fh:
                             new_value = fh.read()
                             fh.close()
@@ -320,7 +325,7 @@ def _edit_record(self, content: DiscoveryObject, pad: str, editable: List[str]) 
 
             # Is the value a path to a file, i.e., a private key file.
             try:
-                if os.path.exists(new_value) is True:
+                if os.path.exists(new_value):
                     with open(new_value, "r") as fh:
                         new_value = fh.read()
                         fh.close()
@@ -393,13 +398,13 @@ def _prompt_display_fields(self, content: DiscoveryObject, pad: str) -> List[str
                     formatted_value.append(value)
                 value = ", ".join(formatted_value)
             else:
-                if has_editable is True:
+                if has_editable:
                     value = f"{bcolors.FAIL}MISSING{bcolors.ENDC}"
                 else:
                     value = f"{bcolors.OKBLUE}None{bcolors.ENDC}"
             color = bcolors.HEADER
-            if has_editable is True:
+            if has_editable:
                 color = bcolors.OKGREEN
             rows = str(value).split("\n")
@@ -482,7 +487,7 @@ def _prompt(self,
             print(f"{pad}{_h(content.description)}")
 
         show_current_object = True
-        while show_current_object is True:
+        while show_current_object:
             print(f"{pad}{bcolors.HEADER}Record Title:{bcolors.ENDC} {content.title}")
             logging.debug(f"Fields: {content.fields}")
@@ -589,10 +594,16 @@ def _prompt(self,
                     raise QuitException()
             print()
 
+        return PromptResult(
+            action=PromptActionEnum.SKIP,
+            acl=acl,
+            content=content
+        )
+
     def _find_user_record(self,
                           params: KeeperParams,
                           bulk_convert_records: List[BulkRecordConvert],
-                          context: Optional[Any] = None) -> (Optional[TypedRecord], bool):
+                          context: Optional[Any] = None) -> Tuple[Optional[TypedRecord], bool]:
 
         gateway_context = context.get("gateway_context")  # type: GatewayContext
         record_link = context.get("record_link")  # type: RecordLink
@@ -672,7 +683,7 @@ def _find_user_record(self,
             parent_record = vault.TypedRecord.load(params, parent_record_uid)  # type: Optional[TypedRecord]
             if parent_record is not None:
                 is_directory_user = self._is_directory_user(parent_record.record_type)
-                if is_directory_user is False:
+                if not is_directory_user:
                     logging.debug(f"pamUser parent for {user_record.title}, "
                                   "{user_record.record_uid} is not a directory; BAD for search")
                     continue
@@ -737,7 +748,7 @@ def _find_user_record(self,
                 b = bcolors.BOLD
                 tc = ""
                 index_str = user_index
-                if admin_search_result.being_used is True:
+                if admin_search_result.being_used:
                     hc = bcolors.WARNING
                     b = bcolors.WARNING
                     tc = bcolors.WARNING
@@ -750,7 +761,7 @@ def _find_user_record(self,
                       f'{tc + "(Already taken)" + bcolors.ENDC if admin_search_result.being_used is True else ""}')
                 user_index += 1
 
-            if has_local_user is True:
+            if has_local_user:
                 print(f"{bcolors.BOLD}* Not a PAM User record. "
                       f"A PAM User would be generated from this record.{bcolors.ENDC}")
 
@@ -763,7 +774,7 @@ def _find_user_record(self,
             else:
                 try:
                     selected = admin_search_results[int(select) - 1]
-                    if selected.being_used is True:
+                    if selected.being_used:
                         print(f"{bcolors.FAIL}Cannot select a record that has already been taken. "
                               f"Another record is using this local user as its administrator.{bcolors.ENDC}")
                         return None, False
@@ -773,6 +784,8 @@ def _find_user_record(self,
                     print(f"{bcolors.FAIL}Entered row index does not exists.{bcolors.ENDC}")
                     continue
 
+        return None, False
+
     @staticmethod
     def _handle_admin_record_from_record(record: TypedRecord,
                                          content: DiscoveryObject,
@@ -865,7 +878,7 @@ def _prompt_admin(self,
                      acl: UserAcl,
                      bulk_convert_records: List[BulkRecordConvert],
                      indent: int = 0,
-                     context: Optional[Any] = None) -> PromptResult:
+                     context: Optional[Any] = None) -> Optional[PromptResult]:
 
         if content is None:
             raise Exception("The admin content was not passed in to prompt the user.")
@@ -971,7 +984,7 @@ def _prompt_confirm_add(bulk_add_records: List[BulkRecordAdd]):
             print(f"{bcolors.FAIL}Did not get 'Y' or 'N'{bcolors.ENDC}")
 
     @staticmethod
-    def _prepare_record(content: DiscoveryObject, context: Optional[Any] = None) -> (Any, str):
+    def _prepare_record(content: DiscoveryObject, context: Optional[Any] = None) -> Tuple[Any, str]:
         """
         Prepare the Vault record side.
@@ -1051,6 +1064,10 @@ def _prepare_record(content: DiscoveryObject, context: Optional[Any] = None) -> 
     def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optional[Any] = None) -> (
             BulkProcessResults):
+        """
+        Create Vault records, setup rotation settings, and configure the resource (if resource).
+        """
+
         if len(bulk_add_records) == 1:
             print("Adding the record to the Vault ...")
         else:
@@ -1061,6 +1078,8 @@ def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optiona
 
         build_process_results = BulkProcessResults()
 
+        ##############################################################################################################
+        #
         # STEP 1 - Batch add new records
         # Generate a list of RecordAdd instance.
@@ -1091,12 +1110,17 @@ def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optiona
             logging.debug(f"attempted to batch add {len(bulk_add_records)} record(s), "
                           f"only have {len(add_results)} results.")
 
-        # STEP 3 - Add rotation settings.
-        # Use the list we passed in, find the results, and add if the additions were successful.
+        ##############################################################################################################
+        #
+        # STEP 2 - Add rotation settings for user and resource configuration for resources
+        # At this point, all the records have been created.
 
         # Keep track of each record we create a rotation for to avoid version problems, if there was a dup.
         created_cache = []
 
+        # TODO: There is a bulk version of the following code, it's not live.
+        #  Wait until live, then switch code to use that.
+
         # For the records passed in to be created.
print("add rotation settings: ", end="") sys.stdout.flush() @@ -1136,7 +1160,7 @@ def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optiona success = (result.status == record_pb2.RecordModifyResult.DESCRIPTOR.values_by_name['RS_SUCCESS'].number) status = record_pb2.RecordModifyResult.DESCRIPTOR.values_by_number[result.status].name - if success is False: + if not success: build_process_results.failure.append( BulkRecordFail( title=title, @@ -1146,24 +1170,39 @@ def _create_records(cls, bulk_add_records: List[BulkRecordAdd], context: Optiona logging.debug(f"Had problem adding record for {title}: {status}") continue - rq = router_pb2.RouterRecordRotationRequest() - rq.recordUid = url_safe_str_to_bytes(bulk_record.record_uid) - rq.revision = 0 + # Only set the rotation setting if the record is a PAM User. + if bulk_record.record_type == PAM_USER: - # Set the gateway/configuration that this record should be connected. - rq.configurationUid = url_safe_str_to_bytes(gateway_context.configuration_uid) + rq = router_pb2.RouterRecordRotationRequest() + rq.recordUid = url_safe_str_to_bytes(bulk_record.record_uid) + rq.revision = 0 - # Only set the resource if the record type is a PAM User. - # Machines, databases, and directories have a login/password in the record that indicates who the admin is. - if bulk_record.record_type == "pamUser" and bulk_record.parent_record_uid is not None: - rq.resourceUid = url_safe_str_to_bytes(bulk_record.parent_record_uid) + # Set the gateway/configuration that this record should be connected. + rq.configurationUid = url_safe_str_to_bytes(gateway_context.configuration_uid) - # Right now, the schedule and password complexity are not set. This would be part of a rule engine. - rq.schedule = '' - rq.pwdComplexity = b'' - rq.disabled = rotation_disabled + if bulk_record.parent_record_uid is not None: + rq.resourceUid = url_safe_str_to_bytes(bulk_record.parent_record_uid) - router_set_record_rotation_information(params, rq) + # Right now, the schedule and password complexity are not set. This would be part of a rule engine. + rq.schedule = '' + rq.pwdComplexity = b'' + rq.disabled = rotation_disabled + + router_set_record_rotation_information(params, rq) + + # This will be a resource. + # A LINK edge will be created between the configuration and resource. + # If there is an admin user, it will be set on the resource. + else: + + # This will create a LINK between the PAM Configuration and the resource. + rq = pam_pb2.PAMResourceConfig() + rq.recordUid = url_safe_str_to_bytes(bulk_record.record_uid) + rq.networkUid = url_safe_str_to_bytes(gateway_context.configuration_uid) + if bulk_record.admin_uid: + rq.adminUid = url_safe_str_to_bytes(bulk_record.admin_uid) + + router_configure_resource(params, rq) created_cache.append(bulk_record.record_uid) @@ -1206,8 +1245,6 @@ def _convert_records(cls, bulk_convert_records: List[BulkRecordConvert], context # Machines, databases, and directories have a login/password in the record that indicates who the admin is. if record.record_type == "pamUser" and bulk_convert_record.parent_record_uid is not None: rq.resourceUid = url_safe_str_to_bytes(bulk_convert_record.parent_record_uid) - else: - rq.resourceUid = None # Right now, the schedule and password complexity are not set. This would be part of a rule engine. rq.schedule = '' @@ -1281,17 +1318,155 @@ def remove_job(params: KeeperParams, configuration_record: KeeperRecord, job_id: logging.error(err) print(f"{bcolors.FAIL}No items left to process. 
Failed to remove discovery job.{bcolors.ENDC}") + def preview(self, job_item: JobItem, params: KeeperParams, gateway_context: GatewayContext, debug_level: int = 0): + + context = { + "params": params, + "gateway_context": gateway_context, + } + + sync_point = job_item.sync_point + infra = Infrastructure(record=gateway_context.configuration, + params=params, + logger=logging, + debug_level=debug_level) + infra.load(sync_point) + + configuration = None + try: + configuration = infra.get_root.has_vertices()[0] + except (Exception,): + print(f"{bcolors.FAIL}Could not find the configuration in the infrastructure graph. " + f"Has discovery been run for this gateway?{bcolors.ENDC}") + + record_type_to_vertices_map = sort_infra_vertices(configuration) + + # ------------ + + def _print_resource(rt: str, rule_result: str): + + printed_something = False + + titles = { + "pamDirectory": "Directories", + "pamMachine": "Machines", + "pamDatabase": "Databases" + } # type: Dict[str, Optional[str]] + + for rv in record_type_to_vertices_map[rt]: # type: DAGVertex + if not rv.active or not rv.has_data: + continue + user_vertices = rv.has_vertices() + + user_list = [] + for user_vertex in user_vertices: + if not user_vertex.active or not user_vertex.has_data: + continue + + user_content = DiscoveryObject.get_discovery_object(user_vertex) + if user_content.ignore_object or self._record_lookup(user_content.record_uid, context) is not None: + continue + + user_list.append(f" . {user_content.item.user} ({user_content.name})") + + c = DiscoveryObject.get_discovery_object(rv) + if len(user_list) == 0 and c.action_rules_result != rule_result or c.ignore_object: + continue + + has_record = "" + record_uid = c.record_uid + if record_uid is not None: + if self._record_lookup(record_uid, context): + has_record = f" (record exists: {record_uid})" + if len(user_list) == 0: + continue + else: + record_uid = None + + if c.action_rules_result != rule_result and not record_uid: + continue + + title = titles.get(c.record_type) + if title is not None: + print(f" {_b(title)}") + titles[c.record_type] = None + + ip = "" + if c.item.host != c.item.ip: + ip = f" ({c.item.ip})" + + with_admin = "" + if c.admin_uid is not None and not record_uid: + with_admin = f" with Administrator UID {c.admin_uid}" + + print(f" * {c.description}{ip}{with_admin}{has_record}{bcolors.ENDC}") + printed_something = True + + if record_uid: + for user in user_list: + print(user) + + return printed_something + + # ------------ + + def _print_cloud_user(rt: str, rule_result: str): + + title = "Users" + + for user_vertex in record_type_to_vertices_map[rt]: # type: DAGVertex + if not user_vertex.active or not user_vertex.has_data: + continue + + uc = DiscoveryObject.get_discovery_object(user_vertex) + + if (uc.action_rules_result != rule_result + or uc.ignore_object + or self._record_lookup(uc.record_uid, context) is not None): + continue + + if title is not None: + print(f" {_b(title)}") + title = None + + print(f" * {uc.item.user} ({uc.name})") + + # ------------ + + print("") + print(_h("Will Be Automatically Added")) + nothing_to_print = True + for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']): + if record_type == "pamUser": + _print_cloud_user("pamUser", rule_result="add") + else: + if _print_resource(record_type, rule_result="add"): + nothing_to_print = False + if nothing_to_print: + print(f" {_w('No records will be automatically added.')}") + + print("") + print(_h("Will Be Prompted For")) + 
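The preview above makes two passes over the same record-type map, one per rule-engine verdict ("add" vs "prompt"). A condensed, runnable sketch of that partitioning, using an assumed stand-in for the discovery object shape:

```
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Discovered:  # assumed stand-in for DiscoveryObject
    name: str
    action_rules_result: str  # "add" or "prompt"
    ignore_object: bool = False

def partition(items: List[Discovered]) -> Tuple[List[Discovered], List[Discovered]]:
    auto_add = [i for i in items if i.action_rules_result == "add" and not i.ignore_object]
    prompt = [i for i in items if i.action_rules_result == "prompt" and not i.ignore_object]
    return auto_add, prompt

added, prompted = partition([Discovered("web01", "add"), Discovered("db01", "prompt")])
assert [i.name for i in added] == ["web01"] and [i.name for i in prompted] == ["db01"]
```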
nothing_to_print = True
+        for record_type in sorted(record_type_to_vertices_map, key=lambda i: VERTICES_SORT_MAP[i]['order']):
+            if record_type == "pamUser":
+                _print_cloud_user("pamUser", rule_result="prompt")
+            else:
+                if _print_resource(record_type, rule_result="prompt"):
+                    nothing_to_print = False
+        if nothing_to_print:
+            print(f"  {_w('No items will be prompted.')}")
+
+        print("")
+
     def execute(self, params: KeeperParams, **kwargs):

         if not hasattr(params, 'pam_controllers'):
             router_get_connected_gateways(params)

+        do_preview = kwargs.get("do_preview", False)
         job_id = kwargs.get("job_id")
         add_all = kwargs.get("add_all", False)
-        smart_add = kwargs.get("smart_add", False)
-
-        # Right now, keep dry_run False. We might add it back in.
-        dry_run = kwargs.get("dry_run", False)
         debug_level = kwargs.get("debug_level", 0)

         all_gateways = GatewayContext.all_gateways(params)
@@ -1329,6 +1504,15 @@ def execute(self, params: KeeperParams, **kwargs):
             print(f'{bcolors.FAIL}Discovery job failed. Cannot process.{bcolors.ENDC}')
             return

+        # Preview is just a way to list which items will be added or prompted.
+        if do_preview:
+            self.preview(
+                job_item=job_item,
+                params=params,
+                gateway_context=gateway_context,
+            )
+            return
+
         process = Process(
             record=configuration_record,
             job_id=job_item.job_id,
@@ -1337,16 +1521,7 @@ def execute(self, params: KeeperParams, **kwargs):
             debug_level=debug_level,
         )

-        if dry_run is True:
-            if add_all is True:
-                logging.debug("dry run has been set, disable auto add.")
-                add_all = False
-
-            print(f"{bcolors.HEADER}The DRY RUN flag has been set. The rule engine will not add any records. "
-                  f"You will not be prompted to edit or add records.{bcolors.ENDC}")
-            print("")
-
-        if add_all is True:
+        if add_all:
             print(f"{bcolors.HEADER}The ADD ALL flag has been set. All found items will be added.{bcolors.ENDC}")
             print("")
@@ -1359,9 +1534,6 @@ def execute(self, params: KeeperParams, **kwargs):
             # Prompt user the about adding records
             prompt_func=self._prompt,

-            # Flag to auto add resources with credential, and all it users.
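The Process constructor above injects behavior through prompt callbacks rather than subclassing. A minimal, runnable sketch of that callback pattern (MiniProcess and its names are illustrative, not the real discovery_common API):

```
from typing import Callable, Iterable

class MiniProcess:
    """Toy stand-in: behavior is injected as callbacks (prompt_func,
    prompt_admin_func, ...) instead of being hard-coded in the class."""

    def __init__(self, prompt_func: Callable[[str], str]):
        self.prompt_func = prompt_func

    def run(self, items: Iterable[str]) -> list:
        added = []
        for item in items:
            if self.prompt_func(item) == "add":
                added.append(item)
        return added

assert MiniProcess(prompt_func=lambda item: "add").run(["machine-1"]) == ["machine-1"]
```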
- smart_add=smart_add, - # Prompt user for an admin for a resource prompt_admin_func=self._prompt_admin, @@ -1392,7 +1564,7 @@ def execute(self, params: KeeperParams, **kwargs): context={ "params": params, "gateway_context": gateway_context, - "dry_run": dry_run, + "dry_run": False, "add_all": add_all } ) @@ -1403,7 +1575,7 @@ def execute(self, params: KeeperParams, **kwargs): if results is not None and results.num_results > 0: print(f"{bcolors.OKGREEN}Successfully added {results.success_count} " f"record{'s' if results.success_count != 1 else ''}.{bcolors.ENDC}") - if results.has_failures is True: + if results.has_failures: print(f"{bcolors.FAIL}There were {results.failure_count} " f"failure{'s' if results.failure_count != 1 else ''}.{bcolors.ENDC}") for fail in results.failure: diff --git a/keepercommander/commands/discover/rule_add.py b/keepercommander/commands/discover/rule_add.py index 6e43713b9..2dedaa089 100644 --- a/keepercommander/commands/discover/rule_add.py +++ b/keepercommander/commands/discover/rule_add.py @@ -22,10 +22,14 @@ class PAMGatewayActionDiscoverRuleAddCommand(PAMGatewayActionDiscoverCommandBase dest='rule_action', action='store', help='Action to take if rule matches') parser.add_argument('--priority', '-p', required=True, dest='priority', action='store', type=int, help='Rule execute priority') + parser.add_argument('--name', '-n', required=False, dest='name', action='store', type=str, + help='Rule name') parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true', help='Ignore value case. Rule value must be in lowercase.') parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid', action='store', help='Folder to place record.') + parser.add_argument('--admin-uid', required=False, dest='admin_uid', + action='store', help='Admin record UID to use for resource.') parser.add_argument('--statement', '-s', required=True, dest='statement', action='store', help='Rule statement') @@ -61,7 +65,7 @@ def validate_rule_statement(params: KeeperParams, gateway_context: GatewayContex statement_struct = data.get("statementStruct") logging.debug(f"Rule Structure = {statement_struct}") - if isinstance(statement_struct, list) is False: + if not isinstance(statement_struct, list): raise Exception(f"The structured rule statement is not a list.") return statement_struct @@ -97,13 +101,27 @@ def execute(self, params, **kwargs): statement=statement ) + shared_folder_uid = kwargs.get("shared_folder_uid") + if shared_folder_uid is not None and len(shared_folder_uid) != 22: + print(f"{bcolors.FAIL}The shared folder UID {shared_folder_uid} is not the correct length." + f"{bcolors.ENDC}") + return + + admin_uid = kwargs.get("admin_uid") + if admin_uid is not None and len(admin_uid) != 22: + print(f"{bcolors.FAIL}The admin UID {admin_uid} is not the correct length." 
+ f"{bcolors.ENDC}") + return + # If the rule passes its validation, then add control DAG rules = Rules(record=gateway_context.configuration, params=params) new_rule = ActionRuleItem( + name=kwargs.get("name"), action=kwargs.get("rule_action"), priority=kwargs.get("priority"), case_sensitive=not kwargs.get("ignore_case", False), - shared_folder_uid=kwargs.get("shared_folder_uid"), + shared_folder_uid=shared_folder_uid, + admin_uid=admin_uid, statement=statement_struct, enabled=True ) diff --git a/keepercommander/commands/discover/rule_list.py b/keepercommander/commands/discover/rule_list.py index e67f702e4..ad06d02ee 100644 --- a/keepercommander/commands/discover/rule_list.py +++ b/keepercommander/commands/discover/rule_list.py @@ -26,24 +26,28 @@ def print_rule_table(rule_list: List[RuleItem]): print("") print(f"{bcolors.HEADER}{'Rule ID'.ljust(15, ' ')} " + f"{'Name'.ljust(20, ' ')} " f"{'Action'.ljust(6, ' ')} " f"{'Priority'.ljust(8, ' ')} " f"{'Case'.ljust(12, ' ')} " f"{'Added'.ljust(19, ' ')} " f"{'Shared Folder UID'.ljust(22, ' ')} " + f"{'Admin UID'.ljust(22, ' ')} " "Rule" f"{bcolors.ENDC}") print(f"{''.ljust(15, '=')} " + f"{''.ljust(20, '=')} " f"{''.ljust(6, '=')} " f"{''.ljust(8, '=')} " f"{''.ljust(12, '=')} " f"{''.ljust(19, '=')} " f"{''.ljust(22, '=')} " + f"{''.ljust(22, '=')} " f"{''.ljust(10, '=')} ") for rule in rule_list: - if rule.case_sensitive is True: + if rule.case_sensitive: ignore_case_str = "Sensitive" else: ignore_case_str = "Insensitive" @@ -51,12 +55,23 @@ def print_rule_table(rule_list: List[RuleItem]): shared_folder_uid = "" if rule.shared_folder_uid is not None: shared_folder_uid = rule.shared_folder_uid + + admin_uid = "" + if rule.admin_uid is not None: + admin_uid = rule.admin_uid + + name = "" + if rule.name is not None: + name = rule.name + print(f"{bcolors.OKGREEN}{rule.rule_id.ljust(14, ' ')}{bcolors.ENDC} " + f"{name[:20].ljust(20, ' ')} " f"{rule.action.value.ljust(6, ' ')} " f"{str(rule.priority).rjust(8, ' ')} " f"{ignore_case_str.ljust(12, ' ')} " f"{rule.added_ts_str.ljust(19, ' ')} " f"{shared_folder_uid.ljust(22, ' ')} " + f"{admin_uid.ljust(22, ' ')} " f"{Rules.make_action_rule_statement_str(rule.statement)}") def execute(self, params, **kwargs): diff --git a/keepercommander/commands/discover/rule_remove.py b/keepercommander/commands/discover/rule_remove.py index 5a14b9c4e..27f4fbcc0 100644 --- a/keepercommander/commands/discover/rule_remove.py +++ b/keepercommander/commands/discover/rule_remove.py @@ -10,8 +10,10 @@ class PAMGatewayActionDiscoverRuleRemoveCommand(PAMGatewayActionDiscoverCommandB parser = argparse.ArgumentParser(prog='pam-action-discover-rule-remove') parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', help='Gateway name of UID') - parser.add_argument('--rule-id', '-i', required=True, dest='rule_id', action='store', + parser.add_argument('--rule-id', '-i', required=False, dest='rule_id', action='store', help='Identifier for the rule') + parser.add_argument('--remove-all', required=False, dest='remove_all', action='store_true', + help='Remove all the rules.') def get_parser(self): return PAMGatewayActionDiscoverRuleRemoveCommand.parser @@ -27,14 +29,28 @@ def execute(self, params, **kwargs): print(f'{bcolors.FAIL}Discovery job gateway [{gateway}] was not found.{bcolors.ENDC}') return + rule_id = kwargs.get("rule_id") + remove_all = kwargs.get("remove_all") + + if rule_id is None and remove_all is None: + print(f'{bcolors.FAIL}Either --rule-id or --remove-all are required.{bcolors.ENDC}') 
+ return + try: - rule_id = kwargs.get("rule_id") rules = Rules(record=gateway_context.configuration, params=params) - rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id) - if rule_item is None: - raise ValueError("Rule Id does not exist.") - rules.remove_rule(rule_item) + if remove_all: + rules.remove_all(RuleTypeEnum.ACTION) + print(f"{bcolors.OKGREEN}All rules removed.{bcolors.ENDC}") + else: + + rule_item = rules.get_rule_item(rule_type=RuleTypeEnum.ACTION, rule_id=rule_id) + if rule_item is None: + raise ValueError("Rule Id does not exist.") + rules.remove_rule(rule_item) - print(f"{bcolors.OKGREEN}Rule has been removed.{bcolors.ENDC}") + print(f"{bcolors.OKGREEN}Rule has been removed.{bcolors.ENDC}") except Exception as err: - print(f"{bcolors.FAIL}Rule was not removed: {err}{bcolors.ENDC}") + if remove_all: + print(f"{bcolors.FAIL}Rules have NOT been removed: {err}{bcolors.ENDC}") + else: + print(f"{bcolors.FAIL}Rule was not removed: {err}{bcolors.ENDC}") diff --git a/keepercommander/commands/discover/rule_update.py b/keepercommander/commands/discover/rule_update.py index 4be0a8d00..f4642d061 100644 --- a/keepercommander/commands/discover/rule_update.py +++ b/keepercommander/commands/discover/rule_update.py @@ -17,14 +17,27 @@ class PAMGatewayActionDiscoverRuleUpdateCommand(PAMGatewayActionDiscoverCommandB dest='rule_action', action='store', help='Update the action to take if rule matches') parser.add_argument('--priority', '-p', required=False, dest='priority', action='store', type=int, help='Update the rule execute priority') + parser.add_argument('--name', '-n', required=False, dest='name', action='store', type=str, + help='Rule name') parser.add_argument('--ignore-case', required=False, dest='ignore_case', action='store_true', help='Update the rule to ignore case') parser.add_argument('--no-ignore-case', required=False, dest='ignore_case', action='store_false', help='Update the rule to not ignore case') parser.add_argument('--shared-folder-uid', required=False, dest='shared_folder_uid', action='store', help='Update the folder to place record.') + parser.add_argument('--admin-uid', required=False, dest='admin_uid', + action='store', help='Admin record UID to use for resource.') + parser.add_argument('--clear-shared-folder-uid', required=False, dest='clear_shared_folder_uid', + action='store_true', help='Clear shared folder UID, use default.') + parser.add_argument('--clear-admin-uid', required=False, dest='clear_admin_uid', + action='store_true', help='Clear admin UID') parser.add_argument('--statement', '-s', required=False, dest='statement', action='store', help='Update the rule statement') + parser.add_argument('--active', required=False, dest='active', action='store_true', + help='Enable rule.') + parser.add_argument('--disable', required=False, dest='active', action='store_false', + help='Disable rule.') + parser.set_defaults(active=None, ignore_case=None) def get_parser(self): return PAMGatewayActionDiscoverRuleUpdateCommand.parser @@ -50,22 +63,69 @@ def execute(self, params, **kwargs): rule_action = kwargs.get("rule_action") if rule_action is not None: rule_item.action = RuleTypeEnum.find_enum(rule_action) + priority = kwargs.get("priority") if priority is not None: + print(" * Changing the priority of the rule.") rule_item.priority = priority + ignore_case = kwargs.get("ignore_case") if ignore_case is not None: + if ignore_case: + print(" * Ignore the case of text.") + else: + print(" * Make rule text case sensitive.") + 
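rule_update's --active/--disable pair above shares a single dest, and set_defaults(active=None) distinguishes "not specified" from an explicit enable or disable, giving a tri-state flag. A runnable sketch of that pattern:

```
import argparse

p = argparse.ArgumentParser(prog='rule-update-demo')
p.add_argument('--active', dest='active', action='store_true', help='Enable rule.')
p.add_argument('--disable', dest='active', action='store_false', help='Disable rule.')
p.set_defaults(active=None)

assert p.parse_args([]).active is None           # not specified: leave unchanged
assert p.parse_args(['--active']).active is True
assert p.parse_args(['--disable']).active is False
```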
rule_item.case_sensitive = not ignore_case

-            shared_folder_uid = kwargs.get("shared_folder_uid")
-            if shared_folder_uid is not None:
-                rule_item.shared_folder_uid = shared_folder_uid
+
+            if kwargs.get("clear_shared_folder_uid"):
+                print(" * Clearing shared folder.")
+                rule_item.shared_folder_uid = None
+            else:
+                shared_folder_uid = kwargs.get("shared_folder_uid")
+                if shared_folder_uid is not None:
+                    if len(shared_folder_uid) != 22:
+                        print(f"{bcolors.FAIL}The shared folder UID {shared_folder_uid} is not the correct length."
+                              f"{bcolors.ENDC}")
+                        return
+                    print(" * Changing shared folder UID.")
+                    rule_item.shared_folder_uid = shared_folder_uid
+
+            if kwargs.get("clear_admin_uid"):
+                print(" * Clearing resource admin UID.")
+                rule_item.admin_uid = None
+            else:
+                admin_uid = kwargs.get("admin_uid")
+                if admin_uid is not None:
+                    if len(admin_uid) != 22:
+                        print(f"{bcolors.FAIL}The admin UID {admin_uid} is not the correct length."
+                              f"{bcolors.ENDC}")
+                        return
+                    print(" * Changing the resource admin UID.")
+                    rule_item.admin_uid = admin_uid
+
             statement = kwargs.get("statement")
             if statement is not None:
+                # validate_rule_statement will throw exceptions.
                 rule_item.statement = PAMGatewayActionDiscoverRuleAddCommand.validate_rule_statement(
                     params=params,
                     gateway_context=gateway_context,
                     statement=statement
                 )
+                print(" * Changing the rule statement.")
+
+            name = kwargs.get("name")
+            if name is not None:
+                print(" * Changing the rule name.")
+                rule_item.name = name
+
+            enabled = kwargs.get("active")
+            if enabled is not None:
+                if enabled:
+                    print(" * Enabling the rule.")
+                else:
+                    print(" * Disabling the rule.")
+                rule_item.enabled = enabled
+
             rules.update_rule(rule_item)
             print(f"{bcolors.OKGREEN}Rule has been updated{bcolors.ENDC}")
         except Exception as err:
diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py
index 48afab83a..76c563ac4 100644
--- a/keepercommander/commands/discoveryrotation.py
+++ b/keepercommander/commands/discoveryrotation.py
@@ -72,7 +72,9 @@
 from .pam_debug.gateway import PAMDebugGatewayCommand
 from .pam_debug.rotation_setting import PAMDebugRotationSettingsCommand
 from .pam_debug.link import PAMDebugLinkCommand
+from .pam_debug.vertex import PAMDebugVertexCommand
 from .pam_import.edit import PAMProjectCommand
+from .pam_launch.launch import PAMLaunchCommand
 from .pam_service.list import PAMActionServiceListCommand
 from .pam_service.add import PAMActionServiceAddCommand
 from .pam_service.remove import PAMActionServiceRemoveCommand
@@ -89,16 +91,18 @@

 def validate_cron_field(field, min_val, max_val):
-    # Accept *, single number, range, step, list
-    pattern = r'^(\*|\d+|\d+-\d+|\*/\d+|\d+(,\d+)*|\d+-\d+/\d+)$'
+    # Accept *, single number, range, step, list, and L suffix for last day/week
+    pattern = r'^(\*|\d+L?|L[W]?|\d+-\d+|\*/\d+|\d+(,\d+)*|\d+-\d+/\d+)$'
     if not re.match(pattern, field):
         return False

     def is_valid_number(n):
-        return n.isdigit() and min_val <= int(n) <= max_val
+        # Strip L and W suffix if present (for last day/week expressions)
+        n_stripped = n.rstrip('LW')
+        return n_stripped and n_stripped.isdigit() and min_val <= int(n_stripped) <= max_val

     parts = re.split(r'[,\-/]', field)
-    return all(part == '*' or is_valid_number(part) for part in parts if part != '*')
+    return all(part == '*' or part in ('L', 'LW') or is_valid_number(part) for part in parts if part != '*')

 def validate_cron_expression(expr, for_rotation=False):
     parts = expr.strip().split()
@@ -180,6 +184,7 @@ def __init__(self):
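A quick spot-check of the extended validate_cron_field above: the 'L' and 'LW' tokens pass field validation, numeric parts are still range-checked with any trailing L/W stripped, and the interpretation of last-day semantics is left to the scheduler.

```
assert validate_cron_field('L', 1, 31)        # last day of month
assert validate_cron_field('LW', 1, 31)       # last weekday of month
assert validate_cron_field('15L', 1, 31)      # numeric value with L suffix
assert validate_cron_field('*/5', 0, 59)      # steps still accepted
assert not validate_cron_field('99', 0, 59)   # out-of-range still rejected
```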
self.register_command('connection', PAMConnectionCommand(), 'Manage Connections', 'n') self.register_command('rbi', PAMRbiCommand(), 'Manage Remote Browser Isolation', 'b') self.register_command('project', PAMProjectCommand(), 'PAM Project Import/Export', 'p') + self.register_command('launch', PAMLaunchCommand(), 'Launch a connection to a PAM resource', 'l') class PAMGatewayCommand(GroupCommand): @@ -298,11 +303,13 @@ def __init__(self): self.register_command('graph', PAMDebugGraphCommand(), 'Render graphs', 'r') # Disable for now. Needs more work. - # self.register_command('verify', PAMDebugVerifyCommand(), 'Verify graphs', 'v') + # self.register_command('verify', PAMDebugVerifyCommand(), 'Verify graphs') self.register_command('acl', PAMDebugACLCommand(), 'Control ACL of PAM Users', 'c') self.register_command('link', PAMDebugLinkCommand(), 'Link resource to configuration', 'l') self.register_command('rs-reset', PAMDebugRotationSettingsCommand(), 'Create/reset rotation settings', 'rs') + self.register_command('vertex', PAMDebugVertexCommand(), + 'Debug a graph vertex', 'v') class PAMLegacyCommand(Command): @@ -368,6 +375,10 @@ class PAMCreateRecordRotationCommand(Command): help='UID or path of the configuration record.') parser.add_argument('--iam-aad-config', '-iac', dest='iam_aad_config_uid', action='store', help='UID of a PAM Configuration. Used for an IAM or Azure AD user in place of --resource.') + parser.add_argument('--rotation-profile', '-rp', dest='rotation_profile', action='store', + choices=['general', 'iam_user', 'scripts_only'], + help='Rotation profile type: general (resource-based), iam_user (IAM/Azure user), ' + 'scripts_only (run PAM scripts only)') parser.add_argument('--resource', '-rs', dest='resource', action='store', help='UID or path of the resource record.') schedule_group = parser.add_mutually_exclusive_group() schedule_group.add_argument('--schedulejson', '-sj', required=False, dest='schedule_json_data', @@ -406,7 +417,8 @@ def config_resource(_dag, target_record, target_config_uid, silent=None): if not _dag.linking_dag.has_graph: # Add DAG for resource if target_config_uid: - _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_config_uid) + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_config_uid, + transmission_key=transmission_key) _dag.edit_tunneling_config(rotation=True) else: raise CommandError('', f'{bcolors.FAIL}Resource "{target_record.record_uid}" is not associated ' @@ -417,7 +429,7 @@ def config_resource(_dag, target_record, target_config_uid, silent=None): if not _dag.resource_belongs_to_config(target_record.record_uid): # Change DAG to this new configuration. 
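The new --rotation-profile option above constrains input through argparse choices; a minimal sketch of how that parser behaves (prog name illustrative, choices mirrored from the definition above):

```
import argparse

p = argparse.ArgumentParser(prog='rotation-profile-demo')
p.add_argument('--rotation-profile', '-rp', dest='rotation_profile',
               choices=['general', 'iam_user', 'scripts_only'])

assert p.parse_args(['-rp', 'iam_user']).rotation_profile == 'iam_user'
assert p.parse_args([]).rotation_profile is None  # optional; defaults to None
```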
resource_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, - target_record.record_uid) + target_record.record_uid, transmission_key=transmission_key) _dag.link_resource_to_config(target_record.record_uid) admin = kwargs.get('admin') @@ -467,6 +479,9 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): record_schedule_data = [] pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation.get('pwd_complexity', '')) record_resource_uid = current_record_rotation.get('resource_uid') + # IAM users have resource_uid == config_uid; should be empty to preserve rotation profile + if record_resource_uid == record_config_uid: + record_resource_uid = None disabled = current_record_rotation.get('disabled', False) schedule = 'On-Demand' @@ -510,10 +525,12 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): return if _dag and not _dag.linking_dag.has_graph: - _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_iam_aad_config_uid) + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_iam_aad_config_uid, + transmission_key=transmission_key) if not _dag or not _dag.linking_dag.has_graph: _dag.edit_tunneling_config(rotation=True) - old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_record.record_uid) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_record.record_uid, + transmission_key=transmission_key) if old_dag.linking_dag.has_graph and old_dag.record.record_uid != target_iam_aad_config_uid: old_dag.remove_from_dag(target_record.record_uid) @@ -669,6 +686,9 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None record_schedule_data = [] pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation.get('pwd_complexity', '')) record_resource_uid = current_record_rotation.get('resource_uid') + # IAM users have resource_uid == config_uid; should be empty to preserve rotation profile + if record_resource_uid == record_config_uid: + record_resource_uid = None disabled = current_record_rotation.get('disabled', False) schedule = 'On-Demand' @@ -735,7 +755,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None return if isinstance(target_resource_uid, str) and len(target_resource_uid) > 0: - _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid) + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid, + transmission_key=transmission_key) if not _dag or not _dag.linking_dag.has_graph: if target_config_uid and target_resource_uid: config_resource(_dag, target_record, target_config_uid, silent=silent) @@ -753,7 +774,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None current_record_rotation = params.record_rotation_cache.get(target_record.record_uid) if not _dag or not _dag.linking_dag.has_graph: - _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid) + _dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, target_resource_uid, + transmission_key=transmission_key) if not _dag.linking_dag.has_graph: raise CommandError('', f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated ' f'with any configuration. 
' @@ -864,9 +886,16 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None # Noop and resource cannot be both assigned if not noop_rotation: record_resource_uid = target_resource_uid + # IAM users are linked directly to config (target_resource_uid == config_uid) + # In this case, resourceUid should be empty to preserve IAM rotation profile + if record_resource_uid == _dag.record.record_uid: + record_resource_uid = None if record_resource_uid is None: if current_record_rotation: record_resource_uid = current_record_rotation.get('resource_uid') + # Also check if the cached resource_uid is actually the config UID + if record_resource_uid == record_config_uid: + record_resource_uid = None if record_resource_uid is None: resource_field = record_pam_config.get_typed_field('pamResources') if resource_field and isinstance(resource_field.value, list) and len(resource_field.value) > 0: @@ -1070,11 +1099,13 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None # use --schedule-only, -so to preserve individual setups (General, IAM, NOOP) # use --iam-aad-config, -iac IAM_AAD_CONFIG_UID to convert to IAM User for _record in pam_records: - tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, _record.record_uid) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, _record.record_uid, + transmission_key=transmission_key) if _record.record_type in ['pamMachine', 'pamDatabase', 'pamDirectory', 'pamRemoteBrowser']: config_resource(tmp_dag, _record, config_uid, silent=kwargs.get('silent')) elif _record.record_type == 'pamUser': iam_aad_config_uid = kwargs.get('iam_aad_config_uid') + rotation_profile = kwargs.get('rotation_profile') if iam_aad_config_uid and iam_aad_config_uid not in pam_configurations: raise CommandError('', f'Record uid {iam_aad_config_uid} is not a PAM Configuration record.') @@ -1084,8 +1115,32 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None ' --resource is used to configure users found on a resource.' ' --iam-aad-config-uid is used to configure AWS IAM or Azure AD users') + # Handle --rotation-profile option + if rotation_profile: + if rotation_profile == 'iam_user': + # Use iam_aad_config_uid if provided, otherwise try to get from current rotation or --config + effective_config_uid = iam_aad_config_uid or config_uid + if not effective_config_uid: + current_rotation = params.record_rotation_cache.get(_record.record_uid) + if current_rotation: + effective_config_uid = current_rotation.get('configuration_uid') + if not effective_config_uid: + raise CommandError('', 'IAM user rotation requires a PAM Configuration. ' + 'Use --config or --iam-aad-config to specify one.') + if effective_config_uid not in pam_configurations: + raise CommandError('', f'Record uid {effective_config_uid} is not a PAM Configuration record.') + config_iam_aad_user(tmp_dag, _record, effective_config_uid) + elif rotation_profile == 'scripts_only': + # Set noop flag for scripts_only profile + kwargs['noop'] = 'TRUE' + config_user(tmp_dag, _record, resource_uid, config_uid, silent=kwargs.get('silent')) + elif rotation_profile == 'general': + # General rotation requires a resource + if not resource_uid: + raise CommandError('', 'General rotation profile requires --resource to be specified.') + config_user(tmp_dag, _record, resource_uid, config_uid, silent=kwargs.get('silent')) # NB! 
--folder=UID without --iam-aad-config, or --schedule-only converts to General rotation - if iam_aad_config_uid: + elif iam_aad_config_uid: config_iam_aad_user(tmp_dag, _record, iam_aad_config_uid) else: config_user(tmp_dag, _record, resource_uid, config_uid, silent=kwargs.get('silent')) @@ -1652,7 +1707,7 @@ def execute(self, params, **kwargs): if format_type == 'table': # Only print tunneling config for table format encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_configuration_uid, - is_config=True) + is_config=True, transmission_key=transmission_key) tmp_dag.print_tunneling_config(pam_configuration_uid, None) @staticmethod @@ -2209,10 +2264,10 @@ def execute(self, params, **kwargs): pam_configuration_create_record_v6(params, record, shared_folder_uid) - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) # Add DAG for configuration tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid=record.record_uid, - is_config=True) + is_config=True, transmission_key=transmission_key) tmp_dag.edit_tunneling_config( kwargs.get('connections'), kwargs.get('tunneling'), @@ -2362,9 +2417,9 @@ def execute(self, params, **kwargs): if (_connections is not None or _tunneling is not None or _rotation is not None or _rbi is not None or _recording is not None or _typescript_recording is not None or orig_admin_cred_ref != admin_cred_ref): - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, - configuration.record_uid, is_config=True) + configuration.record_uid, is_config=True, transmission_key=transmission_key) tmp_dag.edit_tunneling_config(_connections, _tunneling, _rotation, _recording, _typescript_recording, _rbi) if orig_admin_cred_ref != admin_cred_ref: if orig_admin_cred_ref: # just drop is_admin from old Domain @@ -2405,7 +2460,7 @@ def execute(self, params, **kwargs): raise Exception(f'Configuration "{pam_config_uid}" not found') encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config.record_uid, - is_config=True) + is_config=True, transmission_key=transmission_key) if tmp_dag.linking_dag.has_graph: tmp_dag.remove_from_dag(pam_config_uid) pam_configuration_remove(params, pam_config_uid) @@ -3028,7 +3083,8 @@ def record_rotate(self, params, record_uid, slient:bool = False): config_uid = facade.controller_uid if not resource_uid: - tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record.record_uid) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record.record_uid, + transmission_key=transmission_key) resource_uid = tmp_dag.get_resource_uid(record_uid) if not resource_uid: # NOOP records don't need resource_uid diff --git a/keepercommander/commands/enterprise.py b/keepercommander/commands/enterprise.py index f429b970b..a33e9591a 100644 --- a/keepercommander/commands/enterprise.py +++ b/keepercommander/commands/enterprise.py @@ -4351,8 +4351,7 @@ def _handle_api_error(self, error, domain, action, 
output_format): if output_format == 'json': return json.dumps({ - 'error': error_msg, - 'error_code': error_code, + 'message': error_msg, 'domain': domain, 'action': action, }, indent=2) diff --git a/keepercommander/commands/enterprise_push.py b/keepercommander/commands/enterprise_push.py index b5d917d1b..628d26bf3 100644 --- a/keepercommander/commands/enterprise_push.py +++ b/keepercommander/commands/enterprise_push.py @@ -272,7 +272,7 @@ def execute(self, params, **kwargs): record_no += 1 else: logging.warning('User: %s Transfer Record Error: (%s) %s', email, trec.status, trec.message) - logging.info('Pushed %d record(s) to \"%s\"', record_no, email) + logging.info('Pushed %d %s to "%s"', record_no, 'record' if record_no == 1 else 'records', email) if len(pre_delete_rq['objects']) > 0: pre_delete_rs = api.communicate(params, pre_delete_rq) diff --git a/keepercommander/commands/enterprise_reports.py b/keepercommander/commands/enterprise_reports.py index 38e5c09f7..aa49d49ba 100644 --- a/keepercommander/commands/enterprise_reports.py +++ b/keepercommander/commands/enterprise_reports.py @@ -12,12 +12,12 @@ def register_commands(commands): - commands['external-shares-report'] = ExternalSharesReportCommand() + commands['ext-shares-report'] = ExternalSharesReportCommand() commands['license-consumption-report'] = LicenseConsumptionReportCommand() def register_command_info(aliases, command_info): - aliases['esr'] = 'external-shares-report' + aliases['esr'] = 'ext-shares-report' aliases['lcr'] = 'license-consumption-report' for p in [external_share_report_parser, license_consumption_report_parser]: @@ -63,7 +63,7 @@ def get_feature_enforcements_from_constants(): ext_shares_report_desc = 'Run an external record sharing report' -external_share_report_parser = argparse.ArgumentParser(prog='external-shares-report', description=ext_shares_report_desc, +external_share_report_parser = argparse.ArgumentParser(prog='ext-shares-report', description=ext_shares_report_desc, parents=[base.report_output_parser]) external_share_report_parser.add_argument('-a', '--action', action='store', choices=['remove', 'none'], default='none', help='action to perform on external shares, \'none\' if omitted') diff --git a/keepercommander/commands/msp.py b/keepercommander/commands/msp.py index df73e6aea..320484cf8 100644 --- a/keepercommander/commands/msp.py +++ b/keepercommander/commands/msp.py @@ -59,7 +59,7 @@ def register_command_info(aliases, command_info): description='Download current MSP data from the Keeper Cloud') msp_info_parser = argparse.ArgumentParser(prog='msp-info', usage='msp-info', parents=[report_output_parser], - description='Displays MSP details, such as managed companies and pricing') + description='Displays MSP details, including MC info and pricing') msp_info_parser.add_argument('-p', '--pricing', dest='pricing', action='store_true', help='Display pricing information') msp_info_parser.add_argument('-r', '--restriction', dest='restriction', action='store_true', help='Display MSP restriction information') diff --git a/keepercommander/commands/pam/pam_dto.py b/keepercommander/commands/pam/pam_dto.py index b089914ee..b1fc1934f 100644 --- a/keepercommander/commands/pam/pam_dto.py +++ b/keepercommander/commands/pam/pam_dto.py @@ -77,12 +77,14 @@ def toJSON(self): class GatewayAction(metaclass=abc.ABCMeta): - def __init__(self, action, is_scheduled, gateway_destination=None, inputs=None, conversation_id=None): + def __init__(self, action, is_scheduled, gateway_destination=None, inputs=None, 
conversation_id=None, message_id=None): self.action = action self.is_scheduled = is_scheduled self.gateway_destination = gateway_destination self.inputs = inputs self.conversationId = conversation_id + # messageId is derived from conversationId for WebRTC sessions + self.messageId = message_id def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) @@ -96,6 +98,14 @@ def generate_conversation_id(is_bytes=False): message_id = CommonHelperMethods.bytes_to_url_safe_str(message_id_bytes) return message_id + @staticmethod + def conversation_id_to_message_id(conversation_id): + """Convert conversationId to messageId format (replace + with -, / with _)""" + if conversation_id: + # Remove any padding '=' characters and replace special chars + return conversation_id.rstrip('=').replace('+', '-').replace('/', '_') + return None + class GatewayActionGatewayInfo(GatewayAction): @@ -213,8 +223,8 @@ def toJSON(self): class GatewayActionWebRTCSession(GatewayAction): - def __init__(self, inputs: dict,conversation_id=None): - super().__init__('webrtc-session', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + def __init__(self, inputs: dict, conversation_id=None, message_id=None): + super().__init__('webrtc-session', inputs=inputs, conversation_id=conversation_id, message_id=message_id, is_scheduled=False) def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) diff --git a/keepercommander/commands/pam/router_helper.py b/keepercommander/commands/pam/router_helper.py index fd3de1ffe..b5499783b 100644 --- a/keepercommander/commands/pam/router_helper.py +++ b/keepercommander/commands/pam/router_helper.py @@ -91,6 +91,15 @@ def router_set_record_rotation_information(params, proto_request, transmission_k return rs +def router_configure_resource(params, proto_request, transmission_key=None, + encrypted_transmission_key=None, encrypted_session_token=None): + rs = _post_request_to_router(params, 'configure_resource', proto_request, transmission_key=transmission_key, + encrypted_transmission_key=encrypted_transmission_key, + encrypted_session_token=encrypted_session_token) + + return rs + + def router_get_rotation_schedules(params, proto_request): return _post_request_to_router(params, 'get_rotation_schedules', rq_proto=proto_request, rs_type=pam_pb2.PAMRotationSchedulesResponse) diff --git a/keepercommander/commands/pam_debug/graph.py b/keepercommander/commands/pam_debug/graph.py index 2904b963b..4e8a65988 100644 --- a/keepercommander/commands/pam_debug/graph.py +++ b/keepercommander/commands/pam_debug/graph.py @@ -10,10 +10,10 @@ from ...discovery_common.user_service import UserService from ...discovery_common.jobs import Jobs from ...discovery_common.constants import (PAM_USER, PAM_DIRECTORY, PAM_MACHINE, PAM_DATABASE, VERTICES_SORT_MAP, - DIS_INFRA_GRAPH_ID, RECORD_LINK_GRAPH_ID, USER_SERVICE_GRAPH_ID, - DIS_JOBS_GRAPH_ID) + DIS_INFRA_GRAPH_ID, RECORD_LINK_GRAPH_ID, USER_SERVICE_GRAPH_ID, + DIS_JOBS_GRAPH_ID) from ...discovery_common.types import (DiscoveryObject, DiscoveryUser, DiscoveryDirectory, DiscoveryMachine, - DiscoveryDatabase, JobContent) + DiscoveryDatabase, JobContent) from ...discovery_common.dag_sort import sort_infra_vertices from ...keeper_dag import DAG from ...keeper_dag.connection.commander import Connection as CommanderConnection @@ -228,6 +228,8 @@ def _group(configuration_vertex: DAGVertex) -> dict: if acl is None: print(f"{pad} {self._f('missing ACL')}") else: + if acl.is_iam_user is 
True:
+                        print(f"{pad}    . is IAM user")
                     if acl.is_admin is True:
                         print(f"{pad}    . is the {self._b('Admin')}")
                     if acl.belongs_to is True:
@@ -240,6 +242,12 @@ def _group(configuration_vertex: DAGVertex) -> dict:
                         print(f"{pad}    . is a NOOP")
                     if acl.rotation_settings.disabled is True:
                         print(f"{pad}    . rotation is disabled")
+
+                    if (acl.rotation_settings.saas_record_uid_list is not None
+                            and len(acl.rotation_settings.saas_record_uid_list) > 0):
+                        print(f"{pad}    . has SaaS rotation: "
+                              f"{acl.rotation_settings.saas_record_uid_list[0]}")
+
                 continue

             if vertex.has_data is True:
@@ -282,7 +290,7 @@ def _group(configuration_vertex: DAGVertex) -> dict:
                     except Exception as err:
                         print(f"{pad}  ! data not JSON: {err}")
                         for i in bad:
-                            print("{pad}     " + i)
+                            print(f"{pad}     " + i)

         if len(group[PAMDebugGraphCommand.OTHER]) > 0:
             print(f"{pad}  " + self._b("Other PAM Types"))
@@ -302,7 +310,6 @@ def _group(configuration_vertex: DAGVertex) -> dict:
                 vertex = item.get("v")  # type: DAGVertex
                 print(f"{pad}  * {vertex.uid}")

-
     def _do_text_list_service(self, params: KeeperParams, gateway_context: GatewayContext,
                               debug_level: int = 0, indent: int = 0):
diff --git a/keepercommander/commands/pam_debug/info.py b/keepercommander/commands/pam_debug/info.py
index d31a281ea..75da7e09c 100644
--- a/keepercommander/commands/pam_debug/info.py
+++ b/keepercommander/commands/pam_debug/info.py
@@ -472,6 +472,13 @@ def _print_field(f):
                         print(f"      * {task.name} = {task.user}")
                 else:
                     print("      Machines has no schedules tasks that are using non-builtin users.")
+
+                print(f"    {self._b('IIS Pools')} (Non Builtin Users)")
+                if len(content.item.facts.iis_pools) > 0:
+                    for iis_pool in content.item.facts.iis_pools:
+                        print(f"      * {iis_pool.name} = {iis_pool.user}")
+                else:
+                    print("      Machine has no IIS pools that are using non-builtin users.")
             else:
                 print(f"{bcolors.FAIL}    Machine facts are not set. Discover inside may not have been "
                       f"performed.{bcolors.ENDC}")
diff --git a/keepercommander/commands/pam_debug/vertex.py b/keepercommander/commands/pam_debug/vertex.py
new file mode 100644
index 000000000..7d1a25692
--- /dev/null
+++ b/keepercommander/commands/pam_debug/vertex.py
@@ -0,0 +1,198 @@
+
+from __future__ import annotations
+import argparse
+from ..discover import PAMGatewayActionDiscoverCommandBase, GatewayContext
+from ...display import bcolors
+from ... import vault, vault_extensions
+from ...discovery_common.infrastructure import Infrastructure
+from ...discovery_common.record_link import RecordLink
+from ...discovery_common.user_service import UserService
+from ...discovery_common.types import UserAcl, DiscoveryObject
+from ...discovery_common.constants import PAM_USER, PAM_MACHINE, PAM_DATABASE, PAM_DIRECTORY
+from ...keeper_dag import EdgeType
+import time
+import re
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ...vault import TypedRecord
+    from ...params import KeeperParams
+
+
+class PAMDebugVertexCommand(PAMGatewayActionDiscoverCommandBase):
+    parser = argparse.ArgumentParser(prog='pam-action-debug-vertex')
+
+    type_name_map = {
+        PAM_USER: "PAM User",
+        PAM_MACHINE: "PAM Machine",
+        PAM_DATABASE: "PAM Database",
+        PAM_DIRECTORY: "PAM Directory",
+    }
+
+    # The record to base everything on.
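The vertex command defined here ends by walking a vertex's parents (belongs_to_vertices) and children (has_vertices) in the infrastructure DAG. A toy model of that traversal, with assumed stand-in structures rather than the real keeper_dag classes:

```
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyVertex:
    uid: str
    parents: List['ToyVertex'] = field(default_factory=list)
    children: List['ToyVertex'] = field(default_factory=list)

    def belongs_to_vertices(self) -> List['ToyVertex']:
        return self.parents

    def has_vertices(self) -> List['ToyVertex']:
        return self.children

config = ToyVertex('config-uid')
machine = ToyVertex('machine-uid', parents=[config])
config.children.append(machine)
assert machine.belongs_to_vertices()[0].uid == 'config-uid'
assert config.has_vertices()[0].uid == 'machine-uid'
```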
+ parser.add_argument('--gateway', '-g', required=True, dest='gateway', action='store', + help='Gateway name or UID') + parser.add_argument('--vertex', '-i', required=True, dest='vertex_uid', action='store', + help='Vertex in infrastructure graph') + + def get_parser(self): + return PAMDebugVertexCommand.parser + + def execute(self, params: KeeperParams, **kwargs): + + gateway = kwargs.get("gateway") + debug_level = kwargs.get("debug_level", False) + + gateway_context = GatewayContext.from_gateway(params, gateway) + if gateway_context is None: + print(f"{bcolors.FAIL}Could not find the gateway configuration for {gateway}.") + return + + infra = Infrastructure(record=gateway_context.configuration, params=params, fail_on_corrupt=False, + debug_level=debug_level) + infra.load() + + vertex_uid = kwargs.get("vertex_uid") + vertex = infra.dag.get_vertex(vertex_uid) + if vertex is None: + print(f"{bcolors.FAIL}Could not find the vertex in the graph for {gateway}.") + return + + content = DiscoveryObject.get_discovery_object(vertex) + missing_since = "NA" + if content.missing_since_ts is not None: + missing_since = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(content.missing_since_ts)) + + print(self._h("Discovery Object Information")) + print(f" {self._b('Vertex UID')}: {content.uid}") + print(f" {self._b('Object ID')}: {content.id}") + print(f" {self._b('Record UID')}: {content.record_uid}") + print(f" {self._b('Parent Record UID')}: {content.parent_record_uid}") + print(f" {self._b('Shared Folder UID')}: {content.shared_folder_uid}") + print(f" {self._b('Record Type')}: {content.record_type}") + print(f" {self._b('Object Type')}: {content.object_type_value}") + print(f" {self._b('Ignore Object')}: {content.ignore_object}") + print(f" {self._b('Rule Engine Result')}: {content.action_rules_result}") + print(f" {self._b('Name')}: {content.name}") + print(f" {self._b('Generated Title')}: {content.title}") + print(f" {self._b('Generated Description')}: {content.description}") + print(f" {self._b('Missing Since')}: {missing_since}") + print(f" {self._b('Discovery Notes')}:") + for note in content.notes: + print(f" * {note}") + if content.error is not None: + print(f"{bcolors.FAIL} Error: {content.error}{bcolors.ENDC}") + if content.stacktrace is not None: + print(f"{bcolors.FAIL} Stack Trace:{bcolors.ENDC}") + print(f"{bcolors.FAIL}{content.stacktrace}{bcolors.ENDC}") + print("") + print(f"{bcolors.HEADER}Record Type Specifics{bcolors.ENDC}") + + if content.record_type == PAM_USER: + print(f" {self._b('User')}: {content.item.user}") + print(f" {self._b('DN')}: {content.item.dn}") + print(f" {self._b('Database')}: {content.item.database}") + print(f" {self._b('Active')}: {content.item.active}") + print(f" {self._b('Expired')}: {content.item.expired}") + print(f" {self._b('Source')}: {content.item.source}") + elif content.record_type == PAM_MACHINE: + print(f" {self._b('Host')}: {content.item.host}") + print(f" {self._b('IP')}: {content.item.ip}") + print(f" {self._b('Port')}: {content.item.port}") + print(f" {self._b('Operating System')}: {content.item.os}") + print(f" {self._b('Provider Region')}: {content.item.provider_region}") + print(f" {self._b('Provider Group')}: {content.item.provider_group}") + print(f" {self._b('Is the Gateway')}: {content.item.is_gateway}") + print(f" {self._b('Allows Admin')}: {content.item.allows_admin}") + print(f" {self._b('Admin Reason')}: {content.item.admin_reason}") + print("") + # If facts are not set, inside discover may not have been performed for the 
machine.
+            if content.item.facts.id is not None and content.item.facts.name is not None:
+                print(f"  {self._b('Machine Name')}: {content.item.facts.name}")
+                print(f"  {self._b('Machine ID')}: {content.item.facts.id.machine_id}")
+                print(f"  {self._b('Product ID')}: {content.item.facts.id.product_id}")
+                print(f"  {self._b('Board Serial')}: {content.item.facts.id.board_serial}")
+                print(f"  {self._b('Directories')}:")
+                if content.item.facts.directories is not None and len(content.item.facts.directories) > 0:
+                    for directory in content.item.facts.directories:
+                        print(f"    * Directory Domain: {directory.domain}")
+                        print(f"      Software: {directory.software}")
+                        print(f"      Login Format: {directory.login_format}")
+                else:
+                    print("    Machine is not using any directories.")
+
+                print("")
+                print(f"  {self._b('Services')} (Non Builtin Users):")
+                if len(content.item.facts.services) > 0:
+                    for service in content.item.facts.services:
+                        print(f"    * {service.name} = {service.user}")
+                else:
+                    print("    Machine has no services that are using non-builtin users.")
+
+                print(f"  {self._b('Scheduled Tasks')} (Non Builtin Users):")
+                if len(content.item.facts.tasks) > 0:
+                    for task in content.item.facts.tasks:
+                        print(f"    * {task.name} = {task.user}")
+                else:
+                    print("    Machine has no scheduled tasks that are using non-builtin users.")
+
+                print(f"  {self._b('IIS Pools')} (Non Builtin Users):")
+                if len(content.item.facts.iis_pools) > 0:
+                    for iis_pool in content.item.facts.iis_pools:
+                        print(f"    * {iis_pool.name} = {iis_pool.user}")
+                else:
+                    print("    Machine has no IIS pools that are using non-builtin users.")
+
+            else:
+                print(f"{bcolors.FAIL}  Machine facts are not set. Discover inside may not have been "
+                      f"performed.{bcolors.ENDC}")
+        elif content.record_type == PAM_DATABASE:
+            print(f"  {self._b('Host')}: {content.item.host}")
+            print(f"  {self._b('IP')}: {content.item.ip}")
+            print(f"  {self._b('Port')}: {content.item.port}")
+            print(f"  {self._b('Database Type')}: {content.item.type}")
+            print(f"  {self._b('Database')}: {content.item.database}")
+            print(f"  {self._b('Use SSL')}: {content.item.use_ssl}")
+            print(f"  {self._b('Provider Region')}: {content.item.provider_region}")
+            print(f"  {self._b('Provider Group')}: {content.item.provider_group}")
+            print(f"  {self._b('Allows Admin')}: {content.item.allows_admin}")
+            print(f"  {self._b('Admin Reason')}: {content.item.admin_reason}")
+        elif content.record_type == PAM_DIRECTORY:
+            print(f"  {self._b('Host')}: {content.item.host}")
+            print(f"  {self._b('IP')}: {content.item.ip}")
+            print(f"  {self._b('Port')}: {content.item.port}")
+            print(f"  {self._b('Directory Type')}: {content.item.type}")
+            print(f"  {self._b('Use SSL')}: {content.item.use_ssl}")
+            print(f"  {self._b('Provider Region')}: {content.item.provider_region}")
+            print(f"  {self._b('Provider Group')}: {content.item.provider_group}")
+            print(f"  {self._b('Allows Admin')}: {content.item.allows_admin}")
+            print(f"  {self._b('Admin Reason')}: {content.item.admin_reason}")
+
+        print("")
+        print(self._h("Belongs To Vertices (Parents)"))
+        vertices = vertex.belongs_to_vertices()
+        for vertex in vertices:
+            content = DiscoveryObject.get_discovery_object(vertex)
+            print(f"  * {content.description} ({vertex.uid})")
+            for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]:
+                edge = vertex.get_edge(vertex, edge_type=edge_type)
+                if edge is not None:
+                    print(f"    . 
{edge_type}, active: {edge.active}") + + if len(vertices) == 0: + print(f"{bcolors.FAIL} Does not belong to anyone{bcolors.ENDC}") + + print("") + print(f"{bcolors.HEADER}Vertices Belonging To (Children){bcolors.ENDC}") + vertices = vertex.has_vertices() + for vertex in vertices: + content = DiscoveryObject.get_discovery_object(vertex) + print(f" * {content.description} ({vertex.uid})") + for edge_type in [EdgeType.LINK, EdgeType.ACL, EdgeType.KEY, EdgeType.DELETION]: + edge = vertex.get_edge(vertex, edge_type=edge_type) + if edge is not None: + print(f" . {edge_type}, active: {edge.active}") + if len(vertices) == 0: + print(f" Does not have any children.") + + print("") diff --git a/keepercommander/commands/pam_import/edit.py b/keepercommander/commands/pam_import/edit.py index acc5e12a9..9a3959dca 100644 --- a/keepercommander/commands/pam_import/edit.py +++ b/keepercommander/commands/pam_import/edit.py @@ -602,8 +602,9 @@ def generate_discovery_playground_data(self, params, project: dict): resources_folder_uid = project["folders"]["resources_folder_uid"] pam_config_uid = project["pam_config"]["pam_config_uid"] - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config_uid, True) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config_uid, True, + transmission_key=transmission_key) # if not tdag.check_tunneling_enabled_config(enable_connections=True): # logging.warning(f"{bcolors.WARNING}Warning: {bcolors.ENDC} Connections are disabled by PAM Configuration!") # Fix: Rotation is disabled by the PAM configuration. 
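pam_launch/__init__.py below selects the DAG connection from the USE_LOCAL_DAG environment flag. A deterministic sketch of that selection (the boolean parsing is an assumption about utils.value_to_boolean, and the class names are stand-ins for the real connection types):

```
def value_to_boolean_demo(value) -> bool:
    # Assumed semantics; the real helper lives in keepercommander.utils.
    return str(value).strip().lower() in ('true', '1', 'yes', 'y', 'on')

def get_connection_demo(env: dict) -> str:
    # Mirrors the branch below, returning class names instead of objects.
    if not value_to_boolean_demo(env.get('USE_LOCAL_DAG', False)):
        return 'CommanderConnection'
    return 'LocalConnection'

assert get_connection_demo({}) == 'CommanderConnection'
assert get_connection_demo({'USE_LOCAL_DAG': 'true'}) == 'LocalConnection'
```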
@@ -1607,8 +1608,9 @@ def process_data(self, params, project): # return print("Started importing data...") - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_cfg_uid, True) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_cfg_uid, True, + transmission_key=transmission_key) pte = PAMTunnelEditCommand() prc = PAMCreateRecordRotationCommand() diff --git a/keepercommander/commands/pam_launch/__init__.py b/keepercommander/commands/pam_launch/__init__.py new file mode 100644 index 000000000..06cc4970d --- /dev/null +++ b/keepercommander/commands/pam_launch/__init__.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from ...utils import value_to_boolean +import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ...params import KeeperParams + from ...keeper_dag.connection import ConnectionBase + + +def get_connection(params: KeeperParams) -> ConnectionBase: + if value_to_boolean(os.environ.get("USE_LOCAL_DAG", False)) is False: + from ...keeper_dag.connection.commander import Connection as CommanderConnection + return CommanderConnection(params=params) + else: + from ...keeper_dag.connection.local import Connection as LocalConnection + return LocalConnection() diff --git a/keepercommander/commands/pam_launch/guac_cli/__init__.py b/keepercommander/commands/pam_launch/guac_cli/__init__.py new file mode 100644 index 000000000..34e31b416 --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/__init__.py @@ -0,0 +1,42 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' callback mapping + + def register_handler(self, opcode: str, callback: Callable[[GuacInstruction], None]): + """Register a handler for a specific opcode""" + self.handlers[opcode] = callback + + def feed(self, data: bytes) -> List[GuacInstruction]: + """ + Feed raw bytes from the data channel and parse instructions. + + Args: + data: Raw bytes from Guacamole server + + Returns: + List of parsed instructions + """ + try: + # Guacamole protocol is UTF-8 text + text = data.decode('utf-8') + self.buffer += text + except UnicodeDecodeError as e: + logging.warning(f"Failed to decode Guacamole data: {e}") + return [] + + instructions = [] + + # Parse complete instructions (terminated by semicolon) + while ';' in self.buffer: + idx = self.buffer.index(';') + instruction_text = self.buffer[:idx] + self.buffer = self.buffer[idx + 1:] + + # Parse the instruction + instruction = self._parse_instruction(instruction_text) + if instruction: + instructions.append(instruction) + + # Call registered handler if exists + if instruction.opcode in self.handlers: + try: + self.handlers[instruction.opcode](instruction) + except Exception as e: + logging.error(f"Error in handler for {instruction.opcode}: {e}") + + return instructions + + def _parse_instruction(self, text: str) -> Optional[GuacInstruction]: + """ + Parse a single instruction from text format. + + Format: length.value,length.value,... + Example: "4.sync,8.12345678" -> sync instruction with timestamp + + Args: + text: Raw instruction text (without semicolon) + + Returns: + Parsed GuacInstruction or None if parsing fails + """ + if not text: + return None + + try: + elements = [] + remaining = text + + while remaining: + # Find the length prefix + if '.' 
not in remaining: + break + + dot_idx = remaining.index('.') + length_str = remaining[:dot_idx] + + try: + length = int(length_str) + except ValueError: + logging.warning(f"Invalid length prefix in Guacamole instruction: {length_str}") + return None + + # Extract the value + value_start = dot_idx + 1 + value_end = value_start + length + + if value_end > len(remaining): + logging.warning(f"Truncated Guacamole instruction: expected {length} bytes") + return None + + value = remaining[value_start:value_end] + elements.append(value) + + # Move to next element + remaining = remaining[value_end:] + if remaining.startswith(','): + remaining = remaining[1:] + + if not elements: + return None + + # First element is the opcode + opcode = elements[0] + args = elements[1:] + + return GuacInstruction(opcode, args) + + except Exception as e: + logging.error(f"Error parsing Guacamole instruction '{text}': {e}") + return None + + def encode_instruction(self, opcode: str, *args) -> bytes: + """ + Encode a Guacamole instruction to send to the server. + + Args: + opcode: Instruction opcode + *args: Instruction arguments + + Returns: + Encoded instruction as bytes + """ + elements = [opcode] + list(args) + encoded_parts = [] + + for element in elements: + element_str = str(element) + encoded_parts.append(f"{len(element_str)}.{element_str}") + + instruction = ','.join(encoded_parts) + ';' + return instruction.encode('utf-8') + + def encode_key(self, keycode: int, pressed: bool) -> bytes: + """ + Encode a keyboard event. + + Args: + keycode: X11 keysym value + pressed: True for press, False for release + + Returns: + Encoded key instruction + """ + return self.encode_instruction('key', str(keycode), '1' if pressed else '0') + + def encode_mouse(self, x: int, y: int, button_mask: int) -> bytes: + """ + Encode a mouse event. + + Args: + x: X coordinate + y: Y coordinate + button_mask: Bitmask of pressed buttons + + Returns: + Encoded mouse instruction + """ + return self.encode_instruction('mouse', str(x), str(y), str(button_mask)) + + def encode_size(self, width: int, height: int) -> bytes: + """ + Encode a terminal size change. + + Args: + width: Terminal width in characters + height: Terminal height in characters + + Returns: + Encoded size instruction + """ + # Size instruction: size,layer,width,height + # Layer 0 is the root layer + return self.encode_instruction('size', '0', str(width), str(height)) + + def encode_clipboard(self, mimetype: str, data: str) -> bytes: + """ + Encode clipboard data. + + Args: + mimetype: MIME type (typically "text/plain") + data: Clipboard text + + Returns: + Encoded clipboard instruction + """ + return self.encode_instruction('clipboard', mimetype, data) + + def encode_sync(self, timestamp: str) -> bytes: + """ + Encode a sync acknowledgment. 
+
+        Args:
+            timestamp: Timestamp from server's sync instruction
+
+        Returns:
+            Encoded sync instruction
+        """
+        return self.encode_instruction('sync', timestamp)
+
+
+# X11 keysym mappings for common keys
+# Reference: https://www.x.org/releases/X11R7.7/doc/xproto/x11protocol.html#keysym_encoding
+class X11Keysym:
+    """X11 keysym values for common keyboard keys"""
+
+    # Control keys
+    BACKSPACE = 0xFF08
+    TAB = 0xFF09
+    RETURN = 0xFF0D
+    ESCAPE = 0xFF1B
+    DELETE = 0xFFFF
+
+    # Cursor movement
+    HOME = 0xFF50
+    LEFT = 0xFF51
+    UP = 0xFF52
+    RIGHT = 0xFF53
+    DOWN = 0xFF54
+    PAGE_UP = 0xFF55
+    PAGE_DOWN = 0xFF56
+    END = 0xFF57
+
+    # Function keys
+    F1 = 0xFFBE
+    F2 = 0xFFBF
+    F3 = 0xFFC0
+    F4 = 0xFFC1
+    F5 = 0xFFC2
+    F6 = 0xFFC3
+    F7 = 0xFFC4
+    F8 = 0xFFC5
+    F9 = 0xFFC6
+    F10 = 0xFFC7
+    F11 = 0xFFC8
+    F12 = 0xFFC9
+
+    # Modifiers
+    SHIFT_L = 0xFFE1
+    SHIFT_R = 0xFFE2
+    CONTROL_L = 0xFFE3
+    CONTROL_R = 0xFFE4
+    CAPS_LOCK = 0xFFE5
+    META_L = 0xFFE7
+    META_R = 0xFFE8
+    ALT_L = 0xFFE9
+    ALT_R = 0xFFEA
+
+    # ASCII printable range (0x20-0x7E) maps directly
+    # For example: 'A' = 0x41, 'a' = 0x61, '0' = 0x30
+
+    @staticmethod
+    def from_char(ch: str) -> int:
+        """Convert a single character to X11 keysym"""
+        if len(ch) == 1:
+            return ord(ch)
+        return 0
+
diff --git a/keepercommander/commands/pam_launch/guac_cli/input.py b/keepercommander/commands/pam_launch/guac_cli/input.py
new file mode 100644
index 000000000..38a8dc78c
--- /dev/null
+++ b/keepercommander/commands/pam_launch/guac_cli/input.py
@@ -0,0 +1,327 @@
+# _ __
+# | |/ /___ ___ _ __ ___ _ _ ®
+# | ' Optional[str]:
+        """
+        Read an ANSI escape sequence from stdin.
+
+        Returns:
+            Escape sequence string (without ESC prefix) or None
+        """
+        seq = ""
+        for _ in range(5):  # Read up to 5 characters
+            ch = self.stdin_reader.read_char(timeout=0.05)
+            if ch:
+                seq += ch
+                # Common sequences end with a letter
+                if ch.isalpha() or ch == '~':
+                    break
+            else:
+                break
+
+        return seq if seq else None
+
+    def _escape_to_keysym(self, seq: str) -> Optional[int]:
+        """
+        Map an ANSI escape sequence to X11 keysym.
+
+        Args:
+            seq: Escape sequence (without ESC prefix)
+
+        Returns:
+            X11 keysym or None
+        """
+        # Common escape sequences
+        mappings = {
+            '[A': X11Keysym.UP,
+            '[B': X11Keysym.DOWN,
+            '[C': X11Keysym.RIGHT,
+            '[D': X11Keysym.LEFT,
+            '[H': X11Keysym.HOME,
+            '[F': X11Keysym.END,
+            '[1~': X11Keysym.HOME,
+            '[2~': 0xFF63,  # Insert
+            '[3~': X11Keysym.DELETE,
+            '[4~': X11Keysym.END,
+            '[5~': X11Keysym.PAGE_UP,
+            '[6~': X11Keysym.PAGE_DOWN,
+            'OP': X11Keysym.F1,
+            'OQ': X11Keysym.F2,
+            'OR': X11Keysym.F3,
+            'OS': X11Keysym.F4,
+            '[15~': X11Keysym.F5,
+            '[17~': X11Keysym.F6,
+            '[18~': X11Keysym.F7,
+            '[19~': X11Keysym.F8,
+            '[20~': X11Keysym.F9,
+            '[21~': X11Keysym.F10,
+            '[23~': X11Keysym.F11,
+            '[24~': X11Keysym.F12,
+        }
+
+        return mappings.get(seq)
+
+    def _control_char_to_keysym(self, ch: str) -> Optional[int]:
+        """
+        Map control character to X11 keysym.
+
+        Args:
+            ch: Control character
+
+        Returns:
+            X11 keysym or None
+        """
+        code = ord(ch)
+
+        # Common control characters
+        if code == 8:  # Backspace (Ctrl+H)
+            return X11Keysym.BACKSPACE
+        elif code == 9:  # Tab
+            return X11Keysym.TAB
+        elif code == 10:  # Line feed (Enter on Unix)
+            return X11Keysym.RETURN
+        elif code == 13:  # Carriage return (Enter on Windows)
+            return X11Keysym.RETURN
+        elif code == 27:  # ESC
+            return X11Keysym.ESCAPE
+        else:
+            # Ctrl+letter combinations (Ctrl+A = 1, Ctrl+B = 2, etc.)
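+            # e.g. Ctrl+C arrives as byte 0x03 (ord('C') & 0x1F)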
+ # Send as lowercase letter with Ctrl modifier + # For simplicity, just send the control character as-is + # Guacamole can interpret it + return code + + def _send_key(self, keysym: int): + """ + Send a key press and release event. + + Args: + keysym: X11 keysym value + """ + # Send key press + self.key_callback(keysym, True) + + # Send key release + self.key_callback(keysym, False) + + +class UnixStdinReader: + """Unix/Linux stdin reader with raw mode support""" + + def __init__(self): + self.old_settings = None + + def set_raw_mode(self): + """Set terminal to raw mode (non-buffered, non-echoing)""" + try: + import termios + import tty + self.old_settings = termios.tcgetattr(sys.stdin.fileno()) + tty.setraw(sys.stdin.fileno()) + except Exception as e: + logging.warning(f"Failed to set raw mode: {e}") + + def restore(self): + """Restore terminal to normal mode""" + if self.old_settings: + try: + import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, self.old_settings) + except Exception as e: + logging.warning(f"Failed to restore terminal: {e}") + self.old_settings = None + + def read_char(self, timeout: Optional[float] = None) -> Optional[str]: + """ + Read a single character from stdin. + + Args: + timeout: Read timeout in seconds (None = blocking) + + Returns: + Character or None if timeout + """ + if timeout: + import select + ready, _, _ = select.select([sys.stdin], [], [], timeout) + if not ready: + return None + + try: + return sys.stdin.read(1) + except: + return None + + +class WindowsStdinReader: + """Windows stdin reader with raw mode support""" + + def __init__(self): + self.old_mode = None + + def set_raw_mode(self): + """Set console to raw mode on Windows""" + try: + import msvcrt + # Windows console is already non-buffered for getch + pass + except: + pass + + def restore(self): + """Restore console mode""" + pass + + def read_char(self, timeout: Optional[float] = None) -> Optional[str]: + """ + Read a single character from stdin on Windows. + + Args: + timeout: Read timeout in seconds (None = blocking) + + Returns: + Character or None if timeout + """ + try: + import msvcrt + + if timeout: + # Poll for input with timeout + import time + start = time.time() + while time.time() - start < timeout: + if msvcrt.kbhit(): + ch = msvcrt.getch() + return ch.decode('utf-8', errors='ignore') + time.sleep(0.01) + return None + else: + # Blocking read + ch = msvcrt.getch() + return ch.decode('utf-8', errors='ignore') + except Exception as e: + logging.error(f"Error reading from stdin on Windows: {e}") + return None + diff --git a/keepercommander/commands/pam_launch/guac_cli/instructions.py b/keepercommander/commands/pam_launch/guac_cli/instructions.py new file mode 100644 index 000000000..6c6c7c1f9 --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/instructions.py @@ -0,0 +1,568 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' None: + """ + Handle sync instruction - Frame synchronization. + + This is critical - the server waits for sync acknowledgment. + Note: The actual sync response is sent by the caller (GuacamoleHandler). + + Args: + args: [timestamp] or [timestamp, frames] + """ + timestamp = args[0] if args else "?" + frames = args[1] if len(args) > 1 else "0" + logging.debug(f"[SYNC] timestamp={timestamp}, frames={frames}") + + +def handle_name(args: List[str]) -> None: + """ + Handle name instruction - Connection name/title. + + Args: + args: [name] + """ + name = args[0] if args else "?" 
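+    # The name is informational only (e.g. the remote session title)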
+ logging.debug(f"[NAME] {name}") + + +def handle_size(args: List[str]) -> None: + """ + Handle size instruction - Screen/terminal size. + + Size can have different formats: + - size,layer,width,height (3 args) + - size,layer,width,height,dpi (4 args) + - size,width,height (2 args - client sending to server) + """ + if len(args) == 2: + width, height = args + logging.debug(f"[SIZE] {width}x{height}") + elif len(args) >= 3: + layer, width, height = args[0], args[1], args[2] + dpi = args[3] if len(args) > 3 else "96" + logging.debug(f"[SIZE] layer={layer}, {width}x{height} @ {dpi}dpi") + else: + logging.debug(f"[SIZE] {args}") + + +def handle_png(args: List[str]) -> None: + """ + Handle png instruction - PNG image data. + + Args: + args: [channel, layer, x, y, ...data_args] + """ + if len(args) < 4: + logging.debug(f"[PNG] {args}") + return + + channel, layer, x, y = args[0], args[1], args[2], args[3] + data_args = args[4:] + + if data_args: + data = data_args[0] + if len(data) > 16: + try: + hex_preview = data[:16].encode('utf-8').hex() + except: + hex_preview = str(data[:16]) + logging.debug(f"[PNG] channel={channel}, layer={layer}, pos=({x},{y}), data=[{hex_preview}...] ({len(data)} chars)") + else: + logging.debug(f"[PNG] channel={channel}, layer={layer}, pos=({x},{y}), data={data}") + else: + logging.debug(f"[PNG] channel={channel}, layer={layer}, pos=({x},{y})") + + +def handle_jpeg(args: List[str]) -> None: + """ + Handle jpeg instruction - JPEG image data. + + Args: + args: [channel, layer, x, y, ...data_args] + """ + if len(args) < 4: + logging.debug(f"[JPEG] {args}") + return + + channel, layer, x, y = args[0], args[1], args[2], args[3] + data_args = args[4:] + + if data_args: + data = data_args[0] + if len(data) > 16: + try: + hex_preview = data[:16].encode('utf-8').hex() + except: + hex_preview = str(data[:16]) + logging.debug(f"[JPEG] channel={channel}, layer={layer}, pos=({x},{y}), data=[{hex_preview}...] ({len(data)} chars)") + else: + logging.debug(f"[JPEG] channel={channel}, layer={layer}, pos=({x},{y}), data={data}") + else: + logging.debug(f"[JPEG] channel={channel}, layer={layer}, pos=({x},{y})") + + +def handle_img(args: List[str]) -> None: + """ + Handle img instruction - Streamed image data. + + Args: + args: [stream, channel, layer, mimetype, x, y] + """ + if len(args) >= 6: + stream, channel, layer, mimetype, x, y = args[0], args[1], args[2], args[3], args[4], args[5] + logging.debug(f"[IMG] stream={stream}, channel={channel}, layer={layer}, type={mimetype}, pos=({x},{y})") + else: + logging.debug(f"[IMG] {args}") + + +def handle_cursor(args: List[str]) -> None: + """ + Handle cursor instruction - Cursor position/image. + + Args: + args: [x, y, ...additional params] + """ + if len(args) >= 2: + x, y = args[0], args[1] + extra = args[2:] if len(args) > 2 else [] + if extra: + logging.debug(f"[CURSOR] hotspot=({x},{y}), args={extra}") + else: + logging.debug(f"[CURSOR] pos=({x},{y})") + else: + logging.debug(f"[CURSOR] {args}") + + +def handle_move(args: List[str]) -> None: + """ + Handle move instruction - Move layer. + + Args: + args: [layer, parent, x, y, z] + """ + if len(args) >= 5: + layer, parent, x, y, z = args[0], args[1], args[2], args[3], args[4] + logging.debug(f"[MOVE] layer={layer}, parent={parent}, pos=({x},{y}), z={z}") + else: + logging.debug(f"[MOVE] {args}") + + +def handle_rect(args: List[str]) -> None: + """ + Handle rect instruction - Draw rectangle path. 
+ + Args: + args: [layer, x, y, width, height] + """ + if len(args) >= 5: + layer, x, y, width, height = args[0], args[1], args[2], args[3], args[4] + logging.debug(f"[RECT] layer={layer}, rect=({x},{y},{width},{height})") + else: + logging.debug(f"[RECT] {args}") + + +def handle_cfill(args: List[str]) -> None: + """ + Handle cfill instruction - Fill with color. + + Args: + args: [channel, layer, r, g, b, a] + """ + if len(args) >= 6: + channel, layer, r, g, b, a = args[0], args[1], args[2], args[3], args[4], args[5] + logging.debug(f"[CFILL] channel={channel}, layer={layer}, color=rgba({r},{g},{b},{a})") + else: + logging.debug(f"[CFILL] {args}") + + +def handle_copy(args: List[str]) -> None: + """ + Handle copy instruction - Copy rectangle between layers. + + Args: + args: [src_layer, src_x, src_y, width, height, channel, dst_layer, dst_x, dst_y] + """ + if len(args) >= 9: + src_layer, src_x, src_y, width, height = args[0], args[1], args[2], args[3], args[4] + channel, dst_layer, dst_x, dst_y = args[5], args[6], args[7], args[8] + logging.debug(f"[COPY] from layer={src_layer} ({src_x},{src_y},{width},{height}) to layer={dst_layer} ({dst_x},{dst_y})") + else: + logging.debug(f"[COPY] {args}") + + +def handle_clipboard(args: List[str]) -> None: + """ + Handle clipboard instruction - Clipboard data stream. + + Args: + args: [stream, mimetype] + """ + if len(args) >= 2: + stream, mimetype = args[0], args[1] + logging.debug(f"[CLIPBOARD] stream={stream}, type={mimetype}") + else: + logging.debug(f"[CLIPBOARD] {args}") + + +def handle_ack(args: List[str]) -> None: + """ + Handle ack instruction - Acknowledgment. + + Args: + args: [stream, message, code] + """ + if len(args) >= 3: + stream, message, code = args[0], args[1], args[2] + logging.debug(f"[ACK] stream={stream}, message='{message}', code={code}") + else: + logging.debug(f"[ACK] {args}") + + +def handle_error(args: List[str]) -> None: + """ + Handle error instruction - Error message from server. + + Args: + args: [message, code] + """ + if len(args) >= 2: + message, code = args[0], args[1] + logging.error(f"[ERROR] code={code}, message='{message}'") + else: + logging.error(f"[ERROR] {args}") + + +def handle_disconnect(args: List[str]) -> None: + """ + Handle disconnect instruction - Server disconnecting. + + Args: + args: Optional disconnect parameters + """ + logging.debug(f"[DISCONNECT] {args if args else ''}") + + +def handle_mouse(args: List[str]) -> None: + """ + Handle mouse instruction - Mouse position (server-side cursor). + + Args: + args: [x, y] + """ + # Don't print mouse movements to avoid spam + if len(args) >= 2: + x, y = args[0], args[1] + logging.debug(f"MOUSE: ({x},{y})") + + +def handle_blob(args: List[str]) -> None: + """ + Handle blob instruction - Binary blob data for stream. + + Args: + args: [stream, data] + """ + if len(args) >= 2: + stream, data = args[0], args[1] + data_preview = data[:16] if len(data) > 16 else data + logging.debug(f"[BLOB] stream={stream}, data=[{data_preview}...] ({len(data)} chars)") + else: + logging.debug(f"[BLOB] {args}") + + +def handle_end(args: List[str]) -> None: + """ + Handle end instruction - End of stream. + + Args: + args: [stream] + """ + stream = args[0] if args else "?" + logging.debug(f"[END] stream={stream}") + + +def handle_pipe(args: List[str]) -> None: + """ + Handle pipe instruction - Named pipe stream. + + For SSH/TTY sessions, the server sends pipe with name "STDOUT" to + indicate terminal output will follow via blob instructions. 
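+
+    A plaintext SSH session therefore typically produces (illustrative):
+
+        pipe  - stream opened with name "STDOUT"
+        blob  - base64-encoded chunks of terminal output
+        end   - stream closed when the session ends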
+ + Args: + args: [stream_index, mimetype, name] + """ + if len(args) >= 3: + stream, mimetype, name = args[0], args[1], args[2] + logging.debug(f"[PIPE] stream={stream}, type={mimetype}, name={name}") + else: + logging.debug(f"[PIPE] {args}") + + +def handle_args(args: List[str]) -> None: + """ + Handle args instruction - Server requests connection parameters. + + This is CRITICAL - guacd sends this after receiving 'select' to ask what + parameters are needed for the connection. + + Note: The actual handshake response is sent by the caller (GuacamoleHandler). + + Args: + args: List of parameter names that guacd expects + """ + logging.debug(f"[ARGS] Server requesting parameters: {list(args)}") + + +def handle_ready(args: List[str]) -> None: + """ + Handle ready instruction - Server confirms connection is ready. + + This is sent by guacd after processing 'connect' instruction. + + Args: + args: [connection_id] + """ + connection_id = args[0] if args else None + logging.debug(f"[READY] Connection ready (id: {connection_id})") + + +def handle_unknown(opcode: str, args: List[str]) -> None: + """ + Handle any unrecognized instruction - default handler. + + Args: + opcode: Instruction opcode + args: Instruction arguments + """ + # Truncate long arguments for display + arg_preview = [] + for arg in args: + arg_str = str(arg) + if len(arg_str) > 32: + arg_preview.append(arg_str[:32] + "...") + else: + arg_preview.append(arg_str) + + logging.debug(f"[{opcode.upper()}] {arg_preview}") + + +# ============================================================================= +# Instruction Router +# ============================================================================= + +# Map of opcode -> handler function +_INSTRUCTION_HANDLERS: Dict[str, InstructionHandler] = { + # Critical instructions + 'sync': handle_sync, + 'name': handle_name, + 'size': handle_size, + + # Image instructions + 'png': handle_png, + 'jpeg': handle_jpeg, + 'img': handle_img, + + # Display instructions + 'cursor': handle_cursor, + 'move': handle_move, + 'rect': handle_rect, + 'cfill': handle_cfill, + 'copy': handle_copy, + + # I/O instructions + 'clipboard': handle_clipboard, + 'pipe': handle_pipe, + 'blob': handle_blob, + 'end': handle_end, + 'ack': handle_ack, + + # Control instructions + 'error': handle_error, + 'disconnect': handle_disconnect, + + # Connection handshake (CRITICAL) + 'args': handle_args, + 'ready': handle_ready, + + # Mouse (logged but not printed) + 'mouse': handle_mouse, +} + + +def create_instruction_router( + custom_handlers: Optional[Dict[str, InstructionHandler]] = None, + send_ack_callback: Optional[AckCallback] = None, + stdout_stream_tracker: Optional[Any] = None, +) -> Callable[[str, List[str]], None]: + """ + Create an instruction router callback for use with Parser.oninstruction. + + The router dispatches instructions to the appropriate handler based on opcode. + Custom handlers can override the default handlers. + + For plaintext SSH/TTY streams, the router can track STDOUT pipes and decode + blob data to sys.stdout. This requires: + - send_ack_callback: Function to send ack responses + - stdout_stream_tracker: Object with `stdout_stream_index` attribute for tracking + + Args: + custom_handlers: Optional dict of opcode -> handler to override defaults. + send_ack_callback: Optional callback(stream, message, code) to send ack. + stdout_stream_tracker: Optional object with `stdout_stream_index` attribute. 
+ When set, pipe/blob/end for STDOUT streams will be handled specially: + - pipe with name "STDOUT" stores stream index and sends ack + - blob with matching stream decodes base64 to stdout and sends ack + - end with matching stream clears tracking + + Returns: + A callback function with signature (opcode: str, args: List[str]) -> None + suitable for assigning to Parser.oninstruction. + + Example: + from guacamole import Parser + from guac_cli.instructions import create_instruction_router + + parser = Parser() + parser.oninstruction = create_instruction_router() + parser.receive("4.sync,10.1234567890;") + + Example with STDOUT handling: + class StreamTracker: + stdout_stream_index = -1 + + tracker = StreamTracker() + parser.oninstruction = create_instruction_router( + send_ack_callback=lambda s, m, c: send_ack(s, m, c), + stdout_stream_tracker=tracker, + ) + """ + # Merge default handlers with custom handlers + handlers = _INSTRUCTION_HANDLERS.copy() + if custom_handlers: + handlers.update(custom_handlers) + + def router(opcode: str, args: List[str]) -> None: + """Route instruction to appropriate handler.""" + + # Special handling for pipe/blob/end when STDOUT tracking is enabled + if stdout_stream_tracker is not None and send_ack_callback is not None: + + # Handle pipe - track STDOUT stream + if opcode == 'pipe' and len(args) >= 3: + stream_index, mimetype, name = args[0], args[1], args[2] + if name == 'STDOUT': + stdout_stream_tracker.stdout_stream_index = int(stream_index) + send_ack_callback(stream_index, 'OK', '0') + logging.debug(f"STDOUT pipe opened on stream {stream_index}") + # Still call original handler for diagnostics + handler = handlers.get(opcode) + if handler: + try: + handler(args) + except Exception as e: + logging.error(f"Error in pipe handler: {e}") + return + + # Handle blob - decode STDOUT data to sys.stdout + elif opcode == 'blob' and len(args) >= 2: + stream_index = int(args[0]) + if stream_index == stdout_stream_tracker.stdout_stream_index: + # Decode base64 and write to stdout + try: + decoded = base64.b64decode(args[1]) + # Try buffer.write for binary output, fall back to str for compatibility + if hasattr(sys.stdout, 'buffer'): + sys.stdout.buffer.write(decoded) + else: + sys.stdout.write(decoded.decode('utf-8', errors='replace')) + sys.stdout.flush() + send_ack_callback(args[0], 'OK', '0') + except Exception as e: + logging.error(f"Error decoding STDOUT blob: {e}") + return + # Non-STDOUT blob falls through to default handler + + # Handle end - clear STDOUT tracking + elif opcode == 'end' and len(args) >= 1: + stream_index = int(args[0]) + if stream_index == stdout_stream_tracker.stdout_stream_index: + stdout_stream_tracker.stdout_stream_index = -1 + logging.debug(f"STDOUT stream {stream_index} ended") + # Still call original handler for diagnostics + handler = handlers.get(opcode) + if handler: + try: + handler(args) + except Exception as e: + logging.error(f"Error in end handler: {e}") + return + + # Default routing + handler = handlers.get(opcode) + if handler: + try: + handler(args) + except Exception as e: + logging.error(f"Error handling instruction {opcode}: {e}", exc_info=True) + else: + handle_unknown(opcode, args) + + return router + + +def get_default_handlers() -> Dict[str, InstructionHandler]: + """ + Get a copy of the default instruction handlers. + + Returns: + Dict mapping opcode to handler function. 
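+
+    Example (sketch; my_error_handler is a hypothetical callable):
+
+        handlers = get_default_handlers()
+        handlers['error'] = my_error_handler
+        router = create_instruction_router(custom_handlers=handlers)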
+ """ + return _INSTRUCTION_HANDLERS.copy() diff --git a/keepercommander/commands/pam_launch/guac_cli/renderer.py b/keepercommander/commands/pam_launch/guac_cli/renderer.py new file mode 100644 index 000000000..af21bcb6b --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/renderer.py @@ -0,0 +1,362 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' = 2: + try: + x = int(args[0]) + y = int(args[1]) + self.cursor_x = min(max(0, x), self.width - 1) + self.cursor_y = min(max(0, y), self.height - 1) + except ValueError: + pass + + def _handle_text(self, args: list): + """ + Handle text drawing. + + Args format: [layer, x, y, text] + """ + if len(args) >= 4: + try: + # layer = args[0] # Ignore layer for now + x = int(args[1]) + y = int(args[2]) + text = args[3] + + # Draw text at position + self._draw_text(x, y, text) + + except (ValueError, IndexError) as e: + logging.debug(f"Error in text instruction: {e}") + + def _draw_text(self, x: int, y: int, text: str): + """Draw text at specified position in screen buffer""" + if y < 0 or y >= self.height: + return + + for i, ch in enumerate(text): + col = x + i + if col >= 0 and col < self.width: + self.screen[y][col] = ch + self.attrs[y][col] = self._encode_attrs() + + def _encode_attrs(self) -> int: + """Encode current attributes to a single integer""" + attr = 0 + if self.current_bold: + attr |= 1 + if self.current_underline: + attr |= 2 + attr |= (self.current_fg & 0xF) << 4 + attr |= (self.current_bg & 0xF) << 8 + return attr + + def _handle_rect(self, args: list): + """Handle rectangle drawing (fill with color)""" + if len(args) >= 5: + try: + # layer = args[0] + x = int(args[1]) + y = int(args[2]) + w = int(args[3]) + h = int(args[4]) + + # Fill rectangle with spaces + for row in range(y, min(y + h, self.height)): + for col in range(x, min(x + w, self.width)): + if row >= 0 and col >= 0: + self.screen[row][col] = ' ' + self.attrs[row][col] = self._encode_attrs() + + except (ValueError, IndexError): + pass + + def _handle_cfill(self, args: list): + """Handle color fill""" + if len(args) >= 4: + try: + # Parse color (r, g, b, a) + # For simplicity, map to nearest ANSI color + r = int(args[0]) + g = int(args[1]) + b = int(args[2]) + # a = int(args[3]) # alpha + + # Map RGB to ANSI color (0-15) + self.current_bg = self._rgb_to_ansi(r, g, b) + + except (ValueError, IndexError): + pass + + def _rgb_to_ansi(self, r: int, g: int, b: int) -> int: + """Map RGB (0-255) to ANSI color code (0-15)""" + # Simple mapping to 8 basic colors + if r < 128 and g < 128 and b < 128: + return 0 # Black + elif r > 200 and g > 200 and b > 200: + return 7 # White + elif r > 128: + return 1 # Red + elif g > 128: + return 2 # Green + elif b > 128: + return 4 # Blue + else: + return 7 # Default white + + def _handle_copy(self, args: list): + """Handle copy operation (copy screen region)""" + if len(args) >= 7: + try: + # src_layer = args[0] + src_x = int(args[1]) + src_y = int(args[2]) + w = int(args[3]) + h = int(args[4]) + # dst_layer = args[5] + dst_x = int(args[6]) + dst_y = int(args[7]) + + # Copy region + for row in range(h): + for col in range(w): + src_row = src_y + row + src_col = src_x + col + dst_row = dst_y + row + dst_col = dst_x + col + + if (0 <= src_row < self.height and 0 <= src_col < self.width and + 0 <= dst_row < self.height and 0 <= dst_col < self.width): + self.screen[dst_row][dst_col] = self.screen[src_row][src_col] + self.attrs[dst_row][dst_col] = self.attrs[src_row][src_col] + + except (ValueError, IndexError): + pass + + def 
_handle_size(self, args: list): + """Handle terminal size change""" + if len(args) >= 3: + try: + # layer = args[0] + new_width = int(args[1]) + new_height = int(args[2]) + + if new_width > 0 and new_height > 0: + self.resize(new_width, new_height) + + except (ValueError, IndexError): + pass + + def _handle_move(self, args: list): + """Handle layer move (not relevant for text terminal)""" + pass + + def _handle_sync(self, args: list): + """Handle sync instruction - refresh the display""" + self.refresh() + + def _handle_error(self, args: list): + """Handle error message from server""" + if args: + error_msg = args[0] + logging.error(f"Guacamole server error: {error_msg}") + sys.stderr.write(f"\nServer error: {error_msg}\n") + sys.stderr.flush() + + def resize(self, new_width: int, new_height: int): + """ + Resize the terminal buffer. + + Args: + new_width: New width in characters + new_height: New height in characters + """ + # Create new buffers + new_screen = [[' ' for _ in range(new_width)] for _ in range(new_height)] + new_attrs = [[0 for _ in range(new_width)] for _ in range(new_height)] + + # Copy old content + for y in range(min(self.height, new_height)): + for x in range(min(self.width, new_width)): + new_screen[y][x] = self.screen[y][x] + new_attrs[y][x] = self.attrs[y][x] + + self.width = new_width + self.height = new_height + self.screen = new_screen + self.attrs = new_attrs + + # Clamp cursor position + self.cursor_x = min(self.cursor_x, new_width - 1) + self.cursor_y = min(self.cursor_y, new_height - 1) + + # Clear and redraw + sys.stdout.write('\033[2J') # Clear screen + self.refresh() + + def refresh(self): + """Refresh the entire display from the screen buffer""" + if not self.raw_mode: + return + + # Move cursor to home + sys.stdout.write('\033[H') + + # Render each line + for y in range(self.height): + line = ''.join(self.screen[y]) + sys.stdout.write(line) + if y < self.height - 1: + sys.stdout.write('\n') + + # Move cursor to current position + sys.stdout.write(f'\033[{self.cursor_y + 1};{self.cursor_x + 1}H') + sys.stdout.flush() + + def get_size(self) -> Tuple[int, int]: + """ + Get current terminal size. + + Returns: + Tuple of (width, height) in characters + """ + try: + # Try to get actual terminal size + import shutil + size = shutil.get_terminal_size(fallback=(80, 24)) + return (size.columns, size.lines) + except: + return (self.width, self.height) + + def clear(self): + """Clear the screen""" + self.screen = [[' ' for _ in range(self.width)] for _ in range(self.height)] + self.attrs = [[0 for _ in range(self.width)] for _ in range(self.height)] + if self.raw_mode: + sys.stdout.write('\033[2J') + sys.stdout.write('\033[H') + sys.stdout.flush() + diff --git a/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py b/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py new file mode 100644 index 000000000..53c6bdb63 --- /dev/null +++ b/keepercommander/commands/pam_launch/guac_cli/stdin_handler.py @@ -0,0 +1,635 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ (R) +# | ' - Send base64-encoded keyboard input +- end,0 - Close the stream + +Each chunk of stdin input is sent as a complete pipe/blob/end sequence. 
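+
+For example, typing "ls" then Enter could go out as (illustrative; assumes
+the input pipe is named "STDIN" with mimetype "text/plain"):
+
+    4.pipe,1.0,10.text/plain,5.STDIN; 4.blob,1.0,4.bHMK; 3.end,1.0;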
+ +Platform support: +- Unix/Linux: Uses termios for raw mode, select for non-blocking reads +- macOS: Uses same approach as Unix (termios + select) +- Windows: Uses msvcrt for console input +""" + +from __future__ import annotations +import logging +import sys +import threading +from typing import Callable, Optional + + +class StdinHandler: + """ + Handles stdin input for plaintext SSH/TTY sessions. + + Reads raw stdin in non-buffered mode and sends data via callback. + Uses pipe/blob/end pattern matching kcm-cli implementation. + + Enhanced to detect escape sequences (arrow keys, function keys) and + send them as X11 key events instead of raw bytes. + """ + + def __init__(self, stdin_callback: Callable[[bytes], None], + key_callback: Optional[Callable[[int, bool], None]] = None): + """ + Initialize the stdin handler. + + Args: + stdin_callback: Callback function(data: bytes) to send stdin data. + Should call GuacamoleHandler.send_stdin() + key_callback: Optional callback function(keysym: int, pressed: bool) + to send key events. Should call GuacamoleHandler.send_key() + If provided, escape sequences will be converted to key events. + """ + self.stdin_callback = stdin_callback + self.key_callback = key_callback + self.running = False + self.thread: Optional[threading.Thread] = None + self.raw_mode_active = False + self._escape_buffer = b'' # Buffer for escape sequences + + # Platform-specific stdin reader + self._stdin_reader = self._get_stdin_reader() + + def _get_stdin_reader(self): + """Get platform-specific stdin reader.""" + if sys.platform == 'win32': + return _WindowsStdinReader() + elif sys.platform == 'darwin': + return _MacOSStdinReader() + else: + return _UnixStdinReader() + + def start(self): + """Start reading stdin in a background thread.""" + if self.running: + return + + self.running = True + self._stdin_reader.set_raw_mode() + self.raw_mode_active = True + + self.thread = threading.Thread(target=self._input_loop, daemon=True) + self.thread.start() + logging.debug("StdinHandler started") + + def stop(self): + """Stop reading stdin and restore terminal.""" + self.running = False + # Flush any pending escape sequence + if self._escape_buffer: + # If we have just ESC (0x1B) with no following bytes, treat as standalone ESC key + if len(self._escape_buffer) == 1 and self._escape_buffer[0] == 0x1B: + if self.key_callback: + self._send_key(0xFF1B) # X11Keysym.ESCAPE + else: + self.stdin_callback(self._escape_buffer) + else: + # Incomplete escape sequence - send as regular data + logging.debug(f"Flushing incomplete escape sequence: {self._escape_buffer}") + self.stdin_callback(self._escape_buffer) + self._escape_buffer = b'' + if self.raw_mode_active: + self._stdin_reader.restore() + self.raw_mode_active = False + if self.thread: + # Don't wait too long - stdin.read() might be blocking + self.thread.join(timeout=0.5) + logging.debug("StdinHandler stopped") + + def _input_loop(self): + """Main stdin reading loop.""" + import time + last_escape_time = None + + while self.running: + try: + # Read available data (non-blocking with short timeout) + data = self._stdin_reader.read(timeout=0.1) + if data: + self._process_input(data) + # Reset escape timer if we got data + last_escape_time = None + elif self._escape_buffer and len(self._escape_buffer) == 1 and self._escape_buffer[0] == 0x1B: + # We have a standalone ESC in buffer with no more data + # Wait a short time to see if more bytes arrive (escape sequence) + if last_escape_time is None: + last_escape_time = time.time() + elif 
time.time() - last_escape_time > 0.05: # 50ms timeout + # No more bytes after 50ms - treat as standalone ESC key + logging.debug("Standalone ESC key (timeout)") + self._send_key(0xFF1B) # X11Keysym.ESCAPE + self._escape_buffer = b'' + last_escape_time = None + except Exception as e: + if self.running: # Only log if not shutting down + logging.error(f"Error in stdin loop: {e}") + break + + def _process_input(self, data: bytes): + """ + Process input data, detecting escape sequences and converting them to key events. + + Args: + data: Raw bytes from stdin + """ + if not self.key_callback: + # No key callback - send everything as stdin (original behavior) + self.stdin_callback(data) + return + + # Combine any pending escape buffer with new data + if self._escape_buffer: + data = self._escape_buffer + data + self._escape_buffer = b'' + + # Process data byte by byte to detect escape sequences + i = 0 + while i < len(data): + byte = data[i] + + # Check if we're in an escape sequence + if self._escape_buffer: + self._escape_buffer += bytes([byte]) + keysym = self._detect_escape_sequence() + if keysym is not None: + # Found a complete escape sequence - send as key event + logging.debug(f"Detected escape sequence: {self._escape_buffer.hex()} -> keysym 0x{keysym:04X}") + self._send_key(keysym) + self._escape_buffer = b'' + i += 1 + continue + elif len(self._escape_buffer) > 10: + # Escape sequence too long - treat as regular data + logging.warning(f"Invalid escape sequence (too long): {self._escape_buffer.hex()}") + self.stdin_callback(self._escape_buffer) + self._escape_buffer = b'' + i += 1 + continue + else: + # Still waiting for more bytes in escape sequence + # Check if we've reached the end of current data + if i == len(data) - 1: + # Last byte, might need more - keep in buffer for next read + # But if we only have ESC (0x1B) and no more data, treat as standalone ESC + if len(self._escape_buffer) == 1 and self._escape_buffer[0] == 0x1B: + logging.debug("Standalone ESC key (no more data available)") + self._send_key(0xFF1B) # X11Keysym.ESCAPE + self._escape_buffer = b'' + i += 1 + continue + break + i += 1 + continue + + # Check for start of escape sequence + # Unix/Linux/macOS: ESC = 0x1B + # Windows: Extended key = 0xE0 or 0x00 + if byte == 0x1B: + # Start of potential Unix-style escape sequence + # Check if there are more bytes immediately available + if i < len(data) - 1: + # More bytes available - might be an escape sequence + self._escape_buffer = bytes([byte]) + i += 1 + # Continue processing to see if sequence completes in this read + continue + else: + # This is the last byte - could be standalone ESC or start of sequence + # For now, treat as standalone ESC key (user can press ESC twice if needed) + # If it's part of a sequence, the next read will handle it + logging.debug("Standalone ESC key detected") + self._send_key(0xFF1B) # X11Keysym.ESCAPE + i += 1 + continue + elif byte == 0xE0 or byte == 0x00: + # Windows extended key sequence (0xE0 or 0x00 followed by scan code) + self._escape_buffer = bytes([byte]) + i += 1 + # Check if we can read more bytes immediately + if i < len(data): + # Continue processing to see if sequence completes in this read + continue + else: + # End of data, wait for next read + break + + # Regular character - send as stdin + # But first check if it's a control character that might be part of an escape sequence + if byte < 32 and byte != 0x1B: # Control char but not ESC + # Send control characters as-is (they might be Ctrl+key combinations) + 
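+                # e.g. Ctrl+C (0x03) passes straight through to the remote shell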
self.stdin_callback(bytes([byte])) + elif byte >= 32: # Printable character + self.stdin_callback(bytes([byte])) + else: + # Shouldn't reach here, but send anyway + self.stdin_callback(bytes([byte])) + i += 1 + + def _detect_escape_sequence(self) -> Optional[int]: + """ + Detect if the escape buffer contains a known escape sequence. + + Returns: + X11 keysym if sequence is recognized and complete, None if incomplete or unknown + """ + if not self._escape_buffer or len(self._escape_buffer) < 2: + return None + + # Check for Windows extended key sequences (0xE0 or 0x00 prefix) + if len(self._escape_buffer) == 2 and (self._escape_buffer[0] == 0xE0 or self._escape_buffer[0] == 0x00): + # Windows console extended key code + scan_code = self._escape_buffer[1] + + # Windows scan codes for arrow keys + if scan_code == 0x48: # 'H' = Up arrow + return 0xFF52 # UP + elif scan_code == 0x50: # 'P' = Down arrow + return 0xFF54 # DOWN + elif scan_code == 0x4D: # 'M' = Right arrow + return 0xFF53 # RIGHT + elif scan_code == 0x4B: # 'K' = Left arrow + return 0xFF51 # LEFT + elif scan_code == 0x47: # Home + return 0xFF50 # HOME + elif scan_code == 0x4F: # End + return 0xFF57 # END + elif scan_code == 0x49: # Page Up + return 0xFF55 # PAGE_UP + elif scan_code == 0x51: # Page Down + return 0xFF56 # PAGE_DOWN + elif scan_code == 0x52: # Insert + return 0xFF63 # INSERT + elif scan_code == 0x53: # Delete + return 0xFFFF # DELETE + # Function keys F1-F10 (Windows scan codes) + elif scan_code == 0x3B: # F1 + return 0xFFBE # F1 + elif scan_code == 0x3C: # F2 + return 0xFFBF # F2 + elif scan_code == 0x3D: # F3 + return 0xFFC0 # F3 + elif scan_code == 0x3E: # F4 + return 0xFFC1 # F4 + elif scan_code == 0x3F: # F5 + return 0xFFC2 # F5 + elif scan_code == 0x40: # F6 + return 0xFFC3 # F6 + elif scan_code == 0x41: # F7 + return 0xFFC4 # F7 + elif scan_code == 0x42: # F8 + return 0xFFC5 # F8 + elif scan_code == 0x43: # F9 + return 0xFFC6 # F9 + elif scan_code == 0x44: # F10 + return 0xFFC7 # F10 + elif scan_code == 0x85: # F11 + return 0xFFC8 # F11 + elif scan_code == 0x86: # F12 + return 0xFFC9 # F12 + else: + return None + + # Unix/Linux/macOS VT100/xterm escape sequences + # Convert to string for pattern matching + try: + seq = self._escape_buffer[1:].decode('ascii', errors='ignore') + except Exception: + return None + + # Arrow keys and navigation (VT100/xterm style - universal on Linux/macOS) + # These sequences are standard across all Unix-like terminals + if seq == '[A': + return 0xFF52 # UP + elif seq == '[B': + return 0xFF54 # DOWN + elif seq == '[C': + return 0xFF53 # RIGHT + elif seq == '[D': + return 0xFF51 # LEFT + elif seq == '[H': + return 0xFF50 # HOME + elif seq == '[F': + return 0xFF57 # END + # Some terminals send arrow keys with modifiers (e.g., [1;2A for Shift+Up, [1;5A for Ctrl+Up) + # We ignore the modifier part and just use the base key + elif seq.startswith('[1;') and len(seq) >= 4: + # Extract the final character (A, B, C, D for arrows) + final_char = seq[-1] + if final_char == 'A': + return 0xFF52 # UP + elif final_char == 'B': + return 0xFF54 # DOWN + elif final_char == 'C': + return 0xFF53 # RIGHT + elif final_char == 'D': + return 0xFF51 # LEFT + + # Function keys (VT100/xterm style - single character after ESC) + elif seq == 'OP': + return 0xFFBE # F1 + elif seq == 'OQ': + return 0xFFBF # F2 + elif seq == 'OR': + return 0xFFC0 # F3 + elif seq == 'OS': + return 0xFFC1 # F4 + + # Function keys (xterm style - with tilde) + elif seq == '[11~': + return 0xFFBE # F1 + elif seq == '[12~': + return 
0xFFBF # F2 + elif seq == '[13~': + return 0xFFC0 # F3 + elif seq == '[14~': + return 0xFFC1 # F4 + elif seq == '[15~': + return 0xFFC2 # F5 + elif seq == '[17~': + return 0xFFC3 # F6 + elif seq == '[18~': + return 0xFFC4 # F7 + elif seq == '[19~': + return 0xFFC5 # F8 + elif seq == '[20~': + return 0xFFC6 # F9 + elif seq == '[21~': + return 0xFFC7 # F10 + elif seq == '[23~': + return 0xFFC8 # F11 + elif seq == '[24~': + return 0xFFC9 # F12 + + # Other special keys + elif seq == '[1~': + return 0xFF50 # HOME + elif seq == '[2~': + return 0xFF63 # INSERT + elif seq == '[3~': + return 0xFFFF # DELETE + elif seq == '[4~': + return 0xFF57 # END + elif seq == '[5~': + return 0xFF55 # PAGE_UP + elif seq == '[6~': + return 0xFF56 # PAGE_DOWN + + # Check if sequence might be incomplete (common patterns that need more bytes) + # If it starts with '[' and doesn't end with '~' or a letter, might need more + if seq.startswith('[') and len(seq) >= 2: + # Check if it looks like it might be complete (ends with letter or ~) + if seq[-1].isalpha() or seq[-1] == '~': + # Might be complete but not recognized - return None to continue waiting + # This handles edge cases where we might have partial sequences + pass + + # Not a recognized sequence yet - might need more bytes + return None + + def _send_key(self, keysym: int): + """ + Send a key press and release event. + + Args: + keysym: X11 keysym value + """ + if self.key_callback: + # Send key press + self.key_callback(keysym, True) + # Send key release + self.key_callback(keysym, False) + + +class _UnixStdinReader: + """Unix/Linux stdin reader with raw mode support using termios.""" + + def __init__(self): + self.old_settings = None + + def set_raw_mode(self): + """Set terminal to raw mode (non-buffered, non-echoing).""" + try: + import termios + import tty + import time + + # Flush stdout before changing terminal attributes to ensure all output is complete + sys.stdout.flush() + sys.stderr.flush() + + self.old_settings = termios.tcgetattr(sys.stdin.fileno()) + tty.setraw(sys.stdin.fileno()) + + # Small delay to allow terminal to process the attribute change + # This helps prevent visual glitches where lines appear to be deleted + time.sleep(0.01) # 10ms delay + + # Flush again after setting raw mode + sys.stdout.flush() + sys.stderr.flush() + except Exception as e: + logging.warning(f"Failed to set raw mode: {e}") + + def restore(self): + """Restore terminal to normal mode.""" + if self.old_settings: + try: + import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, self.old_settings) + except Exception as e: + logging.warning(f"Failed to restore terminal: {e}") + self.old_settings = None + + def read(self, timeout: Optional[float] = None) -> Optional[bytes]: + """ + Read available data from stdin. + + Args: + timeout: Read timeout in seconds (None = blocking) + + Returns: + Bytes data or None if timeout/no data + """ + import select + + if timeout: + ready, _, _ = select.select([sys.stdin], [], [], timeout) + if not ready: + return None + + try: + # Read what's available (up to 4KB) + data = sys.stdin.buffer.read1(4096) + return data if data else None + except Exception: + return None + + +class _MacOSStdinReader: + """ + macOS stdin reader with raw mode support. + + macOS uses the same POSIX termios approach as Linux, but may have + slight differences in terminal handling. This class provides + macOS-specific optimizations if needed. 
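+
+    Usage sketch (raw mode must always be restored):
+
+        reader = _MacOSStdinReader()
+        reader.set_raw_mode()
+        try:
+            data = reader.read(timeout=0.1)
+        finally:
+            reader.restore()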
+ """ + + def __init__(self): + self.old_settings = None + + def set_raw_mode(self): + """Set terminal to raw mode (non-buffered, non-echoing).""" + try: + import termios + import tty + import time + + # Flush stdout before changing terminal attributes to ensure all output is complete + sys.stdout.flush() + sys.stderr.flush() + + self.old_settings = termios.tcgetattr(sys.stdin.fileno()) + # Use setraw for macOS - same as Linux + tty.setraw(sys.stdin.fileno()) + + # Small delay to allow terminal to process the attribute change + # This helps prevent visual glitches where lines appear to be deleted + time.sleep(0.01) # 10ms delay + + # Flush again after setting raw mode + sys.stdout.flush() + sys.stderr.flush() + except Exception as e: + logging.warning(f"Failed to set raw mode on macOS: {e}") + + def restore(self): + """Restore terminal to normal mode.""" + if self.old_settings: + try: + import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, self.old_settings) + except Exception as e: + logging.warning(f"Failed to restore terminal on macOS: {e}") + self.old_settings = None + + def read(self, timeout: Optional[float] = None) -> Optional[bytes]: + """ + Read available data from stdin on macOS. + + Args: + timeout: Read timeout in seconds (None = blocking) + + Returns: + Bytes data or None if timeout/no data + """ + import select + + if timeout: + # Use select for non-blocking check + ready, _, _ = select.select([sys.stdin], [], [], timeout) + if not ready: + return None + + try: + # Read what's available (up to 4KB) + # Note: On macOS, read1() may not be available on all Python versions, + # so we use os.read() as fallback + import os + fd = sys.stdin.fileno() + data = os.read(fd, 4096) + return data if data else None + except Exception: + return None + + +class _WindowsStdinReader: + """Windows stdin reader using msvcrt for console input.""" + + def __init__(self): + self.old_mode = None + + def set_raw_mode(self): + """Set console to raw mode on Windows.""" + try: + import time + + # Flush stdout before changing console mode to ensure all output is complete + sys.stdout.flush() + sys.stderr.flush() + + # Windows console is already suitable for getch-style reading + # No explicit raw mode needed for msvcrt, but we still flush and delay + # to prevent visual glitches when entering CLI mode + + # Small delay to allow console to process any pending output + # This helps prevent visual glitches where lines appear to be deleted + time.sleep(0.01) # 10ms delay + + # Flush again after the delay + sys.stdout.flush() + sys.stderr.flush() + except Exception as e: + logging.warning(f"Failed to set raw mode on Windows: {e}") + + def restore(self): + """Restore console mode.""" + # Nothing to restore for basic msvcrt usage + pass + + def read(self, timeout: Optional[float] = None) -> Optional[bytes]: + """ + Read available data from stdin on Windows. 
+ + Args: + timeout: Read timeout in seconds (None = blocking) + + Returns: + Bytes data or None if timeout/no data + """ + try: + import msvcrt + import time + + result = b'' + start = time.time() + + # Collect all available characters + while True: + if timeout and (time.time() - start >= timeout): + break + + if msvcrt.kbhit(): + ch = msvcrt.getch() + result += ch + # Continue collecting if more chars available immediately + continue + elif result: + # Have data, return it + break + else: + # No data yet, brief sleep to avoid busy-wait + time.sleep(0.01) + + return result if result else None + + except Exception as e: + logging.error(f"Error reading from stdin on Windows: {e}") + return None diff --git a/keepercommander/commands/pam_launch/guacamole/__init__.py b/keepercommander/commands/pam_launch/guacamole/__init__.py new file mode 100644 index 000000000..c55ffb665 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/__init__.py @@ -0,0 +1,85 @@ +""" +Guacamole Protocol Library for Python. + +A reusable implementation of the Apache Guacamole protocol, ported from +guacamole-common-js. This library provides protocol parsing, event handling, +and client functionality for building Guacamole-based applications. + +Example usage: + from guacamole import Parser, Client, Status + + # Parse incoming instructions + parser = Parser() + parser.oninstruction = lambda opcode, args: print(f"{opcode}: {args}") + parser.receive("4.sync,10.1234567890;") + + # Create instruction strings + instruction = Parser.to_instruction(["key", "65", "1"]) + + # Build a client with a tunnel + class MyTunnel(Tunnel): + # ... implementation + pass + + client = Client(my_tunnel) + client.onstatechange = lambda state: print(f"State: {state}") + client.connect() +""" + +# Exceptions +from .exceptions import ( + GuacamoleError, + InvalidInstructionError, + ProtocolError, + TunnelError, + ClientError, +) + +# Parser +from .parser import Parser, code_point_count, to_instruction + +# Event system +from .event import Event, EventTarget + +# Status +from .status import Status, StatusCode + +# Integer pool +from .integer_pool import IntegerPool + +# Tunnel +from .tunnel import Tunnel, TunnelState + +# Client +from .client import Client, ClientState, ClientMessage, InputStream, OutputStream + + +__all__ = [ + # Exceptions + "GuacamoleError", + "InvalidInstructionError", + "ProtocolError", + "TunnelError", + "ClientError", + # Parser + "Parser", + "code_point_count", + "to_instruction", + # Event system + "Event", + "EventTarget", + # Status + "Status", + "StatusCode", + # Integer pool + "IntegerPool", + # Tunnel + "Tunnel", + "TunnelState", + # Client + "Client", + "ClientState", + "ClientMessage", + "InputStream", + "OutputStream", +] diff --git a/keepercommander/commands/pam_launch/guacamole/client.py b/keepercommander/commands/pam_launch/guacamole/client.py new file mode 100644 index 000000000..7ddd7b720 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/client.py @@ -0,0 +1,615 @@ +""" +Guacamole protocol client. + +This module provides the Client class for handling Guacamole protocol +communication. This is a terminal-focused implementation that handles +the instruction routing and state management needed for SSH/RDP/VNC +terminal sessions. + +Note: GUI-related handlers (png, jpeg, img, rect, cfill, copy, move, cursor, +video, audio, etc.) are not implemented in this port as they require a +graphical display layer. 
This implementation focuses on: +- Connection state management +- Keep-alive/sync handling +- Keyboard and mouse input +- Clipboard support +- Stream management +- Error handling +""" + +import time +from enum import IntEnum +from typing import Any, Callable, Dict, List, Optional + +from .integer_pool import IntegerPool +from .parser import Parser +from .status import Status, StatusCode +from .tunnel import Tunnel, TunnelState + + +class ClientState(IntEnum): + """ + All possible Guacamole client states. + + Attributes: + IDLE: The client is idle, with no active connection. + CONNECTING: The client is in the process of establishing a connection. + WAITING: The client is waiting on further information or a remote + server to establish the connection. + CONNECTED: The client is actively connected to a remote server. + DISCONNECTING: The client is in the process of disconnecting. + DISCONNECTED: The client has completed disconnection. + """ + IDLE = 0 + CONNECTING = 1 + WAITING = 2 + CONNECTED = 3 + DISCONNECTING = 4 + DISCONNECTED = 5 + + +class ClientMessage(IntEnum): + """ + Possible messages that can be sent by the server. + + Attributes: + USER_JOINED: A user has joined the connection. + USER_LEFT: A user has left the connection. + """ + USER_JOINED = 0x0001 + USER_LEFT = 0x0002 + + +class InputStream: + """ + Guacamole input stream for receiving data from the server. + + Attributes: + client: The client that owns this stream. + index: The index of this stream. + onblob: Callback for received blob data. + onend: Callback when stream ends. + """ + + def __init__(self, client: 'Client', index: int): + """ + Initialize a new InputStream. + + Args: + client: The client that owns this stream. + index: The index of this stream. + """ + self.client = client + self.index = index + self.onblob: Optional[Callable[[str], None]] = None + self.onend: Optional[Callable[[], None]] = None + + +class OutputStream: + """ + Guacamole output stream for sending data to the server. + + Attributes: + client: The client that owns this stream. + index: The index of this stream. + onack: Callback for acknowledgement of sent data. + """ + + def __init__(self, client: 'Client', index: int): + """ + Initialize a new OutputStream. + + Args: + client: The client that owns this stream. + index: The index of this stream. + """ + self.client = client + self.index = index + self.onack: Optional[Callable[[Status], None]] = None + + +class Client: + """ + Guacamole protocol client for terminal-focused applications. + + Given a Tunnel, automatically handles incoming and outgoing Guacamole + instructions via the provided tunnel. This implementation focuses on + terminal operations (keyboard, mouse, clipboard) rather than graphical + display rendering. + + Attributes: + tunnel: The tunnel used for communication. + State: Alias for ClientState enum. + Message: Alias for ClientMessage enum. + + Callbacks: + onstatechange: Called when client state changes. + onerror: Called when an error occurs. + onname: Called when connection name is received. + onsync: Called when sync instruction is received. + onclipboard: Called when clipboard data is available. + onfile: Called when a file transfer starts. + onpipe: Called when a named pipe is created. + onargv: Called when argument value is received. + onrequired: Called when additional parameters are required. + onjoin: Called when a user joins. + onleave: Called when a user leaves. + onmsg: Called for general messages. 
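+
+    Callbacks are plain attributes; assign a callable (or None) at any
+    time before or after connect().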
+ + Example: + client = Client(tunnel) + client.onstatechange = lambda state: print(f"State: {state}") + client.connect("hostname=example.com") + """ + + # Expose enums as class attributes + State = ClientState + Message = ClientMessage + + # Keep-alive ping frequency in milliseconds + KEEP_ALIVE_FREQUENCY = 5000 + + def __init__(self, tunnel: Tunnel): + """ + Initialize a new Client. + + Args: + tunnel: The tunnel to use for communication. + """ + self.tunnel = tunnel + self._state = ClientState.IDLE + self._current_timestamp = 0 + self._last_sent_keepalive = 0 + self._keepalive_timeout: Optional[float] = None + + # Stream management + self._stream_indices = IntegerPool() + self._streams: Dict[int, InputStream] = {} + self._output_streams: Dict[int, OutputStream] = {} + + # Callbacks + self.onstatechange: Optional[Callable[[ClientState], None]] = None + self.onerror: Optional[Callable[[Status], None]] = None + self.onname: Optional[Callable[[str], None]] = None + self.onsync: Optional[Callable[[int, int], None]] = None + self.onclipboard: Optional[Callable[[InputStream, str], None]] = None + self.onfile: Optional[Callable[[InputStream, str, str], None]] = None + self.onpipe: Optional[Callable[[InputStream, str, str], None]] = None + self.onargv: Optional[Callable[[InputStream, str, str], None]] = None + self.onrequired: Optional[Callable[[List[str]], None]] = None + self.onjoin: Optional[Callable[[str, str], None]] = None + self.onleave: Optional[Callable[[str, str], None]] = None + self.onmsg: Optional[Callable[[int, List[str]], Optional[bool]]] = None + + # Set up instruction handlers + self._instruction_handlers: Dict[str, Callable[[List[str]], None]] = { + 'ack': self._handle_ack, + 'argv': self._handle_argv, + 'blob': self._handle_blob, + 'clipboard': self._handle_clipboard, + 'disconnect': self._handle_disconnect, + 'end': self._handle_end, + 'error': self._handle_error, + 'file': self._handle_file, + 'msg': self._handle_msg, + 'name': self._handle_name, + 'nop': self._handle_nop, + 'pipe': self._handle_pipe, + 'required': self._handle_required, + 'sync': self._handle_sync, + } + + # Wire up tunnel instruction handler + tunnel.oninstruction = self._on_instruction + + @property + def state(self) -> ClientState: + """Get the current client state.""" + return self._state + + def _set_state(self, state: ClientState) -> None: + """ + Set the client state, firing onstatechange if changed. + + Args: + state: The new client state. + """ + if state != self._state: + self._state = state + if self.onstatechange: + self.onstatechange(state) + + def _is_connected(self) -> bool: + """Return whether the client is connected or waiting.""" + return self._state in (ClientState.CONNECTED, ClientState.WAITING) + + def _on_instruction(self, opcode: str, parameters: List[str]) -> None: + """ + Handle a received instruction. + + Args: + opcode: The instruction opcode. + parameters: The instruction parameters. 
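+
+        Unrecognized opcodes are ignored; every instruction received still
+        counts as network activity for keep-alive scheduling.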
+ """ + handler = self._instruction_handlers.get(opcode) + if handler: + handler(parameters) + + # Schedule next keep-alive on any network activity + self._schedule_keepalive() + + # ========================================================================== + # Instruction Handlers + # ========================================================================== + + def _handle_ack(self, parameters: List[str]) -> None: + """Handle ack instruction.""" + stream_index = int(parameters[0]) + reason = parameters[1] + code = int(parameters[2]) + + stream = self._output_streams.get(stream_index) + if stream: + if stream.onack: + stream.onack(Status(code, reason)) + + # If error code, invalidate stream + if code >= 0x0100 and self._output_streams.get(stream_index) is stream: + self._stream_indices.free(stream_index) + del self._output_streams[stream_index] + + def _handle_argv(self, parameters: List[str]) -> None: + """Handle argv instruction (argument value stream).""" + stream_index = int(parameters[0]) + mimetype = parameters[1] + name = parameters[2] + + if self.onargv: + stream = InputStream(self, stream_index) + self._streams[stream_index] = stream + self.onargv(stream, mimetype, name) + else: + self.send_ack(stream_index, "Receiving argument values unsupported", 0x0100) + + def _handle_blob(self, parameters: List[str]) -> None: + """Handle blob instruction (stream data).""" + stream_index = int(parameters[0]) + data = parameters[1] + + stream = self._streams.get(stream_index) + if stream and stream.onblob: + stream.onblob(data) + + def _handle_clipboard(self, parameters: List[str]) -> None: + """Handle clipboard instruction.""" + stream_index = int(parameters[0]) + mimetype = parameters[1] + + if self.onclipboard: + stream = InputStream(self, stream_index) + self._streams[stream_index] = stream + self.onclipboard(stream, mimetype) + else: + self.send_ack(stream_index, "Clipboard unsupported", 0x0100) + + def _handle_disconnect(self, parameters: List[str]) -> None: + """Handle disconnect instruction.""" + self.disconnect() + + def _handle_end(self, parameters: List[str]) -> None: + """Handle end instruction (stream end).""" + stream_index = int(parameters[0]) + + stream = self._streams.get(stream_index) + if stream: + if stream.onend: + stream.onend() + del self._streams[stream_index] + + def _handle_error(self, parameters: List[str]) -> None: + """Handle error instruction.""" + reason = parameters[0] + code = int(parameters[1]) + + if self.onerror: + self.onerror(Status(code, reason)) + + self.disconnect() + + def _handle_file(self, parameters: List[str]) -> None: + """Handle file instruction (file transfer).""" + stream_index = int(parameters[0]) + mimetype = parameters[1] + filename = parameters[2] + + if self.onfile: + stream = InputStream(self, stream_index) + self._streams[stream_index] = stream + self.onfile(stream, mimetype, filename) + else: + self.send_ack(stream_index, "File transfer unsupported", 0x0100) + + def _handle_msg(self, parameters: List[str]) -> None: + """Handle msg instruction (general message).""" + msgid = int(parameters[0]) + + # Fire general message handler first + allow_default = True + if self.onmsg: + result = self.onmsg(msgid, parameters[1:]) + if result is not None: + allow_default = result + + # Fire specific convenience events if allowed + if allow_default: + if msgid == ClientMessage.USER_JOINED: + user_id = parameters[1] + username = parameters[2] + if self.onjoin: + self.onjoin(user_id, username) + elif msgid == ClientMessage.USER_LEFT: + user_id = 
parameters[1] + username = parameters[2] + if self.onleave: + self.onleave(user_id, username) + + def _handle_name(self, parameters: List[str]) -> None: + """Handle name instruction (connection name).""" + if self.onname: + self.onname(parameters[0]) + + def _handle_nop(self, parameters: List[str]) -> None: + """Handle nop instruction (no operation / keep-alive).""" + # No operation needed - just confirms connection is alive + pass + + def _handle_pipe(self, parameters: List[str]) -> None: + """Handle pipe instruction (named pipe).""" + stream_index = int(parameters[0]) + mimetype = parameters[1] + name = parameters[2] + + if self.onpipe: + stream = InputStream(self, stream_index) + self._streams[stream_index] = stream + self.onpipe(stream, mimetype, name) + else: + self.send_ack(stream_index, "Named pipes unsupported", 0x0100) + + def _handle_required(self, parameters: List[str]) -> None: + """Handle required instruction (additional parameters needed).""" + if self.onrequired: + self.onrequired(parameters) + + def _handle_sync(self, parameters: List[str]) -> None: + """Handle sync instruction.""" + timestamp = int(parameters[0]) + frames = int(parameters[1]) if len(parameters) > 1 else 0 + + # Send sync response + if timestamp != self._current_timestamp: + self.tunnel.send_message("sync", timestamp) + self._current_timestamp = timestamp + + # Transition from WAITING to CONNECTED on first sync + if self._state == ClientState.WAITING: + self._set_state(ClientState.CONNECTED) + + # Fire callback + if self.onsync: + self.onsync(timestamp, frames) + + # ========================================================================== + # Keep-alive Management + # ========================================================================== + + def _send_keepalive(self) -> None: + """Send a keep-alive nop instruction.""" + self.tunnel.send_message('nop') + self._last_sent_keepalive = time.time() * 1000 + + def _schedule_keepalive(self) -> None: + """Schedule the next keep-alive ping.""" + current_time = time.time() * 1000 + keepalive_delay = max( + self._last_sent_keepalive + self.KEEP_ALIVE_FREQUENCY - current_time, + 0 + ) + + if keepalive_delay <= 0: + self._send_keepalive() + else: + # In async environments, this would schedule a timeout + # For sync usage, keep-alive is sent on next network activity + self._keepalive_timeout = current_time + keepalive_delay + + def _stop_keepalive(self) -> None: + """Stop sending keep-alive pings.""" + self._keepalive_timeout = None + + # ========================================================================== + # Public API - Sending + # ========================================================================== + + def send_key_event(self, pressed: bool, keysym: int) -> None: + """ + Send a key event to the server. + + Args: + pressed: True if key is pressed, False if released. + keysym: The X11 keysym of the key. + """ + if not self._is_connected(): + return + self.tunnel.send_message("key", keysym, 1 if pressed else 0) + + def send_mouse_state(self, x: int, y: int, button_mask: int) -> None: + """ + Send a mouse state to the server. + + Args: + x: X coordinate of the mouse. + y: Y coordinate of the mouse. + button_mask: Bitmask of pressed buttons (1=left, 2=middle, 4=right, + 8=scroll-up, 16=scroll-down). + """ + if not self._is_connected(): + return + self.tunnel.send_message("mouse", x, y, button_mask) + + def send_size(self, width: int, height: int) -> None: + """ + Send the current screen size to the server. + + Args: + width: Screen width in pixels. 
+ height: Screen height in pixels. + """ + if not self._is_connected(): + return + self.tunnel.send_message("size", width, height) + + def send_ack(self, stream_index: int, message: str, code: int) -> None: + """ + Acknowledge receipt of data on a stream. + + Args: + stream_index: The index of the stream. + message: Human-readable status message. + code: Status code (0 for success). + """ + if not self._is_connected(): + return + self.tunnel.send_message("ack", stream_index, message, code) + + def send_blob(self, stream_index: int, data: str) -> None: + """ + Send blob data on a stream. + + Args: + stream_index: The index of the stream. + data: Base64-encoded data to send. + """ + if not self._is_connected(): + return + self.tunnel.send_message("blob", stream_index, data) + + def end_stream(self, stream_index: int) -> None: + """ + Mark a stream as complete. + + Args: + stream_index: The index of the stream to end. + """ + if not self._is_connected(): + return + + self.tunnel.send_message("end", stream_index) + + # Free stream index + if stream_index in self._output_streams: + self._stream_indices.free(stream_index) + del self._output_streams[stream_index] + + # ========================================================================== + # Public API - Stream Management + # ========================================================================== + + def create_output_stream(self) -> OutputStream: + """ + Create a new output stream. + + Returns: + A new OutputStream with an allocated index. + """ + index = self._stream_indices.next() + stream = OutputStream(self, index) + self._output_streams[index] = stream + return stream + + def create_clipboard_stream(self, mimetype: str) -> OutputStream: + """ + Create a clipboard stream for sending clipboard data. + + Args: + mimetype: The mimetype of the clipboard data. + + Returns: + An output stream for sending clipboard data. + """ + stream = self.create_output_stream() + self.tunnel.send_message("clipboard", stream.index, mimetype) + return stream + + def create_file_stream(self, mimetype: str, filename: str) -> OutputStream: + """ + Create a file stream for sending a file. + + Args: + mimetype: The mimetype of the file. + filename: The name of the file. + + Returns: + An output stream for sending file data. + """ + stream = self.create_output_stream() + self.tunnel.send_message("file", stream.index, mimetype, filename) + return stream + + def create_pipe_stream(self, mimetype: str, name: str) -> OutputStream: + """ + Create a named pipe stream. + + Args: + mimetype: The mimetype of the data. + name: The name of the pipe. + + Returns: + An output stream for the pipe. + """ + stream = self.create_output_stream() + self.tunnel.send_message("pipe", stream.index, mimetype, name) + return stream + + # ========================================================================== + # Public API - Connection Management + # ========================================================================== + + def connect(self, data: Optional[str] = None) -> None: + """ + Connect to the Guacamole server. + + Args: + data: Arbitrary connection data to send during handshake. + + Raises: + Status: If an error occurs during connection. 
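+
+            Note that whatever exception the tunnel raises is re-raised
+            unchanged, after the client state is reset to IDLE.
+
+        Example:
+            client.connect("hostname=example.com")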
+ """ + self._set_state(ClientState.CONNECTING) + + try: + self.tunnel.connect(data) + except Exception as e: + self._set_state(ClientState.IDLE) + raise + + # Start keep-alive pings + self._schedule_keepalive() + + self._set_state(ClientState.WAITING) + + def disconnect(self) -> None: + """Disconnect from the Guacamole server.""" + if self._state in (ClientState.DISCONNECTED, ClientState.DISCONNECTING): + return + + self._set_state(ClientState.DISCONNECTING) + + # Stop keep-alive + self._stop_keepalive() + + # Send disconnect and close tunnel + self.tunnel.send_message("disconnect") + self.tunnel.disconnect() + + self._set_state(ClientState.DISCONNECTED) diff --git a/keepercommander/commands/pam_launch/guacamole/event.py b/keepercommander/commands/pam_launch/guacamole/event.py new file mode 100644 index 000000000..b5509d375 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/event.py @@ -0,0 +1,179 @@ +""" +Guacamole event system. + +This module provides the Event and EventTarget classes for implementing +event-driven architecture in Guacamole applications. +""" + +import time +from typing import Any, Callable, Dict, List, Optional + + +class Event: + """ + An arbitrary event that can be dispatched by an EventTarget. + + This class serves as the base for more specific event types. Each event + has a type name and timestamp. + + Attributes: + type: The unique name of this event type. + timestamp: Timestamp in seconds when this event was created. + + Example: + event = Event("connection_state_changed") + print(f"Event type: {event.type}") + print(f"Age: {event.get_age()} seconds") + """ + + def __init__(self, event_type: str): + """ + Initialize a new Event. + + Args: + event_type: The unique name of this event type. + """ + self.type: str = event_type + self.timestamp: float = time.time() + + def get_age(self) -> float: + """ + Return the number of seconds elapsed since this event was created. + + Returns: + The age of this event in seconds. + """ + return time.time() - self.timestamp + + def invoke_legacy_handler(self, target: 'EventTarget') -> None: + """ + Invoke the legacy event handler associated with this event. + + This method is called automatically by EventTarget.dispatch() and + provides backward compatibility with single-handler patterns like + "onmousedown" or "onkeyup". + + Subclasses should override this method to invoke the appropriate + legacy handler on the target. + + Args: + target: The EventTarget that emitted this event. + """ + # Default implementation does nothing + pass + + +class EventTarget: + """ + An object that can dispatch Event objects to registered listeners. + + Listeners registered with on() are automatically invoked based on the + event type when dispatch() is called. This class is typically subclassed + by objects that need to emit events. + + Example: + target = EventTarget() + + def on_state_change(event, source): + print(f"State changed: {event.type}") + + target.on("state_change", on_state_change) + target.dispatch(Event("state_change")) + """ + + # Type alias for listener callbacks + Listener = Callable[['Event', 'EventTarget'], None] + + def __init__(self): + """Initialize a new EventTarget.""" + self._listeners: Dict[str, List[EventTarget.Listener]] = {} + + def on(self, event_type: str, listener: 'EventTarget.Listener') -> None: + """ + Register a listener for events of the given type. + + Args: + event_type: The unique name of the event type to listen for. 
+ listener: The callback function to invoke when an event of this + type is dispatched. The function receives the Event object + and the dispatching EventTarget. + """ + if event_type not in self._listeners: + self._listeners[event_type] = [] + self._listeners[event_type].append(listener) + + def on_each(self, types: List[str], listener: 'EventTarget.Listener') -> None: + """ + Register a listener for multiple event types. + + This is equivalent to calling on() for each type in the list. + + Args: + types: List of event type names to listen for. + listener: The callback function to invoke for any of these events. + """ + for event_type in types: + self.on(event_type, listener) + + def off(self, event_type: str, listener: 'EventTarget.Listener') -> bool: + """ + Unregister a previously registered listener. + + If the same listener was registered multiple times, only the first + occurrence is removed. + + Args: + event_type: The event type the listener was registered for. + listener: The listener function to remove. + + Returns: + True if the listener was found and removed, False otherwise. + """ + if event_type not in self._listeners: + return False + + listeners = self._listeners[event_type] + for i, registered_listener in enumerate(listeners): + if registered_listener is listener: + listeners.pop(i) + return True + + return False + + def off_each(self, types: List[str], listener: 'EventTarget.Listener') -> bool: + """ + Unregister a listener from multiple event types. + + This is equivalent to calling off() for each type in the list. + + Args: + types: List of event type names to unregister from. + listener: The listener function to remove. + + Returns: + True if the listener was removed from at least one event type. + """ + changed = False + for event_type in types: + if self.off(event_type, listener): + changed = True + return changed + + def dispatch(self, event: Event) -> None: + """ + Dispatch an event to all registered listeners. + + First invokes the legacy handler (if the event supports it), then + invokes all listeners registered for this event type. + + Args: + event: The event to dispatch. + """ + # Invoke legacy handler for backward compatibility + event.invoke_legacy_handler(self) + + # Invoke all registered listeners + listeners = self._listeners.get(event.type) + if listeners: + for listener in listeners: + listener(event, self) diff --git a/keepercommander/commands/pam_launch/guacamole/exceptions.py b/keepercommander/commands/pam_launch/guacamole/exceptions.py new file mode 100644 index 000000000..a1da065a6 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/exceptions.py @@ -0,0 +1,85 @@ +""" +Custom exceptions for the Guacamole protocol library. + +This module defines the exception hierarchy for all Guacamole-related errors. +""" + +from typing import Optional + + +class GuacamoleError(Exception): + """Base exception for all Guacamole-related errors.""" + + def __init__(self, message: str, code: Optional[int] = None): + """ + Initialize a GuacamoleError. + + Args: + message: Human-readable error description. + code: Optional Guacamole status code. 
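+
+        Example:
+            raise GuacamoleError("Handshake failed", code=0x0200)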
+ """ + super().__init__(message) + self.message = message + self.code = code + + def __str__(self) -> str: + if self.code is not None: + return f"[{self.code}] {self.message}" + return self.message + + +class InvalidInstructionError(GuacamoleError): + """Raised when a malformed Guacamole instruction is encountered.""" + + def __init__(self, message: str, instruction: Optional[str] = None): + """ + Initialize an InvalidInstructionError. + + Args: + message: Description of why the instruction is invalid. + instruction: The malformed instruction data, if available. + """ + super().__init__(message) + self.instruction = instruction + + +class ProtocolError(GuacamoleError): + """Raised when a protocol-level error occurs.""" + + def __init__(self, message: str, code: Optional[int] = None): + """ + Initialize a ProtocolError. + + Args: + message: Description of the protocol error. + code: Optional Guacamole status code. + """ + super().__init__(message, code) + + +class TunnelError(GuacamoleError): + """Raised when a tunnel-related error occurs.""" + + def __init__(self, message: str, code: Optional[int] = None): + """ + Initialize a TunnelError. + + Args: + message: Description of the tunnel error. + code: Optional Guacamole status code. + """ + super().__init__(message, code) + + +class ClientError(GuacamoleError): + """Raised when a client-related error occurs.""" + + def __init__(self, message: str, code: Optional[int] = None): + """ + Initialize a ClientError. + + Args: + message: Description of the client error. + code: Optional Guacamole status code. + """ + super().__init__(message, code) diff --git a/keepercommander/commands/pam_launch/guacamole/integer_pool.py b/keepercommander/commands/pam_launch/guacamole/integer_pool.py new file mode 100644 index 000000000..ca0afca43 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/integer_pool.py @@ -0,0 +1,71 @@ +""" +Integer pool for reusing stream indices. + +This module provides the IntegerPool class for efficiently managing +integer indices that can be acquired and released. +""" + +from typing import List + + +class IntegerPool: + """ + Integer pool that returns consistently increasing integers while in use, + and previously-used integers when possible. + + This is used by the Guacamole client for managing stream indices, + allowing freed indices to be reused to conserve the index space. + + Example: + pool = IntegerPool() + idx1 = pool.next() # Returns 0 + idx2 = pool.next() # Returns 1 + pool.free(idx1) # Release index 0 + idx3 = pool.next() # Returns 0 (reused) + idx4 = pool.next() # Returns 2 + + Attributes: + next_int: The next integer to return if no freed integers are available. + """ + + def __init__(self): + """Initialize a new IntegerPool.""" + self._pool: List[int] = [] + self.next_int: int = 0 + + def next(self) -> int: + """ + Return the next available integer from the pool. + + If previously freed integers exist, one of those is returned. + Otherwise, a new integer is allocated and returned. + + Returns: + The next available integer. + """ + if self._pool: + return self._pool.pop(0) + result = self.next_int + self.next_int += 1 + return result + + def free(self, integer: int) -> None: + """ + Free the given integer, allowing it to be reused. + + Args: + integer: The integer to free. + """ + self._pool.append(integer) + + def __contains__(self, integer: int) -> bool: + """ + Check if an integer is currently in the free pool. + + Args: + integer: The integer to check. 
+ + Returns: + True if the integer is in the free pool, False otherwise. + """ + return integer in self._pool diff --git a/keepercommander/commands/pam_launch/guacamole/parser.py b/keepercommander/commands/pam_launch/guacamole/parser.py new file mode 100644 index 000000000..e785eeb8e --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/parser.py @@ -0,0 +1,313 @@ +""" +Guacamole protocol parser. + +This module provides the Parser class for parsing Guacamole protocol instructions +from incoming data streams. It handles the length-prefixed element format and +properly counts Unicode codepoints. + +The Guacamole protocol uses instructions in the format: + LENGTH.VALUE,LENGTH.VALUE,...,LENGTH.VALUE; + +Where LENGTH is the number of Unicode codepoints in VALUE. +""" + +import re +from typing import Any, Callable, List, Optional + +from .exceptions import InvalidInstructionError + + +# Regex pattern for detecting UTF-16 surrogate pairs +# In Python strings (which are UTF-32 internally), surrogate pairs appear +# as two separate characters when the string came from UTF-16 encoding +_SURROGATE_PAIR_PATTERN = re.compile(r'[\uD800-\uDBFF][\uDC00-\uDFFF]') + +# Minimum codepoint that requires a surrogate pair in UTF-16 +_MIN_CODEPOINT_REQUIRES_SURROGATE = 0x10000 + +# Range checks for surrogates +_HIGH_SURROGATE_MIN = 0xD800 +_HIGH_SURROGATE_MAX = 0xDBFF +_LOW_SURROGATE_MIN = 0xDC00 +_LOW_SURROGATE_MAX = 0xDFFF + + +def _is_high_surrogate(char_code: int) -> bool: + """Check if a character code is a high surrogate.""" + return _HIGH_SURROGATE_MIN <= char_code <= _HIGH_SURROGATE_MAX + + +def _is_low_surrogate(char_code: int) -> bool: + """Check if a character code is a low surrogate.""" + return _LOW_SURROGATE_MIN <= char_code <= _LOW_SURROGATE_MAX + + +class Parser: + """ + Simple Guacamole protocol parser that invokes an oninstruction callback when + full instructions are available from data received via receive(). + + The parser handles the Guacamole wire protocol format where each element + is prefixed with its length in Unicode codepoints, followed by a period, + then the element value, and finally a terminator (',' for more elements + or ';' for end of instruction). + + Example: + parser = Parser() + parser.oninstruction = lambda opcode, args: print(f"{opcode}: {args}") + parser.receive("4.sync,10.1234567890;") + # Output: sync: ['1234567890'] + + Attributes: + oninstruction: Callback invoked when a complete instruction is parsed. + Signature: (opcode: str, parameters: List[str]) -> None + """ + + # Number of parsed characters before truncating the buffer to conserve memory + BUFFER_TRUNCATION_THRESHOLD = 4096 + + def __init__(self): + """Initialize a new Parser instance.""" + # Current buffer of received data + self._buffer: str = '' + + # Buffer of all received, complete elements for current instruction + self._element_buffer: List[str] = [] + + # Character offset of current element's terminator (-1 if not yet known) + self._element_end: int = -1 + + # Character offset where parser should start looking for next element + self._start_index: int = 0 + + # Declared length of current element in Unicode codepoints + self._element_codepoints: int = 0 + + # Callback for completed instructions + self.oninstruction: Optional[Callable[[str, List[str]], None]] = None + + def receive(self, packet: str, is_buffer: bool = False) -> None: + """ + Append instruction data to the internal buffer and execute all + completed instructions at the beginning of the buffer. 
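+
+        Incomplete trailing data is buffered until more data arrives, so
+        callers may deliver the protocol stream in arbitrary chunks:
+
+            parser.receive("4.sy")     # partial element - buffered, no callback
+            parser.receive("nc,1.0;")  # completes: oninstruction('sync', ['0'])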
+ + Args: + packet: The instruction data to receive. + is_buffer: If True, the packet is treated as an external buffer + that grows continuously. The packet MUST always start with + the data provided to the previous call. If False (default), + only new data should be provided and previously-received + data will be buffered automatically. + + Raises: + InvalidInstructionError: If a malformed instruction is encountered. + """ + if is_buffer: + self._buffer = packet + else: + # Truncate buffer as necessary to conserve memory + if (self._start_index > self.BUFFER_TRUNCATION_THRESHOLD and + self._element_end >= self._start_index): + self._buffer = self._buffer[self._start_index:] + # Reset parse positions relative to truncation + self._element_end -= self._start_index + self._start_index = 0 + + # Append data to buffer only if there is outstanding data. + # Otherwise, parse the received buffer as-is for efficiency. + if self._buffer: + self._buffer += packet + else: + self._buffer = packet + + # Parse while search is within currently received data + while self._element_end < len(self._buffer): + + # If we are waiting for element data + if self._element_end >= self._start_index: + + # Count codepoints in the expected element substring + codepoints = code_point_count( + self._buffer, + self._start_index, + self._element_end + ) + + # If we don't have enough codepoints yet, adjust element_end + # This handles characters that are represented as surrogate pairs + if codepoints < self._element_codepoints: + self._element_end += self._element_codepoints - codepoints + continue + + # If element_end points to a character that's part of a surrogate pair, + # we need to adjust. Two cases: + # 1. element_end-1 is HIGH surrogate and element_end is LOW surrogate + # (we're about to split a pair, need to include the LOW) + # 2. 
element_end-1 is >= 0x10000 (combined supplementary char in Python, + # though this is rare since Python usually keeps surrogates separate) + if (self._element_codepoints and + self._element_end > 0 and + self._element_end < len(self._buffer)): + last_char_index = self._element_end - 1 + if last_char_index >= self._start_index: + last_char_code = ord(self._buffer[last_char_index]) + term_char_code = ord(self._buffer[self._element_end]) + + # Case 1: Last char is HIGH surrogate, terminator pos is LOW surrogate + # This means we're about to cut a surrogate pair in half + if _is_high_surrogate(last_char_code) and _is_low_surrogate(term_char_code): + self._element_end += 1 + continue + + # Case 2: Character >= 0x10000 (combined supplementary char) + if last_char_code >= _MIN_CODEPOINT_REQUIRES_SURROGATE: + self._element_end += 1 + continue + + # We now have enough data for the element - parse it + element = self._buffer[self._start_index:self._element_end] + + # Get terminator character + if self._element_end < len(self._buffer): + terminator = self._buffer[self._element_end] + else: + # Need more data + break + + # Add element to array + self._element_buffer.append(element) + + # If last element (semicolon terminator), handle instruction + if terminator == ';': + # Get opcode (first element) + opcode = self._element_buffer.pop(0) + + # Call instruction handler + if self.oninstruction is not None: + self.oninstruction(opcode, self._element_buffer) + + # Clear elements for next instruction + self._element_buffer = [] + + # Immediately truncate buffer if fully parsed + if not is_buffer and self._element_end + 1 == len(self._buffer): + self._element_end = -1 + self._buffer = '' + + elif terminator != ',': + raise InvalidInstructionError( + 'Element terminator of instruction was not ";" nor ",".', + instruction=self._buffer[:self._element_end + 1] + ) + + # Start searching for length at character after terminator + self._start_index = self._element_end + 1 + + # Search for end of length (the period) + length_end = self._buffer.find('.', self._start_index) + if length_end != -1: + # Parse length + length_str = self._buffer[self._element_end + 1:length_end] + try: + self._element_codepoints = int(length_str) + except ValueError: + raise InvalidInstructionError( + 'Non-numeric character in element length.', + instruction=length_str + ) + + # Calculate start of element value + self._start_index = length_end + 1 + + # Calculate location of element terminator + self._element_end = self._start_index + self._element_codepoints + + else: + # No period yet, continue search when more data is received + self._start_index = len(self._buffer) + break + + +def code_point_count(s: str, start: int = 0, end: Optional[int] = None) -> int: + """ + Return the number of Unicode codepoints in the given string or substring. + + In Python, strings are stored as proper Unicode (UTF-32 internally), so + len() already gives the codepoint count. However, this function also handles + edge cases where surrogate characters might appear in strings that originated + from UTF-16 encoding. + + Unlike JavaScript's string.length which counts UTF-16 code units (where + surrogate pairs count as 2), this function counts actual Unicode codepoints. + + Args: + s: The string to inspect. + start: The starting index (default 0). + end: The ending index (exclusive). If None, counts to end of string. + + Returns: + The number of Unicode codepoints in the specified portion of the string. 
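+
+    Note:
+        A high/low surrogate pair (as can appear in strings that originated
+        from UTF-16 data) counts as a single codepoint, not two.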
+ + Example: + >>> code_point_count("hello") + 5 + >>> code_point_count("test string", 0, 4) + 4 + """ + # Extract substring + substring = s[start:end] + + # In Python, len() gives codepoint count for normal strings. + # However, if the string contains unpaired surrogates (from malformed UTF-16), + # we need to handle surrogate pairs that are stored as two characters. + # Find proper surrogate pairs (high surrogate followed by low surrogate) + surrogate_pairs = _SURROGATE_PAIR_PATTERN.findall(substring) + + # Each surrogate pair represents a single codepoint but is stored as + # two characters in Python when originating from UTF-16 data. + # Subtract the number of pairs to get the actual codepoint count. + return len(substring) - len(surrogate_pairs) + + +def to_instruction(elements: List[Any]) -> str: + """ + Convert a list of values into a properly formatted Guacamole instruction. + + Each element is converted to a string and prefixed with its length in + Unicode codepoints, followed by a period. Elements are separated by + commas, and the instruction ends with a semicolon. + + Args: + elements: The values to encode as instruction elements. Must have at + least one element (the opcode). Each element will be converted + to a string. + + Returns: + A complete Guacamole instruction string. + + Example: + >>> to_instruction(["key", "65", "1"]) + '3.key,2.65,1.1;' + >>> to_instruction(["sync", "1234567890"]) + '4.sync,10.1234567890;' + """ + if not elements: + raise ValueError("Instruction must have at least one element (opcode)") + + def to_element(value: Any) -> str: + """Convert a value to a length-prefixed element string.""" + s = str(value) + length = code_point_count(s) + return f"{length}.{s}" + + # Build instruction: first element, then comma-separated remaining elements + instruction = to_element(elements[0]) + for element in elements[1:]: + instruction += ',' + to_element(element) + + return instruction + ';' + + +# Expose functions at module level for convenience +Parser.code_point_count = staticmethod(code_point_count) +Parser.to_instruction = staticmethod(to_instruction) diff --git a/keepercommander/commands/pam_launch/guacamole/status.py b/keepercommander/commands/pam_launch/guacamole/status.py new file mode 100644 index 000000000..9917eee44 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/status.py @@ -0,0 +1,157 @@ +""" +Guacamole status codes and Status class. + +This module provides the Status class and StatusCode enum for representing +Guacamole protocol status codes and associated messages. +""" + +from enum import IntEnum +from typing import Optional + + +class StatusCode(IntEnum): + """ + Enumeration of all Guacamole protocol status codes. 
+ + Status codes are divided into ranges: + - 0x0000-0x00FF: Success/informational + - 0x0100-0x01FF: Unsupported operations + - 0x0200-0x02FF: Server errors + - 0x0300-0x03FF: Client errors + """ + + # Success + SUCCESS = 0x0000 + + # Unsupported + UNSUPPORTED = 0x0100 + + # Server errors + SERVER_ERROR = 0x0200 + SERVER_BUSY = 0x0201 + UPSTREAM_TIMEOUT = 0x0202 + UPSTREAM_ERROR = 0x0203 + RESOURCE_NOT_FOUND = 0x0204 + RESOURCE_CONFLICT = 0x0205 + RESOURCE_CLOSED = 0x0206 + UPSTREAM_NOT_FOUND = 0x0207 + UPSTREAM_UNAVAILABLE = 0x0208 + SESSION_CONFLICT = 0x0209 + SESSION_TIMEOUT = 0x020A + SESSION_CLOSED = 0x020B + + # Client errors + CLIENT_BAD_REQUEST = 0x0300 + CLIENT_UNAUTHORIZED = 0x0301 + CLIENT_FORBIDDEN = 0x0303 + CLIENT_TIMEOUT = 0x0308 + CLIENT_OVERRUN = 0x030D + CLIENT_BAD_TYPE = 0x030F + CLIENT_TOO_MANY = 0x031D + + @classmethod + def from_http_code(cls, http_status: int) -> 'StatusCode': + """ + Return the Guacamole status code that most closely represents + the given HTTP status code. + + Args: + http_status: The HTTP status code to translate. + + Returns: + The corresponding Guacamole status code. + """ + http_to_guac = { + 400: cls.CLIENT_BAD_REQUEST, + 403: cls.CLIENT_FORBIDDEN, + 404: cls.RESOURCE_NOT_FOUND, + 429: cls.CLIENT_TOO_MANY, + 503: cls.SERVER_BUSY, + } + return http_to_guac.get(http_status, cls.SERVER_ERROR) + + @classmethod + def from_websocket_code(cls, ws_code: int) -> 'StatusCode': + """ + Return the Guacamole status code that most closely represents + the given WebSocket close code. + + Args: + ws_code: The WebSocket status code to translate. + + Returns: + The corresponding Guacamole status code. + """ + # Successful disconnect + if ws_code == 1000: # Normal Closure + return cls.SUCCESS + + # Server not reachable + if ws_code in (1006, 1015): # Abnormal Closure, TLS Handshake + return cls.UPSTREAM_NOT_FOUND + + # Server busy/unavailable + if ws_code in (1001, 1012, 1013, 1014): # Going Away, Service Restart, Try Again, Bad Gateway + return cls.UPSTREAM_UNAVAILABLE + + return cls.SERVER_ERROR + + +class Status: + """ + A Guacamole status consisting of a status code and optional message. + + The status code is defined by the protocol, while the message is an + optional human-readable description, typically for debugging. + + Attributes: + code: The Guacamole status code. + message: Optional human-readable message. + + Example: + status = Status(StatusCode.SUCCESS, "Connection established") + if status.is_error(): + print(f"Error: {status.message}") + """ + + def __init__(self, code: int, message: Optional[str] = None): + """ + Initialize a new Status. + + Args: + code: The Guacamole status code (can be int or StatusCode). + message: Optional human-readable message. + """ + self.code: int = int(code) + self.message: Optional[str] = message + + def is_error(self) -> bool: + """ + Return whether this status represents an error. + + Returns: + True if this is an error status, False otherwise. 
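+
+        Note:
+            Codes in the 0x0000-0x00FF range are success/informational;
+            any code outside that range is an error.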
+ """ + return self.code < 0 or self.code > 0x00FF + + def __repr__(self) -> str: + """Return a string representation of this status.""" + try: + code_name = StatusCode(self.code).name + except ValueError: + code_name = f"UNKNOWN({self.code})" + + if self.message: + return f"Status({code_name}, {self.message!r})" + return f"Status({code_name})" + + def __str__(self) -> str: + """Return a human-readable string for this status.""" + try: + code_name = StatusCode(self.code).name + except ValueError: + code_name = f"Code {self.code}" + + if self.message: + return f"{code_name}: {self.message}" + return code_name diff --git a/keepercommander/commands/pam_launch/guacamole/tunnel.py b/keepercommander/commands/pam_launch/guacamole/tunnel.py new file mode 100644 index 000000000..c55fbb794 --- /dev/null +++ b/keepercommander/commands/pam_launch/guacamole/tunnel.py @@ -0,0 +1,156 @@ +""" +Guacamole tunnel abstract base class. + +This module provides the abstract Tunnel class that defines the interface +for Guacamole protocol communication channels. Concrete implementations +should handle the actual transport mechanism (WebSocket, HTTP, WebRTC, etc.). +""" + +from abc import ABC, abstractmethod +from enum import IntEnum +from typing import Any, Callable, List, Optional + +from .status import Status + + +class TunnelState(IntEnum): + """ + All possible tunnel states. + + Attributes: + CONNECTING: A connection is pending. It is not yet known whether + connection was successful. + OPEN: Connection was successful, and data is being received. + CLOSED: The connection is closed. Connection may not have been + successful, the tunnel may have been explicitly closed by + either side, or an error may have occurred. + UNSTABLE: The connection is open, but communication appears to be + disrupted, and the connection may close as a result. + """ + CONNECTING = 0 + OPEN = 1 + CLOSED = 2 + UNSTABLE = 3 + + +class Tunnel(ABC): + """ + Abstract base class for Guacamole protocol tunnels. + + This class defines the interface for sending and receiving Guacamole + protocol instructions over a communication channel. Concrete implementations + should handle the specific transport mechanism. + + Attributes: + state: The current state of this tunnel. + uuid: The UUID uniquely identifying this tunnel, or None if not yet known. + receive_timeout: Maximum time (ms) to wait for data before closing. + unstable_threshold: Time (ms) before connection is considered unstable. + oninstruction: Callback for received instructions. + onstatechange: Callback for state changes. + onerror: Callback for errors. + onuuid: Callback when UUID becomes known. 
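+
+    The receive_timeout (15000 ms) and unstable_threshold (1500 ms) defaults
+    mirror those of guacamole-common-js.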
+ + Example: + class MyTunnel(Tunnel): + def connect(self, data=None): + # Implementation + pass + + def disconnect(self): + # Implementation + pass + + def send_message(self, *elements): + # Implementation + pass + """ + + # Internal data opcode used by tunnel implementations + INTERNAL_DATA_OPCODE = '' + + def __init__(self): + """Initialize a new Tunnel instance.""" + self._state: TunnelState = TunnelState.CLOSED + self.uuid: Optional[str] = None + self.receive_timeout: int = 15000 + self.unstable_threshold: int = 1500 + + # Callbacks + self.oninstruction: Optional[Callable[[str, List[str]], None]] = None + self.onstatechange: Optional[Callable[[TunnelState], None]] = None + self.onerror: Optional[Callable[[Status], None]] = None + self.onuuid: Optional[Callable[[str], None]] = None + + @property + def state(self) -> TunnelState: + """Get the current tunnel state.""" + return self._state + + @state.setter + def state(self, value: TunnelState) -> None: + """Set the tunnel state (use set_state() for callback notification).""" + self._state = value + + def set_state(self, state: TunnelState) -> None: + """ + Change the tunnel state, firing onstatechange if the state changes. + + Args: + state: The new state of this tunnel. + """ + if state != self._state: + self._state = state + if self.onstatechange: + self.onstatechange(state) + + def set_uuid(self, uuid: str) -> None: + """ + Set the tunnel UUID, firing onuuid callback. + + Args: + uuid: The unique identifier for this tunnel. + """ + self.uuid = uuid + if self.onuuid: + self.onuuid(uuid) + + def is_connected(self) -> bool: + """ + Return whether this tunnel is currently connected. + + Returns: + True if the tunnel is in OPEN or UNSTABLE state, False otherwise. + """ + return self._state in (TunnelState.OPEN, TunnelState.UNSTABLE) + + @abstractmethod + def connect(self, data: Optional[str] = None) -> None: + """ + Connect to the tunnel with the given optional data. + + The data is typically used for authentication. The format of data + accepted is up to the tunnel implementation. + + Args: + data: Optional data to send during connection (e.g., auth tokens). + """ + pass + + @abstractmethod + def disconnect(self) -> None: + """Disconnect from the tunnel.""" + pass + + @abstractmethod + def send_message(self, *elements: Any) -> None: + """ + Send a message through the tunnel. + + All messages are guaranteed to be received in the order sent. + + Args: + *elements: The elements of the message to send. These will be + formatted as a Guacamole instruction. + """ + pass diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py new file mode 100644 index 000000000..7712f8471 --- /dev/null +++ b/keepercommander/commands/pam_launch/launch.py @@ -0,0 +1,525 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' bool: + """ + Check if a record is a valid PAM record type. 
+ + Args: + params: KeeperParams instance + record_uid: Record UID to check + + Returns: + True if record is a valid PAM type (version 3 TypedRecord with PAM type), False otherwise + """ + try: + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + return False + if record.version != 3: + return False + return record.record_type in self.VALID_PAM_RECORD_TYPES + except Exception as e: + logging.debug(f"Error checking record type for {record_uid}: {e}") + return False + + def find_record(self, params: KeeperParams, record_token: str) -> Optional[str]: + """ + Find a record by UID, path, or title. + + Args: + params: KeeperParams instance + record_token: Record identifier (UID, path, or title) + + Returns: + Record UID if found, None otherwise + + Raises: + CommandError: If multiple records match + """ + if not record_token: + return None + + record_token = record_token.strip() + + # Step 1: Try UID lookup + uid_pattern = re.compile(r'^[A-Za-z0-9_-]{22}$') + if uid_pattern.match(record_token): + if record_token in params.record_cache: + # Validate it's a PAM record type + if self._is_valid_pam_record(params, record_token): + logging.debug(f"Found record by UID: {record_token}") + return record_token + else: + logging.debug(f"Record {record_token} found but is not a valid PAM record type") + return None + + # Step 2: Try path lookup + record_uid = self._find_by_path(params, record_token) + if record_uid: + return record_uid + + # Step 3: Try full title match + record_uid = self._find_by_title(params, record_token) + if record_uid: + return record_uid + + return None + + def _find_by_path(self, params: KeeperParams, path: str) -> Optional[str]: + """ + Find record by path resolution. + + Args: + params: KeeperParams instance + path: Path to the record + + Returns: + Record UID if found, None otherwise + + Raises: + CommandError: If multiple records match + """ + rs = try_resolve_path(params, path) + if rs is None: + return None + + folder, name = rs + if folder is None or name is None: + return None + + folder_uid = folder.uid or '' + if folder_uid not in params.subfolder_record_cache: + return None + + # Find all records in the folder with matching title (only valid PAM types) + matched_uids = [] + for uid in params.subfolder_record_cache[folder_uid]: + r = api.get_record(params, uid) + if r and r.title and r.title.lower() == name.lower(): + # Only include valid PAM record types + if self._is_valid_pam_record(params, uid): + matched_uids.append(uid) + + if len(matched_uids) > 1: + raise CommandError('pam launch', f'Multiple valid PAM records found with path "{path}". Please use a unique identifier.') + + if matched_uids: + logging.debug(f"Found record by path: {path} -> {matched_uids[0]}") + return matched_uids[0] + + return None + + def _find_by_title(self, params: KeeperParams, title: str) -> Optional[str]: + """ + Find record by exact title match. 
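+
+        Matching is case-insensitive, and only records that pass
+        _is_valid_pam_record() are considered.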
+
+        Args:
+            params: KeeperParams instance
+            title: Title to match
+
+        Returns:
+            Record UID if found, None otherwise
+
+        Raises:
+            CommandError: If multiple records match
+        """
+        matched_uids = []
+        for record_uid in params.record_cache:
+            record = vault.KeeperRecord.load(params, record_uid)
+            if record and record.title and record.title.lower() == title.lower():
+                # Only include valid PAM record types
+                if self._is_valid_pam_record(params, record_uid):
+                    matched_uids.append(record_uid)
+
+        if len(matched_uids) > 1:
+            raise CommandError('pam launch', f'Multiple valid PAM records found with title "{title}". Please use a unique identifier (UID or full path).')
+
+        if matched_uids:
+            logging.debug(f"Found record by title: {title} -> {matched_uids[0]}")
+            return matched_uids[0]
+
+        return None
+
+    def find_gateway(self, params: KeeperParams, record_uid: str) -> Dict:
+        """
+        Find the gateway associated with a PAM record.
+
+        Args:
+            params: KeeperParams instance
+            record_uid: Record UID to find gateway for (must be pre-validated as PAM type)
+
+        Returns:
+            Dictionary with gateway information including:
+            - gateway_uid: Gateway UID (str)
+            - gateway_name: Gateway name (str)
+            - config_uid: PAM configuration UID (str)
+            - gateway_proto: Gateway protobuf object (pam_pb2.PAMController)
+
+        Raises:
+            CommandError: If no gateway is associated with the record or the
+                gateway cannot be found.
+        """
+        # Get the gateway UID from the record
+        # Note: Record type validation happens in find_record()
+        gateway_uid = get_gateway_uid_from_record(params, vault, record_uid)
+
+        if not gateway_uid:
+            raise CommandError('pam launch', f'No gateway found for record {record_uid}.')
+
+        logging.debug(f"Found gateway UID for record: {gateway_uid}")
+
+        # Get all gateways to find the matching one
+        all_gateways = get_all_gateways(params)
+
+        # Find the gateway by UID
+        gateway_uid_bytes = url_safe_str_to_bytes(gateway_uid)
+        gateway_proto = next((g for g in all_gateways if g.controllerUid == gateway_uid_bytes), None)
+
+        if not gateway_proto:
+            raise CommandError('pam launch', f'Gateway {gateway_uid} not found in available gateways.')
+
+        gateway_name = gateway_proto.controllerName
+        logging.debug(f"Found gateway: {gateway_name} ({gateway_uid})")
+
+        # Get the configuration UID
+        config_uid = get_config_uid_from_record(params, vault, record_uid)
+
+        return {
+            'gateway_uid': gateway_uid,
+            'gateway_name': gateway_name,
+            'config_uid': config_uid,
+            'gateway_proto': gateway_proto
+        }
+
+    def execute(self, params: KeeperParams, **kwargs):
+        """
+        Execute the PAM launch command.
+
+        Args:
+            params: KeeperParams instance containing session state
+            **kwargs: Command arguments including 'record' (record path or UID)
+        """
+        # Save original root logger level and set to ERROR if not in DEBUG mode
+        root_logger = logging.getLogger()
+        original_level = root_logger.level
+
+        if root_logger.getEffectiveLevel() > logging.DEBUG:
+            root_logger.setLevel(logging.ERROR)
+
+        try:
+            record_token = kwargs.get('record')
+
+            if not record_token:
+                logging.error("Record path or UID is required")
+                return
+
+            # Find the record
+            record_uid = self.find_record(params, record_token)
+
+            if not record_uid:
+                raise CommandError('pam launch', f'Record not found: {record_token}')
+
+            logging.debug(f"Found record: {record_uid}")
+
+            # Find the gateway for this record (raises CommandError if none)
+            gateway_info = self.find_gateway(params, record_uid)
+
+            logging.debug(f"Found gateway: {gateway_info['gateway_name']} ({gateway_info['gateway_uid']})")
+            logging.debug(f"Configuration: {gateway_info['config_uid']}")
+
+            # Check if Gateway is online before attempting WebRTC connection
+            try:
+                connected_gateways = router_get_connected_gateways(params)
+                connected_gateway_uids = [x.controllerUid for x in connected_gateways.controllers]
+                gateway_uid_bytes = url_safe_str_to_bytes(gateway_info['gateway_uid'])
+
+                if gateway_uid_bytes not in connected_gateway_uids:
+                    raise CommandError(
+                        'pam launch',
+                        f'Gateway "{gateway_info["gateway_name"]}" ({gateway_info["gateway_uid"]}) is currently offline. '
+                        f'Please start the Gateway before attempting to connect. '
+                        f'Use "pam gateway list" to check Gateway status.'
+                    )
+
+                logging.debug("✓ Gateway is online and connected")
+            except Exception as e:
+                # If router is down or there's an error checking status, still try to connect
+                # (the connection attempt will fail later with a more specific error)
+                if isinstance(e, CommandError):
+                    raise
+                logging.warning(f"Could not verify Gateway online status: {e}. Continuing anyway...")
+
+            # Launch terminal connection
+            result = launch_terminal_connection(params, record_uid, gateway_info, **kwargs)
+
+            if result.get('success'):
+                logging.debug("Terminal connection launched successfully")
+                logging.debug(f"Protocol: {result.get('protocol')}")
+
+                # Always start interactive CLI session
+                self._start_cli_session(result, params)
+            else:
+                error_msg = result.get('error', 'Unknown error')
+                raise CommandError('pam launch', f'Failed to launch connection: {error_msg}')
+        finally:
+            # Restore original root logger level
+            root_logger.setLevel(original_level)
+
+    def _start_cli_session(self, tunnel_result: Dict[str, Any], params: KeeperParams):
+        """
+        Start CLI session using PythonHandler protocol mode.
+
+        In PythonHandler mode:
+        - Python initiates connection via tube_registry.open_handler_connection()
+        - Rust forwards OpenConnection to Gateway and handles Ping/Pong heartbeat
+        - Gateway starts guacd and connects to target
+        - Python receives Guacamole protocol data via callback
+        - Python sends Guacamole responses back via tube_registry.send_handler_data()
+
+        Flow:
+        1. Wait for WebRTC connection to be established
+        2. Send OpenConnection to Gateway (conn_no=1)
+        3. Gateway starts guacd, sends 'args' instruction
+        4. Python responds with 'connect', 'size', 'audio', 'image'
+        5.
guacd sends 'ready', terminal session begins + + Args: + tunnel_result: Result from launch_terminal_connection + params: KeeperParams instance + """ + shutdown_requested = False + + def signal_handler_fn(signum, frame): + nonlocal shutdown_requested + shutdown_requested = True + logging.warning("\n\n* Interrupt received - shutting down...") + + original_handler = signal.signal(signal.SIGINT, signal_handler_fn) + + try: + tube_id = tunnel_result['tunnel'].get('tube_id') + if not tube_id: + raise CommandError('pam launch', 'No tube ID in tunnel result') + + tube_registry = tunnel_result['tunnel'].get('tube_registry') + if not tube_registry: + raise CommandError('pam launch', 'No tube registry in tunnel result') + + python_handler = tunnel_result['tunnel'].get('python_handler') + if not python_handler: + raise CommandError('pam launch', 'No python_handler in tunnel result - ensure Rust module supports PythonHandler mode') + + conversation_id = tunnel_result['tunnel'].get('conversation_id') + + logging.debug(f"Starting PythonHandler CLI session for tube {tube_id}") + + # Display connection banner + logging.debug(f"\n{'-' * 60}") + logging.debug(f"CLI Terminal Mode - PythonHandler") + logging.debug(f"Protocol: {tunnel_result['protocol']}") + logging.debug(f"Target: {tunnel_result['settings']['hostname']}:{tunnel_result['settings']['port']}") + logging.debug(f"Tube ID: {tube_id}") + logging.debug(f"{'-' * 60}") + logging.debug("Python sends: OpenConnection (initiates guacd session)") + logging.debug("Rust handles: Ping/Pong heartbeat, message routing") + logging.debug("Python receives: Guacamole protocol data via callback") + logging.debug(f"{'=' * 60}\n") + + # Start the Python handler + python_handler.start() + + # Wait for WebRTC connection to be established + logging.debug("Waiting for WebRTC connection...") + max_wait = 15 + start_time = time.time() + connected = False + + while time.time() - start_time < max_wait: + try: + state = tube_registry.get_connection_state(tube_id) + if state and state.lower() == 'connected': + logging.debug(f"✓ WebRTC connection established: {state}") + connected = True + break + except Exception as e: + logging.debug(f"Checking connection state: {e}") + time.sleep(0.1) + + if not connected: + raise CommandError('pam launch', "WebRTC connection not established within timeout") + + # Send OpenConnection to Gateway to initiate guacd session + # This is critical - without it, Gateway doesn't start guacd and no Guacamole traffic flows + logging.debug(f"Sending OpenConnection to Gateway (conn_no=1, conversation_id={conversation_id})") + try: + tube_registry.open_handler_connection(conversation_id, 1) + logging.debug("✓ OpenConnection sent successfully") + except Exception as e: + logging.error(f"Failed to send OpenConnection: {e}") + raise CommandError('pam launch', f"Failed to send OpenConnection: {e}") + + # Wait for Guacamole ready + print("Waiting for Guacamole connection...") + + # Clear screen by printing terminal height worth of newlines + # This prevents raw mode from overwriting existing screen lines + terminal_height = 24 + try: + terminal_size = shutil.get_terminal_size() + terminal_height = terminal_size.lines + except Exception: + terminal_height = 24 + print("\n" * terminal_height, end='', flush=True) + + guac_ready_timeout = 10.0 # Reduced from 30s - sync triggers readiness quickly + + if python_handler.wait_for_ready(guac_ready_timeout): + logging.debug("* Guacamole connection ready!") + logging.debug("Terminal session active. 
Press Ctrl+C to exit.") + else: + logging.warning(f"Guacamole did not report ready within {guac_ready_timeout}s") + logging.warning("Terminal may still work if data is flowing.") + + # Create stdin handler for pipe/blob/end input pattern + # StdinHandler reads raw stdin and sends via send_stdin (base64-encoded) + # This matches kcm-cli's implementation for plaintext SSH/TTY streams + stdin_handler = StdinHandler( + stdin_callback=lambda data: python_handler.send_stdin(data), + key_callback=lambda keysym, pressed: python_handler.send_key(keysym, pressed) + ) + + # Main event loop with stdin input + try: + # Start stdin handler (runs in background thread) + stdin_handler.start() + logging.debug("STDIN handler started") # (pipe/blob/end mode) + + elapsed = 0 + while not shutdown_requested and python_handler.running: + # Check if tube/connection is closed + try: + state = tube_registry.get_connection_state(tube_id) + if state and state.lower() in ('closed', 'disconnected', 'failed'): + logging.debug(f"Tube/connection closed (state: {state}) - exiting") + python_handler.running = False + break + except Exception: + # If we can't check state, continue (tube might be closing) + pass + time.sleep(0.1) + elapsed += 0.1 + + # Status indicator every 30 seconds + if elapsed % 30.0 < 0.1 and elapsed > 0.1: + rx = python_handler.messages_received + tx = python_handler.messages_sent + syncs = python_handler.sync_count + logging.debug(f"[{int(elapsed)}s] Session active (rx={rx}, tx={tx}, syncs={syncs})") + + except KeyboardInterrupt: + logging.debug("\n\nExiting CLI terminal mode...") + + finally: + # Stop stdin handler first (restores terminal) + logging.debug("Stopping stdin handler...") + try: + stdin_handler.stop() + except Exception as e: + logging.debug(f"Error stopping stdin handler: {e}") + + # Cleanup - check if connection is already closed to avoid deadlock + logging.debug("Stopping Python handler...") + try: + # Check if tube is already closed - if so, skip sending disconnect + try: + state = tube_registry.get_connection_state(tube_id) + skip_disconnect = state and state.lower() in ('closed', 'disconnected', 'failed') + except Exception: + skip_disconnect = False + + python_handler.stop(skip_disconnect=skip_disconnect) + except Exception as e: + logging.debug(f"Error stopping Python handler: {e}") + + # Close the tube (Rust handles CloseConnection automatically) + logging.debug("Closing WebRTC tunnel...") + try: + tube_registry.close_tube(tube_id) + logging.debug(f"Closed tube: {tube_id}") + except Exception as e: + logging.debug(f"Error closing tube: {e}") + + # Clean up registrations + try: + unregister_tunnel_session(tube_id) + if conversation_id: + unregister_conversation_key(conversation_id) + except Exception as e: + logging.debug(f"Error unregistering: {e}") + + logging.info("CLI session ended - cleanup complete") + + except Exception as e: + logging.error(f"Error in PythonHandler CLI session: {e}") + raise CommandError('pam launch', f'Failed to start CLI session: {e}') + finally: + signal.signal(signal.SIGINT, original_handler) diff --git a/keepercommander/commands/pam_launch/python_handler.py b/keepercommander/commands/pam_launch/python_handler.py new file mode 100644 index 000000000..dbdc13d69 --- /dev/null +++ b/keepercommander/commands/pam_launch/python_handler.py @@ -0,0 +1,868 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' CONNECTED on first sync, NOT on 'ready' + # We track both for compatibility: + # - handshake_complete: Set when we receive 'ready' instruction 
(informational) + # - data_flowing: Set when we receive first 'sync' instruction (TRUE readiness) + self.handshake_complete = threading.Event() # 'ready' received (custom extension) + self.data_flowing = threading.Event() # First 'sync' received (protocol standard) + + # For backwards compatibility, connection_ready = data_flowing + # This matches JS client behavior where sync = ready + self.connection_ready = self.data_flowing + + self.guac_connection_id: Optional[str] = None + self.sync_count = 0 # Track number of syncs received + + # Statistics + self.messages_received = 0 + self.bytes_received = 0 + self.messages_sent = 0 + self.bytes_sent = 0 + + def start(self): + """Start the handler.""" + if self.running: + return + self.running = True + logging.debug(f"GuacamoleHandler started (conversation_id={self.conversation_id})") + + def stop(self, skip_disconnect: bool = False): + """ + Stop the handler and optionally send disconnect to guacd. + + Args: + skip_disconnect: If True, skip sending disconnect instruction. + Use this when connection is already closed to avoid deadlock. + """ + if not self.running: + return + + self.running = False + + # Send graceful disconnect to guacd (unless connection already closed) + if not skip_disconnect: + try: + disconnect_instruction = self._format_instruction('disconnect') + self._send_to_gateway(disconnect_instruction) + logging.debug("Sent disconnect instruction to guacd") + except Exception as e: + # Don't warn if connection is already closed - this is expected + if "closed" not in str(e).lower() and "disconnected" not in str(e).lower(): + logging.warning(f"Failed to send disconnect instruction: {e}") + + logging.debug( + f"GuacamoleHandler stopped (conversation_id={self.conversation_id}, " + f"rx={self.messages_received}, tx={self.messages_sent})" + ) + + def handle_events(self, events: List[Dict[str, Any]]): + """ + Handle a batch of events from Rust PythonHandler. + + This is called by the Rust handler task with a list of events. + Events are batched for GIL efficiency (up to 10 messages per batch). + + Args: + events: List of event dicts, each with: + - type: "connection_opened" | "data" | "connection_closed" + - conn_no: Connection number (int) + - conversation_id: Conversation ID (str) + - payload: Bytes data (for "data" events) + - reason: Close reason code (for "connection_closed" events) + """ + for event in events: + try: + self._handle_single_event(event) + except Exception as e: + logging.error(f"Error handling event: {e}", exc_info=True) + + def _handle_single_event(self, event: Dict[str, Any]): + """Handle a single event from Rust.""" + event_type = event.get('type') + conn_no = event.get('conn_no', 1) + + if event_type == 'connection_opened': + self._on_connection_opened(conn_no) + elif event_type == 'data': + payload = event.get('payload', b'') + self._on_data(conn_no, payload) + elif event_type == 'connection_closed': + reason = event.get('reason', 0) + self._on_connection_closed(conn_no, reason) + else: + logging.warning(f"Unknown event type: {event_type}") + + def _on_connection_opened(self, conn_no: int): + """ + Handle connection_opened event from Rust. + + This is sent when the Gateway has acknowledged the OpenConnection + request and the virtual connection is now established. + + Flow: + 1. Python calls tube_registry.open_handler_connection(conversation_id, conn_no) + 2. Rust sends OpenConnection control frame to Gateway via WebRTC + 3. Gateway receives OpenConnection, starts guacd, connects to target + 4. 
Gateway sends ConnectionOpened back to Rust + 5. Rust notifies Python via this callback + 6. Gateway/guacd sends 'args' instruction (Guacamole handshake starts) + """ + logging.debug(f"✓ Connection opened: conn_no={conn_no}") + self.conn_no = conn_no + + # The connection is now ready for Guacamole protocol + # Gateway will send guacd's 'args' instruction next + + def _on_data(self, conn_no: int, payload: bytes): + """ + Handle data event from Rust. + + This contains pure Guacamole protocol data (no channel prefix). + The Rust layer has already stripped the frame header. + """ + if not payload: + return + + self.messages_received += 1 + self.bytes_received += len(payload) + + try: + # Decode Guacamole instructions + instructions_str = payload.decode('utf-8') + + # Log received data for debugging + if logging.getLogger().isEnabledFor(logging.DEBUG): + preview = instructions_str[:100] + logging.debug( + f"<<< GUACD DATA: {len(payload)} bytes, preview: {preview}" + f"{'...' if len(instructions_str) > 100 else ''}" + ) + + # Parse and dispatch instructions + self.parser.receive(instructions_str) + + except UnicodeDecodeError: + logging.debug(f"Binary data received ({len(payload)} bytes): {payload[:32].hex()}...") + + def _on_connection_closed(self, conn_no: int, reason: int): + """ + Handle connection_closed event from Rust. + + This is sent when the gateway/guacd closes the connection. + """ + reason_name = self._close_reason_name(reason) + logging.debug(f"Connection closed: conn_no={conn_no}, reason={reason} ({reason_name})") + + # Stop without sending disconnect (connection already closed) + self.stop(skip_disconnect=True) + + if self.on_disconnect: + try: + self.on_disconnect(reason_name) + except Exception as e: + logging.error(f"Error in disconnect callback: {e}") + + def _on_args(self, args: List[str]) -> None: + """ + Handle args instruction from guacd (via Gateway). + + This is the critical handshake step. When guacd receives 'select' from + the Gateway, it responds with 'args' listing the parameters it needs. + We must respond with 'connect' containing the parameter values, + followed by 'size', 'audio', and 'image' instructions. + + Guacamole handshake sequence: + 1. Gateway sends 'select ' to guacd + 2. guacd responds with 'args' (list of required params) + 3. We respond with 'connect' (param values), 'size', 'audio', 'image' + 4. guacd responds with 'ready' + + Args: + args: Parameter names that guacd expects (first is version, rest are params) + """ + if self.handshake_sent: + logging.debug(f"Ignoring duplicate 'args' instruction (handshake already sent)") + return + + logging.debug(f"✓ Received 'args' from guacd: {list(args)}") + + try: + # Build and send the handshake response + self._send_handshake_response(list(args)) + self.handshake_sent = True + logging.debug("✓ Guacamole handshake sent (connect+size+audio+image)") + except Exception as e: + logging.error(f"Error sending handshake response: {e}", exc_info=True) + + def _send_handshake_response(self, args_list: List[str]): + """ + Send the complete Guacamole handshake response. + + Args: + args_list: List of parameter names from guacd's 'args' instruction + """ + settings = self.connection_settings + + # Get terminal dimensions (default to standard CLI size) + width = settings.get('width', 800) + height = settings.get('height', 600) + dpi = settings.get('dpi', 96) + + # Get guacd parameters (hostname, port, username, password, etc.) 
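+            # Illustrative shape only - actual keys depend on the protocol
+            # and on what the Gateway supplies, e.g.:
+            #   {'hostname': '203.0.113.10', 'port': '22', 'username': 'admin'}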
+ guacd_params = settings.get('guacd_params', {}) + + # Build connect args: first arg is version (from guacd), rest are param values + connect_args = [] + + # First arg from guacd is the version requirement + if args_list: + version = args_list[0] if args_list[0] else "VERSION_1_5_0" + connect_args.append(version) + + # For each remaining parameter guacd requested, provide the value + for param_name in args_list[1:]: + # Normalize param name for lookup (remove hyphens/underscores, lowercase) + normalized = param_name.replace('-', '').replace('_', '').lower() + + # Look up in guacd_params with various key formats + value = "" + for key in [param_name, normalized, param_name.replace('-', '_'), param_name.replace('_', '-')]: + if key in guacd_params: + value = str(guacd_params[key]) + break + # Also try lowercase version + if key.lower() in guacd_params: + value = str(guacd_params[key.lower()]) + break + + connect_args.append(value) + + # Send connect instruction + connect_instruction = self._format_instruction('connect', *connect_args) + self._send_to_gateway(connect_instruction) + logging.debug(f"Sent 'connect' with {len(connect_args)} args") + + # Send size instruction + size_instruction = self._format_instruction('size', width, height, dpi) + self._send_to_gateway(size_instruction) + logging.debug(f"Sent 'size': {width}x{height} @ {dpi}dpi") + + # Send audio instruction (supported audio mimetypes) + audio_mimetypes = settings.get('audio_mimetypes', []) + audio_instruction = self._format_instruction('audio', *audio_mimetypes) + self._send_to_gateway(audio_instruction) + logging.debug(f"Sent 'audio': {audio_mimetypes}") + + # Send video instruction (supported video mimetypes - usually empty for terminal) + video_mimetypes = settings.get('video_mimetypes', []) + video_instruction = self._format_instruction('video', *video_mimetypes) + self._send_to_gateway(video_instruction) + logging.debug(f"Sent 'video': {video_mimetypes}") + + # Send image instruction (supported image mimetypes) + image_mimetypes = settings.get('image_mimetypes', ['image/png', 'image/jpeg', 'image/webp']) + image_instruction = self._format_instruction('image', *image_mimetypes) + self._send_to_gateway(image_instruction) + logging.debug(f"Sent 'image': {image_mimetypes}") + + def _on_sync(self, args: List[str]) -> None: + """ + Handle sync instruction from guacd. + + Guacamole requires sync acknowledgments to maintain connection. + + IMPORTANT: The JS client uses the first sync as the TRUE readiness signal, + transitioning from WAITING to CONNECTED state. We follow the same pattern. + This is more reliable than waiting for 'ready' (which is a custom extension). 
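To make the sequence concrete, the handshake that `_send_handshake_response` emits might look like this on the wire, using standard Guacamole framing (each element is `<character-count>.<value>`, elements are comma-separated, instructions end with `;`). Host, credentials and dimensions below are placeholder values, not captured traffic:

```
guacd -> client:   4.args,13.VERSION_1_5_0,8.hostname,4.port,8.username,8.password;
client -> guacd:   7.connect,13.VERSION_1_5_0,13.192.168.1.100,2.22,5.admin,6.secret;
client -> guacd:   4.size,3.800,3.600,2.96;
client -> guacd:   5.audio;
client -> guacd:   5.video;
client -> guacd:   5.image,9.image/png,10.image/jpeg,10.image/webp;
guacd -> client:   4.sync,...;   (first sync = connection ready, per above)
```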
+ + Args: + args: [timestamp] or [timestamp, frames] + """ + timestamp = args[0] if args else "0" + frames = args[1] if len(args) > 1 else "0" + + self.last_sync_timestamp = timestamp + self.sync_count += 1 + + # First sync = TRUE connection ready (matches JS client behavior) + # JS Client.js line 1679: if (currentState === WAITING) setState(CONNECTED) + if self.sync_count == 1: + self.data_flowing.set() + logging.debug(f"* First sync received - connection ready (timestamp={timestamp})") + + # Call on_ready callback if not already called by 'ready' instruction + if self.on_ready and not self.handshake_complete.is_set(): + try: + self.on_ready() + except Exception as e: + logging.error(f"Error in ready callback: {e}") + + # Log but don't spam + logging.debug(f"SYNC #{self.sync_count}: timestamp={timestamp}, frames={frames}") + + # Send sync acknowledgment back to guacd + try: + ack = self._format_instruction('sync', timestamp) + self._send_to_gateway(ack) + except Exception as e: + logging.error(f"Error sending sync ack: {e}") + + def _on_ready(self, args: List[str]) -> None: + """ + Handle ready instruction from guacd. + + This indicates the Guacamole handshake is complete. + NOTE: This is a custom extension - the JS client doesn't have a 'ready' handler. + The TRUE readiness signal is the first 'sync' instruction. + + Args: + args: [connection_id] + """ + connection_id = args[0] if args else "" + self.guac_connection_id = connection_id + self.handshake_complete.set() + + # Also signal data_flowing for compatibility (in case ready comes before sync) + # This ensures wait_for_ready() returns true on either signal + self.data_flowing.set() + + logging.debug(f"✓ Guacamole ready! Connection established: connection_id={connection_id}") + + if self.on_ready: + try: + self.on_ready() + except Exception as e: + logging.error(f"Error in ready callback: {e}") + + def _on_guac_disconnect(self, args: List[str]) -> None: + """Handle disconnect instruction from guacd.""" + logging.debug(f"Server sent disconnect instruction (args: {args})") + + # Stop without sending disconnect (server already disconnected) + self.stop(skip_disconnect=True) + + if self.on_disconnect: + try: + self.on_disconnect("server_disconnect") + except Exception as e: + logging.error(f"Error in disconnect callback: {e}") + + def _on_error(self, args: List[str]) -> None: + """Handle error instruction from guacd.""" + message = args[0] if args else "Unknown error" + code = args[1] if len(args) > 1 else "0" + + logging.error(f"Guacamole error {code}: {message}") + + def _format_instruction(self, *elements) -> bytes: + """Format elements into a Guacamole instruction.""" + # Use the new guacamole module's to_instruction function + # It takes a list, returns str, we encode to bytes + instruction_str = to_instruction(list(elements)) + return instruction_str.encode('utf-8') + + def _send_ack(self, stream_index: str, message: str, code: str): + """ + Send ack instruction for stream acknowledgment. + + Used by the instruction router to acknowledge pipe/blob instructions. 
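The framing itself is simple enough to sketch inline. `encode_instruction` below is a hypothetical stand-in for the `to_instruction` helper that `_format_instruction` wraps; note lengths are character counts, not byte counts:

```
def encode_instruction(*elements) -> str:
    # "<character-count>.<value>" per element, comma-separated, ';'-terminated
    return ','.join(f"{len(str(e))}.{e}" for e in elements) + ';'

assert encode_instruction('sync', '1234567890123') == '4.sync,13.1234567890123;'
assert encode_instruction('ack', '0', 'OK', '0') == '3.ack,1.0,2.OK,1.0;'
```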
+ + Args: + stream_index: Stream index (as string) + message: Acknowledgment message (usually "OK") + code: Status code (usually "0" for success) + """ + try: + instruction = self._format_instruction('ack', stream_index, message, code) + self._send_to_gateway(instruction) + logging.debug(f"ACK sent: stream={stream_index}, message={message}, code={code}") + except Exception as e: + logging.error(f"Error sending ack: {e}") + + def _send_to_gateway(self, data: bytes): + """ + Send Guacamole data back to gateway via Rust. + + Uses send_handler_data() which routes through the WebRTC channel. + """ + if isinstance(data, str): + data = data.encode('utf-8') + + try: + self.tube_registry.send_handler_data( + self.conversation_id, + self.conn_no, + data + ) + self.messages_sent += 1 + self.bytes_sent += len(data) + logging.debug(f">>> GUACD SEND: {len(data)} bytes") + except Exception as e: + logging.error(f"Failed to send data to gateway: {e}") + raise + + def send_stdin(self, data: bytes): + """ + Send stdin data to guacd using the pipe/blob/end pattern. + + This is the preferred method for plaintext SSH/TTY streams. + It matches the kcm-cli implementation: + - pipe,0,text/plain,STDIN (open stream) + - blob,0, (send data) + - end,0 (close stream) + + Only sends if session is active (running and data flowing). + + Args: + data: Raw bytes to send as stdin (e.g., keyboard input) + """ + # Guard: only send during active session + if not self.running: + logging.debug("Ignoring stdin - handler not running") + return + if not self.data_flowing.is_set(): + logging.debug("Ignoring stdin - connection not ready") + return + + try: + # Use stream index 0 for STDIN (matching kcm-cli) + stream_index = '0' + + # Send pipe instruction to open STDIN stream + pipe_instruction = self._format_instruction('pipe', stream_index, 'text/plain', 'STDIN') + self._send_to_gateway(pipe_instruction) + + # Send blob with base64-encoded data + data_base64 = base64.b64encode(data).decode('ascii') + blob_instruction = self._format_instruction('blob', stream_index, data_base64) + self._send_to_gateway(blob_instruction) + + # Send end to close the stream + end_instruction = self._format_instruction('end', stream_index) + self._send_to_gateway(end_instruction) + + # Log for debugging + if logging.getLogger().isEnabledFor(logging.DEBUG): + preview = data[:20].decode('utf-8', errors='replace') if len(data) <= 20 else data[:20].decode('utf-8', errors='replace') + '...' + logging.debug(f"STDIN: sent {len(data)} bytes: {repr(preview)}") + + except Exception as e: + logging.error(f"Error sending stdin: {e}") + + def send_key(self, keysym: int, pressed: bool): + """ + Send a key event to guacd using X11 keysym. + + NOTE: For plaintext SSH/TTY streams, use send_stdin() instead. + This method is for graphical protocols (RDP, VNC) that use X11 keysyms. + + Only sends if session is active (running and data flowing). 
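As a worked example of the pipe/blob/end pattern (values computed here, not captured from a live session), sending `ls` followed by Enter produces:

```
import base64

data = b"ls\n"
payload = base64.b64encode(data).decode('ascii')
assert payload == "bHMK"

# send_stdin() then emits, in order:
#   4.pipe,1.0,10.text/plain,5.STDIN;
#   4.blob,1.0,4.bHMK;
#   3.end,1.0;
```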
+ + Args: + keysym: X11 keysym value + pressed: True for press, False for release + """ + # Guard: only send keys during active session + if not self.running: + logging.debug(f"Ignoring key event - handler not running") + return + if not self.data_flowing.is_set(): + logging.debug(f"Ignoring key event - connection not ready") + return + + try: + instruction = self._format_instruction('key', keysym, 1 if pressed else 0) + self._send_to_gateway(instruction) + # Log key events for debugging (only press, not release to reduce spam) + if pressed: + # Show printable chars, hex for control/special keys + if 32 <= keysym < 127: + logging.debug(f"KEY: '{chr(keysym)}' (0x{keysym:04X})") + else: + logging.debug(f"KEY: 0x{keysym:04X} (special)") + except Exception as e: + logging.error(f"Error sending key event: {e}") + + def send_mouse(self, x: int, y: int, buttons: int = 0): + """ + Send a mouse event to guacd. + + Only sends if session is active (running and data flowing). + + Args: + x: X coordinate + y: Y coordinate + buttons: Button mask + """ + if not self.running or not self.data_flowing.is_set(): + return + + try: + instruction = self._format_instruction('mouse', x, y, buttons) + self._send_to_gateway(instruction) + except Exception as e: + logging.error(f"Error sending mouse event: {e}") + + def send_size(self, width: int, height: int, dpi: int = 96): + """ + Send terminal size to guacd. + + Only sends if session is active (running and data flowing). + + Args: + width: Width in pixels + height: Height in pixels + dpi: DPI (default 96) + """ + if not self.running or not self.data_flowing.is_set(): + return + + try: + instruction = self._format_instruction('size', width, height, dpi) + self._send_to_gateway(instruction) + except Exception as e: + logging.error(f"Error sending size: {e}") + + def send_clipboard(self, text: str): + """ + Send clipboard data to guacd. + + Only sends if session is active (running and data flowing). + + Args: + text: Clipboard text + """ + if not self.running or not self.data_flowing.is_set(): + return + + try: + instruction = self._format_instruction('clipboard', 'text/plain', text) + self._send_to_gateway(instruction) + except Exception as e: + logging.error(f"Error sending clipboard: {e}") + + def wait_for_ready(self, timeout: float = 10.0) -> bool: + """ + Wait for the Guacamole connection to be ready. + + Connection is considered ready when: + - First 'sync' instruction is received (matches JS client behavior), OR + - 'ready' instruction is received (custom extension) + + The JS Guacamole client (guacamole-common-js) considers the connection + CONNECTED when the first sync is received, not when 'ready' is received. + We follow the same pattern for reliability. + + Args: + timeout: Maximum seconds to wait (default: 10.0, was 30.0) + Handshake typically completes in <500ms on normal networks. + + Returns: + True if ready (sync or ready received), False if timeout + """ + import time + start = time.time() + + result = self.connection_ready.wait(timeout) + + elapsed = time.time() - start + if result: + logging.debug(f"Connection ready after {elapsed:.3f}s") + else: + # Provide diagnostic info on timeout + logging.warning( + f"Timeout after {elapsed:.1f}s waiting for ready - " + f"received {self.messages_received} messages ({self.bytes_received} bytes), " + f"syncs={self.sync_count}, handshake_sent={self.handshake_sent}" + ) + + return result + + def is_data_flowing(self) -> bool: + """ + Check if data is flowing (sync messages being received). 
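A hypothetical caller-side sketch of the send/wait API above; `handler` is assumed to be a started GuacamoleHandler whose tube is already negotiating:

```
if handler.wait_for_ready(timeout=10.0):
    handler.send_size(1024, 768)        # renegotiate terminal dimensions
    handler.send_stdin(b"uname -a\n")   # keystrokes via pipe/blob/end
    handler.send_clipboard("hello")     # push clipboard text
else:
    handler.stop()                      # give up and disconnect gracefully
```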
+ + Returns: + True if at least one sync has been received + """ + return self.sync_count > 0 + + @staticmethod + def _close_reason_name(reason: int) -> str: + """Convert close reason code to name.""" + reasons = { + 0: "unknown", + 1: "normal", + 2: "timeout", + 3: "error", + 4: "refused", + 5: "unreachable", + 6: "reset", + } + return reasons.get(reason, f"code_{reason}") + + def __enter__(self): + """Context manager entry.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.stop() + return False + + +def create_handler_callback(handler: GuacamoleHandler) -> Callable[[List[Dict]], None]: + """ + Create a callback function for the Rust PythonHandler. + + This wraps the handler's handle_events method in a function + that can be passed to the Rust create_tube() call. + + Args: + handler: GuacamoleHandler instance + + Returns: + Callback function that accepts a list of event dicts + """ + def callback(events: List[Dict[str, Any]]): + handler.handle_events(events) + + return callback + + +def create_python_handler( + tube_registry, + conversation_id: str, + conn_no: int = 1, + connection_settings: Optional[Dict[str, Any]] = None, + on_ready: Optional[Callable[[], None]] = None, + on_disconnect: Optional[Callable[[str], None]] = None, +) -> tuple: + """ + Create a PythonHandler callback and handler for Guacamole CLI. + + This is the main entry point for setting up PythonHandler mode. + It creates both the handler instance and the callback function + that should be passed to tube_registry.create_tube(). + + Args: + tube_registry: PyTubeRegistry instance + conversation_id: Conversation/channel ID + conn_no: Connection number (default: 1) + connection_settings: Connection parameters for Guacamole handshake: + - protocol: Protocol type (ssh, telnet, mysql, etc.) + - hostname: Target hostname + - port: Target port + - width: Terminal width in pixels + - height: Terminal height in pixels + - dpi: Display DPI (default 96) + - audio_mimetypes: List of supported audio types + - image_mimetypes: List of supported image types + - guacd_params: Dict of guacd connection parameters + on_ready: Optional callback when Guacamole connection is ready + on_disconnect: Optional callback when connection closes + + Returns: + Tuple of (callback_function, handler_instance) + + Example: + callback, handler = create_python_handler( + tube_registry, + conversation_id, + connection_settings={ + 'protocol': 'ssh', + 'width': 800, + 'height': 600, + 'dpi': 96, + 'guacd_params': { + 'hostname': '192.168.1.100', + 'port': '22', + 'username': 'admin', + 'password': 'secret', + } + }, + on_ready=lambda: print("Connected!"), + on_disconnect=lambda reason: print(f"Disconnected: {reason}") + ) + + # Pass callback to create_tube + result = tube_registry.create_tube( + conversation_id=conversation_id, + settings={...}, + handler_callback=callback, + ... 
+ ) + + # Start handler + handler.start() + + # Wait for connection + if handler.wait_for_ready(timeout=30): + print("Guacamole session ready!") + """ + handler = GuacamoleHandler( + tube_registry=tube_registry, + conversation_id=conversation_id, + conn_no=conn_no, + connection_settings=connection_settings, + on_ready=on_ready, + on_disconnect=on_disconnect, + ) + + callback = create_handler_callback(handler) + + return callback, handler diff --git a/keepercommander/commands/pam_launch/terminal_connection.py b/keepercommander/commands/pam_launch/terminal_connection.py new file mode 100644 index 000000000..64a25b638 --- /dev/null +++ b/keepercommander/commands/pam_launch/terminal_connection.py @@ -0,0 +1,1281 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' Dict[str, int]: + """Convert character columns/rows into pixel measurements for the Gateway.""" + col_value = columns if isinstance(columns, int) and columns > 0 else DEFAULT_TERMINAL_COLUMNS + row_value = rows if isinstance(rows, int) and rows > 0 else DEFAULT_TERMINAL_ROWS + return { + "columns": col_value, + "rows": row_value, + "pixel_width": col_value * DEFAULT_CELL_WIDTH_PX, + "pixel_height": row_value * DEFAULT_CELL_HEIGHT_PX, + "dpi": DEFAULT_SCREEN_DPI, + } + + +DEFAULT_SCREEN_INFO = _build_screen_info(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS) + +MAX_MESSAGE_SIZE_LINE = "a=max-message-size:1073741823" + + +def _ensure_max_message_size_attribute(sdp_offer: Optional[str]) -> Optional[str]: + """ + Ensure the SDP offer advertises the same max-message-size attribute as Web Vault. + + Args: + sdp_offer: Original SDP offer string returned by the WebRTC module. + + Returns: + SDP string containing the MAX_MESSAGE_SIZE_LINE (added if it was missing). + """ + if not sdp_offer or MAX_MESSAGE_SIZE_LINE in sdp_offer: + return sdp_offer + + newline = "\r\n" if "\r\n" in sdp_offer else "\n" + insert_location = None + lower_offer = sdp_offer.lower() + + # Prefer to inject directly after the SCTP port attribute to mimic Web Vault ordering. + sctp_idx = lower_offer.find("a=sctp-port:") + if sctp_idx != -1: + after_sctp = sdp_offer.find(newline, sctp_idx) + if after_sctp == -1: + insert_location = len(sdp_offer) + else: + insert_location = after_sctp + len(newline) + else: + # If the SCTP line is missing, try to inject immediately after the datachannel media line. + media_idx = lower_offer.find("m=application") + if media_idx != -1: + after_media = sdp_offer.find(newline, media_idx) + if after_media == -1: + insert_location = len(sdp_offer) + else: + insert_location = after_media + len(newline) + + if insert_location is None: + # Append at the end, keeping the existing newline style and ensuring we end with one blank line. + suffix = "" if sdp_offer.endswith(newline) else newline + updated_offer = f"{sdp_offer}{suffix}{MAX_MESSAGE_SIZE_LINE}{newline}" + else: + updated_offer = ( + sdp_offer[:insert_location] + + f"{MAX_MESSAGE_SIZE_LINE}{newline}" + + sdp_offer[insert_location:] + ) + + logging.debug("Injected `%s` into SDP offer to match Web Vault behavior", MAX_MESSAGE_SIZE_LINE) + return updated_offer + + +def _notify_gateway_connection_close(params, router_token, terminated=True): + """ + Notify the gateway/router that a WebRTC session should be closed. + + This mirrors the gateway's own POST to /api/device/connect_state so that + stale tubes are cleaned up when Commander aborts before a session fully starts. + + Note: gateway_cookies parameter was removed in commit 338a9fda as router + affinity is now handled server-side. 
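Tracing `_ensure_max_message_size_attribute` with a trimmed, illustrative SDP fragment (not a real browser-generated offer) shows where the attribute lands:

```
offer = (
    "v=0\r\n"
    "m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n"
    "a=sctp-port:5000\r\n"
)
patched = _ensure_max_message_size_attribute(offer)
# Injected directly after the a=sctp-port line, mirroring Web Vault ordering:
assert patched.endswith("a=sctp-port:5000\r\na=max-message-size:1073741823\r\n")
```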
+ """ + if not router_token: + logging.debug("Skipping connection_close notification - router_token missing") + return + + try: + router_url = get_router_url(params) + payload = { + "token": router_token, + "type": "connection_close", + } + if terminated is not None: + payload["terminated"] = terminated + + response = requests.post( + f"{router_url}/api/device/connect_state", + json=payload, + verify=VERIFY_SSL, + timeout=10, + ) + if response.status_code >= 400: + logging.warning( + "Gateway connection_close notification failed (%s): %s", + response.status_code, + response.text, + ) + else: + logging.debug("Sent connection_close notification for router token") + except Exception as notify_err: + logging.debug(f"Failed to notify gateway about connection_close: {notify_err}") + + +def detect_protocol(params: KeeperParams, record_uid: str) -> Optional[str]: + """ + Detect the terminal protocol from a PAM record. + + Args: + params: KeeperParams instance + record_uid: Record UID + + Returns: + Protocol string (ssh, telnet, kubernetes, mysql, postgresql, sqlserver) or None + + Raises: + CommandError: If record type is not supported or protocol cannot be determined + """ + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') + + record_type = record.record_type + + # pamMachine -> SSH or Telnet + if record_type == 'pamMachine': + # Check if telnet is explicitly configured + # Look for telnet-specific fields or settings + pam_settings = record.get_typed_field('pamSettings') + if pam_settings: + settings_value = pam_settings.get_default_value(dict) + if settings_value: + connection = settings_value.get('connection', {}) + if isinstance(connection, dict): + # Check for telnet protocol indicator + protocol_field = connection.get('protocol') + if protocol_field and 'telnet' in str(protocol_field).lower(): + return ProtocolType.TELNET + + # Default to SSH for pamMachine + return ProtocolType.SSH + + # pamDirectory -> Kubernetes + elif record_type == 'pamDirectory': + return ProtocolType.KUBERNETES + + # pamDatabase -> MySQL, PostgreSQL, or SQL Server + elif record_type == 'pamDatabase': + # Inspect the database type field + pam_settings = record.get_typed_field('pamSettings') + if pam_settings: + settings_value = pam_settings.get_default_value(dict) + if settings_value: + connection = settings_value.get('connection', {}) + if isinstance(connection, dict): + db_type = connection.get('databaseType', '').lower() + + if 'mysql' in db_type: + return ProtocolType.MYSQL + elif 'postgres' in db_type or 'postgresql' in db_type: + return ProtocolType.POSTGRESQL + elif 'sql server' in db_type or 'sqlserver' in db_type or 'mssql' in db_type: + return ProtocolType.SQLSERVER + + # Try to infer from port if database type not specified + hostname_field = record.get_typed_field('pamHostname') + if hostname_field: + host_value = hostname_field.get_default_value(dict) + if host_value: + port = host_value.get('port') + if port: + port_int = int(port) if isinstance(port, str) else port + if port_int == 3306: + return ProtocolType.MYSQL + elif port_int == 5432: + return ProtocolType.POSTGRESQL + elif port_int == 1433: + return ProtocolType.SQLSERVER + + # Default to MySQL if we can't determine + logging.warning(f"Could not determine database type for record {record_uid}, defaulting to MySQL") + return ProtocolType.MYSQL + + else: + raise CommandError('pam launch', + f'Record type "{record_type}" 
is not supported for terminal connections. ' + f'Supported types: pamMachine, pamDirectory, pamDatabase') + + +def extract_terminal_settings(params: KeeperParams, record_uid: str, protocol: str) -> Dict[str, Any]: + """ + Extract terminal connection settings from a PAM record. + + Args: + params: KeeperParams instance + record_uid: Record UID + protocol: Protocol type (from detect_protocol) + + Returns: + Dictionary containing terminal settings: + - hostname: Target hostname + - port: Target port + - clipboard: {disableCopy: bool, disablePaste: bool} + - terminal: {colorScheme: str, fontSize: str} + - recording: {includeKeys: bool} + - protocol_specific: Protocol-specific settings dict + + Raises: + CommandError: If required fields are missing + """ + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') + + settings = { + 'hostname': None, + 'port': None, + 'clipboard': {'disableCopy': False, 'disablePaste': False}, + 'terminal': {'colorScheme': 'gray-black', 'fontSize': '12'}, + 'recording': {'includeKeys': False}, + 'protocol_specific': {} + } + + # Extract hostname and port + hostname_field = record.get_typed_field('pamHostname') + if not hostname_field: + raise CommandError('pam launch', f'No hostname configured for record {record_uid}') + + host_value = hostname_field.get_default_value(dict) + if not host_value: + raise CommandError('pam launch', f'Invalid hostname configuration for record {record_uid}') + + settings['hostname'] = host_value.get('hostName') + if not settings['hostname']: + raise CommandError('pam launch', f'Hostname not found in record {record_uid}') + + # Get port (use default if not specified) + port_value = host_value.get('port') + if port_value: + settings['port'] = int(port_value) if isinstance(port_value, str) else port_value + else: + settings['port'] = DEFAULT_PORTS.get(protocol, 22) + + # Extract PAM settings + pam_settings_field = record.get_typed_field('pamSettings') + if pam_settings_field: + pam_settings_value = pam_settings_field.get_default_value(dict) + if pam_settings_value: + connection = pam_settings_value.get('connection', {}) + if isinstance(connection, dict): + # Clipboard settings + settings['clipboard']['disableCopy'] = connection.get('disableCopy', False) + settings['clipboard']['disablePaste'] = connection.get('disablePaste', False) + + # Terminal display settings + color_scheme = connection.get('colorScheme') + if color_scheme: + settings['terminal']['colorScheme'] = color_scheme + + font_size = connection.get('fontSize') + if font_size: + settings['terminal']['fontSize'] = str(font_size) + + # Recording settings + settings['recording']['includeKeys'] = connection.get('recordingIncludeKeys', False) + + # Protocol-specific settings + if protocol == ProtocolType.SSH: + settings['protocol_specific'] = _extract_ssh_settings(connection) + elif protocol == ProtocolType.TELNET: + settings['protocol_specific'] = _extract_telnet_settings(connection) + elif protocol == ProtocolType.KUBERNETES: + settings['protocol_specific'] = _extract_kubernetes_settings(connection) + elif protocol in ProtocolType.DATABASE: + settings['protocol_specific'] = _extract_database_settings(connection, protocol) + + return settings + + +def _extract_ssh_settings(connection: Dict[str, Any]) -> Dict[str, Any]: + """Extract SSH-specific settings""" + return { + 'publicHostKey': connection.get('publicHostKey', ''), + 'executeCommand': 
connection.get('executeCommand', ''), + 'sftpEnabled': connection.get('sftpEnabled', False), + } + + +def _extract_telnet_settings(connection: Dict[str, Any]) -> Dict[str, Any]: + """Extract Telnet-specific settings""" + return { + 'usernameRegex': connection.get('usernameRegex', ''), + 'passwordRegex': connection.get('passwordRegex', ''), + } + + +def _extract_kubernetes_settings(connection: Dict[str, Any]) -> Dict[str, Any]: + """Extract Kubernetes-specific settings""" + return { + 'namespace': connection.get('namespace', 'default'), + 'pod': connection.get('pod', ''), + 'container': connection.get('container', ''), + 'ignoreServerCertificate': connection.get('ignoreServerCertificate', False), + 'caCertificate': connection.get('caCertificate', ''), + 'clientCertificate': connection.get('clientCertificate', ''), + 'clientKey': connection.get('clientKey', ''), + } + + +def _extract_database_settings(connection: Dict[str, Any], protocol: str) -> Dict[str, Any]: + """Extract database-specific settings""" + settings = { + 'defaultDatabase': connection.get('defaultDatabase', ''), + 'disableCsvExport': connection.get('disableCsvExport', False), + 'disableCsvImport': connection.get('disableCsvImport', False), + } + + # Add protocol-specific database settings + if protocol == ProtocolType.MYSQL: + settings['useSSL'] = connection.get('useSSL', False) + elif protocol == ProtocolType.POSTGRESQL: + settings['useSSL'] = connection.get('useSSL', False) + elif protocol == ProtocolType.SQLSERVER: + settings['useSSL'] = connection.get('useSSL', True) # SQL Server typically uses SSL by default + + return settings + + +def create_connection_context(params: KeeperParams, + record_uid: str, + gateway_uid: str, + protocol: str, + settings: Dict[str, Any], + connect_as: Optional[str] = None) -> Dict[str, Any]: + """ + Build connection context for WebRTC tunnel. 
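For an SSH record, `create_connection_context` (defined just below) returns a dictionary of roughly this shape; every value here is a placeholder:

```
context = {
    'protocol': 'ssh',
    'recordUid': '<record-uid>',
    'controllerUid': '<gateway-uid>',
    'targetHost': {'hostname': '192.168.1.100', 'port': 22},
    'clipboard': {'disableCopy': False, 'disablePaste': False},
    'terminal': {'colorScheme': 'gray-black', 'fontSize': '12'},
    'recording': {'includeKeys': False},
    'connectAs': None,
    'conversationType': 'ssh',
    'ssh': {'publicHostKey': '', 'executeCommand': '', 'sftpEnabled': False},
}
```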
+ + Args: + params: KeeperParams instance + record_uid: Record UID + gateway_uid: Gateway UID + protocol: Protocol type + settings: Terminal settings from extract_terminal_settings + connect_as: Optional username to connect as (overrides record) + + Returns: + Connection context dictionary ready for tunnel opening + """ + context = { + 'protocol': protocol, + 'recordUid': record_uid, + 'controllerUid': gateway_uid, + 'targetHost': { + 'hostname': settings['hostname'], + 'port': settings['port'] + }, + 'clipboard': settings['clipboard'], + 'terminal': settings['terminal'], + 'recording': settings['recording'], + 'connectAs': connect_as, + 'conversationType': _get_conversation_type(protocol), + } + + # Add protocol-specific settings + if protocol == ProtocolType.SSH: + context['ssh'] = settings['protocol_specific'] + elif protocol == ProtocolType.TELNET: + context['telnet'] = settings['protocol_specific'] + elif protocol == ProtocolType.KUBERNETES: + context['kubernetes'] = settings['protocol_specific'] + elif protocol in ProtocolType.DATABASE: + context['database'] = settings['protocol_specific'] + context['database']['type'] = protocol + + return context + + +def _get_conversation_type(protocol: str) -> str: + """Map protocol to Guacamole conversation type""" + # Map our protocol names to Guacamole conversation types + mapping = { + ProtocolType.SSH: 'ssh', + ProtocolType.TELNET: 'telnet', + ProtocolType.KUBERNETES: 'kubernetes', + ProtocolType.MYSQL: 'mysql', + ProtocolType.POSTGRESQL: 'postgresql', + ProtocolType.SQLSERVER: 'sql-server', + } + return mapping.get(protocol, protocol) + + +def _build_guacamole_connection_settings( + params: 'KeeperParams', + record_uid: str, + protocol: str, + settings: Dict[str, Any], + context: Dict[str, Any], + screen_info: Dict[str, int], +) -> Dict[str, Any]: + """ + Build connection settings for Guacamole handshake in PythonHandler mode. + + When guacd sends 'args' instruction requesting connection parameters, + we respond with 'connect' containing these values. + + Args: + params: KeeperParams instance + record_uid: Record UID + protocol: Protocol type (ssh, telnet, mysql, etc.) 
+ settings: Terminal settings from extract_terminal_settings() + context: Connection context from create_connection_context() + screen_info: Screen dimensions dict + + Returns: + Dictionary with connection settings for GuacamoleHandler + """ + # Get credentials from the record + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam launch', f'Record {record_uid} is not a TypedRecord') + + # Extract login credentials + login_field = record.get_typed_field('login') + username = '' + if login_field: + username = login_field.get_default_value(str) or '' + + password_field = record.get_typed_field('password') + password = '' + if password_field: + password = password_field.get_default_value(str) or '' + + # Build guacd parameters dictionary + # These map to guacd's expected parameter names + # The 'protocol' field is required for guacd to know which backend to use + guacd_protocol = _get_conversation_type(protocol) # Convert to guacd protocol name (e.g., ssh, telnet) + guacd_params = { + 'protocol': guacd_protocol, # Required: tells guacd which protocol handler to use + 'hostname': settings.get('hostname', ''), + 'port': str(settings.get('port', '')), + 'username': username, + 'password': password, + } + + # Add protocol-specific parameters + protocol_specific = settings.get('protocol_specific', {}) + + if protocol == ProtocolType.SSH: + # SSH-specific params + if protocol_specific.get('publicHostKey'): + guacd_params['host-key'] = protocol_specific['publicHostKey'] + if protocol_specific.get('executeCommand'): + guacd_params['command'] = protocol_specific['executeCommand'] + # Enable SFTP if configured + if protocol_specific.get('sftpEnabled'): + guacd_params['enable-sftp'] = 'true' + + elif protocol == ProtocolType.TELNET: + # Telnet-specific params + if protocol_specific.get('usernameRegex'): + guacd_params['username-regex'] = protocol_specific['usernameRegex'] + if protocol_specific.get('passwordRegex'): + guacd_params['password-regex'] = protocol_specific['passwordRegex'] + + elif protocol == ProtocolType.KUBERNETES: + # Kubernetes-specific params + if protocol_specific.get('namespace'): + guacd_params['namespace'] = protocol_specific['namespace'] + if protocol_specific.get('pod'): + guacd_params['pod'] = protocol_specific['pod'] + if protocol_specific.get('container'): + guacd_params['container'] = protocol_specific['container'] + if protocol_specific.get('caCertificate'): + guacd_params['ca-cert'] = protocol_specific['caCertificate'] + if protocol_specific.get('clientCertificate'): + guacd_params['client-cert'] = protocol_specific['clientCertificate'] + if protocol_specific.get('clientKey'): + guacd_params['client-key'] = protocol_specific['clientKey'] + if protocol_specific.get('ignoreServerCertificate'): + guacd_params['ignore-cert'] = 'true' + + elif protocol in ProtocolType.DATABASE: + # Database-specific params + if protocol_specific.get('defaultDatabase'): + guacd_params['database'] = protocol_specific['defaultDatabase'] + + # Terminal display settings + terminal_settings = settings.get('terminal', {}) + if terminal_settings.get('colorScheme'): + guacd_params['color-scheme'] = terminal_settings['colorScheme'] + if terminal_settings.get('fontSize'): + guacd_params['font-size'] = terminal_settings['fontSize'] + + # Clipboard settings + clipboard_settings = settings.get('clipboard', {}) + if clipboard_settings.get('disableCopy'): + guacd_params['disable-copy'] = 'true' + if clipboard_settings.get('disablePaste'): 
+ guacd_params['disable-paste'] = 'true' + + # Build final connection settings + connection_settings = { + 'protocol': protocol, + 'hostname': settings.get('hostname', ''), + 'port': settings.get('port', 22), + 'width': screen_info.get('pixel_width', 800), + 'height': screen_info.get('pixel_height', 600), + 'dpi': screen_info.get('dpi', 96), + 'guacd_params': guacd_params, + # Supported mimetypes for terminal sessions + 'audio_mimetypes': [], # No audio for terminal + 'video_mimetypes': [], # No video for terminal + 'image_mimetypes': ['image/png', 'image/jpeg', 'image/webp'], + } + + logging.debug(f"Built Guacamole connection settings for {protocol}: " + f"hostname={settings.get('hostname')}, port={settings.get('port')}, " + f"width={connection_settings['width']}x{connection_settings['height']}") + + return connection_settings + + +def _open_terminal_webrtc_tunnel(params: KeeperParams, + record_uid: str, + gateway_uid: str, + protocol: str, + settings: Dict[str, Any], + context: Dict[str, Any], + **kwargs) -> Dict[str, Any]: + """ + Open a WebRTC tunnel for terminal/Guacamole connection. + + This function adapts start_rust_tunnel for terminal protocols by: + - Using the protocol-specific conversation type + - Not requiring local socket listening (Guacamole renders server-side) + - Setting up for text/image streaming only (no audio/video) + + Args: + params: KeeperParams instance + record_uid: Record UID + gateway_uid: Gateway UID + protocol: Protocol type (ssh, telnet, etc.) + settings: Terminal settings + context: Connection context + + Returns: + Dictionary with tunnel information: + - success: bool + - tube_id: str + - conversation_id: str + - tube_registry: PyTubeRegistry + - signal_handler: TunnelSignalHandler + - websocket_thread: Thread + - error: error message if failed + """ + logging.debug(f"{bcolors.HIGHINTENSITYWHITE}Establishing {protocol.upper()} terminal connection via WebRTC...{bcolors.ENDC}") + screen_info = DEFAULT_SCREEN_INFO + + try: + router_token = None + + # Get encryption seed from record + record = vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + return {"success": False, "error": "Invalid record type"} + + # Get traffic encryption seed + seed_field = record.get_typed_field('trafficEncryptionSeed') + if seed_field: + seed = seed_field.get_default_value() + if isinstance(seed, str): + seed = base64_to_bytes(seed) + else: + # Generate a random seed if not present + import secrets + seed = secrets.token_bytes(32) + logging.debug("No trafficEncryptionSeed found, using generated seed") + + # Generate 128-bit (16-byte) random nonce + nonce = os.urandom(MAIN_NONCE_LENGTH) + + # Derive the encryption key using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=SYMMETRIC_KEY_LENGTH, # 256-bit key + salt=nonce, + info=b"KEEPER_TUNNEL_ENCRYPT_AES_GCM_128", + backend=default_backend() + ).derive(seed) + symmetric_key = AESGCM(hkdf) + + # Get tube registry (Rust WebRTC library) + tube_registry = get_or_create_tube_registry(params) + if not tube_registry: + return {"success": False, "error": "Rust WebRTC library (keeper_pam_webrtc_rs) not available"} + + # For terminal connections, we act as client (not server mode) + tube_registry.set_server_mode(False) + + # Generate conversation ID + conversation_id_original = GatewayAction.generate_conversation_id() + conversation_id_bytes = url_safe_str_to_bytes(conversation_id_original) + conversation_id = base64.b64encode(conversation_id_bytes).decode('utf-8') + + 
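The key derivation above can be sketched standalone; `seed` stands in for the record's trafficEncryptionSeed, and the gateway must run the same derivation with the nonce it receives (as base64) alongside the offer:

```
import os
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

seed = os.urandom(32)     # placeholder for the record's trafficEncryptionSeed
nonce = os.urandom(16)    # 128-bit nonce, shared with the gateway as base64
key = HKDF(
    algorithm=hashes.SHA256(),
    length=32,            # 256-bit AES-GCM key
    salt=nonce,
    info=b"KEEPER_TUNNEL_ENCRYPT_AES_GCM_128",
).derive(seed)
aead = AESGCM(key)
ciphertext = aead.encrypt(os.urandom(12), b"offer payload", None)
```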
logging.debug(f"Generated conversation_id_original: {conversation_id_original}") + logging.debug(f"Base64 encoded conversation_id: {conversation_id}") + + base64_nonce = bytes_to_base64(nonce) + + # Get relay server configuration + relay_url = 'krelay.' + params.server + krelay_url = os.getenv('KRELAY_URL') + if krelay_url: + relay_url = krelay_url + + response = router_get_relay_access_creds(params=params, expire_sec=60000000) + if response is None: + return {"success": False, "error": "Failed to get relay access credentials"} + + # Create WebRTC settings for terminal (no local socket needed) + webrtc_settings = { + "turn_only": False, + "relay_url": relay_url, + "stun_url": f"stun:{relay_url}:3478", + "turn_url": f"turn:{relay_url}:3478", + "turn_username": response.username, + "turn_password": response.password, + "conversationType": context['conversationType'], # ssh, telnet, kubernetes, mysql, etc. + "local_listen_addr": "", # No local socket for terminal + "target_host": settings['hostname'], + "target_port": settings['port'], + "socks_mode": False, # Terminal connections don't use SOCKS + "control_channel_label": "control", # Ensure WebRTC data channel label matches gateway expectation + "callback_token": bytes_to_base64(nonce) + } + + # Debug: Log settings to verify control_channel_label is present + logging.debug(f"WebRTC settings before create_tube: {json.dumps(webrtc_settings, default=str)}") + + # Register the encryption key in the global conversation store + register_conversation_key(conversation_id, symmetric_key) + # Create a temporary tunnel session + import uuid + temp_tube_id = str(uuid.uuid4()) + + # Pre-create tunnel session to buffer early ICE candidates + conversation_type = context.get('conversationType', protocol) + + tunnel_session = TunnelSession( + tube_id=temp_tube_id, + conversation_id=conversation_id, + gateway_uid=gateway_uid, + symmetric_key=symmetric_key, + offer_sent=False, + host=None, # No local host for terminal + port=None # No local port for terminal + ) + + # Register the temporary session + register_tunnel_session(temp_tube_id, tunnel_session) + + # Determine trickle ICE setting from kwargs + no_trickle_ice = kwargs.get('no_trickle_ice', False) + trickle_ice = not no_trickle_ice + + # Create signal handler for Rust events + signal_handler = TunnelSignalHandler( + params=params, + record_uid=record_uid, + gateway_uid=gateway_uid, + symmetric_key=symmetric_key, + base64_nonce=base64_nonce, + conversation_id=conversation_id, + tube_registry=tube_registry, + tube_id=temp_tube_id, + trickle_ice=trickle_ice + ) + + # Store signal handler reference + tunnel_session.signal_handler = signal_handler # type: ignore[assignment] + + logging.debug(f"{bcolors.OKBLUE}Creating WebRTC offer for {protocol} connection...{bcolors.ENDC}") + if trickle_ice: + logging.debug("Using trickle ICE for real-time candidate exchange") + else: + logging.debug("Trickle ICE disabled - using standard ICE") + + # Check if PythonHandler mode is requested + use_python_handler = kwargs.get('use_python_handler', True) # Default to True for new mode + python_handler = None + handler_callback = None + + if use_python_handler: + # Import and create PythonHandler for simplified Guacamole protocol handling + from .python_handler import create_python_handler + + logging.debug("Using PythonHandler mode - Rust handles control frames automatically") + + # Set conversationType to "python_handler" to enable PythonHandler protocol mode in Rust + # The actual protocol (ssh, telnet, etc.) 
is passed via guacd_params["protocol"] + webrtc_settings["conversationType"] = "python_handler" + logging.debug(f"Set conversationType to 'python_handler' (actual protocol: {protocol})") + + # Build connection settings for Guacamole handshake + # These are used when guacd sends 'args' instruction + connection_settings = _build_guacamole_connection_settings( + params=params, + record_uid=record_uid, + protocol=protocol, + settings=settings, + context=context, + screen_info=screen_info, + ) + + # Create the handler and callback + handler_callback, python_handler = create_python_handler( + tube_registry=tube_registry, + conversation_id=conversation_id, + conn_no=1, + connection_settings=connection_settings, + ) + + logging.debug(f"Created PythonHandler for conversation {conversation_id}") + logging.debug(f"DEBUG: handler_callback is {'SET' if handler_callback else 'None'}, type={type(handler_callback)}") + logging.debug(f"DEBUG: python_handler is {'SET' if python_handler else 'None'}") + logging.debug(f"DEBUG: connection_settings has {len(connection_settings)} keys: {list(connection_settings.keys())}") + + # Create the tube to get the WebRTC offer + logging.debug(f"DEBUG: Calling create_tube with handler_callback={'SET' if handler_callback else 'None'}") + logging.debug(f"DEBUG: webrtc_settings['conversationType'] = {webrtc_settings.get('conversationType')}") + offer = tube_registry.create_tube( + conversation_id=conversation_id, + settings=webrtc_settings, + trickle_ice=trickle_ice, + callback_token=webrtc_settings["callback_token"], + ksm_config="", + krelay_server=relay_url, + client_version="Commander-Python-Terminal", + offer=None, # Let Rust create the offer + signal_callback=signal_handler.signal_from_rust, + handler_callback=handler_callback, # PythonHandler callback (None if not using) + ) + + if not offer or 'tube_id' not in offer or 'offer' not in offer: + error_msg = "Failed to create tube" + if offer: + error_msg = offer.get('error', error_msg) + # Clean up temporary session on failure + unregister_tunnel_session(temp_tube_id) + unregister_conversation_key(conversation_id) + return {"success": False, "error": error_msg} + + commander_tube_id = offer['tube_id'] + logging.debug(f"Created tube with ID: {commander_tube_id}") + logging.debug(f"Conversation ID for this tube: {conversation_id_original}") + logging.debug(f"Data channel will be named: {conversation_id}") + + # Update signal handler and tunnel session with real tube ID + signal_handler.tube_id = commander_tube_id + tunnel_session.tube_id = commander_tube_id + + # Unregister temporary session and register with real tube ID + unregister_tunnel_session(temp_tube_id) + register_tunnel_session(commander_tube_id, tunnel_session) + + logging.debug(f"Registered encryption key for conversation: {conversation_id}") + logging.debug(f"Expecting WebSocket responses for conversation ID: {conversation_id}") + + # Start WebSocket listener + websocket_thread = start_websocket_listener(params, tube_registry, timeout=300, gateway_uid=gateway_uid, tunnel_session=tunnel_session) + + # Wait a moment for WebSocket to establish connection + import time + time.sleep(1.5) + + # Send offer to gateway via HTTP POST + logging.debug(f"{bcolors.OKBLUE}Sending {protocol} connection offer to gateway...{bcolors.ENDC}") + + # Prepare the offer data with terminal-specific parameters + # Match webvault format: host, size, audio, video, image (for guacd 
configuration) + # These parameters are needed by Gateway to configure guacd BEFORE OpenConnection + import shutil + + raw_columns = DEFAULT_TERMINAL_COLUMNS + raw_rows = DEFAULT_TERMINAL_ROWS + # Get terminal size for Guacamole size parameter + try: + terminal_size = shutil.get_terminal_size(fallback=(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS)) + raw_columns = terminal_size.columns + raw_rows = terminal_size.lines + except Exception: + logging.debug("Falling back to default terminal size for offer payload") + screen_info = _build_screen_info(raw_columns, raw_rows) + logging.debug( + f"Using terminal metrics columns={screen_info['columns']} rows={screen_info['rows']} -> " + f"{screen_info['pixel_width']}x{screen_info['pixel_height']}px @ {screen_info['dpi']}dpi" + ) + + offer_payload = offer.get("offer") + decoded_offer_bytes = None + decoded_offer_text = None + use_re_encoded_offer = False + + if isinstance(offer_payload, str): + try: + # Offers coming from the Rust module are base64-encoded SDP blobs. + decoded_offer_bytes = base64.b64decode(offer_payload, validate=True) + decoded_offer_text = decoded_offer_bytes.decode('utf-8') + use_re_encoded_offer = True + except Exception: + decoded_offer_text = offer_payload + elif isinstance(offer_payload, bytes): + decoded_offer_text = offer_payload.decode('utf-8', errors='ignore') + + if decoded_offer_text is None: + decoded_offer_text = offer_payload + + offer_sdp = _ensure_max_message_size_attribute(decoded_offer_text) + + if offer_sdp is None: + offer_payload = offer.get("offer") + elif use_re_encoded_offer: + offer_payload = base64.b64encode(offer_sdp.encode('utf-8')).decode('utf-8') + else: + offer_payload = offer_sdp + + offer_data = { + "offer": offer_payload, + "audio": ["audio/L8", "audio/L16"], # Supported audio codecs + "video": [], # Supported video codecs - None for terminal + "size": [screen_info['pixel_width'], screen_info['pixel_height'], screen_info['dpi']], # [width, height, dpi] + "image": ["image/jpeg", "image/png", "image/webp"], # Supported image formats + # CRITICAL: Gateway needs 'host' to configure guacd connection + "host": { + "hostName": settings['hostname'], + "port": settings['port'] + } + # these are not sent by webvault during open connection for terminal connections + # "protocol": protocol, + # "terminalSettings": { + # "colorScheme": settings['terminal']['colorScheme'], + # "fontSize": settings['terminal']['fontSize'], + # } + } + + # TODO: Add protocol-specific settings to offer + # if 'protocol_specific' in settings and settings['protocol_specific']: + # offer_data["protocolSettings"] = settings['protocol_specific'] + + # Log what we're sending in the initial offer + logging.debug(f"Sending initial offer with connection parameters: {json.dumps(offer_data, indent=2)}") + + string_data = json.dumps(offer_data) + logging.debug(f"payload.inputs.data JSON before encryption: {string_data}") + bytes_data = string_to_bytes(string_data) + encrypted_data = tunnel_encrypt(symmetric_key, bytes_data) + + # Extract userRecordUid from pamSettings + user_record_uid = None + pam_settings_field = record.get_typed_field('pamSettings') + if pam_settings_field: + pam_settings_value = pam_settings_field.get_default_value(dict) + if pam_settings_value: + connection = pam_settings_value.get('connection', {}) + if isinstance(connection, dict): + user_records = connection.get('userRecords', []) + if user_records and len(user_records) > 0: + user_record_uid = user_records[0] + logging.debug(f"Found userRecordUid: 
{user_record_uid}") + + if not user_record_uid: + logging.warning(f"No userRecordUid found in pamSettings for record {record_uid}") + + time.sleep(1) # Allow time for WebSocket listener to start + + # Send offer via HTTP POST - two paths: streaming vs non-streaming + try: + # Build inputs dict - matching working session format + inputs = { + "recordUid": record_uid, + 'kind': 'start', + 'base64Nonce': base64_nonce, + 'conversationType': context['conversationType'], + "data": encrypted_data, + "trickleICE": trickle_ice, # Set trickle ICE flag + } + + # Add userRecordUid and credentialType if we have a linked user + if user_record_uid: + inputs['userRecordUid'] = user_record_uid + inputs['credentialType'] = 'linked' + + # Router token is no longer extracted from cookies (removed in commit 338a9fda) + # Router affinity is now handled server-side + + # Generate messageId from conversationId (replace + with -, / with _) + message_id = GatewayAction.conversation_id_to_message_id(conversation_id_original) + logging.debug(f"Generated messageId: {message_id} from conversationId: {conversation_id_original}") + + # Two paths: streaming vs non-streaming + if trickle_ice: + # Streaming path: Response will come via WebSocket + router_response = router_send_action_to_gateway( + params=params, + destination_gateway_uid_str=gateway_uid, + gateway_action=GatewayActionWebRTCSession( + conversation_id=conversation_id_original, + inputs=inputs, + message_id=message_id + ), + message_type=pam_pb2.CMT_CONNECT, + is_streaming=True, # Response will come via WebSocket + gateway_timeout=30000 + ) + + logging.debug(f"{bcolors.OKGREEN}Offer sent to gateway (streaming mode){bcolors.ENDC}") + + # Mark offer as sent + signal_handler.offer_sent = True + tunnel_session.offer_sent = True + + # Send any buffered ICE candidates + if tunnel_session.buffered_ice_candidates: + logging.debug(f"Flushing {len(tunnel_session.buffered_ice_candidates)} buffered ICE candidates") + for candidate in tunnel_session.buffered_ice_candidates: + signal_handler._send_ice_candidate_immediately(candidate, commander_tube_id) + tunnel_session.buffered_ice_candidates.clear() + + logging.debug(f"{bcolors.OKGREEN}Terminal connection established for {protocol.upper()}{bcolors.ENDC}") + logging.debug(f"{bcolors.OKBLUE}Connection state: {bcolors.ENDC}gathering candidates...") + + return { + "success": True, + "tube_id": commander_tube_id, + "conversation_id": conversation_id, + "tube_registry": tube_registry, + "signal_handler": signal_handler, + "websocket_thread": websocket_thread, + "status": "connecting", + "screen_info": screen_info, + "python_handler": python_handler, # PythonHandler for simplified guac protocol + "use_python_handler": use_python_handler, + } + else: + # Non-streaming path: Handle response immediately + router_response = router_send_action_to_gateway( + params=params, + destination_gateway_uid_str=gateway_uid, + gateway_action=GatewayActionWebRTCSession( + conversation_id=conversation_id_original, + inputs=inputs, + message_id=message_id + ), + message_type=pam_pb2.CMT_CONNECT, + is_streaming=False, # Response comes immediately in HTTP response + gateway_timeout=30000 + ) + + logging.debug(f"{bcolors.OKGREEN}Offer sent to gateway (non-streaming mode){bcolors.ENDC}") + logging.debug(f"Router response: {router_response}") + + # Handle immediate response + if router_response and router_response.get('response'): + response_dict = router_response['response'] + logging.debug(f"Received immediate response from gateway: {response_dict}") 
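(Aside) The messageId derivation used when sending the offer is, per the inline comment above, the standard base64-to-URL-safe alphabet swap; `to_message_id` below is a hypothetical stand-in for `GatewayAction.conversation_id_to_message_id`:

```
def to_message_id(conversation_id: str) -> str:
    return conversation_id.replace('+', '-').replace('/', '_')

assert to_message_id("ab+cd/ef==") == "ab-cd_ef=="
```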
+ response_payload = response_dict.get('payload') if isinstance(response_dict, dict) else "{}" + if isinstance(response_payload, str): + try: + response_payload = json.loads(response_payload) + except json.JSONDecodeError: + response_payload = {} + + # Check for errors in response + if not (response_payload.get('is_ok') or response_payload.get('isOk')): + error_msg = response_payload.get('error', 'Unknown error from gateway') + raise Exception(f"Gateway error: {error_msg} Payload: {response_payload}") + + # Decrypt and handle payload.data if present (contains SDP answer) + if response_payload.get('is_ok') and response_payload.get('data'): + data_field = response_payload.get('data', '') + + # Check if this is a plain text acknowledgment (not encrypted) + if isinstance(data_field, str) and ( + "ice candidate" in data_field.lower() or + "buffered" in data_field.lower() or + "connected" in data_field.lower() or + "disconnected" in data_field.lower() or + "error" in data_field.lower() or + data_field.endswith(conversation_id_original) + ): + logging.debug(f"Received plain text acknowledgment: {data_field}") + else: + # This is encrypted data - decrypt it + encrypted_data = data_field + if encrypted_data: + logging.debug(f"Found encrypted data in response, length: {len(encrypted_data)}") + try: + decrypted_data = tunnel_decrypt(symmetric_key, encrypted_data) + if decrypted_data: + data_text = bytes_to_string(decrypted_data).replace("'", '"') + logging.debug(f"Successfully decrypted data for {conversation_id_original}, length: {len(data_text)}") + + # Parse JSON + try: + data_json = json.loads(data_text) + + # Ensure data_json is a dictionary + if isinstance(data_json, dict): + logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") + + # Handle SDP answer + if "answer" in data_json: + answer_sdp = data_json.get('answer') + if answer_sdp: + logging.debug(f"Found SDP answer in non-streaming response, sending to Rust for conversation: {conversation_id_original}") + tube_registry.set_remote_description(commander_tube_id, answer_sdp, is_answer=True) + + if hasattr(tunnel_session, "gateway_ready_event") and tunnel_session.gateway_ready_event is not None: + tunnel_session.gateway_ready_event.set() + logging.debug(f"{bcolors.OKBLUE}Connection state: {bcolors.ENDC}SDP answer received, connecting...") + + # Send any buffered local ICE candidates now that we have the answer + if tunnel_session.buffered_ice_candidates: + logging.debug(f"Sending {len(tunnel_session.buffered_ice_candidates)} buffered ICE candidates after answer") + for candidate in tunnel_session.buffered_ice_candidates: + signal_handler._send_ice_candidate_immediately(candidate, commander_tube_id) + tunnel_session.buffered_ice_candidates.clear() + elif "offer" in data_json or (data_json.get("type") == "offer"): + # Gateway is sending us an ICE restart offer (unlikely in non-streaming mode) + logging.warning(f"Received ICE restart offer in non-streaming mode - this is unexpected") + except json.JSONDecodeError as e: + logging.error(f"Failed to parse decrypted data as JSON: {e}") + logging.debug(f"Data text: {data_text[:200]}...") + else: + logging.warning(f"Decryption returned None for conversation {conversation_id_original}") + except Exception as e: + logging.error(f"Failed to decrypt data in non-streaming response: {e}") + logging.debug(f"Data content: {encrypted_data[:100]}...") + # Don't fail the connection if decryption fails - might be a plain text response + + # Mark offer as sent 
+ signal_handler.offer_sent = True + tunnel_session.offer_sent = True + + # No ICE candidates to send in non-streaming mode (all candidates in SDP) + logging.debug(f"{bcolors.OKGREEN}Terminal connection established for {protocol.upper()}{bcolors.ENDC}") + logging.debug(f"{bcolors.OKBLUE}Connection state: {bcolors.ENDC}established (non-streaming mode)...") + + return { + "success": True, + "tube_id": commander_tube_id, + "conversation_id": conversation_id, + "tube_registry": tube_registry, + "signal_handler": signal_handler, + "websocket_thread": websocket_thread, + "status": "connected", + "router_response": router_response, + "screen_info": screen_info, + "python_handler": python_handler, # PythonHandler for simplified guac protocol + "use_python_handler": use_python_handler, + } + + except Exception as e: + signal_handler.cleanup() + unregister_tunnel_session(commander_tube_id) + unregister_conversation_key(conversation_id) + _notify_gateway_connection_close(params, router_token) + return {"success": False, "error": f"Failed to send offer via HTTP: {e}"} + + except Exception as e: + logging.error(f"Error opening terminal WebRTC tunnel: {e}") + if 'conversation_id' in locals() and conversation_id: + unregister_conversation_key(conversation_id) + if 'signal_handler' in locals(): + signal_handler.cleanup() + return {"success": False, "error": f"Failed to establish tunnel: {e}"} + + +def launch_terminal_connection(params: KeeperParams, + record_uid: str, + gateway_info: Dict[str, Any], + connect_as: Optional[str] = None, + **kwargs) -> Dict[str, Any]: + """ + Launch a terminal connection for a PAM record. + + This is the main entry point for terminal connections. It: + 1. Detects the protocol + 2. Extracts settings + 3. Builds connection context + 4. Opens WebRTC tunnel + 5. 
Manages lifecycle + + Args: + params: KeeperParams instance + record_uid: Record UID + gateway_info: Gateway information from find_gateway + connect_as: Optional username to connect as + + Returns: + Dictionary with connection status: + - success: bool + - protocol: str + - context: connection context dict + - tunnel: tunnel result dict + - error: error message if failed + + Raises: + CommandError: If connection cannot be established + """ + try: + # Step 1: Detect protocol + protocol = detect_protocol(params, record_uid) + if not protocol: + raise CommandError('pam launch', f'Could not detect protocol for record {record_uid}') + + logging.debug(f"Detected protocol: {protocol}") + + # Step 2: Extract settings + settings = extract_terminal_settings(params, record_uid, protocol) + logging.debug(f"Extracted settings: hostname={settings['hostname']}, port={settings['port']}") + + # Step 3: Build connection context + context = create_connection_context( + params, + record_uid, + gateway_info['gateway_uid'], + protocol, + settings, + connect_as + ) + logging.debug(f"Built connection context for {protocol}") + + # Step 4: Open WebRTC tunnel + tunnel_result = _open_terminal_webrtc_tunnel( + params, + record_uid, + gateway_info['gateway_uid'], + protocol, + settings, + context, + **kwargs + ) + + if not tunnel_result.get('success'): + error_msg = tunnel_result.get('error', 'Unknown error') + raise CommandError('pam launch', f'Failed to open WebRTC tunnel: {error_msg}') + + logging.debug(f"Terminal connection established for {protocol}") + logging.debug(f"Target: {settings['hostname']}:{settings['port']}") + logging.debug(f"Gateway: {gateway_info['gateway_name']} ({gateway_info['gateway_uid']})") + + return { + 'success': True, + 'protocol': protocol, + 'context': context, + 'settings': settings, + 'gateway_info': gateway_info, + 'tunnel': { + **tunnel_result, + 'registry': tunnel_result.get('tube_registry') # Add registry for CLI mode + }, + 'message': f'Terminal connection established for {protocol}' + } + + except CommandError: + raise + except Exception as e: + logging.error(f"Error launching terminal connection: {e}") + raise CommandError('pam launch', f'Failed to launch terminal connection: {e}') + + diff --git a/keepercommander/commands/pedm/pedm_admin.py b/keepercommander/commands/pedm/pedm_admin.py index 8fed5af17..645966c56 100644 --- a/keepercommander/commands/pedm/pedm_admin.py +++ b/keepercommander/commands/pedm/pedm_admin.py @@ -173,7 +173,7 @@ def resolve_existing_collections( class PedmCommand(base.GroupCommandNew): def __init__(self): - super().__init__('Privilege Manager - PEDM') + super().__init__('Administration of Endpoint Privilege Manager features') self.register_command_new(PedmSyncDownCommand(), 'sync-down') self.register_command_new(PedmDeploymentCommand(), 'deployment', 'd') self.register_command_new(PedmAgentCommand(), 'agent', 'a') @@ -186,7 +186,7 @@ def __init__(self): class PedmScimCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='scim', description='Sync PEDM user/group collections from AD or AzureAD') + parser = argparse.ArgumentParser(prog='scim', description='Sync EPM user/group collections from AD or AzureAD') subparsers = parser.add_subparsers(title='Directory Type', dest='auth_type', required=True, help='Authentication method') record_parser = subparsers.add_parser('record', help='Connection parameters from Keeper record') @@ -497,16 +497,16 @@ def build_group(group: ScimGroup) -> 
Optional[Tuple[admin_types.CollectionData, update_collections = list(update_map.values()) if len(add_collections) == 0 and len(update_collections) == 0: - logging.info('No PEDM collections to add or update.') + logging.info('No EPM collections to add or update.') return status = plugin.modify_collections(add_collections=add_collections, update_collections=update_collections) - logging.info('PEDM SCIM sync completed. Added: %d, Updated: %d', len(status.add), len(status.update)) + logging.info('EPM SCIM sync completed. Added: %d, Updated: %d', len(status.add), len(status.update)) class PedmSyncDownCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='sync-down', description='Sync down PEDM data from the backend') + parser = argparse.ArgumentParser(prog='sync-down', description='Sync down EPM data from the backend') parser.add_argument('--reload', dest='reload', action='store_true', help='Perform full sync') super().__init__(parser) @@ -517,7 +517,7 @@ def execute(self, context: KeeperParams, **kwargs): class PedmDeploymentCommand(base.GroupCommandNew): def __init__(self): - super().__init__('Manage PEDM deployments') + super().__init__('Manage EPM deployments') self.register_command_new(PedmDeploymentListCommand(), 'list', 'l') self.register_command_new(PedmDeploymentAddCommand(), 'add', 'a') self.register_command_new(PedmDeploymentUpdateCommand(), 'edit') @@ -528,7 +528,7 @@ def __init__(self): class PedmDeploymentListCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='list', description='List PEDM deployments', parents=[base.report_output_parser]) + parser = argparse.ArgumentParser(prog='list', description='List EPM deployments', parents=[base.report_output_parser]) parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='print verbose information') super().__init__(parser) @@ -562,7 +562,7 @@ def execute(self, context: KeeperParams, **kwargs): class PedmDeploymentAddCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='add', description='Add PEDM deployments') + parser = argparse.ArgumentParser(prog='add', description='Add EPM deployments') parser.add_argument('-f', '--force', dest='force', action='store_true', help='do not prompt for confirmation') # parser.add_argument('--spiffe-cert', dest='spiffe', action='store', @@ -614,7 +614,7 @@ def execute(self, context: KeeperParams, **kwargs): class PedmDeploymentUpdateCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='update', description='Update PEDM deployment') + parser = argparse.ArgumentParser(prog='update', description='Update EPM deployment') parser.add_argument('--disable', dest='disable', action='store', choices=['on', 'off'], help='do not prompt for confirmation') # parser.add_argument('--spiffe-cert', dest='spiffe', action='store', @@ -659,7 +659,7 @@ def execute(self, context: KeeperParams, **kwargs): class PedmDeploymentDeleteCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='delete', description='Delete PEDM deployment') + parser = argparse.ArgumentParser(prog='delete', description='Delete EPM deployment') parser.add_argument('-f', '--force', dest='force', action='store_true', help='do not prompt for confirmation') parser.add_argument('deployment', metavar='DEPLOYMENT', nargs='+', @@ -705,7 +705,7 @@ def execute(self, context: KeeperParams, **kwargs): class PedmDeploymentDownloadCommand(base.ArgparseCommand): 
def __init__(self): - parser = argparse.ArgumentParser(prog='download', description='Download PEDM deployment package') + parser = argparse.ArgumentParser(prog='download', description='Download EPM deployment package') grp = parser.add_mutually_exclusive_group() grp.add_argument('--file', dest='file', action='store', help='File name') grp.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='Verbose output') @@ -769,7 +769,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Optional[str]: class PedmAgentCommand(base.GroupCommandNew): def __init__(self): - super().__init__('Manage PEDM agents') + super().__init__('Manage EPM agents') self.register_command_new(PedmAgentListCommand(), 'list', 'l') self.register_command_new(PedmAgentEditCommand(), 'edit', 'e') self.register_command_new(PedmAgentDeleteCommand(), 'delete') @@ -780,7 +780,7 @@ def __init__(self): class PedmAgentCollectionCommand(base.ArgparseCommand): def __init__(self): parser = argparse.ArgumentParser(prog='list', parents=[base.report_output_parser], - description='List PEDM agent resources') + description='List EPM agent resources') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='print verbose information') parser.add_argument('--type', dest='type', action='store', type=int, @@ -836,7 +836,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmAgentDeleteCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='update', description='Delete PEDM agents') + parser = argparse.ArgumentParser(prog='update', description='Delete EPM agents') parser.add_argument('--force', dest='force', action='store_true', help='do not prompt for confirmation') parser.add_argument('agent', nargs='+', help='Agent UID(s)') @@ -866,7 +866,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmAgentEditCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='update', description='Update PEDM agents') + parser = argparse.ArgumentParser(prog='update', description='Update EPM agents') parser.add_argument('--enable', dest='enable', action='store', choices=['on', 'off'], help='Enables or disables agents') parser.add_argument('--deployment', dest='deployment', action='store', @@ -920,7 +920,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmAgentListCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='list', description='List PEDM agents', + parser = argparse.ArgumentParser(prog='list', description='List EPM agents', parents=[base.report_output_parser]) parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='print verbose information') @@ -968,7 +968,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmPolicyCommand(base.GroupCommandNew): def __init__(self): - super().__init__('Manage PEDM policies') + super().__init__('Manage EPM policies') self.register_command_new(PedmPolicyListCommand(), 'list', 'l') self.register_command_new(PedmPolicyAddCommand(), 'add', 'a') self.register_command_new(PedmPolicyEditCommand(), 'edit', 'e') @@ -1256,7 +1256,7 @@ def get_policy_filter(plugin: admin_plugin.PedmPlugin, **kwargs) -> Dict[str, An class PedmPolicyListCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='list', description='List PEDM policies', + parser = argparse.ArgumentParser(prog='list', description='List EPM policies', 
parents=[base.report_output_parser]) super().__init__(parser) @@ -1290,7 +1290,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmPolicyAddCommand(base.ArgparseCommand, PedmPolicyMixin): def __init__(self): - parser = argparse.ArgumentParser(prog='add', description='Add PEDM policy', parents=[PedmPolicyMixin.policy_filter]) + parser = argparse.ArgumentParser(prog='add', description='Add EPM policy', parents=[PedmPolicyMixin.policy_filter]) parser.add_argument('--policy-type', dest='policy_type', action='store', default='elevation', choices=['elevation', 'file_access', 'command', 'least_privilege'], help='Policy type') @@ -1409,7 +1409,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmPolicyEditCommand(base.ArgparseCommand, PedmPolicyMixin): def __init__(self): - parser = argparse.ArgumentParser(prog='edit', description='Edit PEDM policy', parents=[PedmPolicyMixin.policy_filter]) + parser = argparse.ArgumentParser(prog='edit', description='Edit EPM policy', parents=[PedmPolicyMixin.policy_filter]) parser.add_argument('policy', help='Policy UID') parser.add_argument('--policy-name', dest='policy_name', action='store', help='Policy name') @@ -1469,7 +1469,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmPolicyViewCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='view', parents=[base.json_output_parser], description='View PEDM policy') + parser = argparse.ArgumentParser(prog='view', parents=[base.json_output_parser], description='View EPM policy') parser.add_argument('policy', help='Policy UID or name') super().__init__(parser) @@ -1489,7 +1489,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmPolicyDeleteCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='delete', description='Delete PEDM policy') + parser = argparse.ArgumentParser(prog='delete', description='Delete EPM policy') parser.add_argument('policy', type=str, nargs='+', help='Policy UID or name') super().__init__(parser) @@ -1607,7 +1607,7 @@ def __init__(self): class PedmCollectionWipeOutCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='wipe-out', description='Wipe out PEDM collections') + parser = argparse.ArgumentParser(prog='wipe-out', description='Wipe out EPM collections') parser.add_argument('--type', dest='type', action='store', type=int, help='collection type') super().__init__(parser) @@ -1629,7 +1629,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionAddCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='add', description='Creates PEDM collections') + parser = argparse.ArgumentParser(prog='add', description='Creates EPM collections') parser.add_argument('--type', dest='type', action='store', type=int, help='collection type') parser.add_argument('data', nargs='+', help='Field assignment key=value (repeatable)') @@ -1672,7 +1672,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionUpdateCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='update', description='Update PEDM collection') + parser = argparse.ArgumentParser(prog='update', description='Update EPM collection') parser.add_argument('--type', dest='type', action='store', type=int, help='collection type (optional)') parser.add_argument('--name', dest='name', action='store', required=True, @@ -1710,7 +1710,7 
@@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionDeleteCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='delete', description='Delete PEDM collections') + parser = argparse.ArgumentParser(prog='delete', description='Delete EPM collections') parser.add_argument('-f', '--force', dest='force', action='store_true', help='do not prompt for confirmation') parser.add_argument('collection', nargs='+', help='Collection or @orphan_resource') @@ -1761,7 +1761,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionConnectCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='link', description='Link values to PEDM collection') + parser = argparse.ArgumentParser(prog='link', description='Link values to EPM collection') parser.add_argument('--collection', '-c', dest='collection', action='store', help='Parent collection UID or name') parser.add_argument('--link-type', dest='link_type', action='store', required=True, @@ -1809,7 +1809,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionDisconnectCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='unlink', description='Unlink values from PEDM collections') + parser = argparse.ArgumentParser(prog='unlink', description='Unlink values from EPM collections') parser.add_argument('--collection', '-c', dest='collection', action='store', help='Parent collection UID or name') parser.add_argument('-f', '--force', dest='force', action='store_true', @@ -1863,7 +1863,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: class PedmCollectionListCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='list', description='List PEDM collections', + parser = argparse.ArgumentParser(prog='list', description='List EPM collections', parents=[base.report_output_parser]) parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='print verbose information') @@ -1960,7 +1960,7 @@ def any_match(row: Any) -> bool: class PedmCollectionViewCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='view', description='Show PEDM collection details', + parser = argparse.ArgumentParser(prog='view', description='Show EPM collection details', parents=[base.report_output_parser]) parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help='print verbose information') @@ -2056,7 +2056,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmApprovalCommand(base.GroupCommandNew): def __init__(self): - super().__init__('Manage PEDM approval requests and approvals') + super().__init__('Manage EPM approval requests and approvals') self.register_command_new(PedmApprovalListCommand(), 'list', 'l') self.register_command_new(PedmApprovalViewCommand(), 'view') self.register_command_new(PedmApprovalStatusCommand(), 'action', 'a') @@ -2065,7 +2065,7 @@ def __init__(self): class PedmApprovalViewCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='view', parents=[base.json_output_parser], description='View PEDM approval') + parser = argparse.ArgumentParser(prog='view', parents=[base.json_output_parser], description='View EPM approval') parser.add_argument('approval', help='Approval UID') super().__init__(parser) @@ -2096,7 +2096,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class 
PedmApprovalListCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='list', description='List PEDM approval requests', + parser = argparse.ArgumentParser(prog='list', description='List EPM approval requests', parents=[base.report_output_parser]) parser.add_argument('--type', dest='type', action='store', choices=['approved', 'denied', 'pending', 'expired'], help='approval type filter') @@ -2139,7 +2139,7 @@ def execute(self, context: KeeperParams, **kwargs) -> Any: class PedmApprovalStatusCommand(base.ArgparseCommand): def __init__(self): - parser = argparse.ArgumentParser(prog='action', description='Modify PEDM approval requests') + parser = argparse.ArgumentParser(prog='action', description='Modify EPM approval requests') parser.add_argument('--approve', dest='approve', action='append', help='Request UIDs for approval') parser.add_argument('--deny', dest='deny', action='append', diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 688ab53b4..c7a901f0c 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -395,6 +395,7 @@ def execute(self, params, **kwargs): } if version < 3 or kwargs.get('legacy') is True: ro['title'] = r.title + ro['record_type'] = r.record_type if r.login: ro['login'] = r.login if r.password: @@ -518,12 +519,17 @@ def execute(self, params, **kwargs): for user in rec['shares']['user_permissions']: print('') if 'username' in user: - print('User: ' + user['username']) + print(' User: ' + user['username']) if 'user_uid' in user: - print('User UID: ' + user['user_uid']) + print(' User UID: ' + user['user_uid']) elif 'accountUid' in user: - print('User UID: ' + user['accountUid']) - + print(' User UID: ' + user['accountUid']) + + # Show owner status + is_owner = user.get('owner', False) + if is_owner: + print(' Owner: Yes') + # Handle both possible spellings of sharable/shareable if 'sharable' in user: shareable = user['sharable'] @@ -531,59 +537,69 @@ def execute(self, params, **kwargs): shareable = user['shareable'] else: shareable = False - + if shareable is None: shareable = False - + # Handle both possible spellings of readable - if 'readable' in user: - readable = user['readable'] + if 'editable' in user: + editable = user['editable'] else: - readable = False - - if readable is None: - readable = False - - print('Shareable: ' + ('Yes' if shareable else 'No')) - print('Read-Only: ' + ('Yes' if not shareable else 'No')) - print('') + editable = False + + if editable is None: + editable = False + + print(' Shareable: ' + ('Yes' if shareable else 'No')) + print(' Read-Only: ' + ('Yes' if not editable else 'No')) if 'shared_folder_permissions' in rec['shares'] and rec['shares']['shared_folder_permissions']: print('') print('Shared Folder Permissions:') for sf in rec['shares']['shared_folder_permissions']: print('') if 'shared_folder_uid' in sf: - print('Shared Folder UID: ' + sf['shared_folder_uid']) + print(' Shared Folder UID: ' + sf['shared_folder_uid']) if 'user_uid' in sf: - print('User UID: ' + sf['user_uid']) + print(' User UID: ' + sf['user_uid']) elif 'accountUid' in sf: - print('User UID: ' + sf['accountUid']) - + print(' User UID: ' + sf['accountUid']) + # Safely access boolean fields with fallback to False if sf.get('manage_users', False) is True: - print('Manage Users: True') + print(' Manage Users: True') if sf.get('manage_records', False) is True: - print('Manage Records: True') + print(' Manage Records: True') if sf.get('can_edit', False) is 
True: - print('Can Edit: True') + print(' Can Edit: True') if sf.get('can_share', False) is True: - print('Can Share: True') - print('') + print(' Can Share: True') if 'team_permissions' in rec['shares'] and rec['shares']['team_permissions']: print('') print('Team Permissions:') for team in rec['shares']['team_permissions']: print('') if 'team_uid' in team: - print('Team UID: ' + team['team_uid']) + print(' Team UID: ' + team['team_uid']) if 'name' in team: - print('Name: ' + team['name']) - print('') + print(' Name: ' + team['name']) if admins: print('') - print('Share Admins:') - for admin in admins: - print(admin) + max_admins_shown = 10 + total_admins = len(admins) + if total_admins <= max_admins_shown: + print(f'Share Admins ({total_admins}):') + for admin in admins: + print(f' {admin}') + else: + print(f'Share Admins ({total_admins}, showing first {max_admins_shown}):') + for admin in admins[:max_admins_shown]: + print(f' {admin}') + print(f' ... and {total_admins - max_admins_shown} more') + + # Display rotation info for pamUser records when --include-dag is specified + if kwargs.get('include_dag') and r.record_type == 'pamUser': + self.display_rotation_info(params, r) + direct_match = True return @@ -790,7 +806,7 @@ def include_dag(self, params, ro, r): ro: Record output dictionary to be modified r: Record object """ - valid_record_types = {'pamDatabase', 'pamDirectory', 'pamMachine'} + valid_record_types = {'pamDatabase', 'pamDirectory', 'pamMachine', 'pamUser', 'pamRemoteBrowser'} if r.record_type not in valid_record_types: return @@ -815,19 +831,31 @@ def include_dag(self, params, ro, r): from .tunnel.port_forward.TunnelGraph import TunnelDAG from ..keeper_dag import EdgeType - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, r.record_uid) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, r.record_uid, + transmission_key=transmission_key) if not tdag.linking_dag.has_graph: + ro['dagDebug'] = {'error': 'No graph loaded', 'has_graph': False} return record_vertex = tdag.linking_dag.get_vertex(r.record_uid) if record_vertex is None: + ro['dagDebug'] = {'error': f'Record vertex not found for {r.record_uid}', 'has_graph': True} return + # Add debug info about the vertex + ro['dagDebug'] = { + 'has_graph': True, + 'vertex_uid': record_vertex.uid, + 'vertex_name': getattr(record_vertex, 'name', None), + 'vertex_type': str(getattr(record_vertex, 'vertex_type', None)), + } + # Extract allowedSettings from vertex content try: content = record_vertex.content_as_dict + ro['dagDebug']['vertex_content'] = content if content and 'allowedSettings' in content: allowed_settings = content['allowedSettings'] if isinstance(allowed_settings, dict): @@ -837,13 +865,14 @@ def include_dag(self, params, ro, r): ro['pamSettingsEnabled']['sessionRecording'] = allowed_settings.get('sessionRecording') ro['pamSettingsEnabled']['typescriptRecording'] = allowed_settings.get('typescriptRecording') ro['pamSettingsEnabled']['remoteBrowserIsolation'] = allowed_settings.get('remoteBrowserIsolation') - except Exception: - pass + except Exception as e: + ro['dagDebug']['content_error'] = str(e) # Find all ACL links where Head is recordUID admin_credential = None launch_credential = None linked_credentials = [] + acl_debug = [] # Get vertices that have ACL edges pointing to this 
record (has_vertices) for user_vertex in record_vertex.has_vertices(EdgeType.ACL): @@ -851,6 +880,10 @@ def include_dag(self, params, ro, r): if acl_edge: try: content = acl_edge.content_as_dict or {} + acl_debug.append({ + 'user_vertex_uid': user_vertex.uid, + 'edge_content': content + }) # belongs_to = content.get('belongs_to', False) is_admin = content.get('is_admin', False) is_launch_credential = content.get('is_launch_credential', None) @@ -864,8 +897,57 @@ def include_dag(self, params, ro, r): if is_launch_credential and launch_credential is None: launch_credential = user_vertex.uid - except Exception: - pass + except Exception as e: + acl_debug.append({'error': str(e), 'user_vertex_uid': user_vertex.uid}) + + ro['dagDebug']['acl_edges'] = acl_debug + ro['dagDebug']['all_edges'] = [{'type': str(e.edge_type), 'head_uid': e.head_uid} for e in record_vertex.edges] + + # For pamUser records, show rotation profile from the ACL edge to parent (config/resource) + if r.record_type == 'pamUser': + rotation_profile = None + rotation_profile_config_uid = None + rotation_profile_resource_uid = None + + for parent_vertex in record_vertex.belongs_to_vertices(): + acl_edge = record_vertex.get_edge(parent_vertex, EdgeType.ACL) + if acl_edge: + try: + edge_content = acl_edge.content_as_dict or {} + + # Extract rotation profile flags from edge content + # (matches web vault PamUserAclData type in dag-pam-link.ts) + belongs_to = edge_content.get('belongs_to', False) + is_iam_user = edge_content.get('is_iam_user', False) + rotation_settings = edge_content.get('rotation_settings', {}) + is_noop = rotation_settings.get('noop', False) if isinstance(rotation_settings, dict) else False + + # Determine rotation profile using same logic as web vault + # (dag-pam-link.ts configLinkRotationProfile) + if is_noop: + rotation_profile = 'scripts_only' # "Run PAM scripts only" + rotation_profile_config_uid = parent_vertex.uid + elif is_iam_user: + rotation_profile = 'iam_user' # "IAM User" + rotation_profile_config_uid = parent_vertex.uid + elif belongs_to: + rotation_profile = 'general' # "General" (linked to resource) + rotation_profile_resource_uid = parent_vertex.uid + + # Store full edge content in dagDebug for troubleshooting + ro['dagDebug']['parentAclEdge'] = { + 'parent_uid': parent_vertex.uid, + 'parent_type': str(getattr(parent_vertex, 'vertex_type', None)), + 'content': edge_content + } + except Exception as e: + ro['dagDebug']['parentAclEdge'] = {'error': str(e), 'parent_uid': parent_vertex.uid} + + ro['rotationProfile'] = { + 'type': rotation_profile, + 'configUid': rotation_profile_config_uid, + 'resourceUid': rotation_profile_resource_uid + } # Update associatedCredentials with found values ro['associatedCredentials']['adminCredential'] = admin_credential @@ -874,6 +956,157 @@ def include_dag(self, params, ro, r): except Exception as e: logging.debug(f"Error accessing DAG for record {r.record_uid}: {e}") + ro['dagDebug'] = {'error': str(e)} + + def display_rotation_info(self, params, r): + """Display rotation info for pamUser records in table format (similar to web vault)""" + from .tunnel.port_forward.tunnel_helpers import get_keeper_tokens + from .tunnel.port_forward.TunnelGraph import TunnelDAG + from keeper_dag.edge import EdgeType + + try: + # Get rotation data from cache + rotation_data = params.record_rotation_cache.get(r.record_uid) + + # Get DAG data for rotation profile + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) + tdag = 
TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, r.record_uid, + transmission_key=transmission_key) + + rotation_profile = None + config_uid = None + resource_uid = None + + if tdag.linking_dag.has_graph: + record_vertex = tdag.linking_dag.get_vertex(r.record_uid) + if record_vertex: + for parent_vertex in record_vertex.belongs_to_vertices(): + acl_edge = record_vertex.get_edge(parent_vertex, EdgeType.ACL) + if acl_edge: + edge_content = acl_edge.content_as_dict or {} + belongs_to = edge_content.get('belongs_to', False) + is_iam_user = edge_content.get('is_iam_user', False) + rotation_settings = edge_content.get('rotation_settings', {}) + is_noop = rotation_settings.get('noop', False) if isinstance(rotation_settings, dict) else False + + if is_noop: + rotation_profile = 'Scripts Only' + config_uid = parent_vertex.uid + elif is_iam_user: + rotation_profile = 'IAM User' + config_uid = parent_vertex.uid + elif belongs_to: + rotation_profile = 'General' + resource_uid = parent_vertex.uid + + # Get config UID from rotation cache if not from DAG + if not config_uid and rotation_data: + config_uid = rotation_data.get('configuration_uid') + if not resource_uid and rotation_data: + resource_uid = rotation_data.get('resource_uid') + # If resource_uid equals config_uid, it's an IAM/NOOP user, not General + if resource_uid and resource_uid == config_uid: + resource_uid = None + + # Get configuration name + config_name = None + if config_uid: + config_record = vault.KeeperRecord.load(params, config_uid) + if config_record: + config_name = config_record.title + + # Get resource name + resource_name = None + if resource_uid: + resource_record = vault.KeeperRecord.load(params, resource_uid) + if resource_record: + resource_name = resource_record.title + + print('') + print('Rotation:') + + if not rotation_data and not rotation_profile: + print(' Status: Not configured') + return + + # Rotation status + if rotation_data: + disabled = rotation_data.get('disabled', False) + print(f' Status: {"Disabled" if disabled else "Enabled"}') + else: + print(' Status: Enabled') + + # Rotation profile + if rotation_profile: + print(f' Profile: {rotation_profile}') + + # PAM Configuration + if config_name: + print(f' Configuration: {config_name}') + elif config_uid: + print(f' Configuration UID: {config_uid}') + + # Resource (for General profile) + if resource_name: + print(f' Resource: {resource_name}') + elif resource_uid: + print(f' Resource UID: {resource_uid}') + + # Schedule + if rotation_data and rotation_data.get('schedule'): + try: + schedule_json = json.loads(rotation_data['schedule']) + if isinstance(schedule_json, list) and len(schedule_json) > 0: + schedule = schedule_json[0] + schedule_type = schedule.get('type', 'ON_DEMAND') + if schedule_type == 'ON_DEMAND': + print(' Schedule: On Demand') + else: + # Format schedule description + time_str = schedule.get('utcTime', '') + tz = schedule.get('tz', 'UTC') + if schedule_type == 'DAILY': + interval = schedule.get('intervalCount', 1) + desc = f"Every {interval} day(s) at {time_str} {tz}" + elif schedule_type == 'WEEKLY': + weekday = schedule.get('weekday', '') + desc = f"Weekly on {weekday} at {time_str} {tz}" + elif schedule_type == 'MONTHLY_BY_DAY': + day = schedule.get('day', 1) + desc = f"Monthly on day {day} at {time_str} {tz}" + elif schedule_type == 'MONTHLY_BY_WEEKDAY': + week = schedule.get('week', 'FIRST') + weekday = schedule.get('weekday', '') + desc = f"{week.title()} {weekday} of the month at {time_str} {tz}" + else: + desc = 
f"{schedule_type} at {time_str} {tz}" + print(f' Schedule: {desc}') + else: + print(' Schedule: On Demand') + except: + print(' Schedule: On Demand') + + # Last rotation + if rotation_data and rotation_data.get('last_rotation'): + last_rotation_ts = rotation_data['last_rotation'] + if last_rotation_ts > 0: + # Convert milliseconds to datetime + last_rotation_dt = datetime.datetime.fromtimestamp(last_rotation_ts / 1000) + print(f' Last Rotated: {last_rotation_dt.strftime("%b %d, %Y at %I:%M %p")}') + + # Show rotation status if available + # RecordRotationStatus enum: 0=NOT_ROTATED, 1=IN_PROGRESS, 2=SUCCESS, 3=FAILURE + last_status = rotation_data.get('last_rotation_status') + if last_status is not None: + status_map = {0: 'Not Rotated', 1: 'In Progress', 2: 'Success', 3: 'Failure'} + status_text = status_map.get(last_status, f'Unknown ({last_status})') + print(f' Last Status: {status_text}') + + except Exception as e: + logging.debug(f"Error displaying rotation info for {r.record_uid}: {e}") + print('') + print('Rotation:') + print(f' Error: Could not retrieve rotation info') class SearchCommand(Command): @@ -893,7 +1126,8 @@ def execute(self, params, **kwargs): all_results = [] if 'r' in categories: - records = list(vault_extensions.find_records(params, pattern)) + + records = list(vault_extensions.find_records(params, pattern, record_version=None if verbose else [2,3])) if records: if fmt == 'json': for record in records: diff --git a/keepercommander/commands/record_edit.py b/keepercommander/commands/record_edit.py index 37682d0e5..9827e67d4 100644 --- a/keepercommander/commands/record_edit.py +++ b/keepercommander/commands/record_edit.py @@ -71,7 +71,7 @@ help='load record type data from strings with dot notation') -append_parser = argparse.ArgumentParser(prog='append-notes', description='Append notes to an existing record.') +append_parser = argparse.ArgumentParser(prog='append-notes', description='Append notes to an existing record') append_parser.add_argument('--notes', dest='notes', action='store', help='notes') append_parser.add_argument('record', nargs='?', type=str, action='store', help='record path or UID') @@ -598,9 +598,10 @@ def assign_typed_fields(self, record, fields): else: if len(parsed_field.value) <= 10: dt = datetime.datetime.strptime(parsed_field.value, '%Y-%m-%d') + dt += datetime.timedelta(hours=12) else: dt = datetime.datetime.strptime(parsed_field.value, '%Y-%m-%dT%H:%M:%SZ') - value = int(dt.timestamp() * 1000) + value = utils.datetime_to_millis(dt) elif isinstance(ft.value, dict): if ft.name == 'name': value = vault.TypedField.import_name_field(parsed_field.value) @@ -805,9 +806,9 @@ def execute(self, params, **kwargs): record_fields.append(parsed_field) if record_type in ('legacy', 'general'): - raise CommandError('record-add', 'Legacy record type is not supported anymore.') - # record = vault.PasswordRecord() - # self.assign_legacy_fields(record, record_fields) + # raise CommandError('record-add', 'Legacy record type is not supported anymore.') + record = vault.PasswordRecord() + self.assign_legacy_fields(record, record_fields) else: rt_fields = self.get_record_type_fields(params, record_type) if not rt_fields: diff --git a/keepercommander/commands/register.py b/keepercommander/commands/register.py index 538c9eb63..a2c662e06 100644 --- a/keepercommander/commands/register.py +++ b/keepercommander/commands/register.py @@ -159,7 +159,7 @@ def register_command_info(aliases, command_info): record_permission_parser.error = raise_parse_exception 
record_permission_parser.exit = suppress_exit -find_ownerless_desc = 'List (and, optionally, claim) records in the user\'s vault that currently do not have an owner' +find_ownerless_desc = 'List (and, optionally, claim) ownerless records in the vault' find_ownerless_parser = argparse.ArgumentParser(prog='find-ownerless', description=find_ownerless_desc, parents=[base.report_output_parser]) find_ownerless_parser.add_argument('--claim', dest='claim', action='store_true', help='claim records found') @@ -460,7 +460,7 @@ def prepare_request(params, kwargs, curr_sf, users, teams, rec_uids, *, logging.warning('Share invitation has been sent to \'%s\'', username) logging.warning('Please repeat this command when invitation is accepted.') keys = params.key_cache.get(email) - if keys and keys.rsa or keys.ec: + if keys and (keys.rsa or keys.ec): uo.manageRecords = curr_sf.get('default_manage_records') is True if mr is None else folder_pb2.BOOLEAN_TRUE if mr == 'on' else folder_pb2.BOOLEAN_FALSE uo.manageUsers = curr_sf.get('default_manage_users') is True if mu is None else folder_pb2.BOOLEAN_TRUE if mu == 'on' else folder_pb2.BOOLEAN_FALSE sf_key = curr_sf.get('shared_folder_key_unencrypted') # type: Optional[bytes] @@ -1890,14 +1890,15 @@ def dump_record_details(records, output, output_fmt): verbose = kwargs.get('verbose') or not claim_records or out records_dump = None if ownerless_records: - logging.info(f'Found [{len(ownerless_records)}] ownerless record(s)') + count = len(ownerless_records) + logging.info(f'Found {count} ownerless {"record" if count == 1 else "records"}') if verbose: records_dump = dump_record_details(ownerless_records, out, fmt) if claim_records: claim_ownerless_records(ownerless_records) SyncDownCommand().execute(params, force=True) else: - logging.info('To claim the record(s) found above, re-run this command with the --claim flag.') + logging.info('To claim the records found above, re-run this command with the --claim flag.') else: logging.info('No ownerless records found') return records_dump diff --git a/keepercommander/commands/risk_management.py b/keepercommander/commands/risk_management.py index 6dfe7cc52..4401aeb02 100644 --- a/keepercommander/commands/risk_management.py +++ b/keepercommander/commands/risk_management.py @@ -20,8 +20,8 @@ def __init__(self): self.register_command('security-benchmarks-get', RiskManagementSecurityBenchmarksGetCommand(), 'Get the list of security benchmark set for the calling enterprise', 'sbg') self.register_command('security-benchmarks-set', RiskManagementSecurityBenchmarksSetCommand(), 'Set a list of security benchmark. 
Corresponding audit events will be logged', 'sbs') #Backward compatibility - self.register_command('user', RiskManagementEnterpriseStatDetailsCommand(), 'Show Risk Management User report (absolete)', 'u') - self.register_command('alert', RiskManagementSecurityAlertsSummaryCommand(), 'Show Risk Management Alert report (absolete)', 'a') + self.register_command('user', RiskManagementEnterpriseStatDetailsCommand(), 'Show Risk Management User report (obsolete)', 'u') + self.register_command('alert', RiskManagementSecurityAlertsSummaryCommand(), 'Show Risk Management Alert report (obsolete)', 'a') rmd_enterprise_stat_parser = argparse.ArgumentParser(prog='risk-management enterprise-stat', description='Risk management enterprise stat', parents=[base.report_output_parser]) @@ -170,7 +170,7 @@ def execute(self, params, **kwargs): aet = kwargs.get('aet') aetid = event_lookup.get(aet, 0) if aetid < 1: - raise ValueError(f'Invalid aetid {aetid}: valid aetid > 0') + raise base.CommandError('Valid Audit Event code or name required') request.auditEventTypeId = aetid done = False header = [ diff --git a/keepercommander/commands/supershell.py b/keepercommander/commands/supershell.py new file mode 100644 index 000000000..771e8e43e --- /dev/null +++ b/keepercommander/commands/supershell.py @@ -0,0 +1,4627 @@ +""" +Keeper SuperShell - A full-screen terminal UI for Keeper vault +""" + +import logging +import asyncio +import random +import sys +import io +import json +import re +import time +import os +from pathlib import Path +from typing import Optional, List, Dict, Any +import pyperclip +from rich.markup import escape as rich_escape + + +# Color themes - each theme uses variations of a primary color +COLOR_THEMES = { + 'green': { + 'primary': '#00ff00', # Bright green + 'primary_dim': '#00aa00', # Dim green + 'primary_bright': '#44ff44', # Light green + 'secondary': '#88ff88', # Light green accent + 'selection_bg': '#004400', # Selection background + 'hover_bg': '#002200', # Hover background (dimmer than selection) + 'text': '#ffffff', # White text + 'text_dim': '#aaaaaa', # Dim text + 'folder': '#44ff44', # Folder color (light green) + 'folder_shared': '#00dd00', # Shared folder (slightly different green) + 'record': '#00aa00', # Record color (dimmer than folders) + 'record_num': '#888888', # Record number + 'attachment': '#00cc00', # Attachment color + 'virtual_folder': '#00ff88', # Virtual folder + 'status': '#00ff00', # Status bar + 'border': '#00aa00', # Borders + 'root': '#00ff00', # Root node + 'header_user': '#00bbff', # Header username (blue contrast) + }, + 'blue': { + 'primary': '#0099ff', + 'primary_dim': '#0066cc', + 'primary_bright': '#66bbff', + 'secondary': '#00ccff', + 'selection_bg': '#002244', + 'hover_bg': '#001122', + 'text': '#ffffff', + 'text_dim': '#aaaaaa', + 'folder': '#66bbff', + 'folder_shared': '#0099ff', + 'record': '#0077cc', # Record color (dimmer than folders) + 'record_num': '#888888', + 'attachment': '#0077cc', + 'virtual_folder': '#00aaff', + 'status': '#0099ff', + 'border': '#0066cc', + 'root': '#0099ff', + 'header_user': '#ff9900', # Header username (orange contrast) + }, + 'magenta': { + 'primary': '#ff66ff', + 'primary_dim': '#cc44cc', + 'primary_bright': '#ff99ff', + 'secondary': '#ffaaff', + 'selection_bg': '#330033', + 'hover_bg': '#220022', + 'text': '#ffffff', + 'text_dim': '#aaaaaa', + 'folder': '#ff99ff', + 'folder_shared': '#ff66ff', + 'record': '#cc44cc', # Record color (dimmer than folders) + 'record_num': '#888888', + 'attachment': '#cc44cc', + 
'virtual_folder': '#ffaaff', + 'status': '#ff66ff', + 'border': '#cc44cc', + 'root': '#ff66ff', + 'header_user': '#66ff66', # Header username (green contrast) + }, + 'yellow': { + 'primary': '#ffff00', + 'primary_dim': '#cccc00', + 'primary_bright': '#ffff66', + 'secondary': '#ffcc00', + 'selection_bg': '#333300', + 'hover_bg': '#222200', + 'text': '#ffffff', + 'text_dim': '#aaaaaa', + 'folder': '#ffff66', + 'folder_shared': '#ffcc00', + 'record': '#cccc00', # Record color (dimmer than folders) + 'record_num': '#888888', + 'attachment': '#cccc00', + 'virtual_folder': '#ffff88', + 'status': '#ffff00', + 'border': '#cccc00', + 'root': '#ffff00', + 'header_user': '#66ccff', # Header username (blue contrast) + }, + 'white': { + 'primary': '#ffffff', + 'primary_dim': '#cccccc', + 'primary_bright': '#ffffff', + 'secondary': '#dddddd', + 'selection_bg': '#444444', + 'hover_bg': '#333333', + 'text': '#ffffff', + 'text_dim': '#999999', + 'folder': '#eeeeee', + 'folder_shared': '#dddddd', + 'record': '#bbbbbb', # Record color (dimmer than folders) + 'record_num': '#888888', + 'attachment': '#cccccc', + 'virtual_folder': '#ffffff', + 'status': '#ffffff', + 'border': '#888888', + 'root': '#ffffff', + 'header_user': '#66ccff', # Header username (blue contrast) + }, +} + +# Preferences file path +PREFS_FILE = Path.home() / '.keeper' / 'supershell_prefs.json' + + +def load_preferences() -> dict: + """Load preferences from file, return defaults if not found""" + defaults = {'color_theme': 'green'} + try: + if PREFS_FILE.exists(): + with open(PREFS_FILE, 'r') as f: + prefs = json.load(f) + # Merge with defaults + return {**defaults, **prefs} + except Exception as e: + logging.debug(f"Error loading preferences: {e}") + return defaults + + +def save_preferences(prefs: dict): + """Save preferences to file""" + try: + PREFS_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(PREFS_FILE, 'w') as f: + json.dump(prefs, f, indent=2) + except Exception as e: + logging.error(f"Error saving preferences: {e}") + +from textual.app import App, ComposeResult +from textual.containers import Container, Horizontal, Vertical, VerticalScroll, Center, Middle +from textual.widgets import Tree, DataTable, Footer, Header, Static, Input, Label, Button +from textual.binding import Binding +from textual.screen import Screen, ModalScreen +from textual.reactive import reactive +from textual import on, work +from textual.message import Message +from textual.timer import Timer +from rich.text import Text +from textual.events import Click, MouseDown, Paste + +from ..commands.base import Command + + +class ClickableDetailLine(Static): + """A single line in the detail view that highlights on hover and copies on click""" + + DEFAULT_CSS = """ + ClickableDetailLine { + width: 100%; + height: auto; + padding: 0 1; + } + + ClickableDetailLine:hover { + background: #1a1a2e; + } + + ClickableDetailLine.has-value { + /* Clickable lines get a subtle left border indicator */ + } + + ClickableDetailLine.has-value:hover { + background: #16213e; + text-style: bold; + border-left: thick #00ff00; + } + """ + + def __init__(self, content: str, copy_value: str = None, record_uid: str = None, is_password: bool = False, *args, **kwargs): + """ + Create a clickable detail line. 
+ + Args: + content: Rich markup content to display + copy_value: Value to copy on click (None = not copyable) + record_uid: Record UID for password audit events + is_password: If True, use ClipboardCommand for audit event + """ + self.copy_value = copy_value + self.record_uid = record_uid + self.is_password = is_password + classes = "has-value" if copy_value else "" + super().__init__(content, classes=classes, *args, **kwargs) + + def on_mouse_down(self, event: MouseDown) -> None: + """Handle mouse down to copy value - fires immediately without waiting for focus""" + if self.copy_value: + try: + if self.is_password and self.record_uid: + # Use ClipboardCommand to generate audit event for password copy + cc = ClipboardCommand() + cc.execute(self.app.params, record=self.record_uid, output='clipboard', + username=None, copy_uid=False, login=False, totp=False, field=None, revision=None) + self.app.notify("🔑 Password copied to clipboard!", severity="information") + else: + # Regular copy for non-password fields + pyperclip.copy(self.copy_value) + self.app.notify(f"Copied: {self.copy_value[:50]}{'...' if len(self.copy_value) > 50 else ''}", severity="information") + except Exception as e: + self.app.notify(f"Copy failed: {e}", severity="error") + + +class ClickableField(Static): + """A clickable field that copies value to clipboard on click""" + + DEFAULT_CSS = """ + ClickableField { + width: 100%; + height: auto; + padding: 0 1; + } + + ClickableField:hover { + background: #333333; + } + + ClickableField.clickable-value:hover { + background: #444444; + text-style: bold; + } + """ + + def __init__(self, label: str, value: str, copy_value: str = None, + label_color: str = "#aaaaaa", value_color: str = "#00ff00", + is_header: bool = False, indent: int = 0, *args, **kwargs): + """ + Create a clickable field. 
+ + Args: + label: The field label (e.g., "Username:") + value: The display value + copy_value: The value to copy (defaults to value) + label_color: Color for label + value_color: Color for value + is_header: If True, style as section header + indent: Indentation level (spaces) + """ + self.copy_value = copy_value if copy_value is not None else value + + # Build content before calling super().__init__ + indent_str = " " * indent + # Escape brackets for Rich markup + safe_value = value.replace('[', '\\[').replace(']', '\\]') if value else '' + safe_label = label.replace('[', '\\[').replace(']', '\\]') if label else '' + + if is_header: + content = f"[bold {value_color}]{indent_str}{safe_label}[/bold {value_color}]" + elif label: + content = f"{indent_str}[{label_color}]{safe_label}[/{label_color}] [{value_color}]{safe_value}[/{value_color}]" + else: + content = f"{indent_str}[{value_color}]{safe_value}[/{value_color}]" + + # Set classes for hover effect + classes = "clickable-value" if self.copy_value else "" + + super().__init__(content, classes=classes, *args, **kwargs) + + def on_mouse_down(self, event: MouseDown) -> None: + """Handle mouse down to copy value - fires immediately without waiting for focus""" + if self.copy_value: + try: + pyperclip.copy(self.copy_value) + self.app.notify(f"Copied to clipboard", severity="information") + except Exception as e: + self.app.notify(f"Copy failed: {e}", severity="error") + + +class ClickableRecordUID(Static): + """A clickable record UID that navigates to the record when clicked""" + + DEFAULT_CSS = """ + ClickableRecordUID { + width: 100%; + height: auto; + padding: 0 1; + } + + ClickableRecordUID:hover { + background: #333344; + text-style: bold underline; + } + """ + + def __init__(self, label: str, record_uid: str, record_title: str = None, + label_color: str = "#aaaaaa", value_color: str = "#ffff00", + indent: int = 0, *args, **kwargs): + """ + Create a clickable record UID that navigates to the record. 
+ + Args: + label: The field label (e.g., "Record UID:") + record_uid: The UID of the record to navigate to + record_title: Optional title to display instead of UID + label_color: Color for label + value_color: Color for value + indent: Indentation level + """ + self.record_uid = record_uid + + # Build content before calling super().__init__ + indent_str = " " * indent + display_value = record_title or record_uid + safe_value = display_value.replace('[', '\\[').replace(']', '\\]') + safe_label = label.replace('[', '\\[').replace(']', '\\]') if label else '' + + if label: + content = f"{indent_str}[{label_color}]{safe_label}[/{label_color}] [{value_color}]{safe_value} ↗[/{value_color}]" + else: + content = f"{indent_str}[{value_color}]{safe_value} ↗[/{value_color}]" + + super().__init__(content, *args, **kwargs) + + def on_mouse_down(self, event: MouseDown) -> None: + """Handle mouse down to navigate to record - fires immediately without waiting for focus""" + # Find the app and trigger record selection + app = self.app + if hasattr(app, 'records') and self.record_uid in app.records: + # Navigate to the record in the tree + app.selected_record = self.record_uid + app.selected_folder = None + app._display_record_detail(self.record_uid) + + # Try to select the node in the tree + tree = app.query_one("#folder_tree", Tree) + app._select_record_in_tree(tree, self.record_uid) + + app.notify(f"Navigated to record", severity="information") + else: + # Just copy the UID if record not found + pyperclip.copy(self.record_uid) + app.notify(f"Record not in vault. UID copied.", severity="warning") + + +from ..commands.record import RecordGetUidCommand, ClipboardCommand +from ..display import bcolors +from .. import vault +from .. import utils + + +class PreferencesScreen(ModalScreen): + """Modal screen for user preferences""" + + DEFAULT_CSS = """ + PreferencesScreen { + align: center middle; + } + + #prefs_container { + width: 40; + height: auto; + max-height: 90%; + background: #111111; + border: solid #444444; + padding: 1 2; + } + + #prefs_title { + text-align: center; + text-style: bold; + padding-bottom: 1; + } + + #prefs_content { + height: auto; + padding: 0 1; + } + + #prefs_footer { + text-align: center; + padding-top: 1; + color: #666666; + } + """ + + BINDINGS = [ + Binding("escape", "dismiss", "Close", show=False), + Binding("q", "dismiss", "Close", show=False), + Binding("1", "select_green", "Green", show=False), + Binding("2", "select_blue", "Blue", show=False), + Binding("3", "select_magenta", "Magenta", show=False), + Binding("4", "select_yellow", "Yellow", show=False), + Binding("5", "select_white", "White", show=False), + ] + + def __init__(self, app_instance): + super().__init__() + self.app_instance = app_instance + + def compose(self) -> ComposeResult: + current = self.app_instance.color_theme + with Vertical(id="prefs_container"): + yield Static("[bold cyan]⚙ Preferences[/bold cyan]", id="prefs_title") + yield Static(f"""[green]Color Theme:[/green] + [#00ff00]1[/#00ff00] {'●' if current == 'green' else '○'} Green + [#0099ff]2[/#0099ff] {'●' if current == 'blue' else '○'} Blue + [#ff66ff]3[/#ff66ff] {'●' if current == 'magenta' else '○'} Magenta + [#ffff00]4[/#ffff00] {'●' if current == 'yellow' else '○'} Yellow + [#ffffff]5[/#ffffff] {'●' if current == 'white' else '○'} White""", id="prefs_content") + yield Static("[dim]Press 1-5 to select, Esc or q to close[/dim]", id="prefs_footer") + + def action_dismiss(self): + """Close the preferences screen""" + self.dismiss() + + def 
key_escape(self): + """Handle escape key directly""" + self.dismiss() + + def key_q(self): + """Handle q key directly""" + self.dismiss() + + def action_select_green(self): + self._apply_theme('green') + + def action_select_blue(self): + self._apply_theme('blue') + + def action_select_magenta(self): + self._apply_theme('magenta') + + def action_select_yellow(self): + self._apply_theme('yellow') + + def action_select_white(self): + self._apply_theme('white') + + def _apply_theme(self, theme_name: str): + """Apply the selected theme and save preferences""" + self.app_instance.set_color_theme(theme_name) + # Save to preferences file + prefs = load_preferences() + prefs['color_theme'] = theme_name + save_preferences(prefs) + self.app_instance.notify(f"Theme changed to {theme_name}") + self.dismiss() + + +class HelpScreen(ModalScreen): + """Modal screen for help/keyboard shortcuts""" + + DEFAULT_CSS = """ + HelpScreen { + align: center middle; + } + + #help_container { + width: 90; + height: auto; + max-height: 90%; + background: #111111; + border: solid #444444; + padding: 1 2; + } + + #help_title { + text-align: center; + text-style: bold; + padding-bottom: 1; + } + + #help_columns { + height: auto; + } + + .help_column { + width: 1fr; + height: auto; + padding: 0 1; + } + + #help_footer { + text-align: center; + padding-top: 1; + color: #666666; + } + """ + + BINDINGS = [ + Binding("escape", "dismiss", "Close", show=False), + Binding("q", "dismiss", "Close", show=False), + ] + + def compose(self) -> ComposeResult: + with Vertical(id="help_container"): + yield Static("[bold cyan]⌨ Keyboard Shortcuts[/bold cyan]", id="help_title") + with Horizontal(id="help_columns"): + yield Static("""[green]Navigation:[/green] + j/k ↑/↓ Move up/down + h/l ←/→ Collapse/expand + g / G Top / bottom + Ctrl+d/u Half page + Ctrl+e/y Scroll line + Esc Clear/collapse + +[green]Focus Cycling:[/green] + Tab Tree→Detail→Search + Shift+Tab Cycle backwards + / Focus search + Ctrl+U Clear search + Esc Focus tree + +[green]General:[/green] + ? Help + ! 
Keeper shell + Ctrl+q Quit""", classes="help_column") + yield Static("""[green]Copy to Clipboard:[/green] + p Password + u Username + c Copy all + w URL + i Record UID + +[green]Actions:[/green] + t Toggle JSON view + m Mask/Unmask + d Sync vault + W User info + D Device info + P Preferences""", classes="help_column") + yield Static("[dim]Press Esc or q to close[/dim]", id="help_footer") + + def action_dismiss(self): + """Close the help screen""" + self.dismiss() + + def key_escape(self): + """Handle escape key directly""" + self.dismiss() + + def key_q(self): + """Handle q key directly""" + self.dismiss() + + +class SuperShellApp(App): + """The Keeper SuperShell TUI application""" + + # Constants for thresholds and limits + AUTO_EXPAND_THRESHOLD = 100 # Auto-expand tree when search results < this number + DEVICE_DISPLAY_LIMIT = 5 # Max devices to show before truncating + SHARE_DISPLAY_LIMIT = 10 # Max shares to show before truncating + PAGE_DOWN_NODES = 10 # Number of nodes to move for half-page down + PAGE_DOWN_FULL_NODES = 20 # Number of nodes to move for full-page down + + @staticmethod + def _strip_ansi_codes(text: str) -> str: + """Remove ANSI color codes from text""" + ansi_escape = re.compile(r'\x1b\[[0-9;]*m') + return ansi_escape.sub('', text) + + CSS = """ + Screen { + background: #000000; + } + + Input { + background: #111111; + color: #ffffff; + } + + Input > .input--content { + color: #ffffff; + } + + Input > .input--placeholder { + color: #666666; + } + + Input > .input--cursor { + color: #ffffff; + text-style: reverse; + } + + Input:focus { + border: solid #888888; + } + + Input:focus > .input--content { + color: #ffffff; + } + + #search_bar { + dock: top; + height: 3; + width: 100%; + background: #222222; + border: solid #666666; + } + + #search_display { + width: 35%; + background: #222222; + color: #ffffff; + padding: 0 2; + height: 3; + } + + #search_results_label { + width: 15%; + color: #aaaaaa; + text-align: right; + padding: 0 2; + height: 3; + background: #222222; + } + + #user_info { + width: auto; + height: 3; + background: #222222; + color: #888888; + padding: 0 1; + } + + #device_status_info { + width: auto; + height: 3; + background: #222222; + color: #888888; + padding: 0 2; + text-align: right; + } + + .clickable-info:hover { + background: #333333; + } + + #main_container { + height: 100%; + background: #000000; + } + + #folder_panel { + width: 50%; + border-right: thick #666666; + padding: 1; + background: #000000; + } + + #folder_tree { + height: 100%; + background: #000000; + } + + #record_panel { + width: 50%; + padding: 1; + background: #000000; + } + + #record_detail { + height: 100%; + overflow-y: auto; + padding: 1; + background: #000000; + } + + #detail_content { + background: #000000; + color: #ffffff; + } + + Tree { + background: #000000; + color: #ffffff; + } + + Tree > .tree--guides { + color: #444444; + } + + Tree > .tree--toggle { + /* Hide expand/collapse icons - nodes still expand/collapse on click */ + width: 0; + } + + Tree > .tree--cursor { + /* Selected row - neutral background that works with all color themes */ + background: #333333; + text-style: bold; + } + + Tree > .tree--highlight { + /* Hover row - subtle background, different from selection */ + background: #1a1a1a; + } + + Tree > .tree--highlight-line { + background: #1a1a1a; + } + + /* Hide tree selection when search input is active */ + Tree.search-input-active > .tree--cursor { + background: transparent; + text-style: none; + } + + Tree.search-input-active > .tree--highlight 
{ + background: transparent; + } + + DataTable { + background: #000000; + color: #ffffff; + } + + DataTable > .datatable--cursor { + background: #444444; + color: #ffffff; + text-style: bold; + } + + DataTable > .datatable--header { + background: #222222; + color: #ffffff; + text-style: bold; + } + + Static { + background: #000000; + color: #ffffff; + } + + VerticalScroll { + background: #000000; + } + + #record_detail:focus { + background: #0a0a0a; + border: solid #333333; + } + + #record_detail:focus-within { + background: #0a0a0a; + } + + #status_bar { + dock: bottom; + height: 1; + background: #000000; + color: #aaaaaa; + padding: 0 2; + } + + #shortcuts_bar { + dock: bottom; + height: 2; + background: #111111; + color: #888888; + padding: 0 1; + border-top: solid #333333; + } + """ + + BINDINGS = [ + Binding("ctrl+q", "quit", "Quit", show=False), + Binding("d", "sync_vault", "Sync", show=False), + Binding("/", "search", "Search", show=False), + Binding("P", "show_preferences", "Preferences", show=False), + Binding("p", "copy_password", "Copy Password", show=False), + Binding("u", "copy_username", "Copy Username", show=False), + Binding("w", "copy_url", "Copy URL", show=False), + Binding("i", "copy_uid", "Copy UID", show=False), + Binding("c", "copy_record", "Copy All", show=False), + Binding("t", "toggle_view_mode", "Toggle JSON", show=False), + Binding("m", "toggle_unmask", "Toggle Unmask", show=False), + Binding("W", "show_user_info", "User Info", show=False), + Binding("D", "show_device_info", "Device Info", show=False), + Binding("?", "show_help", "Help", show=False), + # Vim-style navigation + Binding("j", "cursor_down", "Down", show=False), + Binding("k", "cursor_up", "Up", show=False), + Binding("h", "cursor_left", "Left", show=False), + Binding("l", "cursor_right", "Right", show=False), + Binding("g", "goto_top", "Go to Top", show=False), + Binding("G", "goto_bottom", "Go to Bottom", show=False), + # Vim page navigation + Binding("ctrl+d", "page_down", "Page Down", show=False), + Binding("ctrl+u", "page_up", "Page Up", show=False), + Binding("ctrl+f", "page_down_full", "Page Down (Full)", show=False), + Binding("ctrl+b", "page_up_full", "Page Up (Full)", show=False), + # Vim line scrolling + Binding("ctrl+e", "scroll_down", "Scroll Down", show=False), + Binding("ctrl+y", "scroll_up", "Scroll Up", show=False), + ] + + def __init__(self, params): + super().__init__() + self.params = params + self.records = {} + self.record_to_folder = {} + self.records_in_subfolders = set() # Records in actual subfolders (not root) + self.file_attachment_to_parent = {} # Maps attachment_uid -> parent_record_uid + self.record_file_attachments = {} # Maps record_uid -> list of attachment_uids + self.linked_record_to_parent = {} # Maps linked_record_uid -> parent_record_uid (for addressRef, cardRef, etc.) 
+ self.record_linked_records = {} # Maps record_uid -> list of linked_record_uids + self.app_record_uids = set() # Set of Secrets Manager app record UIDs + self.current_folder = None + self.selected_record = None + self.selected_folder = None + self.view_mode = 'detail' # 'detail' or 'json' + self.unmask_secrets = False # When True, show secret/password/passphrase field values + self.search_query = "" # Current search query + self.search_input_text = "" # Text being typed in search + self.search_input_active = False # True when typing in search, False when navigating results + self.filtered_record_uids = None # None = show all, Set = filtered UIDs + # Save selection before search to restore on ESC + self.pre_search_selected_record = None + self.pre_search_selected_folder = None + self.title = "" + self.sub_title = "" + # Vim-style command mode (e.g., :20 to go to line 20) + self.command_mode = False + self.command_buffer = "" + # Load color theme from preferences + prefs = load_preferences() + self.color_theme = prefs.get('color_theme', 'green') + self.theme_colors = COLOR_THEMES.get(self.color_theme, COLOR_THEMES['green']) + + def set_color_theme(self, theme_name: str): + """Set the color theme and refresh the display""" + if theme_name in COLOR_THEMES: + self.color_theme = theme_name + self.theme_colors = COLOR_THEMES[theme_name] + + # Save the current tree expansion state before rebuilding + tree = self.query_one("#folder_tree", Tree) + expanded_nodes = set() + + def collect_expanded(node): + """Recursively collect UIDs of expanded nodes""" + if node.is_expanded and hasattr(node, 'data') and node.data: + node_uid = node.data.get('uid') if isinstance(node.data, dict) else None + node_type = node.data.get('type') if isinstance(node.data, dict) else None + if node_uid: + expanded_nodes.add(node_uid) + elif node_type == 'root': + expanded_nodes.add('__root__') + elif node_type == 'virtual_folder': + expanded_nodes.add('__secrets_manager_apps__') + for child in node.children: + collect_expanded(child) + + collect_expanded(tree.root) + + # Refresh the tree to apply new colors + self._setup_folder_tree() + + # Restore expansion state + def restore_expanded(node): + """Recursively restore expanded state""" + if hasattr(node, 'data') and node.data: + node_uid = node.data.get('uid') if isinstance(node.data, dict) else None + node_type = node.data.get('type') if isinstance(node.data, dict) else None + + should_expand = False + if node_uid and node_uid in expanded_nodes: + should_expand = True + elif node_type == 'root' and '__root__' in expanded_nodes: + should_expand = True + elif node_type == 'virtual_folder' and node_uid == '__secrets_manager_apps__' and '__secrets_manager_apps__' in expanded_nodes: + should_expand = True + + if should_expand and node.allow_expand: + node.expand() + + for child in node.children: + restore_expanded(child) + + restore_expanded(tree.root) + + # Update CSS dynamically for tree selection/hover + self._apply_theme_css() + + def notify(self, message, *, title="", severity="information", timeout=1.5): + """Override notify to use faster timeout (default 1.5s instead of 5s)""" + super().notify(message, title=title, severity=severity, timeout=timeout) + + def _get_welcome_screen_content(self) -> str: + """Generate the My Vault welcome screen content with current theme colors""" + t = self.theme_colors + return f"""[bold {t['primary']}]● Keeper SuperShell[/bold {t['primary']}] + +[{t['secondary']}]A CLI-based vault viewer with keyboard and mouse 
navigation.[/{t['secondary']}]
+
+[bold {t['primary_bright']}]Getting Started[/bold {t['primary_bright']}]
+ [{t['text_dim']}]•[/{t['text_dim']}] Use [{t['primary']}]j/k[/{t['primary']}] or [{t['primary']}]↑/↓[/{t['primary']}] to navigate up/down
+ [{t['text_dim']}]•[/{t['text_dim']}] Use [{t['primary']}]l[/{t['primary']}] or [{t['primary']}]→[/{t['primary']}] to expand folders
+ [{t['text_dim']}]•[/{t['text_dim']}] Use [{t['primary']}]h[/{t['primary']}] or [{t['primary']}]←[/{t['primary']}] to collapse folders
+ [{t['text_dim']}]•[/{t['text_dim']}] Press [{t['primary']}]/[/{t['primary']}] to search for records
+ [{t['text_dim']}]•[/{t['text_dim']}] Press [{t['primary']}]Esc[/{t['primary']}] to collapse and navigate back
+
+[bold {t['primary_bright']}]Vim-Style Navigation[/bold {t['primary_bright']}]
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]g[/{t['primary']}] - Go to top
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]G[/{t['primary']}] (Shift+G) - Go to bottom
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]Ctrl+d/u[/{t['primary']}] - Half page down/up
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]Ctrl+e/y[/{t['primary']}] - Scroll down/up one line
+
+[bold {t['primary_bright']}]Quick Actions[/bold {t['primary_bright']}]
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]p[/{t['primary']}] - Copy password
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]u[/{t['primary']}] - Copy username
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]c[/{t['primary']}] - Copy all
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]w[/{t['primary']}] - Copy URL
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]t[/{t['primary']}] - Toggle Detail/JSON view
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]m[/{t['primary']}] - Mask/Unmask secrets
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]d[/{t['primary']}] - Sync & refresh vault
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]![/{t['primary']}] - Exit to Keeper shell
+ [{t['text_dim']}]•[/{t['text_dim']}] [{t['primary']}]Ctrl+q[/{t['primary']}] - Quit SuperShell
+
+[{t['text_dim']}]Press [/{t['text_dim']}][{t['primary']}]?[/{t['primary']}][{t['text_dim']}] for full keyboard shortcuts[/{t['text_dim']}]"""
+
+ def _apply_theme_css(self):
+ """Refresh the detail panel content using the current theme colors"""
+ try:
+ # Update detail content - will be refreshed when record is selected
+ if self.selected_record:
+ # Check if it's a Secrets Manager app record
+ if self.selected_record in self.app_record_uids:
+ self._display_secrets_manager_app(self.selected_record)
+ else:
+ self._display_record_detail(self.selected_record)
+ elif self.selected_folder:
+ self._display_folder_with_clickable_fields(self.selected_folder)
+ else:
+ # No record/folder selected - update the "My Vault" welcome screen
+ detail_widget = self.query_one("#detail_content", Static)
+ detail_widget.update(self._get_welcome_screen_content())
+
+ except Exception as e:
+ logging.debug(f"Error applying theme: {e}")
+
+ def action_show_preferences(self):
+ """Show preferences screen"""
+ self.push_screen(PreferencesScreen(self))
+
+ def compose(self) -> ComposeResult:
+ """Create the application layout"""
+ # Search bar at top (initially hidden)
+ with Horizontal(id="search_bar"):
+ yield Static("", id="search_display")
+ yield Static("", id="search_results_label")
+ yield Static("", id="user_info", classes="clickable-info")
+ yield Static("", id="device_status_info", classes="clickable-info")
+
+ with Horizontal(id="main_container"):
+ with Vertical(id="folder_panel"):
+ yield Tree("[#00ff00]● My Vault[/#00ff00]", id="folder_tree")
+ with Vertical(id="record_panel"):
+ with
VerticalScroll(id="record_detail"):
+ yield Static("", id="detail_content")
+ # Fixed footer for shortcuts
+ yield Static("", id="shortcuts_bar")
+ # Status bar at very bottom
+ yield Static("", id="status_bar")
+
+ async def on_mount(self):
+ """Initialize the application when mounted"""
+ logging.debug("SuperShell on_mount called")
+
+ # Initialize clickable fields list for detail panel
+ self.clickable_fields = []
+
+ # Cache for record output to avoid repeated get command calls
+ self._record_output_cache = {}
+
+ # TOTP auto-refresh timer
+ self._totp_timer = None
+ self._totp_record_uid = None # Record currently showing TOTP
+
+ # Sync vault data if needed
+ if not hasattr(self.params, 'record_cache') or not self.params.record_cache:
+ from .utils import SyncDownCommand
+ try:
+ logging.debug("Syncing vault data...")
+ SyncDownCommand().execute(self.params)
+ except Exception as e:
+ logging.error(f"Sync failed: {e}", exc_info=True)
+ self.exit(message=f"Sync failed: {str(e)}")
+ return
+
+ try:
+ # Load vault data
+ logging.debug("Loading vault data...")
+ self._load_vault_data()
+
+ # Load device and user info for header display
+ logging.debug("Loading device and user info...")
+ self.device_info = self._load_device_info()
+ self.whoami_info = self._load_whoami_info()
+
+ # Setup folder tree with records
+ logging.debug("Setting up folder tree...")
+ self._setup_folder_tree()
+
+ # Apply theme styling after components are mounted
+ self._apply_theme_css()
+
+ # Show the initial welcome/help screen in the detail panel
+ # (single source of truth: _get_welcome_screen_content)
+ detail_widget = self.query_one("#detail_content", Static)
+ detail_widget.update(self._get_welcome_screen_content())
+
+ # Initialize shortcuts bar
+ self._update_shortcuts_bar()
+
+ # Initialize search bar with placeholder
+ search_display = self.query_one("#search_display", Static)
+ search_display.update("[dim]Search... (Tab or /)[/dim]")
+
+ # Initialize header info display (user and device)
+ self._update_header_info_display()
+
+ # Focus the folder tree so vim keys work immediately
+ self.query_one("#folder_tree", Tree).focus()
+
+ logging.debug("SuperShell ready!")
+ self._update_status("Navigate: j/k Tab: detail Help: ?")
+ except Exception as e:
+ logging.error(f"Error initializing SuperShell: {e}", exc_info=True)
+ self.exit(message=f"Error: {str(e)}")
+
+ def on_resize(self, event) -> None:
+ """Handle window resize - update header to show/hide sections based on available width"""
+ self._update_header_info_display()
+
+ def _load_vault_data(self):
+ """Load vault data from params"""
+ # Build record to folder mapping using subfolder_record_cache
+ # Records in root folder have folder_uid = '' (empty string)
+ self.record_to_folder = {} # Maps record_uid -> folder_uid
+ self.records_in_subfolders = set() # Track records that are in actual subfolders (not root)
+ if hasattr(self.params, 'subfolder_record_cache'):
+ for folder_uid, record_uids in self.params.subfolder_record_cache.items():
+ for record_uid in record_uids:
+ self.record_to_folder[record_uid] = folder_uid
+ # Track records in non-root folders
+ if folder_uid and folder_uid != '':
+ self.records_in_subfolders.add(record_uid)
+
+ # Track file attachments and their parent records
+ self.file_attachment_to_parent = {} # Maps attachment_uid -> parent_record_uid
+ self.record_file_attachments = {} # Maps record_uid -> list of attachment_uids
+ self.linked_record_to_parent = {} # Maps linked_record_uid -> parent_record_uid (for addressRef, cardRef, etc.)
+ self.record_linked_records = {} # Maps record_uid -> list of linked_record_uids
+
+ # Secrets Manager app UIDs - identified by record type 'app' in the vault cache
+ # NOTE: SuperShell should not make direct API calls during initialization.
+ # Apps are identified by their record type instead of calling vault/get_applications_summary.
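+ # Example (illustrative values): a decrypted v3 record whose data JSON is
+ # {"type": "app", "title": "My KSM App"} is added to app_record_uids below and
+ # rendered under the virtual "Secrets Manager Apps" node instead of a folder.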
+ self.app_record_uids = set()
+
+ # Build record dictionary
+ if hasattr(self.params, 'record_cache'):
+ for record_uid, record_data in self.params.record_cache.items():
+ try:
+ # Try to load and decrypt the record
+ record = vault.KeeperRecord.load(self.params, record_uid)
+
+ if record:
+ # Get record type - try multiple approaches
+ record_type = 'login' # Default
+
+ # First, try get_record_type() method (most reliable)
+ if hasattr(record, 'get_record_type'):
+ try:
+ rt = record.get_record_type()
+ if rt:
+ record_type = rt
+ except Exception:
+ pass
+
+ # If still default, try record_type property
+ if record_type == 'login' and hasattr(record, 'record_type'):
+ try:
+ rt = record.record_type
+ if rt:
+ record_type = rt
+ except Exception:
+ pass
+
+ # Fallback: try to get from cached data
+ if record_type == 'login':
+ cached_rec = self.params.record_cache.get(record_uid, {})
+ version = cached_rec.get('version', 2)
+ if version == 3:
+ try:
+ rec_data = cached_rec.get('data_unencrypted')
+ if rec_data:
+ if isinstance(rec_data, bytes):
+ rec_data = rec_data.decode('utf-8')
+ data_obj = json.loads(rec_data)
+ rt = data_obj.get('type')
+ if rt:
+ record_type = rt
+ except Exception:
+ pass
+ elif version == 2:
+ record_type = 'legacy'
+
+ record_dict = {
+ 'uid': record_uid,
+ 'title': record.title if hasattr(record, 'title') else 'Untitled',
+ 'folder_uid': self.record_to_folder.get(record_uid),
+ 'record_type': record_type,
+ }
+
+ # Identify Secrets Manager apps by record type
+ if record_type == 'app':
+ self.app_record_uids.add(record_uid)
+
+ # Extract fileRef fields to build parent-child relationship
+ # Handles both 'fileRef' type fields and 'script' type fields (rotation scripts)
+ file_refs = []
+ if hasattr(record, 'fields'):
+ for field in record.fields:
+ field_type = getattr(field, 'type', None)
+ field_value = getattr(field, 'value', None)
+
+ if field_type == 'fileRef':
+ # Direct fileRef field - value is list of UIDs
+ if field_value and isinstance(field_value, list):
+ for ref_uid in field_value:
+ if isinstance(ref_uid, str) and ref_uid:
+ file_refs.append(ref_uid)
+ self.file_attachment_to_parent[ref_uid] = record_uid
+
+ elif field_type == 'script':
+ # Script field - value is list of objects with 'fileRef' property
+ if field_value and isinstance(field_value, list):
+ for script_item in field_value:
+ if isinstance(script_item, dict):
+ ref_uid = script_item.get('fileRef')
+ if ref_uid and isinstance(ref_uid, str):
+ file_refs.append(ref_uid)
+ self.file_attachment_to_parent[ref_uid] = record_uid
+
+ if file_refs:
+ self.record_file_attachments[record_uid] = file_refs
+
+ # Extract linked record references (addressRef, cardRef, etc.)
+ # These are records that are embedded/linked into this record
+ linked_refs = []
+ if hasattr(record, 'fields'):
+ for field in record.fields:
+ field_type = getattr(field, 'type', None)
+ field_value = getattr(field, 'value', None)
+
+ # addressRef, cardRef, etc.
- records linked by reference + if field_type in ('addressRef', 'cardRef'): + if field_value and isinstance(field_value, list): + for ref_uid in field_value: + if isinstance(ref_uid, str) and ref_uid: + linked_refs.append(ref_uid) + self.linked_record_to_parent[ref_uid] = record_uid + + if linked_refs: + self.record_linked_records[record_uid] = linked_refs + + # Extract fields based on record type + if hasattr(record, 'login'): + record_dict['login'] = record.login + if hasattr(record, 'password'): + record_dict['password'] = record.password + if hasattr(record, 'login_url'): + record_dict['login_url'] = record.login_url + if hasattr(record, 'notes'): + record_dict['notes'] = record.notes + # Extract TOTP URL (v2 legacy records have it as 'totp' attribute) + if hasattr(record, 'totp') and record.totp: + record_dict['totp_url'] = record.totp + + # For TypedRecords, extract fields from the fields array + if hasattr(record, 'fields'): + custom_fields = [] + for field in record.fields: + field_type = getattr(field, 'type', None) + field_value = getattr(field, 'value', None) + field_label = getattr(field, 'label', None) + + # Extract password from typed field if not already set + if field_type == 'password' and field_value and not record_dict.get('password'): + if isinstance(field_value, list) and len(field_value) > 0: + record_dict['password'] = field_value[0] + elif isinstance(field_value, str): + record_dict['password'] = field_value + + # Extract login from typed field if not already set + if field_type == 'login' and field_value and not record_dict.get('login'): + if isinstance(field_value, list) and len(field_value) > 0: + record_dict['login'] = field_value[0] + elif isinstance(field_value, str): + record_dict['login'] = field_value + + # Extract URL from typed field if not already set + if field_type == 'url' and field_value and not record_dict.get('login_url'): + if isinstance(field_value, list) and len(field_value) > 0: + record_dict['login_url'] = field_value[0] + elif isinstance(field_value, str): + record_dict['login_url'] = field_value + + # Extract TOTP URL from oneTimeCode field + if field_type == 'oneTimeCode' and field_value and not record_dict.get('totp_url'): + if isinstance(field_value, list) and len(field_value) > 0: + record_dict['totp_url'] = field_value[0] + elif isinstance(field_value, str): + record_dict['totp_url'] = field_value + + # Collect custom fields (those with labels) + if field_label and field_value: + custom_fields.append({ + 'name': field_label, + 'value': str(field_value) if field_value else '' + }) + if custom_fields: + record_dict['custom_fields'] = custom_fields + + self.records[record_uid] = record_dict + except Exception as e: + logging.debug(f"Error loading record {record_uid}: {e}") + continue + + def _load_device_info(self): + """Load device info using the 'this-device' command""" + try: + from .utils import ThisDeviceCommand + + # Call get_device_info directly - returns dict without printing + return ThisDeviceCommand.get_device_info(self.params) + + except Exception as e: + logging.error(f"Error loading device info: {e}", exc_info=True) + return None + + def _load_whoami_info(self): + """Load whoami info using the 'whoami' command""" + try: + from .utils import WhoamiCommand + from .. 
import constants + import datetime + + # Call get_whoami_info directly - returns dict without printing + data = WhoamiCommand.get_whoami_info(self.params) + + # Add enterprise license info if available (similar to whoami --json) + if self.params.enterprise: + enterprise_licenses = [] + for x in self.params.enterprise.get('licenses', []): + license_info = {} + product_type_id = x.get('product_type_id', 0) + tier = x.get('tier', 0) + if product_type_id in (3, 5): + plan = 'Enterprise' if tier == 1 else 'Business' + elif product_type_id in (9, 10): + distributor = x.get('distributor', False) + plan = 'Distributor' if distributor else 'Managed MSP' + elif product_type_id in (11, 12): + plan = 'Keeper MSP' + elif product_type_id == 8: + plan = 'MC ' + ('Enterprise' if tier == 1 else 'Business') + else: + plan = 'Unknown' + if product_type_id in (5, 10, 12): + plan += ' Trial' + license_info['base_plan'] = plan + + paid = x.get('paid') is True + if paid: + exp = x.get('expiration') + if exp and exp > 0: + dt = datetime.datetime.fromtimestamp(exp // 1000) + datetime.timedelta(days=1) + n = datetime.datetime.now() + td = (dt - n).days + expires = str(dt.date()) + if td > 0: + expires += f' (in {td} days)' + else: + expires += ' (expired)' + license_info['license_expires'] = expires + + license_info['user_licenses'] = { + 'plan': x.get("number_of_seats", ""), + 'active': x.get("seats_allocated", ""), + 'invited': x.get("seats_pending", "") + } + + file_plan = x.get('file_plan') + file_plan_lookup = {fp[0]: fp[2] for fp in constants.ENTERPRISE_FILE_PLANS} + license_info['secure_file_storage'] = file_plan_lookup.get(file_plan, '') + + addons = [] + addon_lookup = {a[0]: a[1] for a in constants.MSP_ADDONS} + for ao in x.get('add_ons', []): + if isinstance(ao, dict): + enabled = ao.get('enabled') is True + if enabled: + name = ao.get('name') + addon_name = addon_lookup.get(name) or name + if name == 'secrets_manager': + api_count = ao.get('api_call_count') + if isinstance(api_count, int) and api_count > 0: + addon_name += f' ({api_count:,} API calls)' + elif name == 'connection_manager': + seats = ao.get('seats') + if isinstance(seats, int) and seats > 0: + addon_name += f' ({seats} licenses)' + addons.append(addon_name) + if addons: + license_info['add_ons'] = addons + + enterprise_licenses.append(license_info) + + if enterprise_licenses: + data['enterprise_licenses'] = enterprise_licenses + + # Add enterprise name if available + if 'enterprise_name' in self.params.enterprise: + data['enterprise_name'] = self.params.enterprise['enterprise_name'] + + return data + + except Exception as e: + logging.error(f"Error loading whoami info: {e}", exc_info=True) + return None + + def _update_header_info_display(self): + """Update the user and device info displays in the search bar area""" + try: + user_info_widget = self.query_one("#user_info", Static) + device_status_widget = self.query_one("#device_status_info", Static) + t = self.theme_colors + + # Get available width for the header info area (roughly half the screen minus search) + try: + available_width = self.size.width // 2 - 10 # Approximate space for header info + except: + available_width = 80 # Default fallback + + separator = " │ " + sep_len = 3 + + # === User info widget: email | DC (click shows whoami) === + user_parts = [] + user_len = 0 + + if hasattr(self, 'whoami_info') and self.whoami_info: + user = self.whoami_info.get('user', '') + if user: + # Truncate email if longer than 30 chars + max_email_len = 30 + if len(user) > max_email_len: + 
user_display = user[:max_email_len-3] + '...' + else: + user_display = user + user_parts.append(f"[{t['primary']}]{user_display}[/{t['primary']}]") + user_len = len(user_display) + + # Data center + data_center = self.whoami_info.get('data_center', '') + if data_center and user_len + sep_len + len(data_center) < available_width // 2: + user_parts.append(f"[{t['primary']}]{data_center}[/{t['primary']}]") + user_len += sep_len + len(data_center) + + if user_parts: + user_info_widget.update(separator.join(user_parts)) + else: + user_info_widget.update("") + + # === Device status widget: Stay Logged In | Logout (click shows device info) === + device_parts = [] + device_len = 0 + remaining_width = available_width - user_len - sep_len + + if hasattr(self, 'device_info') and self.device_info: + di = self.device_info + + # Stay Logged In status + stay_logged_in_len = 19 # "Stay Logged In: OFF" + if stay_logged_in_len < remaining_width: + if di.get('persistent_login'): + device_parts.append(f"[{t['text_dim']}]Stay Logged In:[/{t['text_dim']}] [green]ON[/green]") + else: + device_parts.append(f"[{t['text_dim']}]Stay Logged In:[/{t['text_dim']}] [red]OFF[/red]") + device_len = stay_logged_in_len + + # Logout timeout + timeout = di.get('effective_logout_timeout') or di.get('device_logout_timeout') or '' + if timeout: + timeout_str = str(timeout) + timeout_str = timeout_str.replace(' days', 'd').replace(' day', 'd') + timeout_str = timeout_str.replace(' hours', 'h').replace(' hour', 'h') + timeout_str = timeout_str.replace(' minutes', 'm').replace(' minute', 'm') + logout_text = f"Logout: {timeout_str}" + if device_len + sep_len + len(logout_text) < remaining_width: + device_parts.append(f"[{t['text_dim']}]Logout:[/{t['text_dim']}] [{t['primary_dim']}]{timeout_str}[/{t['primary_dim']}]") + + if device_parts: + device_status_widget.update(separator.join(device_parts)) + else: + device_status_widget.update("") + + except Exception as e: + logging.debug(f"Error updating header info display: {e}") + + def _display_whoami_info(self): + """Display whoami info in the detail panel""" + try: + # Clear any clickable fields from previous record display + self._clear_clickable_fields() + + t = self.theme_colors + detail_widget = self.query_one("#detail_content", Static) + + if not hasattr(self, 'whoami_info') or not self.whoami_info: + detail_widget.update("[dim]Whoami info unavailable[/dim]") + return + + wi = self.whoami_info + + lines = [f"[bold {t['primary']}]● User Information[/bold {t['primary']}]", ""] + + # Format basic fields + fields = [ + ('User', wi.get('user')), + ('Server', wi.get('server')), + ('Data Center', wi.get('data_center')), + ('Environment', wi.get('environment')), + ('Account Type', wi.get('account_type')), + ('Admin', 'Yes' if wi.get('admin') else 'No' if 'admin' in wi else None), + ('Enterprise', wi.get('enterprise_name')), + ('Renewal Date', wi.get('renewal_date')), + ('Storage Capacity', wi.get('storage_capacity')), + ('Storage Usage', wi.get('storage_usage')), + ('Storage Renewal', wi.get('storage_renewal_date')), + ('BreachWatch', 'Yes' if wi.get('breachwatch') else 'No'), + ('Reporting & Alerts', 'Yes' if wi.get('reporting_and_alerts') else 'No' if 'reporting_and_alerts' in wi else None), + ] + + for label, value in fields: + if value is not None: + lines.append(f" [{t['text_dim']}]{label}:[/{t['text_dim']}] [{t['primary']}]{value}[/{t['primary']}]") + + # Add enterprise license info if available + enterprise_licenses = wi.get('enterprise_licenses', []) + for lic in 
enterprise_licenses: + lines.append("") + lines.append(f"[bold {t['primary']}]● Enterprise License[/bold {t['primary']}]") + lines.append("") + + lic_fields = [ + ('Base Plan', lic.get('base_plan')), + ('License Expires', lic.get('license_expires')), + ('Secure File Storage', lic.get('secure_file_storage')), + ] + for label, value in lic_fields: + if value: + lines.append(f" [{t['text_dim']}]{label}:[/{t['text_dim']}] [{t['primary']}]{value}[/{t['primary']}]") + + # User licenses + user_lic = lic.get('user_licenses', {}) + if user_lic: + plan_seats = user_lic.get('plan', '') + active = user_lic.get('active', '') + invited = user_lic.get('invited', '') + if plan_seats: + lines.append(f" [{t['text_dim']}]User Licenses:[/{t['text_dim']}] [{t['primary']}]{plan_seats}[/{t['primary']}] [{t['text_dim']}](Active: {active}, Invited: {invited})[/{t['text_dim']}]") + + # Add-ons + addons = lic.get('add_ons', []) + if addons: + lines.append("") + lines.append(f" [{t['text_dim']}]Add-ons:[/{t['text_dim']}]") + for addon in addons: + lines.append(f" [{t['primary']}]• {addon}[/{t['primary']}]") + + detail_widget.update("\n".join(lines)) + self._update_status("User information | Press Esc to return") + self._update_shortcuts_bar(clear=True) + + except Exception as e: + logging.debug(f"Error displaying whoami info: {e}") + + def _display_device_info(self): + """Display this-device info in the detail panel""" + try: + # Clear any clickable fields from previous record display + self._clear_clickable_fields() + + t = self.theme_colors + detail_widget = self.query_one("#detail_content", Static) + + if not hasattr(self, 'device_info') or not self.device_info: + detail_widget.update("[dim]Device info unavailable[/dim]") + return + + di = self.device_info + + lines = [f"[bold {t['primary']}]● Device Information[/bold {t['primary']}]", ""] + + # Helper for ON/OFF display + def on_off(val): + return "[green]ON[/green]" if val else "[red]OFF[/red]" + + fields = [ + ('Device Name', di.get('device_name')), + ('Data Key Present', 'Yes' if di.get('data_key_present') else 'No'), + ('IP Auto Approve', on_off(di.get('ip_auto_approve'))), + ('Persistent Login', on_off(di.get('persistent_login'))), + ('Security Key No PIN', on_off(di.get('security_key_no_pin'))), + ('Device Logout Timeout', di.get('device_logout_timeout')), + ('Enterprise Logout Timeout', di.get('enterprise_logout_timeout')), + ('Effective Logout Timeout', di.get('effective_logout_timeout')), + ('Is SSO User', 'Yes' if di.get('is_sso_user') else 'No'), + ('Config File', di.get('config_file')), + ] + + for label, value in fields: + if value is not None: + lines.append(f" [{t['text_dim']}]{label}:[/{t['text_dim']}] [{t['primary']}]{value}[/{t['primary']}]") + + detail_widget.update("\n".join(lines)) + self._update_status("Device information | Press Esc to return") + self._update_shortcuts_bar(clear=True) + + except Exception as e: + logging.debug(f"Error displaying device info: {e}") + + def _is_displayable_record(self, record: dict) -> bool: + """Check if a record should be displayed in normal folder structure. 
+ Excludes file attachments, linked records, and Secrets Manager app records."""
+ record_uid = record.get('uid')
+
+ # Exclude file attachments - they'll be shown under their parent
+ if record_uid in self.file_attachment_to_parent:
+ return False
+
+ # Exclude linked records (addressRef, cardRef) - they'll be shown under their parent
+ if record_uid in self.linked_record_to_parent:
+ return False
+
+ # Exclude Secrets Manager app records - they go in virtual folder
+ if record_uid in self.app_record_uids:
+ return False
+
+ return True
+
+ def _add_record_with_attachments(self, parent_node, record: dict, idx: int, auto_expand: bool = False, total_count: int = 0):
+ """Add a record to the tree. Records with attachments show a [+] indicator."""
+ record_uid = record.get('uid')
+ record_title = record.get('title', 'Untitled')
+ t = self.theme_colors # Theme colors
+
+ # Calculate width for right-aligned numbers based on total count
+ width = len(str(total_count)) if total_count > 0 else len(str(idx))
+ idx_str = str(idx).rjust(width)
+
+ # Check if this record has file attachments or linked records
+ attachments = self.record_file_attachments.get(record_uid, [])
+ linked_records = self.record_linked_records.get(record_uid, [])
+
+ # Add [+] indicator if record has attachments
+ attachment_indicator = f" [{t['text_dim']}]\\[+][/{t['text_dim']}]" if (attachments or linked_records) else ""
+
+ record_label = f"[{t['record_num']}]{idx_str}.[/{t['record_num']}] [{t['record']}]{rich_escape(str(record_title))}[/{t['record']}]{attachment_indicator}"
+
+ # All records are leaf nodes for consistent alignment
+ parent_node.add_leaf(
+ record_label,
+ data={'type': 'record', 'uid': record_uid, 'has_attachments': bool(attachments or linked_records)}
+ )
+
+ def _setup_folder_tree(self):
+ """Setup the folder tree structure with records as children"""
+ tree = self.query_one("#folder_tree", Tree)
+ tree.clear()
+ t = self.theme_colors # Theme colors
+
+ # Root node represents "My Vault"
+ root = tree.root
+ root_folder = self.params.root_folder
+ if root_folder:
+ root.label = f"[{t['root']}]● {root_folder.name}[/{t['root']}]"
+ root.data = {'type': 'root', 'uid': None}
+ else:
+ root.label = f"[{t['root']}]● My Vault[/{t['root']}]"
+ root.data = {'type': 'root', 'uid': None}
+
+ # Determine if we should auto-expand (when filtering with < AUTO_EXPAND_THRESHOLD results)
+ auto_expand = False
+ if self.filtered_record_uids is not None and len(self.filtered_record_uids) < self.AUTO_EXPAND_THRESHOLD:
+ auto_expand = True
+
+ # Build tree recursively from root using proper folder structure
+ def add_folder_node(parent_tree_node, folder_node, folder_uid):
+ """Recursively add folder and its children to tree"""
+ if not folder_node:
+ return None
+
+ # Get records in this folder (filtered if search is active)
+ # Exclude file attachments and 'app' type records
+ folder_records = []
+ for r in self.records.values():
+ if r.get('folder_uid') == folder_uid and self._is_displayable_record(r):
+ # Apply filter if active
+ if self.filtered_record_uids is None or r['uid'] in self.filtered_record_uids:
+ folder_records.append(r)
+
+ # Get subfolders that have matching records (recursively)
+ subfolders_with_records = []
+ if hasattr(folder_node, 'subfolders') and folder_node.subfolders:
+ for subfolder_uid in folder_node.subfolders:
+ if subfolder_uid in self.params.folder_cache:
+ subfolder = self.params.folder_cache[subfolder_uid]
+ # Check if this subfolder has any matching records
+ if
self._folder_has_matching_records(subfolder_uid): + subfolders_with_records.append((subfolder.name.lower() if subfolder.name else '', subfolder_uid, subfolder)) + subfolders_with_records.sort(key=lambda x: x[0]) + + # Skip this folder only when SEARCHING and it has no matching records/subfolders + # When not searching (filtered_record_uids is None), show all folders including empty ones + if self.filtered_record_uids is not None and not folder_records and not subfolders_with_records: + return None + + # Determine label and color based on folder type + color = t['folder'] + if folder_node.type == 'shared_folder': + # Shared folder: bold green name with share icon after + label = f"[bold {color}]{folder_node.name}[/bold {color}] 👥" + else: + # Regular folder: bold green name + label = f"[bold {color}]{folder_node.name}[/bold {color}]" + + # Add this folder to the tree with color + tree_node = parent_tree_node.add( + label, + data={'type': 'folder', 'uid': folder_uid} + ) + + # Add subfolders + for _, subfolder_uid, subfolder in subfolders_with_records: + add_folder_node(tree_node, subfolder, subfolder_uid) + + # Sort and add records (with their file attachments as children) + folder_records.sort(key=lambda r: r.get('title', '').lower()) + total_records = len(folder_records) + + for idx, record in enumerate(folder_records, start=1): + self._add_record_with_attachments(tree_node, record, idx, auto_expand, total_records) + + # Auto-expand if we're in search mode with < 100 results + if auto_expand: + tree_node.expand() + + return tree_node + + # Get and sort root-level folders that have matching records + root_folders = [] + if root_folder and hasattr(root_folder, 'subfolders'): + for folder_uid in root_folder.subfolders: + if folder_uid in self.params.folder_cache: + folder = self.params.folder_cache[folder_uid] + # Only include folders with matching records + if self._folder_has_matching_records(folder_uid): + root_folders.append((folder.name.lower() if folder.name else '', folder_uid, folder)) + root_folders.sort(key=lambda x: x[0]) + + # Add root folders + for _, folder_uid, folder in root_folders: + add_folder_node(root, folder, folder_uid) + + # Add root-level records (records not in any subfolder) + # A record is at root if it's NOT in any actual subfolder (not in records_in_subfolders) + # Exclude file attachments and Secrets Manager app records + root_records = [] + for r in self.records.values(): + record_uid = r.get('uid') + # Record is at root if it's not in any subfolder + is_root_record = record_uid not in self.records_in_subfolders + if is_root_record and self._is_displayable_record(r): + # Apply filter if active + if self.filtered_record_uids is None or record_uid in self.filtered_record_uids: + root_records.append(r) + root_records.sort(key=lambda r: r.get('title', '').lower()) + total_root_records = len(root_records) + + for idx, record in enumerate(root_records, start=1): + self._add_record_with_attachments(root, record, idx, auto_expand, total_root_records) + + # Add virtual "Secrets Manager Apps" folder at the bottom for app records + app_records = [] + for r in self.records.values(): + if r.get('uid') in self.app_record_uids: + # Apply filter if active + if self.filtered_record_uids is None or r['uid'] in self.filtered_record_uids: + app_records.append(r) + + if app_records: + app_records.sort(key=lambda r: r.get('title', '').lower()) + total_app_records = len(app_records) + # Create virtual folder with distinct styling + apps_folder = root.add( + 
f"[{t['virtual_folder']}]★ Secrets Manager Apps[/{t['virtual_folder']}]", + data={'type': 'virtual_folder', 'uid': '__secrets_manager_apps__'} + ) + + for idx, record in enumerate(app_records, start=1): + self._add_record_with_attachments(apps_folder, record, idx, auto_expand, total_app_records) + + if auto_expand: + apps_folder.expand() + + # Expand root + root.expand() + + def _folder_has_matching_records(self, folder_uid: str) -> bool: + """Check if a folder should be displayed. + When no search filter is active, all folders are shown (including empty ones). + When searching, only folders with matching records are shown.""" + # If no search filter, show all folders including empty ones + if self.filtered_record_uids is None: + return True + + # When searching, check if this folder has any matching displayable records + for r in self.records.values(): + if r.get('folder_uid') == folder_uid and self._is_displayable_record(r): + if r['uid'] in self.filtered_record_uids: + return True + + # Check subfolders recursively + if folder_uid in self.params.folder_cache: + folder = self.params.folder_cache[folder_uid] + if hasattr(folder, 'subfolders') and folder.subfolders: + for subfolder_uid in folder.subfolders: + if self._folder_has_matching_records(subfolder_uid): + return True + + return False + + def _restore_tree_selection(self, tree: Tree): + """Restore tree selection to previously selected record or folder""" + try: + target_uid = self.selected_record or self.selected_folder + if not target_uid: + return + + # Find and select the node in the tree + def find_and_select(node): + if hasattr(node, 'data') and node.data: + data = node.data + node_uid = data.get('uid') if isinstance(data, dict) else None + if node_uid == target_uid: + # Found the node - select it + tree.select_node(node) + node.expand() + # Also expand parent nodes + parent = node.parent + while parent: + parent.expand() + parent = parent.parent + return True + # Check children + for child in node.children: + if find_and_select(child): + return True + return False + + find_and_select(tree.root) + + # Update the detail pane if a record was selected + if self.selected_record: + self._display_record_detail(self.selected_record) + elif self.selected_folder: + folder = self.params.folder_cache.get(self.selected_folder) + folder_name = folder.name if folder else "Unknown" + detail = self.query_one("#detail_content", Static) + t = self.theme_colors + detail.update(f"[bold {t['primary']}]📁 {rich_escape(str(folder_name))}[/bold {t['primary']}]") + + except Exception as e: + logging.error(f"Error restoring tree selection: {e}", exc_info=True) + + def _select_record_in_tree(self, tree: Tree, record_uid: str): + """Select a specific record in the tree by its UID""" + try: + def find_and_select(node): + if hasattr(node, 'data') and node.data: + data = node.data + node_uid = data.get('uid') if isinstance(data, dict) else None + if node_uid == record_uid: + # Found the node - select it + tree.select_node(node) + # Expand parent nodes to make visible + parent = node.parent + while parent: + parent.expand() + parent = parent.parent + return True + # Check children + for child in node.children: + if find_and_select(child): + return True + return False + + find_and_select(tree.root) + except Exception as e: + logging.debug(f"Error selecting record in tree: {e}") + + def _search_records(self, query: str) -> set: + """ + Search records with smart partial matching. + Returns set of matching record UIDs. 
+ + Search logic: + - Tokenizes query by whitespace + - Each token must match (partial) at least one field OR folder name + - Order doesn't matter: "aws prod us" matches "us production aws" + - Searches: title, url, custom field values, notes, AND folder name + - If folder name matches, all records in that folder are candidates + (but other tokens must still match the record) + """ + if not query or not query.strip(): + return None # None means show all + + # Tokenize query - split by whitespace and lowercase + query_tokens = [token.lower().strip() for token in query.split() if token.strip()] + if not query_tokens: + return None + + matching_uids = set() + + # Build folder name cache for quick lookup + folder_names = {} # folder_uid -> folder_name (lowercase) + if hasattr(self.params, 'folder_cache'): + for folder_uid, folder in self.params.folder_cache.items(): + if hasattr(folder, 'name') and folder.name: + folder_names[folder_uid] = folder.name.lower() + + for record_uid, record in self.records.items(): + # Build searchable text from all record fields + record_parts = [] + + # Record UID - important for searching by UID + record_parts.append(record_uid) + + # Title + if record.get('title'): + record_parts.append(str(record['title'])) + + # URL + if record.get('login_url'): + record_parts.append(str(record['login_url'])) + + # Username/Login + if record.get('login'): + record_parts.append(str(record['login'])) + + # Custom fields + if record.get('custom_fields'): + for field in record['custom_fields']: + name = field.get('name', '') + value = field.get('value', '') + if name: + record_parts.append(str(name)) + if value: + record_parts.append(str(value)) + + # Notes + if record.get('notes'): + record_parts.append(str(record['notes'])) + + # Combine record text + record_text = ' '.join(record_parts).lower() + + # Get folder UID and name for this record + folder_uid = self.record_to_folder.get(record_uid) + folder_name = folder_names.get(folder_uid, '') if folder_uid else '' + + # Combined text includes record fields, folder UID, AND folder name + combined_text = record_text + ' ' + (folder_uid.lower() if folder_uid else '') + ' ' + folder_name + + # Check if ALL query tokens match somewhere (record OR folder) + # This allows "customer 123 google" to match record "google" in folder "Customer 123" + all_tokens_match = all( + token in combined_text + for token in query_tokens + ) + + if all_tokens_match: + matching_uids.add(record_uid) + + return matching_uids + + def _perform_live_search(self, query: str) -> int: + """ + Perform live search and update tree. + Returns count of matching records. 
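+
+ Illustrative example: a query of "aws prod" keeps only records whose
+ combined record/folder text contains both "aws" and "prod" (any order),
+ then rebuilds the folder tree from that filtered set.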
+ """ + self.search_query = query + + # Get matching record UIDs + self.filtered_record_uids = self._search_records(query) + + # Rebuild tree with filtered results + self._setup_folder_tree() + + # Return count + if self.filtered_record_uids is None: + return len(self.records) + else: + return len(self.filtered_record_uids) + + def _format_record_for_tui(self, record_uid: str) -> str: + """Format record details for TUI display using the 'get' command output""" + t = self.theme_colors # Get theme colors + + try: + # Use the get command (same as shell) to fetch record details + output = self._get_record_output(record_uid, format_type='detail') + # Strip ANSI codes from command output + output = self._strip_ansi_codes(output) + + if not output or output.strip() == '': + return "[red]Failed to get record details[/red]" + + # Escape any Rich markup characters in the output + output = output.replace('[', '\\[').replace(']', '\\]') + + # Parse and format the output more cleanly + lines = [] + current_section = None + prev_was_blank = False + seen_first_user = False # Track if we've seen first user in permissions section + in_totp_section = False + # Section headers - only when value is empty + section_headers = {'Custom Fields', 'Attachments', 'User Permissions', + 'Shared Folder Permissions', 'Share Admins', 'One-Time Share URL'} + + def is_section_header(key, value): + """Check if key is a section header (only when value is empty)""" + if value: + return False + if key in section_headers: + return True + for header in section_headers: + if key.startswith(header): + return True + return False + + for line in output.split('\n'): + stripped = line.strip() + + # Skip multiple consecutive blank lines + if not stripped: + if not prev_was_blank and lines: + prev_was_blank = True + continue + prev_was_blank = False + + # Check if line contains a colon (key: value format) + if ':' in stripped: + parts = stripped.split(':', 1) + key = parts[0].strip() + value = parts[1].strip() if len(parts) > 1 else '' + + # UID - use theme primary color + if key in ['UID', 'Record UID']: + lines.append(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + # Title - bold primary with label + elif key in ['Title', 'Name'] and not current_section: + lines.append(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [bold {t['primary']}]{rich_escape(str(value))}[/bold {t['primary']}]") + # Type field + elif key == 'Type': + display_type = value if value else 'app' if record_uid in self.app_record_uids else '' + lines.append(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary_dim']}]{rich_escape(str(display_type))}[/{t['primary_dim']}]") + # Notes - always a section + elif key == 'Notes': + lines.append("") + lines.append(f"[bold {t['secondary']}]Notes:[/bold {t['secondary']}]") + current_section = 'Notes' + if value: + lines.append(f" [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + # TOTP fields - skip, will be calculated from stored URL + elif key == 'TOTP URL': + pass + elif key == 'Two Factor Code': + pass + # Section headers + elif is_section_header(key, value): + current_section = key + seen_first_user = False + in_totp_section = False + if lines: + lines.append("") + lines.append(f"[bold {t['secondary']}]{key}:[/bold {t['secondary']}]") + # Regular key-value pairs + elif value: + if key == 'User' and current_section == 'User Permissions': + if seen_first_user: + lines.append("") + seen_first_user = True + if current_section: + lines.append(f" 
[{t['text_dim']}]{rich_escape(str(key))}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + else: + lines.append(f"[{t['text_dim']}]{rich_escape(str(key))}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + elif key: + lines.append(f" [{t['primary_dim']}]{rich_escape(str(key))}[/{t['primary_dim']}]") + else: + # Lines without colons - continuation of notes or other content + if current_section == 'Notes': + lines.append(f" [{t['primary']}]{rich_escape(str(stripped))}[/{t['primary']}]") + elif stripped: + lines.append(f" [{t['primary_dim']}]{rich_escape(str(stripped))}[/{t['primary_dim']}]") + + return "\n".join(lines) + + except Exception as e: + logging.error(f"Error formatting record for TUI: {e}", exc_info=True) + error_msg = str(e).replace('[', '\\[').replace(']', '\\]') + return f"[red]Error formatting record:[/red]\n{error_msg}" + + def _format_folder_for_tui(self, folder_uid: str) -> str: + """Format folder/shared folder details for TUI display""" + t = self.theme_colors # Get theme colors + + try: + # Create a StringIO buffer to capture stdout from get command + stdout_buffer = io.StringIO() + old_stdout = sys.stdout + sys.stdout = stdout_buffer + + # Execute the get command for folder + get_cmd = RecordGetUidCommand() + get_cmd.execute(self.params, uid=folder_uid, format='detail') + + # Restore stdout + sys.stdout = old_stdout + + # Get the captured output + output = stdout_buffer.getvalue() + # Strip ANSI codes + output = self._strip_ansi_codes(output) + + if not output or output.strip() == '': + # Fallback to basic folder info if get command didn't work + folder = self.params.folder_cache.get(folder_uid) + if folder: + folder_type = folder.get_folder_type() if hasattr(folder, 'get_folder_type') else folder.type + return ( + f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]\n" + f"[bold {t['primary']}]{rich_escape(str(folder.name))}[/bold {t['primary']}]\n" + f"[{t['text_dim']}]UID:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(folder_uid))}[/{t['primary']}]\n" + f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]\n\n" + f"[{t['secondary']}]{'Type':>20}:[/{t['secondary']}] [{t['primary']}]{rich_escape(str(folder_type))}[/{t['primary']}]\n\n" + f"[{t['primary_dim']}]Expand folder (press 'l' or →) to view records[/{t['primary_dim']}]" + ) + return "[red]Folder not found[/red]" + + # Format the output with proper alignment and theme colors + lines = [] + lines.append(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]") + + for line in output.split('\n'): + line = line.strip() + if not line: + lines.append("") + continue + + # Check if line contains a colon (key: value format) + if ':' in line: + parts = line.split(':', 1) + if len(parts) == 2: + key = parts[0].strip() + value = parts[1].strip() + + # Special formatting for headers + if key in ['Shared Folder UID', 'Folder UID', 'Team UID']: + lines.append(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + elif key == 'Name': + lines.append(f"[bold {t['primary']}]{rich_escape(str(value))}[/bold {t['primary']}]") + # Section headers (no value or short value) + elif key in ['Record Permissions', 'User Permissions', 'Team Permissions', 'Share Administrators']: + lines.append("") + lines.append(f"[bold {t['primary_bright']}]{key}:[/bold {t['primary_bright']}]") + # Boolean values + elif value.lower() in ['true', 'false']: + color = t['primary'] if value.lower() == 'true' else 
t['primary_dim'] + lines.append(f"[{t['secondary']}]{rich_escape(str(key)):>25}:[/{t['secondary']}] [{color}]{rich_escape(str(value))}[/{color}]") + # Regular key-value pairs + else: + # Add indentation for permission entries + if key and not key[0].isspace(): + lines.append(f"[{t['secondary']}] • {rich_escape(str(key))}:[/{t['secondary']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + else: + lines.append(f"[{t['secondary']}]{rich_escape(str(key)):>25}:[/{t['secondary']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]") + else: + lines.append(f"[{t['primary']}]{rich_escape(str(line))}[/{t['primary']}]") + else: + # Lines without colons (section content) + if line: + lines.append(f"[{t['primary']}] {rich_escape(str(line))}[/{t['primary']}]") + + lines.append(f"\n[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]") + return "\n".join(lines) + + except Exception as e: + sys.stdout = old_stdout + logging.error(f"Error formatting folder for TUI: {e}", exc_info=True) + return f"[red]Error displaying folder:[/red]\n{str(e)}" + + def _get_record_output(self, record_uid: str, format_type: str = 'detail', include_dag: bool = False) -> str: + """Get record output using Commander's get command (cached for performance)""" + # Check cache first + cache_key = f"{record_uid}:{format_type}:{include_dag}" + if hasattr(self, '_record_output_cache') and cache_key in self._record_output_cache: + return self._record_output_cache[cache_key] + + try: + # Create a StringIO buffer to capture stdout + stdout_buffer = io.StringIO() + old_stdout = sys.stdout + sys.stdout = stdout_buffer + + try: + # Execute the get command + get_cmd = RecordGetUidCommand() + get_cmd.execute(self.params, uid=record_uid, format=format_type, include_dag=include_dag) + finally: + # Always restore stdout + sys.stdout = old_stdout + + # Get the captured output and cache it + output = stdout_buffer.getvalue() + + # If output is empty or error, don't cache it + if output and not output.startswith("Error"): + if not hasattr(self, '_record_output_cache'): + self._record_output_cache = {} + self._record_output_cache[cache_key] = output + + return output if output else "Record data not available" + + except Exception as e: + if sys.stdout != old_stdout: + sys.stdout = old_stdout + logging.error(f"Error getting record output for {record_uid}: {e}", exc_info=True) + return f"Error displaying record: {str(e)}" + + def _get_rotation_info(self, record_uid: str) -> Optional[Dict[str, Any]]: + """Get rotation info for pamUser records from DAG and rotation cache. + + NOTE: This method fetches DAG data which makes API calls. This is acceptable + because it only runs when a user explicitly views a pamUser record in SuperShell, + not during initialization or sync operations. + """ + try: + record_data = self.records.get(record_uid, {}) + record_type = record_data.get('record_type', '') + + # Check if this is a PAM User record (or has rotation data configured) + has_rotation_data = record_uid in self.params.record_rotation_cache + is_pam_user = record_type == 'pamUser' + + if not is_pam_user and not has_rotation_data: + return None + + from .. 
import vault + + rotation_info = {} + rotation_profile = None + config_uid = None + resource_uid = None + + # Get rotation data from cache + rotation_data = self.params.record_rotation_cache.get(record_uid) + + # Only fetch DAG data for pamUser records (requires PAM infrastructure) + if is_pam_user: + try: + from .tunnel.port_forward.tunnel_helpers import get_keeper_tokens + from .tunnel.port_forward.TunnelGraph import TunnelDAG + from keeper_dag.edge import EdgeType + + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(self.params) + tdag = TunnelDAG(self.params, encrypted_session_token, encrypted_transmission_key, record_uid, + transmission_key=transmission_key) + + if tdag.linking_dag.has_graph: + record_vertex = tdag.linking_dag.get_vertex(record_uid) + if record_vertex: + for parent_vertex in record_vertex.belongs_to_vertices(): + acl_edge = record_vertex.get_edge(parent_vertex, EdgeType.ACL) + if acl_edge: + edge_content = acl_edge.content_as_dict or {} + belongs_to = edge_content.get('belongs_to', False) + is_iam_user = edge_content.get('is_iam_user', False) + rotation_settings = edge_content.get('rotation_settings', {}) + is_noop = rotation_settings.get('noop', False) if isinstance(rotation_settings, dict) else False + + if is_noop: + rotation_profile = 'Scripts Only' + config_uid = parent_vertex.uid + elif is_iam_user: + rotation_profile = 'IAM User' + config_uid = parent_vertex.uid + elif belongs_to: + rotation_profile = 'General' + resource_uid = parent_vertex.uid + except Exception: + pass # DAG fetch failed, continue with cached data only + + # Get config UID from rotation cache if not from DAG + if not config_uid and rotation_data: + config_uid = rotation_data.get('configuration_uid') + if not resource_uid and rotation_data: + resource_uid = rotation_data.get('resource_uid') + # If resource_uid equals config_uid, it's an IAM/NOOP user, not General + if resource_uid and resource_uid == config_uid: + resource_uid = None + + # Get configuration name + config_name = None + if config_uid: + config_record = vault.KeeperRecord.load(self.params, config_uid) + if config_record: + config_name = config_record.title + + # Get resource name + resource_name = None + if resource_uid: + resource_record = vault.KeeperRecord.load(self.params, resource_uid) + if resource_record: + resource_name = resource_record.title + + # Determine rotation status + if not rotation_data and not rotation_profile: + rotation_info['status'] = 'Not configured' + return rotation_info + + # Rotation status + if rotation_data: + disabled = rotation_data.get('disabled', False) + rotation_info['status'] = 'Disabled' if disabled else 'Enabled' + else: + rotation_info['status'] = 'Enabled' + + # Rotation profile + if rotation_profile: + rotation_info['profile'] = rotation_profile + + # PAM Configuration + if config_name: + rotation_info['config_name'] = config_name + elif config_uid: + rotation_info['config_uid'] = config_uid + + # Resource (for General profile) + if resource_name: + rotation_info['resource_name'] = resource_name + elif resource_uid: + rotation_info['resource_uid'] = resource_uid + + # Schedule + if rotation_data and rotation_data.get('schedule'): + try: + schedule_json = json.loads(rotation_data['schedule']) + if isinstance(schedule_json, list) and len(schedule_json) > 0: + schedule = schedule_json[0] + schedule_type = schedule.get('type', 'ON_DEMAND') + if schedule_type == 'ON_DEMAND': + rotation_info['schedule'] = 'On Demand' + else: + # Format schedule 
description + time_str = schedule.get('utcTime', schedule.get('time', '')) + tz = schedule.get('tz', 'UTC') + if schedule_type == 'DAILY': + interval = schedule.get('intervalCount', 1) + rotation_info['schedule'] = f"Every {interval} day(s) at {time_str} {tz}" + elif schedule_type == 'WEEKLY': + weekday = schedule.get('weekday', '') + rotation_info['schedule'] = f"Weekly on {weekday} at {time_str} {tz}" + elif schedule_type == 'MONTHLY_BY_DAY': + day = schedule.get('monthDay', 1) + rotation_info['schedule'] = f"Monthly on day {day} at {time_str} {tz}" + elif schedule_type == 'MONTHLY_BY_WEEKDAY': + week = schedule.get('occurrence', 'FIRST') + weekday = schedule.get('weekday', '') + rotation_info['schedule'] = f"{week.title()} {weekday} of month at {time_str} {tz}" + else: + rotation_info['schedule'] = f"{schedule_type} at {time_str} {tz}" + else: + rotation_info['schedule'] = 'On Demand' + except (json.JSONDecodeError, KeyError, ValueError, TypeError) as e: + logging.debug(f"Error parsing schedule: {e}") + rotation_info['schedule'] = 'On Demand' + + # Last rotation + if rotation_data and rotation_data.get('last_rotation'): + last_rotation_ts = rotation_data['last_rotation'] + if last_rotation_ts > 0: + import datetime + last_rotation_dt = datetime.datetime.fromtimestamp(last_rotation_ts / 1000) + rotation_info['last_rotated'] = last_rotation_dt.strftime("%b %d, %Y at %I:%M %p") + + # Show rotation status if available + last_status = rotation_data.get('last_rotation_status') + # RecordRotationStatus enum: 0=NOT_ROTATED, 1=IN_PROGRESS, 2=SUCCESS, 3=FAILURE + if last_status is not None: + status_map = {0: 'Not Rotated', 1: 'In Progress', 2: 'Success', 3: 'Failure'} + rotation_info['last_status'] = status_map.get(last_status, f'Unknown ({last_status})') + + return rotation_info if rotation_info else None + + except (KeyError, AttributeError, ValueError, TypeError) as e: + logging.debug(f"Error getting rotation info: {e}") + return None + + def _clear_clickable_fields(self): + """Remove any dynamically mounted clickable field widgets""" + try: + detail_scroll = self.query_one("#record_detail", VerticalScroll) + # Collect all widgets to remove first, then batch remove + widgets_to_remove = [] + widgets_to_remove.extend(detail_scroll.query(ClickableDetailLine)) + widgets_to_remove.extend(detail_scroll.query(ClickableField)) + widgets_to_remove.extend(detail_scroll.query(ClickableRecordUID)) + # Also remove any dynamically added Static widgets (but keep #detail_content) + for widget in detail_scroll.query(Static): + if widget.id != "detail_content" and widget.id != "shortcuts_bar": + widgets_to_remove.append(widget) + # Batch remove all at once + for widget in widgets_to_remove: + widget.remove() + self.clickable_fields.clear() + except Exception as e: + logging.debug(f"Error clearing clickable fields: {e}") + + def _display_record_with_clickable_fields(self, record_uid: str): + """Display record details with clickable fields for copy-on-click""" + t = self.theme_colors + detail_scroll = self.query_one("#record_detail", VerticalScroll) + detail_widget = self.query_one("#detail_content", Static) + + # Clear previous clickable fields + self._clear_clickable_fields() + + # Get and parse record output + output = self._get_record_output(record_uid, format_type='detail') + output = self._strip_ansi_codes(output) + + if not output or output.strip() == '': + detail_widget.update("[red]Failed to get record details[/red]") + return + + # Hide the static placeholder + detail_widget.update("") + + # Collect 
+        widgets_to_mount = []
+
+        # Helper to collect clickable lines (batched for performance)
+        def mount_line(content: str, copy_value: str = None, is_password: bool = False):
+            line = ClickableDetailLine(content, copy_value, record_uid=record_uid, is_password=is_password)
+            widgets_to_mount.append(line)
+
+        # Get the actual record data for password lookup
+        record_data = self.records.get(record_uid, {})
+        actual_password = record_data.get('password', '')
+
+        # Parse and create clickable lines
+        current_section = None
+        seen_first_user = False  # Track if we've seen first user in permissions section
+        totp_displayed = False  # Track if TOTP has been displayed
+        totp_url = record_data.get('totp_url')  # Get TOTP URL once for use in display
+        # Section headers are only headers when they have NO value on the same line
+        section_headers = {'Custom Fields', 'Attachments', 'User Permissions',
+                           'Shared Folder Permissions', 'Share Admins', 'One-Time Share URL'}
+
+        def display_totp():
+            """Helper to display TOTP section"""
+            nonlocal totp_displayed
+            if totp_url and not totp_displayed:
+                from ..record import get_totp_code
+                try:
+                    result = get_totp_code(totp_url)
+                    if result:
+                        code, seconds_remaining, period = result
+                        mount_line("", None)  # Blank line before TOTP
+                        mount_line(f"[bold {t['secondary']}]Two-Factor Authentication:[/bold {t['secondary']}]", None)
+                        mount_line(f" [{t['text_dim']}]Code:[/{t['text_dim']}] [bold {t['primary']}]{code}[/bold {t['primary']}] [{t['text_dim']}]valid for[/{t['text_dim']}] [bold {t['secondary']}]{seconds_remaining} sec[/bold {t['secondary']}]", code)
+                        mount_line("", None)  # Blank line after TOTP
+                        totp_displayed = True
+                except Exception as e:
+                    logging.debug(f"Error calculating TOTP: {e}")
+
+        def is_section_header(key, value):
+            """Check if key is a section header (only when value is empty)"""
+            if value:  # If there's a value on same line, it's not a section header
+                return False
+            if key in section_headers:
+                return True
+            # Handle cases like "Share Admins (64, showing first 10)"
+            for header in section_headers:
+                if key.startswith(header):
+                    return True
+            return False
+
+        # Get attachments for this record
+        file_attachment_uids = self.record_file_attachments.get(record_uid, [])
+        linked_record_uids = self.record_linked_records.get(record_uid, [])
+        attachments_displayed = False
+
+        def display_attachments():
+            """Helper to display file attachments section"""
+            nonlocal attachments_displayed
+            if attachments_displayed:
+                return
+            if not file_attachment_uids and not linked_record_uids:
+                return
+
+            mount_line("", None)  # Blank line before attachments
+            mount_line(f"[bold {t['secondary']}]File Attachments:[/bold {t['secondary']}]", None)
+
+            # Display file attachments (use + symbol instead of emoji)
+            for att_uid in file_attachment_uids:
+                att_record = self.records.get(att_uid, {})
+                att_title = att_record.get('title', att_uid)
+                mount_line(f" [{t['text_dim']}]+[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(att_title))}[/{t['primary']}]", att_uid)
+
+            # Display linked records (addressRef, cardRef, etc.)
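+            # addressRef/cardRef style fields hold UIDs of other vault records;
+            # their titles are resolved from the local record cache below when available.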
+            for link_uid in linked_record_uids:
+                link_record = self.records.get(link_uid, {})
+                link_title = link_record.get('title', link_uid)
+                link_type = link_record.get('record_type', '')
+                type_label = f" ({rich_escape(str(link_type))})" if link_type else ""
+                mount_line(f" [{t['text_dim']}]→[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(link_title))}[/{t['primary']}][{t['text_dim']}]{type_label}[/{t['text_dim']}]", link_uid)
+
+            attachments_displayed = True
+
+        # Rotation info for pamUser records
+        rotation_info = self._get_rotation_info(record_uid)
+        rotation_displayed = False
+
+        def display_rotation():
+            """Helper to display rotation info section"""
+            nonlocal rotation_displayed
+            if rotation_displayed or not rotation_info:
+                return
+
+            mount_line("", None)  # Blank line before rotation section
+            mount_line(f"[bold {t['secondary']}]Rotation:[/bold {t['secondary']}]", None)
+
+            status = rotation_info.get('status', 'Unknown')
+            status_color = '#00ff00' if status == 'Enabled' else '#ff6600' if status == 'Disabled' else t['text_dim']
+            mount_line(f" [{t['text_dim']}]Status:[/{t['text_dim']}] [{status_color}]{status}[/{status_color}]", status)
+
+            if rotation_info.get('profile'):
+                mount_line(f" [{t['text_dim']}]Profile:[/{t['text_dim']}] [{t['primary']}]{rotation_info['profile']}[/{t['primary']}]", rotation_info['profile'])
+
+            if rotation_info.get('config_name'):
+                mount_line(f" [{t['text_dim']}]Configuration:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(rotation_info['config_name']))}[/{t['primary']}]", rotation_info['config_name'])
+            elif rotation_info.get('config_uid'):
+                mount_line(f" [{t['text_dim']}]Configuration UID:[/{t['text_dim']}] [{t['primary']}]{rotation_info['config_uid']}[/{t['primary']}]", rotation_info['config_uid'])
+
+            if rotation_info.get('resource_name'):
+                mount_line(f" [{t['text_dim']}]Resource:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(rotation_info['resource_name']))}[/{t['primary']}]", rotation_info['resource_name'])
+            elif rotation_info.get('resource_uid'):
+                mount_line(f" [{t['text_dim']}]Resource UID:[/{t['text_dim']}] [{t['primary']}]{rotation_info['resource_uid']}[/{t['primary']}]", rotation_info['resource_uid'])
+
+            if rotation_info.get('schedule'):
+                mount_line(f" [{t['text_dim']}]Schedule:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(rotation_info['schedule']))}[/{t['primary']}]", rotation_info['schedule'])
+
+            if rotation_info.get('last_rotated'):
+                mount_line(f" [{t['text_dim']}]Last Rotated:[/{t['text_dim']}] [{t['primary']}]{rotation_info['last_rotated']}[/{t['primary']}]", rotation_info['last_rotated'])
+
+            if rotation_info.get('last_status'):
+                last_status = rotation_info['last_status']
+                last_status_color = '#00ff00' if last_status == 'Success' else '#ff0000' if last_status == 'Failure' else '#ffff00'
+                mount_line(f" [{t['text_dim']}]Last Status:[/{t['text_dim']}] [{last_status_color}]{last_status}[/{last_status_color}]", last_status)
+
+            rotation_displayed = True
+
+        for line in output.split('\n'):
+            stripped = line.strip()
+            if not stripped:
+                continue
+
+            if ':' in stripped:
+                parts = stripped.split(':', 1)
+                key = parts[0].strip()
+                value = parts[1].strip() if len(parts) > 1 else ''
+
+                if key in ['UID', 'Record UID']:
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                elif key in ['Title', 'Name'] and not current_section:
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [bold {t['primary']}]{rich_escape(str(value))}[/bold {t['primary']}]", value)
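+                # Each branch below special-cases one well-known key from the
+                # line-oriented 'get' output; keys that match nothing fall through
+                # to the generic key/value rendering at the end of the chain.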
+                elif key == 'Type':
+                    # Show 'app' for app records if type is blank
+                    display_type = value if value else 'app' if record_uid in self.app_record_uids else ''
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary_dim']}]{rich_escape(str(display_type))}[/{t['primary_dim']}]", display_type)
+                elif key == 'Password':
+                    # Show masked password but use ClipboardCommand to copy (generates audit event)
+                    # Respect unmask_secrets toggle
+                    if self.unmask_secrets:
+                        display_value = actual_password if actual_password else value
+                    else:
+                        display_value = '******' if actual_password else value
+                    copy_value = actual_password if actual_password else None
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(display_value))}[/{t['primary']}]", copy_value, is_password=True)
+                elif key == 'URL':
+                    # Display URL, then TOTP if present
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                    display_totp()  # Add TOTP section right after URL (before Notes)
+                elif key == 'Notes':
+                    # Display TOTP before Notes if not already shown (for records without URL)
+                    display_totp()
+                    # Notes section - check if it has content on same line or is multi-line
+                    mount_line("", None)  # Blank line before Notes
+                    mount_line(f"[bold {t['secondary']}]Notes:[/bold {t['secondary']}]", None)
+                    current_section = 'Notes'
+                    if value:
+                        # Notes content is on the same line
+                        mount_line(f" [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                elif key == 'TOTP URL':
+                    # Skip TOTP URL - we'll show the code calculated from stored URL
+                    pass
+                elif key == 'Two Factor Code':
+                    # Skip - we'll calculate and show TOTP from stored URL below
+                    pass
+                elif key == 'Passkey':
+                    # Passkey section header
+                    mount_line("", None)  # Blank line before
+                    mount_line(f"[bold {t['secondary']}]Passkey:[/bold {t['secondary']}]", None)
+                    current_section = 'Passkey'
+                elif current_section == 'Passkey' and key in ('Created', 'Username', 'Relying Party'):
+                    # Passkey detail fields
+                    mount_line(f" [{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                elif is_section_header(key, value):
+                    # Display rotation BEFORE User Permissions section
+                    if key == 'User Permissions' and not rotation_displayed:
+                        display_rotation()
+                    # Display attachments BEFORE Share Admins section
+                    if key.startswith('Share Admins') and not attachments_displayed:
+                        display_attachments()
+                    current_section = key
+                    seen_first_user = False  # Reset for new section
+                    mount_line("", None)  # Blank line
+                    mount_line(f"[bold {t['secondary']}]{key}:[/bold {t['secondary']}]", None)
+                elif key.rstrip(':') in ('fileRef', 'addressRef', 'cardRef'):
+                    # Skip reference fields - we handle attachments/linked records separately
+                    pass
+                elif value:
+                    # Strip type prefixes from field names (e.g., "text:Sign-In Address" -> "Sign-In Address")
+                    display_key = key
+                    field_type_prefixes = ('text:', 'multiline:', 'url:', 'phone:', 'email:', 'secret:', 'date:', 'name:', 'host:', 'address:')
+                    for prefix in field_type_prefixes:
+                        if key.lower().startswith(prefix):
+                            display_key = key[len(prefix):]
+                            # If label was empty, use a friendly name based on type
+                            if not display_key:
+                                type_friendly_names = {
+                                    'text:': 'Text',
+                                    'multiline:': 'Note',
+                                    'url:': 'URL',
+                                    'phone:': 'Phone',
+                                    'email:': 'Email',
+                                    'secret:': 'Secret',
+                                    'date:': 'Date',
+                                    'name:': 'Name',
+                                    'host:': 'Host',
+                                    'address:': 'Address',
+                                }
+                                display_key = type_friendly_names.get(prefix, prefix.rstrip(':').title())
+                            break
+
+                    # Add blank line before each User entry in User Permissions section (except first)
+                    if display_key == 'User' and current_section == 'User Permissions':
+                        if seen_first_user:
+                            mount_line("", None)  # Blank line between users
+                        seen_first_user = True
+                    indent = " " if current_section else ""
+
+                    # Check if this is a sensitive field that should be masked
+                    is_sensitive = self._is_sensitive_field(display_key) or self._is_sensitive_field(key)
+                    if is_sensitive and not self.unmask_secrets:
+                        display_value = '******'
+                        # Use is_password=False so it uses pyperclip.copy(value) instead of ClipboardCommand
+                        # ClipboardCommand only copies the record's Password field, not arbitrary secret fields
+                        mount_line(f"{indent}[{t['text_dim']}]{rich_escape(str(display_key))}:[/{t['text_dim']}] [{t['primary']}]{display_value}[/{t['primary']}]", value)
+                    else:
+                        mount_line(f"{indent}[{t['text_dim']}]{rich_escape(str(display_key))}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                elif key:
+                    mount_line(f" [{t['primary_dim']}]{rich_escape(str(key))}[/{t['primary_dim']}]", key)
+            else:
+                # Lines without colons - continuation of notes or other multi-line content
+                if current_section == 'Notes':
+                    mount_line(f" [{t['primary']}]{rich_escape(str(stripped))}[/{t['primary']}]", stripped)
+                elif stripped:
+                    mount_line(f" [{t['primary_dim']}]{rich_escape(str(stripped))}[/{t['primary_dim']}]", stripped)
+
+        # Display attachments at end if not already shown (records without Share Admins section)
+        display_attachments()
+
+        # Display rotation at end if not already shown (records without User Permissions section)
+        display_rotation()
+
+        # Batch mount all widgets at once for better performance
+        if widgets_to_mount:
+            detail_scroll.mount(*widgets_to_mount, before=detail_widget)
+
+        # Start TOTP auto-refresh timer if this record has TOTP
+        # Skip timer management if we're in a refresh callback
+        if not getattr(self, '_totp_refreshing', False):
+            self._stop_totp_timer()  # Stop any existing timer
+            if totp_url:
+                self._totp_record_uid = record_uid
+                self._totp_timer = self.set_interval(1.0, self._refresh_totp_display)
+
+    def _stop_totp_timer(self):
+        """Stop the TOTP auto-refresh timer"""
+        if hasattr(self, '_totp_timer') and self._totp_timer:
+            self._totp_timer.stop()
+            self._totp_timer = None
+        self._totp_record_uid = None
+
+    def _refresh_totp_display(self):
+        """Refresh the TOTP display (called by timer every second)"""
+        if not hasattr(self, '_totp_record_uid') or not self._totp_record_uid:
+            self._stop_totp_timer()
+            return
+
+        if self._totp_record_uid != self.selected_record:
+            self._stop_totp_timer()
+            return
+
+        # Don't refresh if in JSON view mode
+        if self.view_mode == 'json':
+            return
+
+        # Re-display the record (TOTP is calculated fresh each time)
+        record_uid = self._totp_record_uid
+        self._totp_refreshing = True  # Flag to prevent timer restart
+        try:
+            self._display_record_with_clickable_fields(record_uid)
+        finally:
+            self._totp_refreshing = False
+            # Restore the record UID since display clears it
+            self._totp_record_uid = record_uid
+
+    def _display_json_with_clickable_fields(self, record_uid: str):
+        """Display JSON view with clickable string values, syntax highlighting, masking passwords"""
+        # Stop TOTP timer when in JSON view (no live countdown)
+        self._stop_totp_timer()
+
+        t = self.theme_colors
+        container = self.query_one("#record_detail", VerticalScroll)
+        detail_widget = self.query_one("#detail_content", Static)
+
+        # Clear previous clickable fields
+        self._clear_clickable_fields()
+
+        # Get JSON output (include DAG data for PAM records)
+        output = self._get_record_output(record_uid, format_type='json', include_dag=True)
+        output = self._strip_ansi_codes(output)
+
+        try:
+            json_obj = json.loads(output)
+        except json.JSONDecodeError:
+            # If JSON parsing fails, show raw output
+            detail_widget.update(f"[{t['primary']}]JSON View:\n\n{rich_escape(str(output))}[/{t['primary']}]")
+            return
+
+        # Keep unmasked JSON for copying actual values
+        unmasked_obj = json_obj
+
+        # Create masked version for display
+        display_obj = self._mask_passwords_in_json(json_obj)
+
+        # Clear detail widget content
+        detail_widget.update("")
+
+        # Collect all widgets first for batch mounting
+        widgets_to_mount = []
+
+        # Helper to collect widgets (batched for performance)
+        def mount_line(content: str, copy_value: str = None, is_password: bool = False):
+            """Collect a clickable line for batch mounting"""
+            line = ClickableDetailLine(
+                content,
+                copy_value=copy_value,
+                record_uid=record_uid if is_password else None,
+                is_password=is_password
+            )
+            widgets_to_mount.append(line)
+            self.clickable_fields.append(line)
+
+        # Render JSON header
+        mount_line(f"[bold {t['primary']}]JSON View:[/bold {t['primary']}]")
+        mount_line("")
+
+        # Render JSON with syntax highlighting
+        self._render_json_lines(display_obj, unmasked_obj, mount_line, t, record_uid)
+
+        # Batch mount all widgets at once for better performance
+        if widgets_to_mount:
+            container.mount(*widgets_to_mount, before=detail_widget)
+
+    def _render_json_lines(self, display_obj, unmasked_obj, mount_line, t, record_uid, indent=0):
+        """Recursively render JSON object as clickable lines with syntax highlighting"""
+        indent_str = " " * indent
+        key_color = "#88ccff"  # Light blue for keys
+        string_color = t['primary']  # Theme color for strings
+        number_color = "#ffcc66"  # Orange for numbers
+        bool_color = "#ff99cc"  # Pink for booleans
+        null_color = "#999999"  # Gray for null
+        bracket_color = t['text_dim']
+
+        if isinstance(display_obj, dict):
+            # Make the root opening brace copyable with the entire object
+            mount_line(f"{indent_str}[{bracket_color}]{{[/{bracket_color}]",
+                       copy_value=json.dumps(unmasked_obj, indent=2))
+            items = list(display_obj.items())
+            for i, (key, value) in enumerate(items):
+                comma = "," if i < len(items) - 1 else ""
+                # Get unmasked value for copying
+                unmasked_value = unmasked_obj.get(key, value) if isinstance(unmasked_obj, dict) else value
+
+                if isinstance(value, str):
+                    # Escape brackets for Rich markup
+                    display_val = value.replace("[", "\\[")
+                    is_password = (value == "************")
+                    copy_val = unmasked_value if isinstance(unmasked_value, str) else str(unmasked_value)
+                    mount_line(
+                        f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{string_color}]\"{display_val}\"[/{string_color}]{comma}",
+                        copy_value=copy_val,
+                        is_password=is_password
+                    )
+                elif isinstance(value, bool):
+                    bool_str = "true" if value else "false"
+                    mount_line(
+                        f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{bool_color}]{bool_str}[/{bool_color}]{comma}",
+                        copy_value=str(value)
+                    )
+                elif isinstance(value, (int, float)):
+                    mount_line(
+                        f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{number_color}]{value}[/{number_color}]{comma}",
+                        copy_value=str(value)
+                    )
+                elif value is None:
+                    mount_line(
+                        f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{null_color}]null[/{null_color}]{comma}"
+                    )
+                elif isinstance(value, list):
+                    unmasked_list = unmasked_value if isinstance(unmasked_value, list) else value
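+                    # display_obj carries masked values for rendering while
+                    # unmasked_obj mirrors the same structure so a click copies
+                    # the real value.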
+                    # Make the opening line copyable with the entire array
+                    mount_line(f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{bracket_color}]\\[[/{bracket_color}]",
+                               copy_value=json.dumps(unmasked_list, indent=2))
+                    self._render_json_list_items(value, unmasked_list, mount_line, t, record_uid, indent + 2)
+                    mount_line(f"{indent_str} [{bracket_color}]][/{bracket_color}]{comma}")
+                elif isinstance(value, dict):
+                    unmasked_dict = unmasked_value if isinstance(unmasked_value, dict) else value
+                    # Make the opening line copyable with the entire object
+                    mount_line(f"{indent_str} [{key_color}]\"{key}\"[/{key_color}]: [{bracket_color}]{{[/{bracket_color}]",
+                               copy_value=json.dumps(unmasked_dict, indent=2))
+                    self._render_json_dict_items(value, unmasked_dict, mount_line, t, record_uid, indent + 2)
+                    mount_line(f"{indent_str} [{bracket_color}]}}[/{bracket_color}]{comma}")
+            mount_line(f"{indent_str}[{bracket_color}]}}[/{bracket_color}]")
+        elif isinstance(display_obj, list):
+            # Make the root opening bracket copyable with the entire array
+            mount_line(f"{indent_str}[{bracket_color}]\\[[/{bracket_color}]",
+                       copy_value=json.dumps(unmasked_obj, indent=2))
+            self._render_json_list_items(display_obj, unmasked_obj, mount_line, t, record_uid, indent + 1)
+            mount_line(f"{indent_str}[{bracket_color}]][/{bracket_color}]")
+
+    def _render_json_dict_items(self, display_dict, unmasked_dict, mount_line, t, record_uid, indent):
+        """Render dict items for nested objects"""
+        indent_str = " " * indent
+        key_color = "#88ccff"
+        string_color = t['primary']
+        number_color = "#ffcc66"
+        bool_color = "#ff99cc"
+        null_color = "#999999"
+        bracket_color = t['text_dim']
+
+        items = list(display_dict.items())
+        for i, (key, value) in enumerate(items):
+            comma = "," if i < len(items) - 1 else ""
+            unmasked_value = unmasked_dict.get(key, value) if isinstance(unmasked_dict, dict) else value
+
+            if isinstance(value, str):
+                display_val = rich_escape(value)
+                is_password = (value == "************")
+                copy_val = unmasked_value if isinstance(unmasked_value, str) else str(unmasked_value)
+                mount_line(
+                    f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{string_color}]\"{display_val}\"[/{string_color}]{comma}",
+                    copy_value=copy_val,
+                    is_password=is_password
+                )
+            elif isinstance(value, bool):
+                bool_str = "true" if value else "false"
+                mount_line(
+                    f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{bool_color}]{bool_str}[/{bool_color}]{comma}",
+                    copy_value=str(value)
+                )
+            elif isinstance(value, (int, float)):
+                mount_line(
+                    f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{number_color}]{value}[/{number_color}]{comma}",
+                    copy_value=str(value)
+                )
+            elif value is None:
+                mount_line(f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{null_color}]null[/{null_color}]{comma}")
+            elif isinstance(value, list):
+                unmasked_list = unmasked_value if isinstance(unmasked_value, list) else value
+                # Make the opening line copyable with the entire array
+                mount_line(f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{bracket_color}]\\[[/{bracket_color}]",
+                           copy_value=json.dumps(unmasked_list, indent=2))
+                self._render_json_list_items(value, unmasked_list, mount_line, t, record_uid, indent + 1)
+                mount_line(f"{indent_str}[{bracket_color}]][/{bracket_color}]{comma}")
+            elif isinstance(value, dict):
+                unmasked_inner = unmasked_value if isinstance(unmasked_value, dict) else value
+                # Make the opening line copyable with the entire object
+                mount_line(f"{indent_str}[{key_color}]\"{rich_escape(str(key))}\"[/{key_color}]: [{bracket_color}]{{[/{bracket_color}]",
+                           copy_value=json.dumps(unmasked_inner, indent=2))
+                self._render_json_dict_items(value, unmasked_inner, mount_line, t, record_uid, indent + 1)
+                mount_line(f"{indent_str}[{bracket_color}]}}[/{bracket_color}]{comma}")
+
+    def _render_json_list_items(self, display_list, unmasked_list, mount_line, t, record_uid, indent):
+        """Render list items for arrays"""
+        indent_str = " " * indent
+        string_color = t['primary']
+        number_color = "#ffcc66"
+        bool_color = "#ff99cc"
+        null_color = "#999999"
+        bracket_color = t['text_dim']
+
+        for i, value in enumerate(display_list):
+            comma = "," if i < len(display_list) - 1 else ""
+            unmasked_value = unmasked_list[i] if isinstance(unmasked_list, list) and i < len(unmasked_list) else value
+
+            if isinstance(value, str):
+                display_val = rich_escape(value)
+                is_password = (value == "************")
+                copy_val = unmasked_value if isinstance(unmasked_value, str) else str(unmasked_value)
+                mount_line(
+                    f"{indent_str}[{string_color}]\"{display_val}\"[/{string_color}]{comma}",
+                    copy_value=copy_val,
+                    is_password=is_password
+                )
+            elif isinstance(value, bool):
+                bool_str = "true" if value else "false"
+                mount_line(f"{indent_str}[{bool_color}]{bool_str}[/{bool_color}]{comma}", copy_value=str(value))
+            elif isinstance(value, (int, float)):
+                mount_line(f"{indent_str}[{number_color}]{value}[/{number_color}]{comma}", copy_value=str(value))
+            elif value is None:
+                mount_line(f"{indent_str}[{null_color}]null[/{null_color}]{comma}")
+            elif isinstance(value, dict):
+                unmasked_inner = unmasked_value if isinstance(unmasked_value, dict) else value
+                # Make the opening brace copyable with the entire object
+                mount_line(f"{indent_str}[{bracket_color}]{{[/{bracket_color}]",
+                           copy_value=json.dumps(unmasked_inner, indent=2))
+                self._render_json_dict_items(value, unmasked_inner, mount_line, t, record_uid, indent + 1)
+                mount_line(f"{indent_str}[{bracket_color}]}}[/{bracket_color}]{comma}")
+            elif isinstance(value, list):
+                unmasked_inner = unmasked_value if isinstance(unmasked_value, list) else value
+                # Make the opening bracket copyable with the entire array
+                mount_line(f"{indent_str}[{bracket_color}]\\[[/{bracket_color}]",
+                           copy_value=json.dumps(unmasked_inner, indent=2))
+                self._render_json_list_items(value, unmasked_inner, mount_line, t, record_uid, indent + 1)
+                mount_line(f"{indent_str}[{bracket_color}]][/{bracket_color}]{comma}")
+
+    def _is_sensitive_field(self, field_name: str) -> bool:
+        """Check if a field name indicates it contains sensitive data"""
+        if not field_name:
+            return False
+        name_lower = field_name.lower()
+        return any(term in name_lower for term in ('secret', 'password', 'passphrase'))
+
+    def _mask_passwords_in_json(self, obj, parent_key: str = None):
+        """Recursively mask password/secret/passphrase values in JSON object for display"""
+        if self.unmask_secrets:
+            return obj  # Don't mask if unmask mode is enabled
+
+        if isinstance(obj, dict):
+            # Check if this dict is a password field (has type: "password")
+            if obj.get('type') == 'password':
+                masked = dict(obj)
+                if 'value' in masked and isinstance(masked['value'], list) and len(masked['value']) > 0:
+                    masked['value'] = ['************']
+                return masked
+            # Check if this dict has a label that indicates sensitive data
+            label = obj.get('label', '')
+            if self._is_sensitive_field(label):
+                masked = dict(obj)
+                if 'value' in masked and isinstance(masked['value'], list) and len(masked['value']) > 0:
+                    masked['value'] = ['************']
+                return masked
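+            # Illustrative example (hypothetical values): {"type": "password",
+            # "value": ["hunter2"]} renders as {"type": "password",
+            # "value": ["************"]}, while the unmasked original is kept
+            # separately for copy-on-click.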
+            # Otherwise recurse into dict values
+            result = {}
+            for key, value in obj.items():
+                # Check if key itself indicates sensitive data
+                if self._is_sensitive_field(key) and isinstance(value, str) and value:
+                    result[key] = '************'
+                else:
+                    result[key] = self._mask_passwords_in_json(value, parent_key=key)
+            return result
+        elif isinstance(obj, list):
+            return [self._mask_passwords_in_json(item, parent_key=parent_key) for item in obj]
+        else:
+            return obj
+
+    def _display_folder_with_clickable_fields(self, folder_uid: str):
+        """Display folder details with clickable fields for copy-on-click"""
+        # Stop TOTP timer when viewing folders
+        self._stop_totp_timer()
+
+        # Check if JSON view is requested
+        if self.view_mode == 'json':
+            self._display_folder_json(folder_uid)
+            return
+
+        t = self.theme_colors
+        detail_scroll = self.query_one("#record_detail", VerticalScroll)
+        detail_widget = self.query_one("#detail_content", Static)
+
+        # Clear previous clickable fields
+        self._clear_clickable_fields()
+
+        # Get folder from cache for type info
+        folder = self.params.folder_cache.get(folder_uid)
+        folder_type = ""
+        if folder:
+            folder_type = folder.get_folder_type() if hasattr(folder, 'get_folder_type') else str(folder.type)
+
+        # Get folder output from get command
+        try:
+            stdout_buffer = io.StringIO()
+            old_stdout = sys.stdout
+            sys.stdout = stdout_buffer
+            get_cmd = RecordGetUidCommand()
+            get_cmd.execute(self.params, uid=folder_uid, format='detail')
+            sys.stdout = old_stdout
+            output = stdout_buffer.getvalue()
+            output = self._strip_ansi_codes(output)
+        except Exception as e:
+            sys.stdout = old_stdout
+            logging.error(f"Error getting folder output: {e}")
+            output = ""
+
+        # Hide the static placeholder
+        detail_widget.update("")
+
+        # Helper to mount clickable lines
+        def mount_line(content: str, copy_value: str = None):
+            line = ClickableDetailLine(content, copy_value)
+            detail_scroll.mount(line, before=detail_widget)
+
+        # Header line
+        mount_line(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+
+        if not output or output.strip() == '':
+            # Fallback to basic folder info
+            if folder:
+                mount_line(f"[bold {t['primary']}]{rich_escape(str(folder.name))}[/bold {t['primary']}]", folder.name)
+                mount_line(f"[{t['text_dim']}]UID:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(folder_uid))}[/{t['primary']}]", folder_uid)
+                mount_line(f"[{t['text_dim']}]Type:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(folder_type))}[/{t['primary']}]", folder_type)
+                mount_line(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+            return
+
+        # Parse the output and format with clickable lines
+        current_section = None
+        section_headers = {'Record Permissions', 'User Permissions', 'Team Permissions', 'Share Administrators'}
+        share_admins_count = 0
+
+        # First pass: count share admins
+        in_share_admins = False
+        for line in output.split('\n'):
+            stripped = line.strip()
+            if ':' in stripped:
+                key = stripped.split(':', 1)[0].strip()
+                if key == 'Share Administrators':
+                    in_share_admins = True
+                elif key in section_headers and key != 'Share Administrators':
+                    in_share_admins = False
+                elif in_share_admins and key == 'User':
+                    share_admins_count += 1
+
+        # Second pass: build clickable lines
+        in_share_admins = False
+        for line in output.split('\n'):
+            stripped = line.strip()
+            if not stripped:
+                continue
+
+            if ':' in stripped:
+                parts = stripped.split(':', 1)
+                key = parts[0].strip()
+                value = parts[1].strip() if len(parts) > 1 else ''
+
+                # UID fields
+                if key in ['Shared Folder UID', 'Folder UID', 'Team UID']:
+                    mount_line(f"[{t['text_dim']}]{key}:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                # Folder Type
+                elif key == 'Folder Type':
+                    display_type = value if value else folder_type
+                    mount_line(f"[{t['text_dim']}]Type:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(display_type))}[/{t['primary']}]", display_type)
+                # Name - title
+                elif key == 'Name':
+                    mount_line(f"[bold {t['primary']}]{rich_escape(str(value))}[/bold {t['primary']}]", value)
+                # Section headers
+                elif key in section_headers:
+                    current_section = key
+                    in_share_admins = (key == 'Share Administrators')
+                    mount_line("", None)
+                    if key == 'Share Administrators' and share_admins_count > 0:
+                        mount_line(f"[bold {t['primary_bright']}]{key}:[/bold {t['primary_bright']}] [{t['text_dim']}]({share_admins_count} users)[/{t['text_dim']}]", None)
+                    else:
+                        mount_line(f"[bold {t['primary_bright']}]{key}:[/bold {t['primary_bright']}]", None)
+                # Record UID in Record Permissions - show title
+                elif key == 'Record UID' and current_section == 'Record Permissions':
+                    if value in self.records:
+                        record_title = self.records[value].get('title', 'Untitled')
+                        mount_line(f" [{t['text_dim']}]Record:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(record_title))}[/{t['primary']}]", record_title)
+                        mount_line(f" [{t['text_dim']}]UID:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                    else:
+                        mount_line(f" [{t['text_dim']}]Record UID:[/{t['text_dim']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                # Boolean values
+                elif value.lower() in ['true', 'false']:
+                    color = t['primary'] if value.lower() == 'true' else t['primary_dim']
+                    indent = " " if current_section else ""
+                    mount_line(f"{indent}[{t['secondary']}]{rich_escape(str(key))}:[/{t['secondary']}] [{color}]{rich_escape(str(value))}[/{color}]", value)
+                # Regular key-value pairs
+                elif value:
+                    indent = " " if current_section else ""
+                    # Skip Share Admins details (collapsed)
+                    if in_share_admins and key in ['User', 'Email']:
+                        continue
+                    mount_line(f"{indent}[{t['secondary']}]{rich_escape(str(key))}:[/{t['secondary']}] [{t['primary']}]{rich_escape(str(value))}[/{t['primary']}]", value)
+                elif key:
+                    indent = " " if current_section else ""
+                    if in_share_admins:
+                        continue
+                    mount_line(f"{indent}[{t['primary_dim']}]{rich_escape(str(key))}[/{t['primary_dim']}]", key)
+            else:
+                if stripped:
+                    indent = " " if current_section else ""
+                    if in_share_admins:
+                        continue
+                    mount_line(f"{indent}[{t['primary']}]{rich_escape(str(stripped))}[/{t['primary']}]", stripped)
+
+        # Footer line
+        mount_line(f"\n[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+
+    def _display_folder_json(self, folder_uid: str):
+        """Display folder/shared folder as JSON with clickable values"""
+        t = self.theme_colors
+        container = self.query_one("#record_detail", VerticalScroll)
+        detail_widget = self.query_one("#detail_content", Static)
+
+        # Clear previous clickable fields
+        self._clear_clickable_fields()
+
+        # Get JSON output from get command
+        try:
+            stdout_buffer = io.StringIO()
+            old_stdout = sys.stdout
+            sys.stdout = stdout_buffer
+            get_cmd = RecordGetUidCommand()
+            get_cmd.execute(self.params, uid=folder_uid, format='json')
+            sys.stdout = old_stdout
+            output = stdout_buffer.getvalue()
+            output = self._strip_ansi_codes(output)
+        except Exception as e:
+            sys.stdout = old_stdout
+            logging.error(f"Error getting folder JSON output: {e}")
+            detail_widget.update(f"[red]Error getting folder JSON: {str(e)}[/red]")
+            return
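+        # The folder JSON comes from the same RecordGetUidCommand used for the
+        # detail view, just invoked with format='json' and captured from stdout,
+        # which keeps the TUI decoupled from the command's internal data structures.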
+        try:
+            json_obj = json.loads(output)
+        except json.JSONDecodeError:
+            # If JSON parsing fails, show raw output
+            detail_widget.update(f"[{t['primary']}]JSON View:\n\n{rich_escape(str(output))}[/{t['primary']}]")
+            return
+
+        # Clear detail widget content
+        detail_widget.update("")
+
+        # Helper to mount clickable JSON lines
+        def mount_json_line(content: str, copy_value: str = None, indent: int = 0):
+            line = ClickableDetailLine(content, copy_value)
+            container.mount(line, before=detail_widget)
+
+        # Build formatted JSON output with clickable values
+        mount_json_line(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+        mount_json_line(f"[bold {t['primary']}]JSON View[/bold {t['primary']}] [{t['text_dim']}](press 't' for detail view)[/{t['text_dim']}]", None)
+        mount_json_line(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+        mount_json_line("", None)
+
+        def render_json(obj, indent=0):
+            """Recursively render JSON with clickable string values"""
+            prefix = " " * indent
+            if isinstance(obj, dict):
+                # Make the opening brace copyable with the entire object
+                mount_json_line(f"{prefix}{{", json.dumps(obj, indent=2))
+                items = list(obj.items())
+                for i, (key, value) in enumerate(items):
+                    comma = "," if i < len(items) - 1 else ""
+                    if isinstance(value, str):
+                        escaped_value = rich_escape(value)
+                        mount_json_line(
+                            f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary']}]\"{escaped_value}\"[/{t['primary']}]{comma}",
+                            value
+                        )
+                    elif isinstance(value, bool):
+                        bool_str = "true" if value else "false"
+                        mount_json_line(
+                            f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary_bright']}]{bool_str}[/{t['primary_bright']}]{comma}",
+                            str(value)
+                        )
+                    elif isinstance(value, (int, float)):
+                        mount_json_line(
+                            f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary_bright']}]{value}[/{t['primary_bright']}]{comma}",
+                            str(value)
+                        )
+                    elif value is None:
+                        mount_json_line(
+                            f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['text_dim']}]null[/{t['text_dim']}]{comma}",
+                            None
+                        )
+                    elif isinstance(value, dict):
+                        # Make the key line copyable with the entire object
+                        mount_json_line(f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: {{",
+                                        json.dumps(value, indent=2))
+                        render_json_items(value, indent + 2)
+                        mount_json_line(f"{prefix} }}{comma}", None)
+                    elif isinstance(value, list):
+                        # Make the key line copyable with the entire array
+                        mount_json_line(f"{prefix} [{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [",
+                                        json.dumps(value, indent=2))
+                        render_json_list_items(value, indent + 2)
+                        mount_json_line(f"{prefix} ]{comma}", None)
+                mount_json_line(f"{prefix}}}", None)
+            elif isinstance(obj, list):
+                mount_json_line(f"{prefix}[", json.dumps(obj, indent=2))
+                render_json_list_items(obj, indent + 1)
+                mount_json_line(f"{prefix}]", None)
+
+        def render_json_items(obj, indent):
+            """Render dict items without outer braces"""
+            prefix = " " * indent
+            items = list(obj.items())
+            for i, (key, value) in enumerate(items):
+                comma = "," if i < len(items) - 1 else ""
+                if isinstance(value, str):
+                    escaped_value = rich_escape(value)
+                    mount_json_line(
+                        f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary']}]\"{escaped_value}\"[/{t['primary']}]{comma}",
+                        value
+                    )
+                elif isinstance(value, bool):
+                    bool_str = "true" if value else "false"
+                    mount_json_line(
+                        f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary_bright']}]{bool_str}[/{t['primary_bright']}]{comma}",
+                        str(value)
+                    )
+                elif isinstance(value, (int, float)):
+                    mount_json_line(
+                        f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['primary_bright']}]{value}[/{t['primary_bright']}]{comma}",
+                        str(value)
+                    )
+                elif value is None:
+                    mount_json_line(
+                        f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [{t['text_dim']}]null[/{t['text_dim']}]{comma}",
+                        None
+                    )
+                elif isinstance(value, dict):
+                    mount_json_line(f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: {{",
+                                    json.dumps(value, indent=2))
+                    render_json_items(value, indent + 1)
+                    mount_json_line(f"{prefix}}}{comma}", None)
+                elif isinstance(value, list):
+                    mount_json_line(f"{prefix}[{t['secondary']}]\"{rich_escape(key)}\"[/{t['secondary']}]: [",
+                                    json.dumps(value, indent=2))
+                    render_json_list_items(value, indent + 1)
+                    mount_json_line(f"{prefix}]{comma}", None)
+
+        def render_json_list_items(obj, indent):
+            """Render list items without outer brackets"""
+            prefix = " " * indent
+            for i, item in enumerate(obj):
+                comma = "," if i < len(obj) - 1 else ""
+                if isinstance(item, str):
+                    escaped_item = rich_escape(item)
+                    mount_json_line(f"{prefix}[{t['primary']}]\"{escaped_item}\"[/{t['primary']}]{comma}", item)
+                elif isinstance(item, dict):
+                    mount_json_line(f"{prefix}{{", json.dumps(item, indent=2))
+                    render_json_items(item, indent + 1)
+                    mount_json_line(f"{prefix}}}{comma}", None)
+                elif isinstance(item, list):
+                    mount_json_line(f"{prefix}[", json.dumps(item, indent=2))
+                    render_json_list_items(item, indent + 1)
+                    mount_json_line(f"{prefix}]{comma}", None)
+                else:
+                    mount_json_line(f"{prefix}[{t['primary_bright']}]{item}[/{t['primary_bright']}]{comma}", str(item))
+
+        render_json(json_obj)
+
+        mount_json_line("", None)
+        mount_json_line(f"[bold {t['secondary']}]{'━' * 60}[/bold {t['secondary']}]", None)
+
+        # Add copy full JSON option
+        full_json = json.dumps(json_obj, indent=2)
+        mount_json_line(f"\n[{t['text_dim']}]Click to copy full JSON:[/{t['text_dim']}]", full_json)
+
+    def _display_secrets_manager_app(self, app_uid: str):
+        """Display Secrets Manager application details"""
+        # Clear any previous content
+        self._clear_clickable_fields()
+
+        detail_widget = self.query_one("#detail_content", Static)
+        t = self.theme_colors
+
+        try:
+            from ..proto import APIRequest_pb2, enterprise_pb2
+            from .. import api, utils
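+            # The protobuf request built below asks the vault/get_app_info
+            # endpoint for this application's client devices and shares; a
+            # failed fetch falls back to the empty app_data initialized below.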
+            import json
+
+            record = self.records[app_uid]
+            app_title = record.get('title', 'Untitled')
+
+            # Fetch app info from API
+            app_data = {
+                "app_name": app_title,
+                "app_uid": app_uid,
+                "client_devices": [],
+                "shares": []
+            }
+
+            try:
+                rq = APIRequest_pb2.GetAppInfoRequest()
+                rq.appRecordUid.append(utils.base64_url_decode(app_uid))
+                rs = api.communicate_rest(self.params, rq, 'vault/get_app_info', rs_type=APIRequest_pb2.GetAppInfoResponse)
+
+                if rs.appInfo:
+                    app_info = rs.appInfo[0]
+
+                    # Collect client devices
+                    client_devices = [x for x in app_info.clients if x.appClientType == enterprise_pb2.GENERAL]
+                    for client in client_devices:
+                        app_data["client_devices"].append({
+                            "device_name": client.id
+                        })
+
+                    # Collect application access (shares)
+                    for share in app_info.shares:
+                        uid_str = utils.base64_url_encode(share.secretUid)
+                        share_type = APIRequest_pb2.ApplicationShareType.Name(share.shareType)
+
+                        # Get title from cache
+                        title = "Unknown"
+                        if share_type == 'SHARE_TYPE_RECORD':
+                            if uid_str in self.params.record_cache:
+                                rec = self.params.record_cache[uid_str]
+                                if 'data_unencrypted' in rec:
+                                    data = json.loads(rec['data_unencrypted'])
+                                    title = data.get('title', 'Untitled')
+                            share_type_display = "RECORD"
+                        elif share_type == 'SHARE_TYPE_FOLDER':
+                            if hasattr(self.params, 'folder_cache'):
+                                folder = self.params.folder_cache.get(uid_str)
+                                if folder:
+                                    title = folder.name
+                            share_type_display = "FOLDER"
+                        else:
+                            share_type_display = share_type
+
+                        app_data["shares"].append({
+                            "share_type": share_type,
+                            "uid": uid_str,
+                            "editable": share.editable,
+                            "title": title,
+                            "type": share_type_display
+                        })
+
+            except (KeyError, AttributeError, json.JSONDecodeError, ValueError) as e:
+                logging.debug(f"Error fetching app info: {e}", exc_info=True)
+
+            # Display based on view mode
+            if self.view_mode == 'json':
+                # JSON view with syntax highlighting
+                # Clear previous clickable fields
+                self._clear_clickable_fields()
+                detail_widget.update("")
+
+                # Collect widgets for batch mounting
+                container = self.query_one("#record_detail", VerticalScroll)
+                widgets_to_mount = []
+
+                def mount_line(content: str, copy_value: str = None, is_password: bool = False):
+                    """Collect a clickable line for batch mounting"""
+                    line = ClickableDetailLine(
+                        content,
+                        copy_value=copy_value,
+                        record_uid=app_uid if is_password else None,
+                        is_password=is_password
+                    )
+                    widgets_to_mount.append(line)
+                    self.clickable_fields.append(line)
+
+                # Render JSON header
+                mount_line(f"[bold {t['primary']}]JSON View:[/bold {t['primary']}]")
+                mount_line("")
+
+                # Render JSON with syntax highlighting
+                self._render_json_lines(app_data, app_data, mount_line, t, app_uid)
+
+                # Batch mount all widgets
+                if widgets_to_mount:
+                    container.mount(*widgets_to_mount, before=detail_widget)
+            else:
+                # Detail view
+                lines = []
+                lines.append(f"[bold {t['primary']}]Secrets Manager Application[/bold {t['primary']}]")
+                lines.append(f"[{t['text_dim']}]App Name:[/{t['text_dim']}] [{t['primary']}]{app_title}[/{t['primary']}]")
+                lines.append(f"[{t['text_dim']}]App UID:[/{t['text_dim']}] [{t['primary']}]{app_uid}[/{t['primary']}]")
+                lines.append("")
+
+                # Show client devices
+                if app_data["client_devices"]:
+                    lines.append(f"[bold {t['secondary']}]Client Devices ({len(app_data['client_devices'])}):[/bold {t['secondary']}]")
+                    for idx, device in enumerate(app_data["client_devices"][:self.DEVICE_DISPLAY_LIMIT], 1):
+                        lines.append(f" [{t['text_dim']}]{idx}.[/{t['text_dim']}] [{t['primary']}]{device['device_name']}[/{t['primary']}]")
+                    if len(app_data["client_devices"]) > self.DEVICE_DISPLAY_LIMIT:
+                        lines.append(f" [{t['text_dim']}]... and {len(app_data['client_devices']) - self.DEVICE_DISPLAY_LIMIT} more[/{t['text_dim']}]")
+                    lines.append("")
+                else:
+                    lines.append(f"[{t['text_dim']}]No client devices registered for this Application[/{t['text_dim']}]")
+                    lines.append("")
+
+                # Show application access
+                if app_data["shares"]:
+                    lines.append(f"[bold {t['secondary']}]Application Access:[/bold {t['secondary']}]")
+                    lines.append("")
+                    for idx, share in enumerate(app_data["shares"][:self.SHARE_DISPLAY_LIMIT], 1):
+                        lines.append(f" [{t['text_dim']}]{share['type']}:[/{t['text_dim']}] [{t['primary']}]{share['title']}[/{t['primary']}]")
+                        lines.append(f" [{t['text_dim']}]UID:[/{t['text_dim']}] [{t['text']}]{share['uid']}[/{t['text']}]")
+                        permissions = "Editable" if share['editable'] else "Read-Only"
+                        lines.append(f" [{t['text_dim']}]Permissions:[/{t['text_dim']}] [{t['primary_dim']}]{permissions}[/{t['primary_dim']}]")
+                        lines.append("")
+                    if len(app_data["shares"]) > self.SHARE_DISPLAY_LIMIT:
+                        lines.append(f" [{t['text_dim']}]... and {len(app_data['shares']) - self.SHARE_DISPLAY_LIMIT} more shares[/{t['text_dim']}]")
+                        lines.append("")
+                else:
+                    lines.append(f"[bold {t['secondary']}]Application Access:[/bold {t['secondary']}]")
+                    lines.append(f"[{t['text_dim']}]No shared folders or records[/{t['text_dim']}]")
+                    lines.append("")
+
+                detail_widget.update("\n".join(lines))
+
+            self._update_shortcuts_bar(record_selected=True)
+
+        except Exception as e:
+            logging.error(f"Error displaying Secrets Manager app: {e}", exc_info=True)
+            detail_widget.update(f"[red]Error displaying app:[/red]\n{str(e)}")
+
+    def _display_record_detail(self, record_uid: str):
+        """Display record details in the right panel using Commander's get command"""
+        detail_widget = self.query_one("#detail_content", Static)
+        t = self.theme_colors  # Get theme colors
+
+        try:
+            if record_uid not in self.records:
+                self._clear_clickable_fields()
+                detail_widget.update("[red]Record not found[/red]")
+                return
+
+            # Use clickable fields for both views
+            if self.view_mode == 'json':
+                self._display_json_with_clickable_fields(record_uid)
+            else:
+                self._display_record_with_clickable_fields(record_uid)
+
+            # Update shortcuts bar to show record-specific shortcuts
+            self._update_shortcuts_bar(record_selected=True)
+
+        except Exception as e:
+            logging.error(f"Error displaying record detail: {e}", exc_info=True)
+            self._clear_clickable_fields()
+            # Fallback to simple static display
+            try:
+                content = self._format_record_for_tui(record_uid)
+                detail_widget.update(content)
+            except Exception:
+                error_msg = str(e).replace('[', '\\[').replace(']', '\\]')
+                detail_widget.update(f"[red]Error displaying record:[/red]\n{error_msg}\n\n[dim]Press 't' to toggle view mode[/dim]")
+
+    def _update_status(self, message: str):
+        """Update the status bar"""
+        status_bar = self.query_one("#status_bar", Static)
+        status_bar.update(f"⚡ {message}")
+
+    def _update_shortcuts_bar(self, record_selected: bool = False, folder_selected: bool = False, clear: bool = False):
+        """Update the shortcuts bar at bottom of detail panel"""
+        try:
+            shortcuts_bar = self.query_one("#shortcuts_bar", Static)
+            t = self.theme_colors
+
+            if clear:
+                # Clear the shortcuts bar (for info displays like device/user info)
+                shortcuts_bar.update("")
+            elif record_selected:
+                mode = "JSON" if self.view_mode == 'json' else "Detail"
+                mask_label = "Mask" if self.unmask_secrets else "Unmask"
else "Unmask" + shortcuts_bar.update( + f"[{t['secondary']}]Mode: {mode}[/{t['secondary']}] " + f"[{t['text_dim']}]t[/{t['text_dim']}]=Toggle " + f"[{t['text_dim']}]p[/{t['text_dim']}]=Password " + f"[{t['text_dim']}]u[/{t['text_dim']}]=Username " + f"[{t['text_dim']}]c[/{t['text_dim']}]=Copy All " + f"[{t['text_dim']}]m[/{t['text_dim']}]={mask_label}" + ) + elif folder_selected: + mode = "JSON" if self.view_mode == 'json' else "Detail" + mask_label = "Mask" if self.unmask_secrets else "Unmask" + shortcuts_bar.update( + f"[{t['secondary']}]Mode: {mode}[/{t['secondary']}] " + f"[{t['text_dim']}]t[/{t['text_dim']}]=Toggle " + f"[{t['text_dim']}]c[/{t['text_dim']}]=Copy All " + f"[{t['text_dim']}]m[/{t['text_dim']}]={mask_label}" + ) + else: + # Root or other - hide navigation help + shortcuts_bar.update("") + except Exception as e: + logging.debug(f"Error updating shortcuts bar: {e}") + + @on(Click, "#search_bar, #search_display") + def on_search_bar_click(self, event: Click) -> None: + """Activate search mode when search bar is clicked""" + tree = self.query_one("#folder_tree", Tree) + self.search_input_active = True + tree.add_class("search-input-active") + self._update_search_display(perform_search=False) # Don't change tree when entering search + self._update_status("Type to search | Tab to navigate | Ctrl+U to clear") + event.stop() + + @on(Click, "#user_info") + def on_user_info_click(self, event: Click) -> None: + """Show whoami info when user info is clicked""" + self._display_whoami_info() + event.stop() + + @on(Click, "#device_status_info") + def on_device_status_click(self, event: Click) -> None: + """Show device info when Stay Logged In / Logout section is clicked""" + self._display_device_info() + event.stop() + + def on_paste(self, event: Paste) -> None: + """Handle paste events (Cmd+V on Mac, Ctrl+V on Windows/Linux)""" + if self.search_input_active and event.text: + # Append pasted text to search input (strip newlines) + pasted_text = event.text.replace('\n', ' ').replace('\r', '') + self.search_input_text += pasted_text + self._update_search_display() + event.stop() + + @on(Tree.NodeSelected) + def on_tree_node_selected(self, event: Tree.NodeSelected): + """Handle tree node selection (folder or record)""" + # Deactivate search input mode when selecting a node (clicking or navigating) + if self.search_input_active: + self.search_input_active = False + tree = self.query_one("#folder_tree", Tree) + tree.remove_class("search-input-active") + # Update search display to remove cursor + search_display = self.query_one("#search_display", Static) + if self.search_input_text: + search_display.update(rich_escape(self.search_input_text)) + else: + search_display.update("[dim]Search... 
(Tab or /)[/dim]") + + node_data = event.node.data + if not node_data: + return + + node_type = node_data.get('type') + node_uid = node_data.get('uid') + + if node_type == 'record': + # Record selected - show details + self.selected_record = node_uid + self.selected_folder = None # Clear folder selection + + # Verify record exists before displaying + if node_uid in self.records: + # Check if this is an app record (Secrets Manager) + if node_uid in self.app_record_uids: + # Display Secrets Manager app info + self._display_secrets_manager_app(node_uid) + self._update_status(f"App record selected: {self.records[node_uid].get('title', 'Untitled')}") + else: + self._display_record_detail(node_uid) + self._update_status(f"Record selected: {self.records[node_uid].get('title', 'Untitled')}") + else: + # Record not found - show error + detail_widget = self.query_one("#detail_content", Static) + detail_widget.update(f"[red]Error: Record not found[/red]\n\nUID: {node_uid}\n\nThis record may have been deleted or you may not have access to it.") + self._update_status(f"Record not found: {node_uid}") + elif node_type == 'folder': + # Folder selected - show folder info with clickable fields + self.selected_record = None # Clear record selection + self.selected_folder = node_uid # Set folder selection + folder = self.params.folder_cache.get(node_uid) + if folder: + # Use clickable fields for folder display + self._display_folder_with_clickable_fields(node_uid) + self._update_status(f"Folder: {folder.name}") + else: + self._clear_clickable_fields() + detail_widget = self.query_one("#detail_content", Static) + detail_widget.update("[red]Folder not found[/red]") + self._update_shortcuts_bar(folder_selected=True) + elif node_type == 'virtual_folder': + # Virtual folder selected (e.g., Secrets Manager Apps) + self.selected_record = None + self.selected_folder = None + self._clear_clickable_fields() + detail_widget = self.query_one("#detail_content", Static) + t = self.theme_colors + if node_uid == '__secrets_manager_apps__': + # Count app records + app_count = len(self.app_record_uids) + detail_widget.update( + f"[bold {t['virtual_folder']}]★ Secrets Manager Apps[/bold {t['virtual_folder']}]\n\n" + f"[{t['primary_dim']}]Contains {app_count} Secrets Manager application {'record' if app_count == 1 else 'records'}.\n" + f"Select a record to view details.[/{t['primary_dim']}]" + ) + self._update_status("Secrets Manager Apps") + else: + detail_widget.update(f"[{t['primary_dim']}]Virtual folder[/{t['primary_dim']}]") + self._update_status("Virtual folder") + self._update_shortcuts_bar(folder_selected=True) + elif node_type == 'root': + # Root selected - show welcome/help content + self.selected_record = None # Clear record selection + self.selected_folder = None # Clear folder selection + self._clear_clickable_fields() + detail_widget = self.query_one("#detail_content", Static) + detail_widget.update(self._get_welcome_screen_content()) + self._update_status("My Vault") + self._update_shortcuts_bar(clear=True) # Help content is already in the panel + + def _update_search_display(self, perform_search=True): + """Update the search display and results with blinking cursor. + + Args: + perform_search: If True, perform search when text changes. Set to False + when just entering search mode to avoid tree changes. 
+ """ + try: + search_display = self.query_one("#search_display", Static) + results_label = self.query_one("#search_results_label", Static) + + # Force visibility + if search_display.styles.display == "none": + search_display.styles.display = "block" + + # Update display with blinking cursor at end + if self.search_input_text: + # Show text with blinking cursor (escape special chars for Rich markup) + display_text = f"> {rich_escape(self.search_input_text)}[blink]▎[/blink]" + else: + # Show prompt with blinking cursor (ready to type) + display_text = "> [blink]▎[/blink]" + + search_display.update(display_text) + + # Update status bar + self._update_status("Type to search | Enter/Tab/↓ to navigate | ESC to close") + + # Only perform search when requested and there's text, or when clearing + if perform_search: + if self.search_input_text: + result_count = self._perform_live_search(self.search_input_text) + t = self.theme_colors + + if result_count == 0: + results_label.update("[#ff0000]No matches[/#ff0000]") + elif result_count == 1: + results_label.update(f"[{t['secondary']}]1 match[/{t['secondary']}]") + else: + results_label.update(f"[{t['secondary']}]{result_count} matches[/{t['secondary']}]") + else: + # Clear results label when no text + results_label.update("") + else: + # Just entering search mode - don't change results label + pass + + except Exception as e: + logging.error(f"Error in _update_search_display: {e}", exc_info=True) + self._update_status(f"ERROR: {str(e)}") + + def on_key(self, event): + """Handle keyboard events""" + search_bar = self.query_one("#search_bar") + tree = self.query_one("#folder_tree", Tree) + + # Global key handlers that work regardless of focus + # ! exits to regular shell (works from any widget) + if event.character == "!" and not self.search_input_active: + self.exit("Exited to Keeper shell. Type 'supershell' or 'ss' to return.") + event.prevent_default() + event.stop() + return + + # Handle Tab/Shift+Tab cycling: Tree → Detail → Search (counterclockwise) + detail_scroll = self.query_one("#record_detail", VerticalScroll) + + # Handle search input mode Tab/Shift+Tab first (search_input_active takes priority) + if self.search_input_active: + if event.key == "tab": + # Search input → Tree (forward in cycle) + self.search_input_active = False + tree.remove_class("search-input-active") + search_display = self.query_one("#search_display", Static) + if self.search_input_text: + search_display.update(rich_escape(self.search_input_text)) + else: + search_display.update("[dim]Search...[/dim]") + tree.focus() + self._update_status("Navigate with j/k | Tab to detail | ? 
for help") + event.prevent_default() + event.stop() + return + elif event.key == "shift+tab": + # Search input → Detail pane (backwards in cycle) + self.search_input_active = False + tree.remove_class("search-input-active") + search_display = self.query_one("#search_display", Static) + if self.search_input_text: + search_display.update(rich_escape(self.search_input_text)) + else: + search_display.update("[dim]Search...[/dim]") + detail_scroll.focus() + self._update_status("Detail pane | Tab to search | Shift+Tab to tree") + event.prevent_default() + event.stop() + return + + if detail_scroll.has_focus: + if event.key == "tab": + # Detail pane → Search input + self.search_input_active = True + tree.add_class("search-input-active") + self._update_search_display(perform_search=False) # Don't change tree when entering search + self._update_status("Type to search | Tab to tree | Ctrl+U to clear") + event.prevent_default() + event.stop() + return + elif event.key == "shift+tab": + # Detail pane → Tree + tree.focus() + self._update_status("Navigate with j/k | Tab to detail | ? for help") + event.prevent_default() + event.stop() + return + elif event.key == "escape": + tree.focus() + event.prevent_default() + event.stop() + return + elif event.key == "ctrl+y": + # Ctrl+Y scrolls viewport up (like vim) + detail_scroll.scroll_relative(y=-1) + event.prevent_default() + event.stop() + return + elif event.key == "ctrl+e": + # Ctrl+E scrolls viewport down (like vim) + detail_scroll.scroll_relative(y=1) + event.prevent_default() + event.stop() + return + + if search_bar.styles.display != "none": + # Search bar is active + + # If we're navigating results (not typing), let tree/app handle its keys + if not self.search_input_active and tree.has_focus: + # Handle left/right arrow keys for expand/collapse + if event.key == "left": + if tree.cursor_node and tree.cursor_node.allow_expand: + tree.cursor_node.collapse() + event.prevent_default() + event.stop() + return + elif event.key == "right": + if tree.cursor_node and tree.cursor_node.allow_expand: + tree.cursor_node.expand() + event.prevent_default() + event.stop() + return + # Navigation keys for tree + if event.key in ("j", "k", "h", "l", "up", "down", "enter", "space"): + return + # Action keys (copy, toggle view, etc.) 
+                if event.key in ("t", "c", "u", "w", "i", "y", "d", "g", "p", "question_mark"):
+                    return
+                # Shift+G for go to bottom
+                if event.character == "G":
+                    return
+                # Tab switches to detail pane
+                if event.key == "tab":
+                    detail_scroll.focus()
+                    self._update_status("Detail pane | Tab to search | Shift+Tab to tree")
+                    event.prevent_default()
+                    event.stop()
+                    return
+                # Shift+Tab switches to search input
+                elif event.key == "shift+tab":
+                    self.search_input_active = True
+                    tree.add_class("search-input-active")
+                    self._update_search_display(perform_search=False)  # Don't change tree when entering search
+                    self._update_status("Type to search | Tab to tree | Ctrl+U to clear")
+                    event.prevent_default()
+                    event.stop()
+                    return
+                elif event.key == "slash":
+                    # Switch back to search input mode
+                    self.search_input_active = True
+                    tree.add_class("search-input-active")
+                    self._update_search_display(perform_search=False)  # Don't change tree when entering search
+                    event.prevent_default()
+                    event.stop()
+                    return
+
+            # Ctrl+U clears the search input (like bash)
+            # Reset tree to show all items when clearing search
+            if event.key == "ctrl+u" and self.search_input_active:
+                self.search_input_text = ""
+                self._update_search_display(perform_search=False)  # Just update display
+                self._perform_live_search("")  # Reset tree to show all
+                event.prevent_default()
+                event.stop()
+                return
+
+            # "/" to switch to search input mode (works from anywhere when search bar visible)
+            if event.key == "slash" and not self.search_input_active:
+                self.search_input_active = True
+                tree.add_class("search-input-active")
+                self._update_search_display(perform_search=False)  # Don't change tree when entering search
+                event.prevent_default()
+                event.stop()
+                return
+
+            if event.key == "escape":
+                # Clear search and move focus to tree (don't hide search bar)
+                self.search_input_text = ""
+                self.search_input_active = False
+                tree.remove_class("search-input-active")
+                self._perform_live_search("")  # Reset to show all
+
+                # Update search display to show placeholder
+                search_display = self.query_one("#search_display", Static)
+                search_display.update("[dim]Search... (Tab or /)[/dim]")
+                results_label = self.query_one("#search_results_label", Static)
+                results_label.update("")
+
+                # Restore previous selection
+                self.selected_record = self.pre_search_selected_record
+                self.selected_folder = self.pre_search_selected_folder
+                self._restore_tree_selection(tree)
+
+                tree.focus()
+                self._update_status("Navigate with j/k | Tab to detail | ? for help")
for help") + event.prevent_default() + event.stop() + elif event.key in ("enter", "down"): + # Move focus to tree to navigate results + # Switch to navigation mode + self.search_input_active = False + + # Show tree selection - remove the class that hides it + tree.remove_class("search-input-active") + + # Remove cursor from search display + search_display = self.query_one("#search_display", Static) + if self.search_input_text: + search_display.update(rich_escape(self.search_input_text)) # No cursor + else: + search_display.update("[dim]Search...[/dim]") + + # Force focus to tree + self.set_focus(tree) + tree.focus() + + self._update_status("Navigate results with j/k | / to edit search | ESC to close") + event.prevent_default() + event.stop() + return # Return immediately to avoid further processing + elif event.key == "backspace": + # Delete last character + if self.search_input_text: + self.search_input_text = self.search_input_text[:-1] + self._update_search_display() + event.prevent_default() + event.stop() + elif self.search_input_active and event.character and event.character.isprintable(): + # Only add characters when search input is active (not when navigating results) + self.search_input_text += event.character + self._update_search_display() + event.prevent_default() + event.stop() + else: + # Search bar is NOT active - handle escape and command mode + + # Handle command mode (vim :N navigation) + if self.command_mode: + if event.key == "escape": + # Cancel command mode + self.command_mode = False + self.command_buffer = "" + self._update_status("Command cancelled") + event.prevent_default() + event.stop() + return + elif event.key == "enter": + # Execute command + self._execute_command(self.command_buffer) + self.command_mode = False + self.command_buffer = "" + event.prevent_default() + event.stop() + return + elif event.key == "backspace": + # Delete last character + if self.command_buffer: + self.command_buffer = self.command_buffer[:-1] + self._update_status(f":{self.command_buffer}") + else: + # Exit command mode if buffer is empty + self.command_mode = False + self._update_status("Navigate with j/k | / to search | ? for help") + event.prevent_default() + event.stop() + return + elif event.character and event.character.isdigit(): + # Accept digits for line number navigation + self.command_buffer += event.character + self._update_status(f":{self.command_buffer}") + event.prevent_default() + event.stop() + return + else: + # Invalid character for command mode + event.prevent_default() + event.stop() + return + + # Enter command mode with : + if event.character == ":": + self.command_mode = True + self.command_buffer = "" + self._update_status(":") + event.prevent_default() + event.stop() + return + + # Handle arrow keys for expand/collapse when search is not active + if tree.has_focus: + if event.key == "left": + if tree.cursor_node and tree.cursor_node.allow_expand: + tree.cursor_node.collapse() + event.prevent_default() + event.stop() + return + elif event.key == "right": + if tree.cursor_node and tree.cursor_node.allow_expand: + tree.cursor_node.expand() + event.prevent_default() + event.stop() + return + + if event.key == "escape": + # Escape: collapse current folder or go to parent, stop at root + self._collapse_current_or_parent(tree) + event.prevent_default() + event.stop() + return + + def _collapse_current_or_parent(self, tree: Tree): + """Collapse current node if expanded, or go to parent. 
Stop at root."""
+        cursor_node = tree.cursor_node
+        if cursor_node is None:
+            return
+
+        # If we're at root, do nothing - this is as far as we go
+        if cursor_node == tree.root:
+            self._update_status("At root")
+            return
+
+        if cursor_node.is_expanded and cursor_node.children:
+            # Current node is expanded - collapse it
+            cursor_node.collapse()
+            self._update_status("Collapsed")
+        elif cursor_node.parent:
+            # Go to parent
+            tree.select_node(cursor_node.parent)
+            self._update_status("Moved to parent")
+
+    def _execute_command(self, command: str):
+        """Execute a vim-style command (e.g., :20 to go to line 20)"""
+        command = command.strip()
+
+        # Try to parse as a line number
+        try:
+            line_num = int(command)
+            self._goto_line(line_num)
+        except ValueError:
+            self._update_status(f"Unknown command: {command}")
+
+    def _goto_line(self, line_num: int):
+        """Go to the specified line number in the tree (1-indexed, like vim)"""
+        tree = self.query_one("#folder_tree", Tree)
+
+        # Build the list of currently visible nodes
+        visible_nodes = []
+
+        def collect_visible_nodes(node, include_self=True):
+            """Collect all visible nodes in order"""
+            if include_self:
+                visible_nodes.append(node)
+            if node.is_expanded:
+                for child in node.children:
+                    collect_visible_nodes(child, include_self=True)
+
+        # Start from root (line 1 = root)
+        collect_visible_nodes(tree.root)
+
+        # Convert 1-indexed to 0-indexed and clamp to the visible range
+        target_index = line_num - 1
+
+        if target_index < 0:
+            target_index = 0
+        elif target_index >= len(visible_nodes):
+            target_index = len(visible_nodes) - 1
+
+        if visible_nodes:
+            target_node = visible_nodes[target_index]
+            tree.select_node(target_node)
+            self._update_status(f"Line {target_index + 1} of {len(visible_nodes)}")
+        else:
+            self._update_status("No visible nodes")
+
+    def check_action(self, action: str, parameters: tuple) -> bool | None:
+        """Control whether actions are enabled based on search state"""
+        # When search input is active, disable all bindings except quit and search.
+        # This allows keys to be captured as text input instead of triggering actions.
+        if hasattr(self, 'search_input_active') and self.search_input_active:
+            # Only allow the quit and search actions while typing in search
+            if action in ("quit", "search"):
+                return True
+            # Disable all other actions - keys will be captured as text
+            return False
+        # When not in search input mode, allow all actions
+        return True
+
+    def action_search(self):
+        """Activate search input mode"""
+        tree = self.query_one("#folder_tree", Tree)
+
+        # Save the current selection before activating search
+        if not self.search_input_active:
+            self.pre_search_selected_record = self.selected_record
+            self.pre_search_selected_folder = self.selected_folder
+
+        # Activate search input mode
+        self.search_input_active = True
+        tree.add_class("search-input-active")
+        self._update_search_display(perform_search=False)  # Don't change tree when entering search
+        self._update_status("Type to search | Tab to navigate | Ctrl+U to clear")
+
+    def action_toggle_view_mode(self):
+        """Toggle between detail and JSON view modes"""
+        # Works for records, folders, and shared folders
+        if not self.selected_record and not self.selected_folder:
+            self.notify("⚠️ No record or folder selected", severity="warning")
+            return
+
+        if self.view_mode == 'detail':
+            self.view_mode = 'json'
+            self.notify("📋 Switched to JSON view", severity="information")
+        else:
+            self.view_mode = 'detail'
+            self.notify("📋 Switched to Detail view", severity="information")
+
+        # Refresh the current display
+        try:
+            if self.selected_record:
+                #
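# Example — the ":N" goto implemented in _goto_line above is a pre-order walk
# over *expanded* nodes only, so collapsed subtrees never receive line numbers.
# The same traversal written as a generator (a sketch, not code from this patch):
#
#   def visible_nodes(node):
#       yield node
#       if node.is_expanded:
#           for child in node.children:
#               yield from visible_nodes(child)
#
#   # ":20" then selects list(visible_nodes(tree.root))[19], clamped to the list bounds.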
Check if it's a Secret Manager app + if self.selected_record in self.app_record_uids: + self._display_secrets_manager_app(self.selected_record) + else: + self._display_record_detail(self.selected_record) + elif self.selected_folder: + self._display_folder_with_clickable_fields(self.selected_folder) + except Exception as e: + logging.error(f"Error toggling view mode: {e}", exc_info=True) + self.notify(f"⚠️ Error switching view: {str(e)}", severity="error") + + def action_toggle_unmask(self): + """Toggle unmasking of secret/password/passphrase fields""" + if not self.selected_record and not self.selected_folder: + self.notify("No record or folder selected", severity="warning") + return + + self.unmask_secrets = not self.unmask_secrets + status = "unmasked" if self.unmask_secrets else "masked" + self.notify(f"Secrets {status}", severity="information") + + # Refresh the current display and shortcuts bar + try: + if self.selected_record: + self._display_record_detail(self.selected_record) + self._update_shortcuts_bar(record_selected=True) + elif self.selected_folder: + self._display_folder_with_clickable_fields(self.selected_folder) + self._update_shortcuts_bar(folder_selected=True) + except Exception as e: + logging.error(f"Error toggling unmask: {e}", exc_info=True) + + def action_copy_password(self): + """Copy password of selected record to clipboard using clipboard-copy command (generates audit event)""" + if self.selected_record and self.selected_record in self.records: + try: + # Use ClipboardCommand to copy password - this generates the audit event + cc = ClipboardCommand() + cc.execute(self.params, record=self.selected_record, output='clipboard', + username=None, copy_uid=False, login=False, totp=False, field=None, revision=None) + self.notify("🔑 Password copied to clipboard!", severity="information") + except Exception as e: + logging.debug(f"ClipboardCommand error: {e}") + self.notify("⚠️ No password found for this record", severity="warning") + else: + self.notify("⚠️ No record selected", severity="warning") + + def action_copy_username(self): + """Copy username of selected record to clipboard""" + if self.selected_record and self.selected_record in self.records: + record = self.records[self.selected_record] + if 'login' in record: + pyperclip.copy(record['login']) + self.notify("👤 Username copied to clipboard!", severity="information") + else: + self.notify("⚠️ No username found for this record", severity="warning") + else: + self.notify("⚠️ No record selected", severity="warning") + + def action_copy_url(self): + """Copy URL of selected record to clipboard""" + if self.selected_record and self.selected_record in self.records: + record = self.records[self.selected_record] + if 'login_url' in record: + pyperclip.copy(record['login_url']) + self.notify("🔗 URL copied to clipboard!", severity="information") + else: + self.notify("⚠️ No URL found for this record", severity="warning") + else: + self.notify("⚠️ No record selected", severity="warning") + + def action_copy_uid(self): + """Copy UID of selected record or folder to clipboard""" + if self.selected_record: + pyperclip.copy(self.selected_record) + self.notify("📋 Record UID copied to clipboard!", severity="information") + elif self.selected_folder: + pyperclip.copy(self.selected_folder) + self.notify("📋 Folder UID copied to clipboard!", severity="information") + else: + self.notify("⚠️ No record or folder selected", severity="warning") + + def action_copy_record(self): + """Copy entire record contents to clipboard (formatted or 
JSON based on view mode)""" + if self.selected_record: + try: + # Check if it's a Secrets Manager app record + if self.selected_record in self.app_record_uids: + # For Secrets Manager apps, copy the app data in JSON format + from ..proto import APIRequest_pb2, enterprise_pb2 + from .. import api, utils + import json + + record = self.records[self.selected_record] + app_title = record.get('title', 'Untitled') + + app_data = { + "app_name": app_title, + "app_uid": self.selected_record, + "client_devices": [], + "shares": [] + } + + try: + rq = APIRequest_pb2.GetAppInfoRequest() + rq.appRecordUid.append(utils.base64_url_decode(self.selected_record)) + rs = api.communicate_rest(self.params, rq, 'vault/get_app_info', rs_type=APIRequest_pb2.GetAppInfoResponse) + + if rs.appInfo: + app_info = rs.appInfo[0] + + # Collect client devices + client_devices = [x for x in app_info.clients if x.appClientType == enterprise_pb2.GENERAL] + for client in client_devices: + app_data["client_devices"].append({"device_name": client.id}) + + # Collect application access (shares) + for share in app_info.shares: + uid_str = utils.base64_url_encode(share.secretUid) + share_type = APIRequest_pb2.ApplicationShareType.Name(share.shareType) + + title = "Unknown" + if share_type == 'SHARE_TYPE_RECORD': + if uid_str in self.params.record_cache: + rec = self.params.record_cache[uid_str] + if 'data_unencrypted' in rec: + data = json.loads(rec['data_unencrypted']) + title = data.get('title', 'Untitled') + share_type_display = "RECORD" + elif share_type == 'SHARE_TYPE_FOLDER': + if hasattr(self.params, 'folder_cache'): + folder = self.params.folder_cache.get(uid_str) + if folder: + title = folder.name + share_type_display = "FOLDER" + else: + share_type_display = share_type + + app_data["shares"].append({ + "share_type": share_type, + "uid": uid_str, + "editable": share.editable, + "title": title, + "type": share_type_display + }) + except Exception as e: + logging.debug(f"Error fetching app info for copy: {e}") + + # Format based on view mode + if self.view_mode == 'json': + # Copy as JSON + formatted = json.dumps(app_data, indent=2) + pyperclip.copy(formatted) + self.notify("📋 Secrets Manager app JSON copied to clipboard!", severity="information") + else: + # Copy as formatted text (detail view) + lines = [] + lines.append("Secrets Manager Application") + lines.append(f"App Name: {app_title}") + lines.append(f"App UID: {self.selected_record}") + lines.append("") + + # Client devices + if app_data["client_devices"]: + lines.append(f"Client Devices ({len(app_data['client_devices'])}):") + for idx, device in enumerate(app_data["client_devices"], 1): + lines.append(f" {idx}. 
{device['device_name']}") + lines.append("") + else: + lines.append("No client devices registered for this Application") + lines.append("") + + # Application access + if app_data["shares"]: + lines.append("Application Access:") + lines.append("") + for share in app_data["shares"]: + lines.append(f" {share['type']}: {share['title']}") + lines.append(f" UID: {share['uid']}") + permissions = "Editable" if share['editable'] else "Read-Only" + lines.append(f" Permissions: {permissions}") + lines.append("") + else: + lines.append("Application Access:") + lines.append("No shared folders or records") + lines.append("") + + formatted = "\n".join(lines) + pyperclip.copy(formatted) + self.notify("📋 Secrets Manager app details copied to clipboard!", severity="information") + else: + # Regular record handling + record_data = self.records.get(self.selected_record, {}) + has_password = bool(record_data.get('password')) + + if self.view_mode == 'json': + # Copy JSON format (with actual password, not masked) + output = self._get_record_output(self.selected_record, format_type='json') + output = self._strip_ansi_codes(output) + json_obj = json.loads(output) + formatted = json.dumps(json_obj, indent=2) + pyperclip.copy(formatted) + # Generate audit event since JSON contains the password + if has_password: + self.params.queue_audit_event('copy_password', record_uid=self.selected_record) + self.notify("📋 JSON copied to clipboard!", severity="information") + else: + # Copy formatted text (without Rich markup) + content = self._format_record_for_tui(self.selected_record) + # Strip Rich markup for plain text clipboard + import re + plain = re.sub(r'\[/?[^\]]+\]', '', content) + pyperclip.copy(plain) + # Generate audit event if record has password (detail view includes password) + if has_password: + self.params.queue_audit_event('copy_password', record_uid=self.selected_record) + self.notify("📋 Record contents copied to clipboard!", severity="information") + except Exception as e: + logging.error(f"Error copying record: {e}", exc_info=True) + self.notify("⚠️ Failed to copy record contents", severity="error") + else: + self.notify("⚠️ No record selected", severity="warning") + + def action_show_help(self): + """Show help modal""" + self.push_screen(HelpScreen()) + + def action_show_user_info(self): + """Show user/whoami information in detail panel""" + self._display_whoami_info() + + def action_show_device_info(self): + """Show device information in detail panel""" + self._display_device_info() + + def action_sync_vault(self): + """Sync vault data from server (sync-down + enterprise-down) and refresh UI""" + self._update_status("Syncing vault data from server...") + + try: + # Run sync-down command + from .utils import SyncDownCommand + SyncDownCommand().execute(self.params) + + # Run enterprise-down if available (enterprise users) + try: + from .enterprise import EnterpriseDownCommand + EnterpriseDownCommand().execute(self.params) + except Exception: + pass # Not an enterprise user or command not available + + # Reload vault data and refresh UI + self.records = {} + self.record_to_folder = {} + self.records_in_subfolders = set() + self.file_attachment_to_parent = {} + self.record_file_attachments = {} + self.linked_record_to_parent = {} + self.record_linked_records = {} + self.app_record_uids = set() + self._record_output_cache = {} # Clear record output cache + self._load_vault_data() + self.device_info = self._load_device_info() # Refresh device info + self.whoami_info = self._load_whoami_info() # Refresh whoami 
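# Example — action_sync_vault above deliberately resets every derived cache
# before reloading, so stale UIDs cannot survive a sync. The same resets,
# condensed into one helper (a sketch using this class's own attribute names):
#
#   def _reset_caches(self) -> None:
#       for name in ('records', 'record_to_folder', 'file_attachment_to_parent',
#                    'record_file_attachments', 'linked_record_to_parent',
#                    'record_linked_records', '_record_output_cache'):
#           setattr(self, name, {})
#       self.records_in_subfolders = set()
#       self.app_record_uids = set()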
info + self._setup_folder_tree() + self._update_header_info_display() # Update header info display + + self._update_status("Vault synced & refreshed") + self.notify("Vault synced & refreshed", severity="information") + except Exception as e: + logging.error(f"Error syncing vault: {e}", exc_info=True) + self._update_status(f"Sync failed: {str(e)}") + self.notify(f"Sync failed: {str(e)}", severity="error") + + # Vim-style navigation actions + def action_cursor_down(self): + """Move cursor down (Vim j)""" + focused = self.focused + if isinstance(focused, (Tree, DataTable)): + focused.action_cursor_down() + elif isinstance(focused, VerticalScroll): + # Scroll down in the detail view + focused.scroll_down(animate=False) + + def action_cursor_up(self): + """Move cursor up (Vim k)""" + focused = self.focused + if isinstance(focused, (Tree, DataTable)): + focused.action_cursor_up() + elif isinstance(focused, VerticalScroll): + # Scroll up in the detail view + focused.scroll_up(animate=False) + + def action_cursor_left(self): + """Move cursor left (Vim h)""" + focused = self.focused + if isinstance(focused, Tree): + # Collapse node in tree + if focused.cursor_node and focused.cursor_node.allow_expand: + focused.cursor_node.collapse() + + def action_cursor_right(self): + """Move cursor right (Vim l)""" + focused = self.focused + if isinstance(focused, Tree): + # Expand node in tree + if focused.cursor_node and focused.cursor_node.allow_expand: + focused.cursor_node.expand() + + def action_goto_top(self): + """Go to top (Vim g)""" + focused = self.focused + if isinstance(focused, DataTable): + focused.move_cursor(row=0) + elif isinstance(focused, Tree): + # Get first child of root instead of root itself to avoid collapsing + if focused.root and focused.root.children: + first_child = focused.root.children[0] + focused.select_node(first_child) + else: + focused.select_node(focused.root) + elif isinstance(focused, VerticalScroll): + focused.scroll_home(animate=False) + + def action_goto_bottom(self): + """Go to bottom (Vim G)""" + focused = self.focused + if isinstance(focused, DataTable): + focused.move_cursor(row=focused.row_count - 1) + elif isinstance(focused, Tree): + # Find the last visible node in the tree + def get_last_visible_node(node): + """Recursively find the last visible (expanded) node""" + if node.is_expanded and node.children: + return get_last_visible_node(node.children[-1]) + return node + last_node = get_last_visible_node(focused.root) + focused.select_node(last_node) + elif isinstance(focused, VerticalScroll): + focused.scroll_end(animate=False) + + def action_page_down(self): + """Page down (Vim CTRL+d) - half page""" + focused = self.focused + if isinstance(focused, DataTable): + # Move down by half the visible height + current_row = focused.cursor_row + page_size = max(1, self.size.height // 4) # Half page + new_row = min(current_row + page_size, focused.row_count - 1) + focused.move_cursor(row=new_row) + elif isinstance(focused, Tree): + # Move down through tree nodes + for _ in range(self.PAGE_DOWN_NODES): # Move down half page + focused.action_cursor_down() + elif isinstance(focused, VerticalScroll): + # Scroll down by page in detail view + focused.scroll_page_down(animate=False) + + def action_page_up(self): + """Page up (Vim CTRL+u) - half page""" + focused = self.focused + if isinstance(focused, DataTable): + # Move up by half the visible height + current_row = focused.cursor_row + page_size = max(1, self.size.height // 4) # Half page + new_row = max(current_row - 
page_size, 0) + focused.move_cursor(row=new_row) + elif isinstance(focused, Tree): + # Move up through tree nodes + for _ in range(self.PAGE_DOWN_NODES): # Move up half page + focused.action_cursor_up() + elif isinstance(focused, VerticalScroll): + # Scroll up by page in detail view + focused.scroll_page_up(animate=False) + + def action_page_down_full(self): + """Page down (Vim CTRL+f) - full page""" + focused = self.focused + if isinstance(focused, DataTable): + # Move down by full visible height + current_row = focused.cursor_row + page_size = max(1, self.size.height // 2) # Full page + new_row = min(current_row + page_size, focused.row_count - 1) + focused.move_cursor(row=new_row) + elif isinstance(focused, Tree): + # Move down through tree nodes + for _ in range(self.PAGE_DOWN_FULL_NODES): # Move down full page + focused.action_cursor_down() + elif isinstance(focused, VerticalScroll): + # Scroll down by full page in detail view + focused.scroll_page_down(animate=False) + + def action_page_up_full(self): + """Page up (Vim CTRL+b) - full page""" + focused = self.focused + if isinstance(focused, DataTable): + # Move up by full visible height + current_row = focused.cursor_row + page_size = max(1, self.size.height // 2) # Full page + new_row = max(current_row - page_size, 0) + focused.move_cursor(row=new_row) + elif isinstance(focused, Tree): + # Move up through tree nodes + for _ in range(self.PAGE_DOWN_FULL_NODES): # Move up full page + focused.action_cursor_up() + elif isinstance(focused, VerticalScroll): + # Scroll up by full page in detail view + focused.scroll_page_up(animate=False) + + def action_scroll_up(self): + """Scroll up one line (Vim CTRL+y)""" + focused = self.focused + if not self.search_input_active: + if isinstance(focused, Tree): + focused.scroll_relative(y=-1) + elif isinstance(focused, VerticalScroll): + focused.scroll_relative(y=-1) + + def action_scroll_down(self): + """Scroll down one line (Vim CTRL+e)""" + focused = self.focused + if not self.search_input_active: + if isinstance(focused, Tree): + focused.scroll_relative(y=1) + elif isinstance(focused, VerticalScroll): + focused.scroll_relative(y=1) + + def action_quit(self): + """Quit the application""" + self._stop_totp_timer() + self.exit() + + +class SuperShellCommand(Command): + """Command to launch the SuperShell TUI""" + + def get_parser(self): + from argparse import ArgumentParser + parser = ArgumentParser(prog='supershell', description='Launch full terminal vault UI with vim navigation') + # -h/--help is automatically added by ArgumentParser + return parser + + def is_authorised(self): + """Don't require pre-authentication - TUI handles all auth""" + return False + + def execute(self, params, **kwargs): + """Launch the SuperShell TUI - handles login if needed""" + from .. 
import display + from ..cli import debug_manager + + # Show government warning for GOV environments when entering SuperShell + if params.server and 'govcloud' in params.server.lower(): + display.show_government_warning() + + # Disable debug mode for SuperShell to prevent log output from messing up the TUI + saved_debug = getattr(params, 'debug', False) + saved_log_level = logging.getLogger().level + if saved_debug or logging.getLogger().level == logging.DEBUG: + params.debug = False + debug_manager.set_console_debug(False, params.batch_mode) + # Also set root logger level to suppress all debug output + logging.getLogger().setLevel(logging.WARNING) + + try: + self._execute_supershell(params, **kwargs) + finally: + # Restore debug state when SuperShell exits + if saved_debug: + params.debug = saved_debug + debug_manager.set_console_debug(True, params.batch_mode) + logging.getLogger().setLevel(saved_log_level) + + def _execute_supershell(self, params, **kwargs): + """Internal method to run SuperShell""" + import threading + import time + import sys + + class Spinner: + """Animated spinner that runs in a background thread""" + def __init__(self, message="Loading..."): + self.message = message + self.running = False + self.thread = None + self.chars = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] + self.colors = ['\033[36m', '\033[32m', '\033[33m', '\033[35m'] + + def _spin(self): + i = 0 + while self.running: + color = self.colors[i % len(self.colors)] + char = self.chars[i % len(self.chars)] + # Check running again before writing to avoid race condition + if not self.running: + break + sys.stdout.write(f"\r {color}{char}\033[0m {self.message}") + sys.stdout.flush() + time.sleep(0.1) + i += 1 + + def start(self): + self.running = True + self.thread = threading.Thread(target=self._spin, daemon=True) + self.thread.start() + + def stop(self, success_message=None): + self.running = False + if self.thread: + self.thread.join(timeout=0.5) + # Small delay to ensure thread has stopped writing + time.sleep(0.15) + # Clear spinner line (do it twice to handle any race condition) + sys.stdout.write("\r\033[K") + sys.stdout.write("\r\033[K") + sys.stdout.flush() + if success_message: + print(f" \033[32m✓\033[0m {success_message}") + + def update(self, message): + self.message = message + + # Check if authentication is needed + if not params.session_token: + from .utils import LoginCommand + try: + # Run login (no spinner - login may prompt for 2FA, password, etc.) + LoginCommand().execute(params, email=params.user, password=params.password, new_login=False) + + if not params.session_token: + logging.error("\nLogin failed or was cancelled.") + return + + # Sync vault data with spinner + sync_spinner = Spinner("Syncing vault data...") + sync_spinner.start() + try: + from .utils import SyncDownCommand + SyncDownCommand().execute(params) + sync_spinner.stop("Vault synced!") + except Exception as e: + sync_spinner.stop() + raise + + print() # Blank line before TUI + + except KeyboardInterrupt: + print("\n\nLogin cancelled.") + return + except Exception as e: + logging.error(f"\nLogin failed: {e}") + return + + # Launch the TUI app + try: + app = SuperShellApp(params) + result = app.run() + + # If user pressed '!' 
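# Example — typical use of the Spinner defined above, as in _execute_supershell:
# start it before blocking work and always stop it so the line is cleared
# (message text illustrative):
#
#   spinner = Spinner("Syncing vault data...")
#   spinner.start()
#   try:
#       SyncDownCommand().execute(params)
#       spinner.stop("Vault synced!")
#   except Exception:
#       spinner.stop()
#       raise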
to exit to shell, start the Keeper shell + if result and "Exited to Keeper shell" in str(result): + print(result) # Show the exit message + # Check if we were in batch mode BEFORE modifying it + was_batch_mode = params.batch_mode + # Clear batch mode and pending commands so the shell runs interactively + params.batch_mode = False + params.commands = [c for c in params.commands if c.lower() not in ('q', 'quit')] + # Only start a new shell if we were in batch mode (ran 'keeper supershell' directly) + # Otherwise, just return to the existing interactive shell + if was_batch_mode: + from ..cli import loop as shell_loop + shell_loop(params, skip_init=True, suppress_goodbye=True) + # When the inner shell exits, queue 'q' so the outer batch-mode loop also exits + params.commands.append('q') + except KeyboardInterrupt: + logging.debug("SuperShell interrupted") + except Exception as e: + logging.error(f"Error running SuperShell: {e}") + raise diff --git a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py index a41a1d199..a1c98c0db 100644 --- a/keepercommander/commands/tunnel/port_forward/TunnelGraph.py +++ b/keepercommander/commands/tunnel/port_forward/TunnelGraph.py @@ -1,4 +1,4 @@ -from ....commands.tunnel.port_forward.tunnel_helpers import generate_random_bytes, get_config_uid +from .tunnel_helpers import generate_random_bytes, get_config_uid from ....keeper_dag import DAG, EdgeType from ....keeper_dag.connection.commander import Connection from ....keeper_dag.types import RefType, PamEndpoints @@ -21,7 +21,8 @@ def get_vertex_content(vertex): class TunnelDAG: - def __init__(self, params, encrypted_session_token, encrypted_transmission_key, record_uid: str, is_config=False): + def __init__(self, params, encrypted_session_token, encrypted_transmission_key, record_uid: str, + is_config=False, transmission_key=None): config_uid = None if not is_config: config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) @@ -32,8 +33,11 @@ def __init__(self, params, encrypted_session_token, encrypted_transmission_key, self.record.record_key = generate_random_bytes(32) self.encrypted_session_token = encrypted_session_token self.encrypted_transmission_key = encrypted_transmission_key - self.conn = Connection(params=params, encrypted_transmission_key=self.encrypted_transmission_key, + self.transmission_key = transmission_key + self.conn = Connection(params=params, + encrypted_transmission_key=self.encrypted_transmission_key, encrypted_session_token=self.encrypted_session_token, + transmission_key=self.transmission_key, use_write_protobuf=True ) self.linking_dag = DAG(conn=self.conn, record=self.record, graph_id=0, write_endpoint=PamEndpoints.PAM) diff --git a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py index 2b07bb750..a7c752472 100644 --- a/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py +++ b/keepercommander/commands/tunnel/port_forward/tunnel_helpers.py @@ -82,7 +82,7 @@ class CloseConnectionReason: Represents a structured close reason for WebRTC tunnel connections. Provides categorization and backward compatibility with legacy outcome strings. 
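# Example — how callers are expected to consume CloseConnectionReason, using
# codes and flags from the REASONS table below (20 = UpstreamClosed, retryable;
# 19 = ProtocolError, critical):
#
#   reason = CloseConnectionReason.from_legacy_outcome("upstream_closed")
#   assert reason.code == 20 and reason.is_retryable()
#   if CloseConnectionReason.from_code(19).is_critical():
#       pass  # critical failures end the session immediately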
""" - + # Close reason codes with their properties REASONS = { 0: {"name": "Normal", "critical": False, "user_initiated": True, "retryable": False}, @@ -106,7 +106,7 @@ class CloseConnectionReason: 19: {"name": "ProtocolError", "critical": True, "user_initiated": False, "retryable": False}, 20: {"name": "UpstreamClosed", "critical": False, "user_initiated": False, "retryable": True}, } - + # Legacy outcome mapping for backward compatibility LEGACY_OUTCOMES = { "normal": 0, @@ -142,12 +142,12 @@ class CloseConnectionReason: "protocol_error": 19, "upstream_closed": 20, } - + def __init__(self, code, name=None): self.code = code self._reason_info = self.REASONS.get(code, self.REASONS[6]) # Default to Unknown self.name = name or self._reason_info["name"] - + @classmethod def from_code(cls, code): """Create CloseConnectionReason from numeric code""" @@ -156,37 +156,37 @@ def from_code(cls, code): else: logging.warning(f"Unknown close reason code: {code}, defaulting to Unknown") return cls(6) # Unknown - + @classmethod def from_legacy_outcome(cls, outcome): """Create CloseConnectionReason from legacy outcome string""" if not outcome or not isinstance(outcome, str): return cls(6) # Unknown - + # Try direct mapping first outcome_lower = outcome.lower().strip() code = cls.LEGACY_OUTCOMES.get(outcome_lower) - + if code is not None: return cls(code) - + # Try partial matching for common variations for legacy_key, legacy_code in cls.LEGACY_OUTCOMES.items(): if legacy_key in outcome_lower or outcome_lower in legacy_key: return cls(legacy_code) - + # Default to Unknown logging.warning(f"Unknown legacy outcome: '{outcome}', defaulting to Unknown") return cls(6) - + def is_critical(self): """Returns True if this is a critical failure requiring immediate attention""" return self._reason_info["critical"] - + def is_user_initiated(self): """Returns True if this was initiated by user action""" return self._reason_info["user_initiated"] - + def is_retryable(self): """Returns True if this failure is potentially retryable""" return self._reason_info["retryable"] @@ -214,7 +214,12 @@ def __init__(self, tube_id, conversation_id, gateway_uid, symmetric_key, self.websocket_thread = None self.websocket_ready_event = None self.websocket_stop_event = None - + # Optional attributes (set dynamically) + # Note: signal_handler is set after TunnelSignalHandler is created + self.signal_handler = None # type: ignore[assignment] + # Note: gateway_ready_event is an optional threading.Event set if needed + self.gateway_ready_event = None # type: ignore[assignment] + def update_activity(self): """Update last activity timestamp""" self.last_activity = time.time() @@ -291,13 +296,13 @@ def get_conversation_status(): with _CONVERSATION_KEYS_LOCK: active_conversations = len(_GLOBAL_CONVERSATION_KEYS) conversation_ids = list(_GLOBAL_CONVERSATION_KEYS.keys()) - + # Get tunnel session info to count active WebSockets with _TUNNEL_SESSIONS_LOCK: active_websockets = sum(1 for session in _GLOBAL_TUNNEL_SESSIONS.values() if session.websocket_thread and session.websocket_thread.is_alive()) total_tunnels = len(_GLOBAL_TUNNEL_SESSIONS) - + return { "active_conversations": active_conversations, "conversation_ids": conversation_ids, @@ -307,6 +312,105 @@ def get_conversation_status(): # Tunnel helper functions +def _configure_rust_logger_levels(current_is_debug: bool, log_level: int): + """ + Configure Rust logger levels based on debug mode. 
+
+    Args:
+        current_is_debug: Whether debug mode is currently enabled
+        log_level: Current effective log level
+    """
+    # Quick fix: switch Rust loggers only between ERROR and DEBUG.
+    # Root cause: Commander has only two modes, DEBUG and non-debug (the default),
+    # yet every Rust log message - including DEBUG - was printed even in non-debug mode.
+
+    # Configure Rust logger level based on debug mode
+    if current_is_debug or log_level <= logging.DEBUG:
+        root_logger = logging.getLogger()
+        # Ensure root logger can handle DEBUG messages
+        if root_logger.level > logging.DEBUG:
+            root_logger.setLevel(logging.DEBUG)
+
+        # CRITICAL: Ensure root logger has a handler.
+        # pyo3_log sends Rust logs to Python loggers, but if loggers have no handlers,
+        # messages are lost even if propagate=True
+        import sys
+        if not root_logger.handlers:
+            # Add a console handler if none exists
+            console_handler = logging.StreamHandler(sys.stderr)
+            console_handler.setFormatter(logging.Formatter(
+                '%(levelname)s:%(name)s:%(message)s'
+            ))
+            console_handler.setLevel(logging.DEBUG)
+            root_logger.addHandler(console_handler)
+
+        # Set up a custom logger class that adds handlers to all Rust loggers.
+        # This ensures handlers are added even if loggers are created dynamically.
+        original_logger_class = logging.getLoggerClass()
+
+        class RustLoggerHandler(logging.Logger):
+            """Custom logger that auto-adds handlers for Rust loggers"""
+            def __init__(self, name, level=logging.NOTSET):
+                super().__init__(name, level)
+                if name.startswith('keeper_pam_webrtc_rs'):
+                    self.setLevel(logging.DEBUG)
+                    self.propagate = False  # Disable propagation to prevent duplicate logs
+                    if not self.handlers:
+                        handler = logging.StreamHandler(sys.stderr)
+                        handler.setFormatter(logging.Formatter(
+                            '%(levelname)s:%(name)s:%(message)s'
+                        ))
+                        handler.setLevel(logging.DEBUG)
+                        self.addHandler(handler)
+
+        # Temporarily set our custom logger class
+        logging.setLoggerClass(RustLoggerHandler)
+
+        # Now set up all existing loggers (disable propagation to prevent duplicates)
+        for logger_name in list(logging.Logger.manager.loggerDict.keys()):
+            if isinstance(logger_name, str) and logger_name.startswith('keeper_pam_webrtc_rs'):
+                rust_logger = logging.getLogger(logger_name)
+                rust_logger.setLevel(logging.DEBUG)
+                rust_logger.propagate = False  # Disable propagation to prevent duplicate logs
+                if not rust_logger.handlers:
+                    handler = logging.StreamHandler(sys.stderr)
+                    handler.setFormatter(logging.Formatter(
+                        '%(levelname)s:%(name)s:%(message)s'
+                    ))
+                    handler.setLevel(logging.DEBUG)
+                    rust_logger.addHandler(handler)
+
+        # pyo3_log creates loggers based on Rust module paths
+        tube_registry_logger = logging.getLogger("keeper_pam_webrtc_rs.python.tube_registry_binding")
+        tube_registry_logger.setLevel(logging.DEBUG)
+        tube_registry_logger.propagate = False  # Disable propagation to prevent duplicate logs
+        if not tube_registry_logger.handlers:
+            handler = logging.StreamHandler(sys.stderr)
+            handler.setFormatter(logging.Formatter(
+                '%(levelname)s:%(name)s:%(message)s'
+            ))
+            handler.setLevel(logging.DEBUG)
+            tube_registry_logger.addHandler(handler)
+
+        # Restore original logger class
+        logging.setLoggerClass(original_logger_class)
+
+        logging.debug("Rust loggers enabled at DEBUG level")
+        enabled_loggers = [name for name in logging.Logger.manager.loggerDict.keys()
+                           if isinstance(name, str) and name.startswith('keeper_pam_webrtc_rs')]
+        logging.debug(f"Enabled Rust loggers: {enabled_loggers}")
+    else:
+        # Set to ERROR when not debugging
+        main_rust_logger = 
logging.getLogger("keeper_pam_webrtc_rs") + main_rust_logger.setLevel(logging.ERROR) + + # Also set all existing Rust sub-loggers to ERROR + for logger_name in list(logging.Logger.manager.loggerDict.keys()): + if isinstance(logger_name, str) and logger_name.startswith('keeper_pam_webrtc_rs'): + logger = logging.getLogger(logger_name) + logger.setLevel(logging.ERROR) + + def get_or_create_tube_registry(params): """Get or create the tube registry instance, storing it on params for reuse""" try: @@ -322,6 +426,9 @@ def get_or_create_tube_registry(params): level=log_level ) + # Configure Rust logger levels based on debug mode + _configure_rust_logger_levels(current_is_debug, log_level) + # Reuse existing registry or create new one if not hasattr(params, 'tube_registry') or params.tube_registry is None: params.tube_registry = PyTubeRegistry() @@ -342,7 +449,7 @@ def cleanup_tube_registry(params): params.tube_registry = None except Exception as e: logging.warning(f"Error cleaning up tube registry: {e}") - + # Also clear all conversation keys when cleaning up everything clear_all_conversation_keys() @@ -678,7 +785,7 @@ async def connect_websocket_with_fallback(ws_endpoint, headers, ssl_context, tub """ Connect to WebSocket with backward compatibility for both websockets 15.0.1+ and 11.0.3 Handles parameter name differences between versions - + Args: ws_endpoint: WebSocket URL headers: Connection headers @@ -694,14 +801,14 @@ async def connect_websocket_with_fallback(ws_endpoint, headers, ssl_context, tub "ping_timeout": 20, "close_timeout": 30 } - + if WEBSOCKETS_VERSION == "asyncio": # websockets 15.0.1+ uses additional_headers and ssl_context/ssl parameters connect_kwargs = { **base_kwargs, "additional_headers": headers } - + # Try ssl_context parameter first, fallback to ssl if not supported if ssl_context: try: @@ -728,23 +835,23 @@ async def connect_websocket_with_fallback(ws_endpoint, headers, ssl_context, tub raise else: async with websockets_connect(ws_endpoint, **connect_kwargs) as websocket: - logging.info("WebSocket connection established") + logging.debug("WebSocket connection established") # Signal ready event immediately after connection if ready_event: ready_event.set() logging.debug("WebSocket ready event signaled") await handle_websocket_messages(websocket, tube_registry, timeout, stop_event) - + elif WEBSOCKETS_VERSION == "legacy": # websockets 11.0.3 uses extra_headers and ssl parameters connect_kwargs = { **base_kwargs, "extra_headers": headers } - + if ssl_context: async with websockets_connect(ws_endpoint, ssl=ssl_context, **connect_kwargs) as websocket: - logging.info("WebSocket connection established (legacy)") + logging.debug("WebSocket connection established (legacy)") # Signal ready event immediately after connection if ready_event: ready_event.set() @@ -752,7 +859,7 @@ async def connect_websocket_with_fallback(ws_endpoint, headers, ssl_context, tub await handle_websocket_messages(websocket, tube_registry, timeout, stop_event) else: async with websockets_connect(ws_endpoint, **connect_kwargs) as websocket: - logging.info("WebSocket connection established (legacy)") + logging.debug("WebSocket connection established (legacy)") # Signal ready event immediately after connection if ready_event: ready_event.set() @@ -770,7 +877,7 @@ async def handle_websocket_responses(params, tube_registry, timeout=60, gateway_ """ Direct WebSocket handler that connects, listens for responses, and routes them to Rust. Uses global conversation key store to support multiple concurrent tunnels. 
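# Example — routing is keyed on conversationId: each inbound frame is matched
# against the per-conversation symmetric keys in the global store and handed
# to Rust. The core loop, reduced to a sketch (helper names from this module;
# route_message_to_rust performs the key lookup itself):
#
#   msg = json.loads(message_text)
#   items = msg if isinstance(msg, list) else [msg]
#   for item in items:
#       route_message_to_rust(item, tube_registry)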
- + Args: params: KeeperParams instance tube_registry: PyTubeRegistry instance @@ -794,7 +901,6 @@ async def handle_websocket_responses(params, tube_registry, timeout=60, gateway_ 'TransmissionKey': bytes_to_base64(encrypted_transmission_key), 'Authorization': f'KeeperUser {bytes_to_base64(encrypted_session_token)}', } - # Set up SSL context ssl_context = None if ws_endpoint.startswith('wss://'): @@ -802,7 +908,7 @@ async def handle_websocket_responses(params, tube_registry, timeout=60, gateway_ if not VERIFY_SSL: ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - + # Connect and handle messages with backward compatibility # Handle parameter differences between websockets versions await connect_websocket_with_fallback(ws_endpoint, headers, ssl_context, tube_registry, timeout, ready_event, stop_event) @@ -810,14 +916,14 @@ async def handle_websocket_responses(params, tube_registry, timeout=60, gateway_ async def handle_websocket_messages(websocket, tube_registry, timeout, stop_event=None): """Handle WebSocket message processing - + Args: websocket: WebSocket connection tube_registry: PyTubeRegistry instance timeout: Maximum time to listen for messages stop_event: threading.Event to signal when to stop listening """ - + # Listen for messages with timeout try: start_time = time.time() @@ -826,7 +932,7 @@ async def handle_websocket_messages(websocket, tube_registry, timeout, stop_even if stop_event and stop_event.is_set(): logging.debug("WebSocket stop event received, closing connection") break - + try: # Wait for a message with short timeout to allow checking stop event and overall timeout message_text = await asyncio.wait_for(websocket.recv(), timeout=1.0) @@ -836,10 +942,17 @@ async def handle_websocket_messages(websocket, tube_registry, timeout, stop_even response_data = json.loads(message_text) if isinstance(response_data, list): # Handle an array of responses - for response_item in response_data: + logging.debug(f"Received {len(response_data)} WebSocket messages") + for idx, response_item in enumerate(response_data): + logging.debug(f" Message {idx+1}/{len(response_data)}: conversationId={response_item.get('conversationId', 'N/A')}, type={response_item.get('type', 'N/A')}") + if 'payload' in response_item: + logging.debug(f" Payload preview: {str(response_item['payload'])[:100]}...") route_message_to_rust(response_item, tube_registry) elif isinstance(response_data, dict): # Handle a single response object + logging.debug(f"Received WebSocket message: conversationId={response_data.get('conversationId', 'N/A')}, type={response_data.get('type', 'N/A')}") + if 'payload' in response_data: + logging.debug(f" Payload preview: {str(response_data['payload'])[:100]}...") route_message_to_rust(response_data, tube_registry) else: logging.warning(f"Unexpected WebSocket message format: {type(response_data)}") @@ -848,7 +961,7 @@ async def handle_websocket_messages(websocket, tube_registry, timeout, stop_even # No message received within 1 second, continue loop to check stop event and overall timeout continue except ConnectionClosed: - logging.info("WebSocket connection closed") + logging.debug("WebSocket connection closed") break except Exception as e: @@ -862,25 +975,25 @@ def route_message_to_rust(response_item, tube_registry): try: conversation_id = response_item.get('conversationId') logging.debug(f"Processing WebSocket message for conversation: {conversation_id}") - + if not conversation_id: logging.debug("No conversationId in response, skipping") return - + # Get the 
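# Example — the decrypt path below, reduced to its essential steps. The actual
# decrypt call is unchanged context not shown in this hunk, so decrypt_payload
# here is a hypothetical stand-in; the surrounding names are from this function:
#
#   payload_data = json.loads(response_item['payload'])
#   if payload_data.get('is_ok') and payload_data.get('data'):
#       decrypted = decrypt_payload(symmetric_key, payload_data['data'])  # hypothetical helper
#       data_json = json.loads(bytes_to_string(decrypted).replace("'", '"'))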
symmetric key for this conversation from global store symmetric_key = get_conversation_key(conversation_id) - + if not symmetric_key: logging.debug(f"No encryption key found for conversation: {conversation_id}") logging.debug(f"Registered conversations: {get_all_conversation_ids()}") return - + logging.debug(f"Found encryption key for conversation: {conversation_id}") - + # Decrypt the message payload encrypted_payload = response_item.get('payload', '') logging.debug(f"Processing payload for conversation {conversation_id}, payload length: {len(encrypted_payload) if encrypted_payload else 0}") - + if encrypted_payload: # Parse the payload JSON string first try: @@ -891,11 +1004,11 @@ def route_message_to_rust(response_item, tube_registry): logging.error(f"Failed to parse payload as JSON: {e}") logging.error(f"Raw payload: {encrypted_payload[:200]}...") return - + # Handle different types of responses if payload_data.get('is_ok') and payload_data.get('data'): data_field = payload_data.get('data', '') - + # Check if this is a plain text acknowledgment (not encrypted) if isinstance(data_field, str) and ( "ice candidate" in data_field.lower() or @@ -906,13 +1019,24 @@ def route_message_to_rust(response_item, tube_registry): data_field.endswith(conversation_id) # Plain text responses often end with conversation ID ): logging.debug(f"Received plain text acknowledgment: {data_field}") + + # CRITICAL: Mark ICE candidate response received to allow next candidate + if "ice candidate" in data_field.lower() or "ice candidates" in data_field.lower(): + # Find the signal handler and mark response received + tube_id = tube_registry.tube_id_from_connection_id(conversation_id) + if tube_id: + session = get_tunnel_session(tube_id) + if session and hasattr(session, 'signal_handler') and session.signal_handler: + session.signal_handler.ice_candidate_response_received = True + logging.debug(f"Marked ICE candidate response received for tube {tube_id}") + return - + # Check if this is just a buffered acknowledgment (these sometimes have invalid base64) if "buffered" in data_field.lower(): logging.debug(f"Received buffered acknowledgment: {data_field}") return - + logging.debug("Detected SDP answer response - processing...") # This looks like an SDP answer response encrypted_data = data_field @@ -929,7 +1053,7 @@ def route_message_to_rust(response_item, tube_registry): if decrypted_data: data_text = bytes_to_string(decrypted_data).replace("'", '"') logging.debug(f"Successfully decrypted data for {conversation_id}, length: {len(data_text)}") - + # Check if this is a simple JSON-encoded acknowledgment string try: parsed_text = json.loads(data_text) @@ -940,8 +1064,17 @@ def route_message_to_rust(response_item, tube_registry): return except (json.JSONDecodeError, TypeError): pass # Not a simple JSON string, continue with normal processing - + data_json = json.loads(data_text) + + # Ensure data_json is a dictionary before processing + if not isinstance(data_json, dict): + logging.debug(f"Data is not a dictionary (got {type(data_json).__name__}), treating as acknowledgment: {data_json}") + return + + # Log what type of data we received + logging.debug(f"🔓 Decrypted payload type: {data_json.get('type', 'unknown')}, keys: {list(data_json.keys())}") + if "answer" in data_json: answer_sdp = data_json.get('answer') @@ -957,7 +1090,7 @@ def route_message_to_rust(response_item, tube_registry): tube_id = tube_registry.tube_id_from_connection_id(url_safe_conversation_id) if tube_id: logging.debug(f"Found tube using URL-safe 
conversion: {url_safe_conversation_id}") - + if not tube_id: logging.error(f"No tube ID found for conversation: {conversation_id} (also tried URL-safe version)") return @@ -967,23 +1100,27 @@ def route_message_to_rust(response_item, tube_registry): # Send any buffered local ICE candidates now that we have the answer session = get_tunnel_session(tube_id) - if session and session.buffered_ice_candidates: - logging.debug(f"Sending {len(session.buffered_ice_candidates)} buffered ICE candidates after answer") - # Need to get the signal handler to send candidates - # Since we're in the routing function, we need to find the handler - # is stored in the session for this purpose - if hasattr(session, 'signal_handler') and session.signal_handler: - for candidate in session.buffered_ice_candidates: - session.signal_handler._send_ice_candidate_immediately(candidate, tube_id) - session.buffered_ice_candidates.clear() - else: - logging.warning(f"No signal handler found for tube {tube_id} to send buffered candidates") + if session: + session.gateway_ready_event.set() + + # Send any buffered local ICE candidates now that we have the answer + if session.buffered_ice_candidates: + logging.debug(f"Sending {len(session.buffered_ice_candidates)} buffered ICE candidates after answer") + # Need to get the signal handler to send candidates + # Since we're in the routing function, we need to find the handler + # is stored in the session for this purpose + if hasattr(session, 'signal_handler') and session.signal_handler: + for candidate in session.buffered_ice_candidates: + session.signal_handler._send_ice_candidate_immediately(candidate, tube_id) + session.buffered_ice_candidates.clear() + else: + logging.warning(f"No signal handler found for tube {tube_id} to send buffered candidates") elif "offer" in data_json or (data_json.get("type") == "offer"): # Gateway is sending us an ICE restart offer offer_sdp = data_json.get('sdp') or data_json.get('offer') if offer_sdp: - logging.info(f"Received ICE restart offer from Gateway for conversation: {conversation_id}") + logging.debug(f"Received ICE restart offer from Gateway for conversation: {conversation_id}") tube_id = tube_registry.tube_id_from_connection_id(conversation_id) if not tube_id: @@ -1016,7 +1153,7 @@ def route_message_to_rust(response_item, tube_registry): answer_sdp = tube_registry.create_answer(tube_id) if answer_sdp: - logging.info(f"Generated ICE restart answer for tube {tube_id}") + logging.debug(f"Generated ICE restart answer for tube {tube_id}") # Get session to access symmetric key and other info session = get_tunnel_session(tube_id) @@ -1048,11 +1185,12 @@ def route_message_to_rust(response_item, tube_registry): destination_gateway_uid_str=session.gateway_uid, gateway_action=GatewayActionWebRTCSession( conversation_id=session.conversation_id, + message_id=GatewayAction.conversation_id_to_message_id(session.conversation_id), inputs={ "recordUid": signal_handler.record_uid, 'kind': 'ice_restart_answer', 'base64Nonce': signal_handler.base64_nonce, - 'conversationType': 'tunnel', + 'conversationType': signal_handler.conversation_type, "data": encrypted_data, "trickleICE": signal_handler.trickle_ice, } @@ -1062,17 +1200,14 @@ def route_message_to_rust(response_item, tube_registry): gateway_timeout=GATEWAY_TIMEOUT ) - logging.info(f"ICE restart answer sent for tube {tube_id}") - print(f"{bcolors.OKGREEN}ICE restart answer sent successfully{bcolors.ENDC}") + logging.debug(f"ICE restart answer sent for tube {tube_id}") else: logging.error(f"No signal 
handler found for tube {tube_id} to send answer") else: logging.error(f"Failed to generate ICE restart answer for tube {tube_id}") - print(f"{bcolors.FAIL}Failed to generate ICE restart answer{bcolors.ENDC}") except Exception as e: logging.error(f"Error handling ICE restart offer for tube {tube_id}: {e}") - print(f"{bcolors.FAIL}Error processing ICE restart offer: {e}{bcolors.ENDC}") else: logging.warning(f"Received offer message without SDP data for conversation: {conversation_id}") elif "candidates" in data_json: @@ -1084,7 +1219,7 @@ def route_message_to_rust(response_item, tube_registry): tube_id = tube_registry.tube_id_from_connection_id(url_safe_conversation_id) if tube_id: logging.debug(f"Found tube using URL-safe conversion: {url_safe_conversation_id}") - + if not tube_id: logging.error(f"No tube ID found for conversation: {conversation_id} (also tried URL-safe version)") return @@ -1092,7 +1227,7 @@ def route_message_to_rust(response_item, tube_registry): candidates_list = data_json.get('candidates', []) candidate_count = len(candidates_list) logging.debug(f"Received {candidate_count} ICE candidates from gateway for {conversation_id}") - + # Gateway sends candidates in consistent format, pass them directly to Rust for candidate in candidates_list: logging.debug(f"Forwarding candidate to Rust: {candidate[:100]}...") # Log first 100 chars @@ -1105,7 +1240,7 @@ def route_message_to_rust(response_item, tube_registry): logging.error("Failed to decrypt data") else: logging.warning("No 'data' field found in response") - + # Handle error responses elif (payload_data.get('errors') is not None and payload_data.get('errors') != [] and @@ -1121,7 +1256,7 @@ def route_message_to_rust(response_item, tube_registry): logging.warning(f"Unhandled payload type for {conversation_id}: {payload_data}") else: logging.warning(f"No encrypted payload in message for conversation: {conversation_id}") - + except Exception as e: logging.error(f"Error routing message to Rust: {e}") import traceback @@ -1131,29 +1266,29 @@ def route_message_to_rust(response_item, tube_registry): def start_websocket_listener(params, tube_registry, timeout=60, gateway_uid=None, tunnel_session=None): """ Start WebSocket listener in a background thread. - + Creates a DEDICATED WebSocket for the provided tunnel_session. Each tunnel gets its own independent WebSocket connection. 
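# Example — the per-tunnel listener pattern used below: a fresh asyncio event
# loop on a daemon thread, coordinated by two threading.Events (ready/stop).
# A standalone sketch of the same shape:
#
#   import asyncio, threading
#
#   def start_listener(make_coro):
#       ready, stop = threading.Event(), threading.Event()
#       def run():
#           loop = asyncio.new_event_loop()
#           asyncio.set_event_loop(loop)
#           try:
#               loop.run_until_complete(make_coro(ready, stop))
#           finally:
#               loop.close()
#       thread = threading.Thread(target=run, daemon=True)
#       thread.start()
#       return thread, ready, stop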
- + Args: params: KeeperParams instance tube_registry: PyTubeRegistry instance timeout: Maximum time to listen for messages (seconds) gateway_uid: Gateway UID (optional) tunnel_session: TunnelSession instance for dedicated WebSocket (required) - + Returns: (thread, is_reused) tuple - is_reused is always False (each tunnel gets its own WebSocket) """ if tunnel_session is None: raise ValueError("tunnel_session is required for dedicated WebSocket architecture") - + logging.debug(f"Creating dedicated WebSocket for tunnel {tunnel_session.tube_id}") - + # Create per-tunnel events tunnel_session.websocket_ready_event = threading.Event() tunnel_session.websocket_stop_event = threading.Event() - + # Start a dedicated WebSocket listener thread for this tunnel def run_dedicated_websocket(): loop = asyncio.new_event_loop() @@ -1169,7 +1304,7 @@ def run_dedicated_websocket(): finally: loop.close() logging.debug(f"Dedicated WebSocket closed for tunnel {tunnel_session.tube_id}") - + tunnel_session.websocket_thread = threading.Thread( target=run_dedicated_websocket, daemon=True, @@ -1209,14 +1344,16 @@ def __init__(self, conversation_id, record_uid): class TunnelSignalHandler: """ Signal handler for WebRTC tunnel events with HTTP sending and WebSocket receiving. - + Features immediate ICE candidate sending: - Sends ICE candidates immediately as they arrive from Rust - Always sends candidates in {"candidates": [candidate]} array format for gateway consistency - Maintains consistent protocol with gateway expectations """ - def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, conversation_id, tube_registry, tube_id=None, trickle_ice=False, websocket_router=None): + def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, conversation_id, + tube_registry, tube_id=None, trickle_ice=False, websocket_router=None, + conversation_type='tunnel'): self.params = params self.record_uid = record_uid self.gateway_uid = gateway_uid @@ -1232,7 +1369,7 @@ def __init__(self, params, record_uid, gateway_uid, symmetric_key, base64_nonce, self.websocket_router = websocket_router # For key cleanup self.offer_sent = False # Track if offer has been sent to gateway self.buffered_ice_candidates = [] # Buffer ICE candidates until offer is sent - + # WebSocket routing is handled automatically - no setup needed if trickle_ice and not WEBSOCKETS_AVAILABLE: raise Exception("Trickle ICE requires WebSocket support - install with: pip install websockets") @@ -1250,6 +1387,11 @@ def signal_from_rust(self, response: dict): session = get_tunnel_session(tube_id) if tube_id else None if session: session.update_activity() + else: + if tube_id: + logging.debug( + f"No tunnel session found for tube {tube_id} while handling signal '{signal_kind}'" + ) # Handle local connection state changes if signal_kind == 'connection_state_changed': @@ -1264,32 +1406,38 @@ def signal_from_rust(self, response: dict): logging.error(f"Connection failed for tube {tube_id} - ICE restart may be attempted by Rust") elif new_state == 'connected': - logging.debug(f"Connection established/restored for tube {tube_id}") + logging.debug( + f"Connection established/restored for tube {tube_id} " + f"(conversation_id={conversation_id_from_signal or self.conversation_id})" + ) + logging.debug(f"Connection state: connected") + + # CRITICAL: Mark connection as connected - IMMEDIATELY stop sending ICE candidates + self.connection_connected = True + self.ice_sending_in_progress = False # Stop any pending ICE candidate sends if not 
self.connection_success_shown: self.connection_success_shown = True # Get tunnel session for record details if session: - print(f"\n{bcolors.OKGREEN}Connection established successfully.{bcolors.ENDC}") + logging.debug(f"\n{bcolors.OKGREEN}Connection established successfully.{bcolors.ENDC}") # Display record title if available if session.record_title: - print(f"{bcolors.OKBLUE}Record:{bcolors.ENDC} {session.record_title}") + logging.debug(f"{bcolors.OKBLUE}Record:{bcolors.ENDC} {session.record_title}") # Display remote target if session.target_host and session.target_port: - print(f"{bcolors.OKBLUE}Remote:{bcolors.ENDC} {session.target_host}:{session.target_port}") + logging.debug(f"{bcolors.OKBLUE}Remote:{bcolors.ENDC} {session.target_host}:{session.target_port}") # Display local listening address if session.host and session.port: - print(f"{bcolors.OKBLUE}Local:{bcolors.ENDC} {session.host}:{session.port}") + logging.debug(f"{bcolors.OKBLUE}Local:{bcolors.ENDC} {session.host}:{session.port}") # Display conversation ID if session.conversation_id: - print(f"{bcolors.OKBLUE}Conversation ID:{bcolors.ENDC} {session.conversation_id}") - - print() # Empty line for readability + logging.debug(f"{bcolors.OKBLUE}Conversation ID:{bcolors.ENDC} {session.conversation_id}") # Flush any buffered ICE candidates now that we're connected if session and session.buffered_ice_candidates: @@ -1302,7 +1450,7 @@ def signal_from_rust(self, response: dict): logging.debug(f"Connection in progress for tube {tube_id}") elif new_state == "closed": - logging.info(f"Connection closed for tube {tube_id}") + logging.debug(f"Connection closed for tube {tube_id}") else: logging.debug(f"Connection state for tube {tube_id}: {new_state}") @@ -1311,7 +1459,7 @@ def signal_from_rust(self, response: dict): elif signal_kind == 'channel_closed': conversation_id_from_signal = conversation_id_from_signal or self.conversation_id - logging.info(f"Received 'channel_closed' signal for conversation '{conversation_id_from_signal}' of tube '{tube_id}'.") + logging.debug(f"Received 'channel_closed' signal for conversation '{conversation_id_from_signal}' of tube '{tube_id}'.") # Check if the tunnel session exists and is already closed session = get_tunnel_session(tube_id) if tube_id else None @@ -1322,41 +1470,37 @@ def signal_from_rust(self, response: dict): try: data_json = json.loads(data) if data else {} - + # Try to get structured close reason first close_reason = None if "close_reason" in data_json: reason_code = data_json["close_reason"].get("code") if reason_code is not None: close_reason = CloseConnectionReason.from_code(reason_code) - logging.info(f" Structured close reason: {close_reason.name} (code: {reason_code})") - + logging.debug(f" Structured close reason: {close_reason.name} (code: {reason_code})") + # Fallback to old string-based outcome for backward compatibility if close_reason is None: outcome = data_json.get("outcome", "unknown") close_reason = CloseConnectionReason.from_legacy_outcome(outcome) - logging.info(f" Legacy outcome: '{outcome}' -> {close_reason.name}") - + logging.debug(f" Legacy outcome: '{outcome}' -> {close_reason.name}") + # Handle based on reason type if close_reason.is_critical(): - logging.error(f"Critical failure in tunnel '{tube_id}': {close_reason.name}. 
Stopping session immediately.") - print(f"{bcolors.FAIL}Tunnel closed due to critical failure: {close_reason.name}{bcolors.ENDC}") - + logging.error(f"{bcolors.FAIL}Tunnel closed due to critical failure - '{tube_id}': {close_reason.name}{bcolors.ENDC}") + elif close_reason.is_user_initiated(): - logging.info(f"User-initiated closure of tunnel '{tube_id}': {close_reason.name}.") - print(f"{bcolors.OKBLUE}Tunnel closed: {close_reason.name}{bcolors.ENDC}") - + logging.debug(f"{bcolors.OKBLUE}User-initiated closure of tunnel '{tube_id}': {close_reason.name}{bcolors.ENDC}") + elif close_reason.is_retryable(): - logging.warning(f"Retryable failure in tunnel '{tube_id}': {close_reason.name}.") - print(f"{bcolors.WARNING}Tunnel closed with retryable error: {close_reason.name}{bcolors.ENDC}") - + logging.debug(f"{bcolors.WARNING}Tunnel closed with retryable error - '{tube_id}': {close_reason.name}{bcolors.ENDC}") + else: - logging.info(f"Tunnel '{tube_id}' closed with reason: {close_reason.name}.") - print(f"{bcolors.OKBLUE}Tunnel closed: {close_reason.name}{bcolors.ENDC}") + logging.debug(f"{bcolors.OKBLUE}Tunnel '{tube_id}' closed with reason: {close_reason.name}{bcolors.ENDC}") except (json.JSONDecodeError, KeyError) as e: logging.error(f"Failed to parse close reason: {e}. Defaulting to critical handling.") - print(f"{bcolors.FAIL}Tunnel closed due to unknown error{bcolors.ENDC}") + logging.debug(f"{bcolors.FAIL}Tunnel closed due to unknown error{bcolors.ENDC}") # Clean up the tunnel session when channel closes if tube_id: @@ -1367,7 +1511,7 @@ def signal_from_rust(self, response: dict): if hasattr(session, 'signal_handler') and session.signal_handler: session.signal_handler.cleanup() logging.debug(f"Cleaned up conversation keys for closed tunnel {tube_id}") - + # Stop dedicated WebSocket if this tunnel has one if session.websocket_stop_event and session.websocket_thread: logging.debug(f"Stopping dedicated WebSocket for tunnel {tube_id}") @@ -1378,14 +1522,13 @@ def signal_from_rust(self, response: dict): logging.warning(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") else: logging.debug(f"Dedicated WebSocket closed for tunnel {tube_id}") - + unregister_tunnel_session(tube_id) return # Local event, no gateway response needed elif signal_kind == 'error': error_msg = data if data else 'Unknown error' logging.error(f"Tunnel error for {tube_id}: {error_msg}") - print(f"{bcolors.FAIL}Tunnel error: {error_msg}{bcolors.ENDC}") # Clean up on error as well if tube_id and data.lower() in ["failed", "closed"]: # Get session before unregistering to access signal handler @@ -1395,7 +1538,7 @@ def signal_from_rust(self, response: dict): if hasattr(session, 'signal_handler') and session.signal_handler: session.signal_handler.cleanup() logging.debug(f"Cleaned up conversation keys for failed tunnel {tube_id}") - + # Stop dedicated WebSocket if this tunnel has one if session.websocket_stop_event and session.websocket_thread: logging.debug(f"Stopping dedicated WebSocket for failed tunnel {tube_id}") @@ -1406,12 +1549,26 @@ def signal_from_rust(self, response: dict): logging.warning(f"Dedicated WebSocket for tunnel {tube_id} did not close in time") else: logging.debug(f"Dedicated WebSocket closed for failed tunnel {tube_id}") - + unregister_tunnel_session(tube_id) return # Local event, no gateway response needed # Handle ICE candidates - use session to check if offer is sent AND WebSocket is ready elif signal_kind == 'icecandidate': + # CRITICAL: Stop immediately if connection is already 
established + if self.connection_connected: + logging.debug(f"Skipping ICE candidate - connection already established for tube {tube_id}") + return + + # CRITICAL: Check if we're already sending a candidate (serialize) + if self.ice_sending_in_progress: + logging.debug(f"ICE candidate send already in progress for tube {tube_id}, buffering this candidate") + if session: + session.buffered_ice_candidates.append(data) + priority_score, candidate_type = self._get_ice_candidate_priority(data) + logging.debug(f"Buffered candidate (priority={priority_score}, type={candidate_type})") + return + logging.debug(f"Received ICE candidate for tube {tube_id}") # Check if we should buffer this candidate @@ -1427,13 +1584,14 @@ def signal_from_rust(self, response: dict): logging.debug(f"Buffering ICE candidate - WebSocket not ready yet for tube {tube_id}") if should_buffer: - session.buffered_ice_candidates.append(data) + if session: + session.buffered_ice_candidates.append(data) else: # Send the candidate immediately (but still in array format for gateway consistency) self._send_ice_candidate_immediately(data, tube_id) return elif signal_kind == 'ice_restart_request': - logging.info(f"Received ICE restart request for tube {tube_id}") + logging.debug(f"Received ICE restart request for tube {tube_id}") # ICE restart requires trickle ICE mode if not self.trickle_ice: @@ -1445,7 +1603,7 @@ def signal_from_rust(self, response: dict): restart_sdp = self.tube_registry.restart_ice(tube_id) if restart_sdp: - logging.info(f"ICE restart successful for tube {tube_id}") + logging.debug(f"ICE restart successful for tube {tube_id}") self._send_restart_offer(restart_sdp, tube_id) else: logging.error(f"ICE restart failed for tube {tube_id}") @@ -1458,7 +1616,7 @@ def signal_from_rust(self, response: dict): elif signal_kind == 'ice_restart_offer': # Rust initiated ICE restart and generated offer (e.g., network change detected) # We need to send this offer to Gateway and get an answer - logging.info(f"Received ice_restart_offer from Rust for tube {tube_id}") + logging.debug(f"Received ice_restart_offer from Rust for tube {tube_id}") # ICE restart requires trickle ICE mode if not self.trickle_ice: @@ -1478,13 +1636,86 @@ def signal_from_rust(self, response: dict): # Unknown signal type else: logging.debug(f"Unknown signal type: {signal_kind}") - + + def _get_ice_candidate_priority(self, candidate_data): + """ + Extract priority from ICE candidate string to determine send order. + + ICE candidate priority order (highest to lowest): + 1. typ host - Direct local connection (fastest, most reliable) + 2. typ srflx - Server reflexive via STUN (medium speed) + 3. 
typ relay - Relay via TURN (slowest, last resort) + + Args: + candidate_data: ICE candidate string (SDP format) + + Returns: + tuple: (priority_score, candidate_type) where: + - priority_score: Higher = better (0-100) + - candidate_type: 'host', 'srflx', 'relay', or 'unknown' + """ + if not isinstance(candidate_data, str): + candidate_str = str(candidate_data) + else: + candidate_str = candidate_data + + # Parse candidate string for type + candidate_lower = candidate_str.lower() + + if 'typ host' in candidate_lower: + # Highest priority: direct local connection + return (100, 'host') + elif 'typ srflx' in candidate_lower: + # Medium priority: server reflexive (NAT traversal via STUN) + return (50, 'srflx') + elif 'typ relay' in candidate_lower: + # Lowest priority: relay via TURN (most expensive/slowest) + return (0, 'relay') + else: + # Unknown type - try to extract priority number from candidate string + # Format: "candidate:foundation component priority protocol ..." + import re + priority_match = re.search(r'candidate:[^\s]+\s+\d+\s+\w+\s+(\d+)', candidate_str) + if priority_match: + priority_num = int(priority_match.group(1)) + # Normalize priority (ICE priorities are typically 0-2^31) + # Higher priority number = better candidate + normalized = min(100, priority_num // 21474836) # Scale to 0-100 + return (normalized, 'unknown') + return (10, 'unknown') # Default low priority for unknown types + + def _sort_candidates_by_priority(self, candidates): + """ + Sort ICE candidates by priority (best/fastest first). + + Args: + candidates: List of candidate strings + + Returns: + List of candidates sorted by priority (highest first) + """ + def priority_key(candidate): + priority_score, candidate_type = self._get_ice_candidate_priority(candidate) + # Sort by priority score (descending), then by type + return (-priority_score, candidate_type) + + return sorted(candidates, key=priority_key) def _send_ice_candidate_immediately(self, candidate_data, tube_id=None): """Send a single ICE candidate immediately via HTTP POST to /send_controller_message - + Always sends candidates as {"candidates": [candidate]} array format for gateway consistency. This matches the gateway expectation: action_inputs['data'].get('candidates') + + Serializes sending to prevent parallel sends and stops immediately if connection is established. 
""" + # CRITICAL: Double-check connection state before sending (connection might have been established) + if self.connection_connected: + logging.debug(f"Skipping ICE candidate send - connection already established") + return + + # Set flag to serialize sending (prevent parallel sends) + self.ice_sending_in_progress = True + try: # Always use array format for consistency with gateway expectations # Gateway expects: action_inputs['data'].get('candidates') and iterates: for candidate in ice_candidates @@ -1501,11 +1732,12 @@ def _send_ice_candidate_immediately(self, candidate_data, tube_id=None): destination_gateway_uid_str=self.gateway_uid, gateway_action=GatewayActionWebRTCSession( conversation_id=self.conversation_id, + message_id=GatewayAction.conversation_id_to_message_id(self.conversation_id), inputs={ "recordUid": self.record_uid, 'kind': 'icecandidate', 'base64Nonce': self.base64_nonce, - 'conversationType': 'tunnel', + 'conversationType': self.conversation_type, "data": encrypted_data, "trickleICE": self.trickle_ice, } @@ -1514,12 +1746,12 @@ def _send_ice_candidate_immediately(self, candidate_data, tube_id=None): is_streaming=self.trickle_ice, # Streaming only for trickle ICE gateway_timeout=GATEWAY_TIMEOUT ) - + if self.trickle_ice: logging.debug("ICE candidate sent via HTTP POST - response expected via WebSocket") else: logging.debug("ICE candidate sent via HTTP POST") - + except Exception as e: # Check if this is a gateway offline error (RRC_CONTROLLER_DOWN) or bad state error (RRC_BAD_STATE) error_str = str(e) @@ -1533,9 +1765,8 @@ def _send_ice_candidate_immediately(self, candidate_data, tube_id=None): # Bad state - transient during WebSocket startup, log at debug level logging.debug(f"Bad state when sending ICE candidate: {e}") else: - # Other errors - log at error level and print to console + # Other errors - log at error level logging.error(f"Failed to send ICE candidate via HTTP: {e}") - print(f"{bcolors.WARNING}Failed to send ICE candidate: {e}{bcolors.ENDC}") def _send_restart_offer(self, restart_sdp, tube_id): """Send ICE restart offer via HTTP POST to /send_controller_message with encryption @@ -1561,11 +1792,12 @@ def _send_restart_offer(self, restart_sdp, tube_id): destination_gateway_uid_str=self.gateway_uid, gateway_action=GatewayActionWebRTCSession( conversation_id=self.conversation_id, + message_id=GatewayAction.conversation_id_to_message_id(self.conversation_id), inputs={ "recordUid": self.record_uid, 'kind': 'ice_restart_offer', # New kind for ICE restart 'base64Nonce': self.base64_nonce, - 'conversationType': 'tunnel', + 'conversationType': self.conversation_type, "data": encrypted_data, "trickleICE": self.trickle_ice, } @@ -1576,10 +1808,10 @@ def _send_restart_offer(self, restart_sdp, tube_id): ) if self.trickle_ice: - logging.info(f"ICE restart offer sent via HTTP POST for tube {tube_id} - response expected via WebSocket") + logging.debug(f"ICE restart offer sent via HTTP POST for tube {tube_id} - response expected via WebSocket") else: - logging.info(f"ICE restart offer sent via HTTP POST for tube {tube_id}") - print(f"{bcolors.OKGREEN}ICE restart offer sent successfully{bcolors.ENDC}") + logging.debug(f"ICE restart offer sent via HTTP POST for tube {tube_id}") + logging.debug(f"{bcolors.OKGREEN}ICE restart offer sent successfully{bcolors.ENDC}") except Exception as e: # Check if this is a gateway offline error (RRC_CONTROLLER_DOWN) or bad state error (RRC_BAD_STATE) @@ -1594,9 +1826,8 @@ def _send_restart_offer(self, restart_sdp, tube_id): # Bad state - 
transient during WebSocket startup, log at debug level logging.debug(f"Bad state when sending ICE restart offer for tube {tube_id}: {e}") else: - # Other errors - log at error level and print to console + # Other errors - log at error level logging.error(f"Failed to send ICE restart offer for tube {tube_id}: {e}") - print(f"{bcolors.FAIL}Failed to send ICE restart offer: {e}{bcolors.ENDC}") def cleanup(self): """Cleanup resources""" @@ -1626,18 +1857,18 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, seed, target_host, target_port, socks, trickle_ice=True, record_title=None, allow_supply_host=False): """ Start a tunnel using Rust WebRTC with trickle ICE via HTTP POST and WebSocket responses. - + This function uses a global WebSocket architecture that supports multiple concurrent tunnels. Messages are routed to Rust based on conversationId using a shared global key store. The endpoint table is displayed ONLY when both the local socket AND WebRTC connection are ready. - + Architecture: - Shared WebSocket listener handles multiple tunnels simultaneously - Global conversation key store: conversationId → symmetric_key mapping - Message flow: WebSocket → decrypt with a conversation key → send to Rust - Signal handler shows endpoint table only when fully connected - Multiple tunnels can run concurrently - + Display flow: 1. "Establishing a tunnel with trickle ICE between Commander and Gateway..." 2. "Creating WebRTC offer and setting up local listener..." @@ -1651,24 +1882,24 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, - "connected" 8. Shows endpoint table with listening address (ONLY when fully ready) 9. "Tunnel is ready for traffic" - + Multi-tunnel Support: - Each tunnel gets its own conversation ID and encryption key - Single shared WebSocket connection handles all tunnel communications - Automatic key registration/cleanup per tunnel - Concurrent tunnels work independently - + Usage: # Start tunnel (shows endpoint table only when truly ready) result = start_rust_tunnel(params, record_uid, gateway_uid, host, port, seed, target_host, target_port, socks) - + If result["success"]: # Global WebSocket router automatically handles all responses # Endpoint table shown only when both socket and WebRTC are ready - + # Multiple tunnels can be started concurrently result2 = start_rust_tunnel(params, record_uid2, gateway_uid2, host2, port2, ...) 
- + Returns: dict: { "success": bool, @@ -1717,13 +1948,16 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, params, host, port, target_host, target_port, socks, nonce ) + # Determine conversation type (tunnel or protocol-specific) + conversation_type = webrtc_settings.get('conversationType', 'tunnel') + # Register the encryption key in the global conversation store # IMPORTANT: Gateway may convert URL-safe base64 to standard base64 and add padding # Register both versions to handle the conversion: # - URL-safe: uses - and _ (e.g., "2srIxfCAsQAEWGmH-52yzw") # - Standard: uses + and / with = padding (e.g., "2srIxfCAsQAEWGmH+52yzw==") register_conversation_key(conversation_id_original, symmetric_key) - + # Also register the standard base64 version that gateway might return # Convert URL-safe base64 to standard base64 standard_conversation_id = conversation_id_original.replace('-', '+').replace('_', '/') @@ -1734,11 +1968,10 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, if standard_conversation_id != conversation_id_original: register_conversation_key(standard_conversation_id, symmetric_key) logging.debug(f"Registered both URL-safe and standard base64 conversation IDs") - # Create a temporary tunnel session BEFORE creating the tube so ICE candidates can be buffered immediately import uuid temp_tube_id = str(uuid.uuid4()) - + # Pre-create tunnel session with temporary ID to buffer early ICE candidates tunnel_session = TunnelSession( tube_id=temp_tube_id, @@ -1753,13 +1986,13 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, target_host=target_host, target_port=target_port ) - + # Register the temporary session so ICE candidates can be buffered immediately register_tunnel_session(temp_tube_id, tunnel_session) - + # Create the tube to get the WebRTC offer with trickle ICE logging.debug("Creating WebRTC offer with trickle ICE gathering") - + # Create signal handler for Rust events signal_handler = TunnelSignalHandler( params=params, @@ -1771,13 +2004,14 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, tube_registry=tube_registry, tube_id=temp_tube_id, # Use temp ID initially trickle_ice=trickle_ice, + conversation_type=conversation_type ) # Store signal handler reference so we can send buffered candidates later tunnel_session.signal_handler = signal_handler - + logging.debug(f"{bcolors.OKBLUE}Creating WebRTC offer and setting up local listener...{bcolors.ENDC}") - + offer = tube_registry.create_tube( conversation_id=conversation_id_original, # Use original, not base64 encoded settings=webrtc_settings, @@ -1799,13 +2033,13 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, return {"success": False, "error": error_msg} commander_tube_id = offer['tube_id'] - + # Update both signal handler and tunnel session with real tube ID signal_handler.tube_id = commander_tube_id signal_handler.host = host # Store for later endpoint display signal_handler.port = port tunnel_session.tube_id = commander_tube_id - + # Get the actual listening address from Rust (source of truth) if 'actual_local_listen_addr' in offer and offer['actual_local_listen_addr']: rust_addr = offer['actual_local_listen_addr'] @@ -1819,27 +2053,27 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, logging.debug(f"Using actual Rust listening address: {rust_host}:{rust_port}") except Exception as e: logging.warning(f"Failed to parse Rust address '{rust_addr}': {e}") - + # Unregister temporary session and register with real 
tube ID unregister_tunnel_session(temp_tube_id) register_tunnel_session(commander_tube_id, tunnel_session) - + logging.debug(f"Registered encryption key for conversation: {conversation_id_original}") logging.debug(f"Expecting WebSocket responses for conversation ID: {conversation_id_original}") - + # Start DEDICATED WebSocket listener for this tunnel # Each tunnel gets its own WebSocket connection - no sharing, no contention! websocket_thread, is_websocket_reused = start_websocket_listener( params, tube_registry, timeout=300, gateway_uid=gateway_uid, tunnel_session=tunnel_session # Pass tunnel_session for dedicated WebSocket ) - + # Wait for WebSocket to establish connection before sending streaming requests # The router requires an active WebSocket connection for is_streaming=True if trickle_ice: # For trickle ICE, we MUST wait for WebSocket to be ready max_wait = 15.0 # Maximum wait time in seconds (increased for slow networks) - + # Use the tunnel's dedicated ready event (not global) if tunnel_session.websocket_ready_event: logging.debug(f"Waiting for dedicated WebSocket to connect (max {max_wait}s)...") @@ -1850,7 +2084,7 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, unregister_tunnel_session(commander_tube_id) return {"success": False, "error": "WebSocket connection timeout"} logging.debug("Dedicated WebSocket connection established and ready for streaming") - + # DEDICATED WebSocket: No mutex needed! Each tunnel has its own connection. # Backend registration is independent - no contention, no delays needed. # Just a small delay to ensure backend is ready (much shorter than before) @@ -1868,17 +2102,17 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, # Give it a moment to start but don't block time.sleep(0.5) logging.debug("Non-trickle ICE: WebSocket optional, proceeding") - + # Verify the session was stored correctly stored_session = get_tunnel_session(commander_tube_id) if stored_session: logging.debug(f"Verified tunnel session stored: tube={commander_tube_id}, host={stored_session.host}, port={stored_session.port}") else: logging.error(f"Failed to store tunnel session for tube: {commander_tube_id}") - + # Send offer to gateway via HTTP POST with streamResponse=true logging.debug(f"{bcolors.OKBLUE}Sending offer for {conversation_id_original} to gateway...{bcolors.ENDC}") - + # Prepare the offer data data = {"offer": offer.get("offer")} @@ -1899,13 +2133,13 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, # For non-trickle ICE, use is_streaming=False (response via HTTP) max_retries = int(os.getenv('TUNNEL_START_RETRIES', '2')) retry_delay = float(os.getenv('TUNNEL_RETRY_DELAY', '1.5')) - + for attempt in range(max_retries + 1): try: if attempt > 0: logging.debug(f"Retry attempt {attempt}/{max_retries} after {retry_delay}s delay...") time.sleep(retry_delay) - + router_response = router_send_action_to_gateway( params=params, destination_gateway_uid_str=gateway_uid, @@ -1925,20 +2159,20 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, is_streaming=trickle_ice, # Streaming only for trickle ICE gateway_timeout=GATEWAY_TIMEOUT ) - + # Success! 
Break out of retry loop logging.debug(f"{bcolors.OKGREEN}Offer sent to gateway{bcolors.ENDC}") - + # Mark offer as sent in both signal handler and session signal_handler.offer_sent = True tunnel_session.offer_sent = True break # Success - exit retry loop - + except Exception as e: error_msg = str(e) is_bad_state = "RRC_BAD_STATE" in error_msg is_last_attempt = (attempt == max_retries) - + if is_bad_state and not is_last_attempt: # Retryable error and we have attempts left - this is expected during WebSocket startup logging.debug(f"RRC_BAD_STATE on attempt {attempt + 1}/{max_retries + 1} - WebSocket backend may need more time") @@ -1949,9 +2183,9 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, if is_bad_state and is_last_attempt: logging.error(f"RRC_BAD_STATE persists after {max_retries} retries") logging.error("This may indicate network issues or backend problems") - + logging.error(f"Failed to send offer via HTTP: {error_msg}") - + # Cleanup on final failure logging.debug(f"Cleaning up failed tunnel {commander_tube_id}") @@ -1973,30 +2207,30 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, unregister_tunnel_session(commander_tube_id) return {"success": False, "error": f"Failed to send offer via HTTP: {e}"} - + # Continue with the rest of the flow after successful offer send # Trickle ICE: Response comes via WebSocket (HTTP response is empty) # Non-trickle ICE: Response comes via HTTP (contains SDP answer) - + # For non-trickle ICE, process the HTTP response (contains SDP answer) if not trickle_ice and router_response: logging.debug("Non-trickle ICE: Processing SDP answer from HTTP response") try: # router_response is a dict with 'response' key containing the gateway payload gateway_payload = router_response.get('response', {}) - + # The response has nested structure: response -> payload (JSON string) -> data (encrypted) payload_str = gateway_payload.get('payload') if payload_str: payload_json = json.loads(payload_str) logging.debug(f"Non-trickle ICE: Parsed payload JSON, keys: {payload_json.keys()}") - + encrypted_answer = payload_json.get('data') if encrypted_answer: # Decrypt the answer using the tunnel's symmetric key decrypted_answer = tunnel_decrypt(symmetric_key, encrypted_answer) answer_data = json.loads(decrypted_answer) - + if 'answer' in answer_data: answer_sdp = answer_data['answer'] logging.debug(f"Non-trickle ICE: Received SDP answer via HTTP, setting in Rust") @@ -2012,7 +2246,7 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, logging.error(f"Non-trickle ICE: Failed to process HTTP response: {e}") import traceback logging.error(f"Traceback: {traceback.format_exc()}") - + # Send any buffered ICE candidates that arrived before offer was sent (trickle ICE only) if trickle_ice and tunnel_session.buffered_ice_candidates: # Ensure WebSocket backend is fully ready before flushing candidates @@ -2069,12 +2303,12 @@ def start_rust_tunnel(params, record_uid, gateway_uid, host, port, def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): """ Check the connection status of a tunnel tube. 
- + Args: tube_registry: The PyTubeRegistry instance tube_id: The tube ID to check timeout: Optional timeout in seconds to wait for connection (None = no waiting) - + Returns: dict: { "connected": bool, @@ -2084,7 +2318,7 @@ def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): """ if not tube_registry or not tube_id: return {"connected": False, "state": "unknown", "error": "Invalid tube registry or ID"} - + try: if timeout is None: # Check the current state @@ -2098,17 +2332,17 @@ def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): # Wait for connection with timeout max_wait_time = timeout check_interval = 0.5 - + for i in range(int(max_wait_time / check_interval)): try: state = tube_registry.get_connection_state(tube_id) logging.debug(f"Connection state check {i+1}: {state}") - + if state.lower() == "connected": return {"connected": True, "state": state, "error": None} elif state.lower() in ["failed", "closed", "disconnected"]: return {"connected": False, "state": state, "error": f"Connection failed with state: {state}"} - + time.sleep(check_interval) except Exception as e: if "not found" in str(e).lower(): @@ -2116,7 +2350,7 @@ def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): else: logging.warning(f"Could not check connection state: {e}") time.sleep(check_interval) - + # Timeout reached try: final_state = tube_registry.get_connection_state(tube_id) @@ -2126,7 +2360,7 @@ def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): return {"connected": False, "state": "not_found", "error": "Tube was removed from registry"} else: return {"connected": False, "state": "unknown", "error": f"Connection verification failed: {e}"} - + except Exception as e: return {"connected": False, "state": "error", "error": str(e)} @@ -2134,35 +2368,35 @@ def check_tunnel_connection_status(tube_registry, tube_id, timeout=None): def wait_for_tunnel_connection(tunnel_result, timeout=30, show_progress=True): """ Wait for a tunnel to establish connection, with optional progress display. 
- + Args: tunnel_result: Result dict from start_rust_tunnel timeout: Maximum time to wait in seconds show_progress: Whether to show progress messages - + Returns: dict: Connection status result """ if not tunnel_result.get("success"): return {"connected": False, "error": "Tunnel initiation failed"} - + tube_registry = tunnel_result.get("tube_registry") tube_id = tunnel_result.get("tube_id") - + if not tube_registry or not tube_id: return {"connected": False, "error": "Invalid tunnel result - missing registry or tube ID"} - + if show_progress: - print(f"{bcolors.OKBLUE}Waiting for tunnel connection (timeout: {timeout}s)...{bcolors.ENDC}") - + logging.debug(f"{bcolors.OKBLUE}Waiting for tunnel connection (timeout: {timeout}s)...{bcolors.ENDC}") + result = check_tunnel_connection_status(tube_registry, tube_id, timeout) - + if show_progress: if result["connected"]: # Success messages are now shown by the signal handler when connection establishes logging.debug("Tunnel connection wait completed successfully") else: error_msg = result.get("error", "Unknown error") - print(f"{bcolors.FAIL}Tunnel connection failed: {error_msg}{bcolors.ENDC}") - - return result \ No newline at end of file + logging.debug(f"{bcolors.FAIL}Tunnel connection failed: {error_msg}{bcolors.ENDC}") + + return result diff --git a/keepercommander/commands/tunnel_and_connections.py b/keepercommander/commands/tunnel_and_connections.py index 155c1a059..8ff224d4f 100644 --- a/keepercommander/commands/tunnel_and_connections.py +++ b/keepercommander/commands/tunnel_and_connections.py @@ -323,7 +323,8 @@ def execute(self, params, **kwargs): encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) if record_type in "pamNetworkConfiguration pamAwsConfiguration pamAzureConfiguration".split(): - tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True, + transmission_key=transmission_key) tmp_dag.edit_tunneling_config(tunneling=_tunneling) tmp_dag.print_tunneling_config(record_uid, None) else: @@ -355,8 +356,10 @@ def execute(self, params, **kwargs): existing_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) - tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid) - old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid) + tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid, + transmission_key=transmission_key) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid, + transmission_key=transmission_key) if config_uid and existing_config_uid != config_uid: old_dag.remove_from_dag(record_uid) @@ -827,7 +830,8 @@ def execute(self, params, **kwargs): encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) if record_type in "pamNetworkConfiguration pamAwsConfiguration pamAzureConfiguration".split(): - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, record_uid, is_config=True, + transmission_key=transmission_key) tdag.edit_tunneling_config(connections=_connections, session_recording=_recording, typescript_recording=_typescript_recording) if not 
kwargs.get("silent", False): tdag.print_tunneling_config(record_uid, None) else: @@ -922,8 +926,10 @@ def execute(self, params, **kwargs): existing_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid) - old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid, + transmission_key=transmission_key) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid, + transmission_key=transmission_key) if config_uid and existing_config_uid != config_uid: old_dag.remove_from_dag(record_uid) @@ -1307,7 +1313,7 @@ def update_connection_int(field_name, value): return # resolve PAM Config - encrypted_session_token, encrypted_transmission_key, _ = get_keeper_tokens(params) + encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) existing_config_uid = get_config_uid(params, encrypted_session_token, encrypted_transmission_key, record_uid) existing_config_uid = str(existing_config_uid) if existing_config_uid else '' @@ -1333,7 +1339,8 @@ def update_connection_int(field_name, value): if not config_uid: raise CommandError('pam rbi edit', f'{bcolors.FAIL}PAM Config record not found.{bcolors.ENDC}') - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, config_uid, + transmission_key=transmission_key) if tdag is None or not tdag.linking_dag.has_graph: raise CommandError('', f"{bcolors.FAIL}No valid PAM Configuration UID set. " "This must be set or supplied for connections to work. " @@ -1342,7 +1349,8 @@ def update_connection_int(field_name, value): if config_uid: if existing_config_uid and existing_config_uid != config_uid: - old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid) + old_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, existing_config_uid, + transmission_key=transmission_key) old_dag.remove_from_dag(record_uid) logging.debug(f'Updated existing PAM Config UID from: {existing_config_uid} to: {config_uid}') tdag.link_resource_to_config(record_uid) @@ -1527,7 +1535,8 @@ def execute(self, params, **kwargs): if pam_config_uid: encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) - tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config_uid, True) + tdag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, pam_config_uid, True, + transmission_key=transmission_key) tdag.link_resource_to_config(record_uid) tdag.link_user_to_resource(pam_user_uid, record_uid, True, True) diff --git a/keepercommander/commands/utils.py b/keepercommander/commands/utils.py index 4d2feb96a..214f28619 100644 --- a/keepercommander/commands/utils.py +++ b/keepercommander/commands/utils.py @@ -25,6 +25,7 @@ from datetime import timedelta from typing import Optional, Dict, List, Set +from colorama import Fore, Style from cryptography.hazmat.primitives.asymmetric import ec, rsa from google.protobuf.json_format import MessageToDict @@ -40,7 +41,7 @@ from .. import __version__, vault from .. 
import api, rest_api, loginv3, crypto, utils, constants, error, vault_extensions from ..breachwatch import BreachWatch -from ..display import bcolors +from ..display import bcolors, post_login_summary from ..error import CommandError from ..generator import KeeperPasswordGenerator, DicewarePasswordGenerator, CryptoPassphraseGenerator from ..params import KeeperParams, LAST_RECORD_UID, LAST_FOLDER_UID, LAST_SHARED_FOLDER_UID @@ -234,8 +235,51 @@ def register_command_info(aliases, command_info): this_device_available_command_verbs = ['rename', 'register', 'persistent-login', 'ip-auto-approve', 'no-yubikey-pin', 'timeout', '2fa_expiration'] -this_device_parser = argparse.ArgumentParser(prog='this-device', description='Display and modify settings of the current device') -this_device_parser.add_argument('ops', nargs='*', help="operation str: " + ", ".join(this_device_available_command_verbs)) +this_device_help_epilog = """ +Sub-commands: + + this-device rename + Change the display name of this device + Example: this-device rename "My Laptop" + + this-device register + Register device for persistent login (stores encrypted data key) + + this-device persistent-login + Enable/disable 'Stay Logged In' on this device + Example: this-device persistent-login on + + this-device timeout + Set auto-logout timer + Examples: this-device timeout 10m (10 minutes) + this-device timeout 1h (1 hour) + this-device timeout 7d (7 days) + this-device timeout 30d (30 days) + + this-device ip-auto-approve + Auto-approve device when logging in from same IP address + Example: this-device ip-auto-approve on + + this-device no-yubikey-pin + Skip PIN prompt for YubiKey/security key authentication + Example: this-device no-yubikey-pin on + + this-device 2fa_expiration + Set how long 2FA approval is remembered + Examples: this-device 2fa_expiration login (every login) + this-device 2fa_expiration 12h (12 hours) + this-device 2fa_expiration 24h (24 hours) + this-device 2fa_expiration 30d (30 days) + this-device 2fa_expiration forever (remember forever) +""" +this_device_parser = argparse.ArgumentParser( + prog='this-device', + description='Display and modify settings of the current device', + epilog=this_device_help_epilog, + formatter_class=argparse.RawDescriptionHelpFormatter +) +this_device_parser.add_argument('ops', nargs='*', metavar='command', help='sub-command to run (see below)') +this_device_parser.add_argument('--format', dest='format', action='store', choices=['text', 'json'], default='text', help='output format') this_device_parser.error = raise_parse_exception this_device_parser.exit = suppress_exit @@ -286,7 +330,7 @@ def register_command_info(aliases, command_info): help_parser = argparse.ArgumentParser(prog='help', description='Displays help on a specific command') -help_help = 'Commander\'s command (Optional -- if not specified, list of available commands is displayed)' +help_help = 'Command name, or "basics" for essential commands guide' help_parser.add_argument('command', action='store', type=str, nargs='*', help=help_help) help_parser.add_argument('--legacy', dest='legacy', action='store_true', help='Show legacy/deprecated commands') help_parser.error = raise_parse_exception @@ -451,20 +495,35 @@ def get_parser(self): def execute(self, params, **kwargs): ops = kwargs.get('ops') + output_format = kwargs.get('format', 'text') if len(ops) == 0: - ThisDeviceCommand.print_device_info(params) + return ThisDeviceCommand.print_device_info(params, output_format=output_format) + + + action = ops[0].lower() if ops 
else '' + + # Check for valid sub-commands that need a value + if len(ops) == 1 and action != 'register': + if action in ('timeout', 'to'): + print(f"Usage: this-device timeout <value> (e.g., 10m, 1h, 7d, 30d)") + elif action == '2fa_expiration': + print(f"Usage: this-device 2fa_expiration <value> (e.g., login, 12h, 24h, 30d, forever)") + elif action in ('persistent_login', 'persistent-login', 'pl'): + print(f"Usage: this-device persistent-login <on|off>") + elif action in ('ip_auto_approve', 'ip-auto-approve', 'iaa'): + print(f"Usage: this-device ip-auto-approve <on|off>") + elif action == 'no-yubikey-pin': + print(f"Usage: this-device no-yubikey-pin <on|off>") + elif action in ('rename', 'ren'): + print(f"Usage: this-device rename <name>") + else: + print(f"Unknown sub-command: {action}") + print(f"Run {Fore.GREEN}this-device -h{Fore.RESET} for detailed help.") return - if len(ops) >= 1 and ops[0].lower() != 'register': - if len(ops) == 1 and ops[0].lower() != 'register': - logging.error("Must supply action and value. Available sub-commands: " + ", ".join(this_device_available_command_verbs)) - return - - if len(ops) != 2: - logging.error("Must supply action and value. Available sub-commands: " + ", ".join(this_device_available_command_verbs)) - return - - action = ops[0].lower() + if len(ops) >= 1 and action != 'register' and len(ops) != 2: + print(f"Invalid arguments. Run {Fore.GREEN}this-device -h{Fore.RESET} for help.") return def register_device(): is_device_registered = loginv3.LoginV3API.register_encrypted_data_key_for_device(params) @@ -589,68 +648,119 @@ def compare_device_tokens(t1: str, t2: str): return acct_summary_dict, this_device @staticmethod - def print_device_info(params: KeeperParams): + def get_device_info(params: KeeperParams): + """Get device info as a dictionary (for programmatic use)""" acct_summary_dict, this_device = ThisDeviceCommand.get_account_summary_and_this_device(params) - print('{:>32}: {}'.format('Device Name', this_device['deviceName'])) - # print("{:>32}: {}".format('API Client Version', rest_api.CLIENT_VERSION)) + # Build device info dictionary + device_name = this_device.get('deviceName', 'Unknown') + data_key_present = this_device.get('encryptedDataKeyPresent', False) - if 'encryptedDataKeyPresent' in this_device: - print("{:>32}: {}".format('Data Key Present', (bcolors.OKGREEN + 'YES' + bcolors.ENDC) if this_device['encryptedDataKeyPresent'] else (bcolors.FAIL + 'NO' + bcolors.ENDC))) - else: - print("{:>32}: {}".format('Data Key Present', (bcolors.FAIL + 'missing' + bcolors.ENDC))) - - if 'ipDisableAutoApprove' in acct_summary_dict['settings']: - ipDisableAutoApprove = acct_summary_dict['settings']['ipDisableAutoApprove'] - # ip_disable_auto_approve - If enabled, the device is NOT automatically approved - # If disabled, the device will be auto approved - ipAutoApprove = not ipDisableAutoApprove - print("{:>32}: {}".format('IP Auto Approve', - (bcolors.OKGREEN + 'ON' + bcolors.ENDC) - if ipAutoApprove else - (bcolors.FAIL + 'OFF' + bcolors.ENDC))) - else: - print("{:>32}: {}".format('IP Auto Approve', (bcolors.OKGREEN + 'ON' + bcolors.ENDC))) - # ip_disable_auto_approve = 0 / disabled (default) <==> IP Auto Approve :ON + # IP Auto Approve (inverted from ipDisableAutoApprove) + ip_auto_approve = not acct_summary_dict['settings'].get('ipDisableAutoApprove', False) - persistentLogin = acct_summary_dict['settings'].get('persistentLogin', False) - print("{:>32}: {}".format('Persistent Login', - (bcolors.OKGREEN + 'ON' + bcolors.ENDC) - if persistentLogin and not
ThisDeviceCommand.is_persistent_login_disabled(params) else - (bcolors.FAIL + 'OFF' + bcolors.ENDC))) + # Persistent Login + persistent_login = acct_summary_dict['settings'].get('persistentLogin', False) + if persistent_login: + persistent_login = not ThisDeviceCommand.is_persistent_login_disabled(params) - no_user_verify = acct_summary_dict['settings'].get('securityKeysNoUserVerify', False) - print("{:>32}: {}".format( - 'Security Key No PIN', (bcolors.OKGREEN + 'ON' + bcolors.ENDC) - if no_user_verify else (bcolors.FAIL + 'OFF' + bcolors.ENDC))) - - if 'securityKeysNoUserVerify' in acct_summary_dict['settings']: - device_timeout = get_delta_from_timeout_setting(acct_summary_dict['settings']['logoutTimer']) - print("{:>32}: {}".format('Device Logout Timeout', format_timeout(device_timeout))) + # Security Key No PIN + security_key_no_pin = acct_summary_dict['settings'].get('securityKeysNoUserVerify', False) + # Device Logout Timeout if 'logoutTimer' in acct_summary_dict['settings']: device_timeout = get_delta_from_timeout_setting(acct_summary_dict['settings']['logoutTimer']) - print("{:>32}: {}".format('Device Logout Timeout', format_timeout(device_timeout))) - + device_timeout_str = format_timeout(device_timeout) else: device_timeout = timedelta(hours=1) - print("{:>32}: Default".format('Logout Timeout')) + device_timeout_str = 'Default' + # Enterprise Logout Timeout + enterprise_timeout = None + enterprise_timeout_str = None + effective_timeout_str = None if 'Enforcements' in acct_summary_dict and 'longs' in acct_summary_dict['Enforcements']: logout_timeout = next((x['value'] for x in acct_summary_dict['Enforcements']['longs'] if x['key'] == 'logout_timer_desktop'), None) if logout_timeout: enterprise_timeout = timedelta(minutes=int(logout_timeout)) - print("{:>32}: {}".format('Enterprise Logout Timeout', format_timeout(enterprise_timeout))) + enterprise_timeout_str = format_timeout(enterprise_timeout) + effective_timeout_str = format_timeout(min(enterprise_timeout, device_timeout)) + + is_sso_user = params.settings.get('sso_user', False) if hasattr(params, 'settings') else False + config_file = params.config_filename if hasattr(params, 'config_filename') else None + + return { + 'device_name': device_name, + 'data_key_present': data_key_present, + 'ip_auto_approve': ip_auto_approve, + 'persistent_login': persistent_login, + 'security_key_no_pin': security_key_no_pin, + 'device_logout_timeout': device_timeout_str, + 'enterprise_logout_timeout': enterprise_timeout_str, + 'effective_logout_timeout': effective_timeout_str, + 'is_sso_user': is_sso_user, + 'config_file': config_file, + } + + @staticmethod + def print_device_info(params: KeeperParams, output_format: str = 'text'): + device_info = ThisDeviceCommand.get_device_info(params) + + device_name = device_info['device_name'] + data_key_present = device_info['data_key_present'] + ip_auto_approve = device_info['ip_auto_approve'] + persistent_login = device_info['persistent_login'] + security_key_no_pin = device_info['security_key_no_pin'] + device_timeout_str = device_info['device_logout_timeout'] + enterprise_timeout_str = device_info['enterprise_logout_timeout'] + effective_timeout_str = device_info['effective_logout_timeout'] + is_sso_user = device_info['is_sso_user'] + config_file = device_info['config_file'] + + # JSON output + if output_format == 'json': + import json + print(json.dumps(device_info)) + return - print("{:>32}: {}".format('Effective Logout Timeout', - format_timeout(min(enterprise_timeout, device_timeout)))) + # Text 
output - clean display + DIM = Fore.WHITE - print('{:>32}: {}'.format('Is SSO User', params.settings['sso_user'] if 'sso_user' in params.settings else False)) + def label_value(label, value): + return f" {DIM}{label:>22}{Fore.RESET}: {value}" - print('{:>32}: {}'.format('Config file', params.config_filename)) + print() + print(f" {Style.BRIGHT}Device Settings{Style.RESET_ALL}") + print(f" {DIM}{'─' * 50}{Fore.RESET}") - print("\nAvailable sub-commands: ", bcolors.OKBLUE + (", ".join(this_device_available_command_verbs)) + bcolors.ENDC) + print(label_value('Device Name', device_name)) + + if data_key_present is not None: + dk_display = f"{Fore.GREEN}YES{Fore.RESET}" if data_key_present else f"{Fore.RED}NO{Fore.RESET}" + else: + dk_display = f"{Fore.RED}missing{Fore.RESET}" + print(label_value('Data Key Present', dk_display)) + + ip_display = f"{Fore.GREEN}ON{Fore.RESET}" if ip_auto_approve else f"{Fore.RED}OFF{Fore.RESET}" + print(label_value('IP Auto Approve', ip_display)) + + pl_display = f"{Fore.GREEN}ON{Fore.RESET}" if persistent_login else f"{Fore.RED}OFF{Fore.RESET}" + print(label_value('Persistent Login', pl_display)) + + sk_display = f"{Fore.GREEN}ON{Fore.RESET}" if security_key_no_pin else f"{Fore.RED}OFF{Fore.RESET}" + print(label_value('Security Key No PIN', sk_display)) + + print(label_value('Logout Timeout', device_timeout_str)) + + if enterprise_timeout_str: + print(label_value('Enterprise Timeout', enterprise_timeout_str)) + print(label_value('Effective Timeout', effective_timeout_str)) + + print(label_value('Is SSO User', str(is_sso_user))) + print(label_value('Config file', config_file)) + print() + print(f" For usage details, type {Fore.GREEN}this-device -h{Fore.RESET}") class RecordDeleteAllCommand(Command): @@ -1122,6 +1232,60 @@ class WhoamiCommand(Command): def get_parser(self): return whoami_parser + @staticmethod + def get_whoami_info(params: KeeperParams, verbose: bool = False): + """Get whoami info as a dictionary (for programmatic use)""" + data = {} + + if params.session_token: + data['logged_in'] = True + hostname = get_hostname(params.rest_context.server_base) + data['user'] = params.user + data['server'] = hostname + data['data_center'] = get_data_center(hostname) + + environment = get_environment(hostname) + if environment: + data['environment'] = environment + + if params.license: + account_type = params.license['account_type'] if 'account_type' in params.license else None + if account_type == 2: + data['admin'] = params.enterprise is not None + + account_type_name = 'Enterprise' if account_type == 2 \ + else 'Family Plan' if account_type == 1 \ + else params.license['product_type_name'] + data['account_type'] = account_type_name + data['renewal_date'] = params.license['expiration_date'] + + if 'bytes_total' in params.license: + storage_bytes = int(params.license['bytes_total']) + storage_gb = storage_bytes >> 30 + storage_bytes_used = params.license['bytes_used'] if 'bytes_used' in params.license else 0 + data['storage_capacity'] = f'{storage_gb}GB' + storage_usage = (int(storage_bytes_used) * 100 // storage_bytes) if storage_bytes != 0 else 0 + data['storage_usage'] = f'{storage_usage}%' + data['storage_renewal_date'] = params.license['storage_expiration_date'] + + data['breachwatch'] = params.license.get('breach_watch_enabled', False) + if params.enterprise: + data['reporting_and_alerts'] = params.license.get('audit_and_reporting_enabled', False) + + if verbose: + data['records_count'] = len(params.record_cache) + sf_count = 
len(params.shared_folder_cache) + if sf_count > 0: + data['shared_folders_count'] = sf_count + team_count = len(params.team_cache) + if team_count > 0: + data['teams_count'] = team_count + else: + data['logged_in'] = False + data['message'] = 'Not logged in' + + return data + def execute(self, params, **kwargs): json_output = kwargs.get('json_output', False) verbose = kwargs.get('verbose', False) @@ -1251,52 +1415,69 @@ def execute(self, params, **kwargs): import json print(json.dumps(data, indent=2)) else: - # Original formatted output + # Clean formatted output + DIM = Fore.WHITE + + def label_value(label, value): + return f" {DIM}{label:>22}{Fore.RESET}: {value}" + + def yes_no(val): + return f"{Fore.GREEN}Yes{Fore.RESET}" if val else f"{Fore.RED}No{Fore.RESET}" + if params.session_token: hostname = get_hostname(params.rest_context.server_base) - print('{0:>20s}: {1:<20s}'.format('User', params.user)) - print('{0:>20s}: {1:<20s}'.format('Server', hostname)) - print('{0:>20s}: {1:<20s}'.format('Data Center', get_data_center(hostname))) + + print() + print(f" {Style.BRIGHT}User Info{Style.RESET_ALL}") + print(f" {DIM}{'─' * 50}{Fore.RESET}") + print(label_value('User', params.user)) + print(label_value('Server', hostname)) + print(label_value('Data Center', get_data_center(hostname))) environment = get_environment(hostname) if environment: - print('{0:>20s}: {1:<20s}'.format('Environment', get_environment(hostname))) + print(label_value('Environment', environment)) if params.license: account_type = params.license['account_type'] if 'account_type' in params.license else None if account_type == 2: - display_admin = 'No' if params.enterprise is None else 'Yes' - print('{0:>20s}: {1:<20s}'.format('Admin', display_admin)) + display_admin = yes_no(params.enterprise is not None) + print(label_value('Admin', display_admin)) - print('') + print() + print(f" {Style.BRIGHT}Account{Style.RESET_ALL}") + print(f" {DIM}{'─' * 50}{Fore.RESET}") account_type_name = 'Enterprise' if account_type == 2 \ else 'Family Plan' if account_type == 1 \ else params.license['product_type_name'] - print('{0:>20s}: {1:<20s}'.format('Account Type', account_type_name)) - print('{0:>20s}: {1:<20s}'.format('Renewal Date', params.license['expiration_date'])) + print(label_value('Account Type', account_type_name)) + print(label_value('Renewal Date', params.license['expiration_date'])) if 'bytes_total' in params.license: - storage_bytes = int(params.license['bytes_total']) # note: int64 in protobuf in python produces string as opposed to an int or long. + storage_bytes = int(params.license['bytes_total']) storage_gb = storage_bytes >> 30 storage_bytes_used = params.license['bytes_used'] if 'bytes_used' in params.license else 0 - print('{0:>20s}: {1:<20s}'.format('Storage Capacity', f'{storage_gb}GB')) - storage_usage = (int(storage_bytes_used) * 100 // storage_bytes) if storage_bytes != 0 else 0 # note: int64 in protobuf in python produces string as opposed to an int or long. 
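Note on the storage arithmetic in this hunk: per the inline comment, protobuf int64 fields arrive in Python as strings, which is why both values are cast with int() before the 30-bit shift and the integer-percentage division. A self-contained sketch of the same computation, with made-up byte counts for illustration:

```python
def storage_summary(bytes_total, bytes_used=0):
    # int64 protobuf fields may arrive as strings in Python, hence the casts
    total = int(bytes_total)
    used = int(bytes_used)
    gb = total >> 30                               # 2**30 bytes per GB
    pct = (used * 100 // total) if total else 0    # guard against zero capacity
    return f'{gb}GB', f'{pct}%'

print(storage_summary('107374182400', '32212254720'))  # -> ('100GB', '30%')
```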
- print('{0:>20s}: {1:<20s}'.format('Usage', f'{storage_usage}%')) - print('{0:>20s}: {1:<20s}'.format('Storage Renewal Date', params.license['storage_expiration_date'])) - print('{0:>20s}: {1:<20s}'.format('BreachWatch', 'Yes' if params.license.get('breach_watch_enabled') else 'No')) + print(label_value('Storage Capacity', f'{storage_gb}GB')) + storage_usage = (int(storage_bytes_used) * 100 // storage_bytes) if storage_bytes != 0 else 0 + print(label_value('Usage', f'{storage_usage}%')) + print(label_value('Storage Renewal Date', params.license['storage_expiration_date'])) + print(label_value('BreachWatch', yes_no(params.license.get('breach_watch_enabled')))) if params.enterprise: - print('{0:>20s}: {1:<20s}'.format('Reporting & Alerts', 'Yes' if params.license.get('audit_and_reporting_enabled') else 'No')) + print(label_value('Reporting & Alerts', yes_no(params.license.get('audit_and_reporting_enabled')))) if verbose: - print('') - print('{0:>20s}: {1}'.format('Records', len(params.record_cache))) + print() + print(f" {Style.BRIGHT}Vault Stats{Style.RESET_ALL}") + print(f" {DIM}{'─' * 50}{Fore.RESET}") + print(label_value('Records', str(len(params.record_cache)))) sf_count = len(params.shared_folder_cache) if sf_count > 0: - print('{0:>20s}: {1}'.format('Shared Folders', sf_count)) + print(label_value('Shared Folders', str(sf_count))) team_count = len(params.team_cache) if team_count > 0: - print('{0:>20s}: {1}'.format('Teams', team_count)) + print(label_value('Teams', str(team_count))) if params.enterprise: - print('') - print('{0:>20s}:'.format('Enterprise License')) + print() + print(f" {Style.BRIGHT}Enterprise License{Style.RESET_ALL}") + print(f" {DIM}{'─' * 50}{Fore.RESET}") for x in params.enterprise.get('licenses', []): product_type_id = x.get('product_type_id', 0) tier = x.get('tier', 0) @@ -1313,7 +1494,7 @@ def execute(self, params, **kwargs): plan = 'Unknown' if product_type_id in (5, 10, 12): plan += ' Trial' - print('{0:>20s}: {1}'.format('Base Plan', plan)) + print(label_value('Base Plan', plan)) paid = x.get('paid') is True if paid: exp = x.get('expiration') @@ -1325,12 +1506,15 @@ def execute(self, params, **kwargs): if td > 0: expires += f' (in {td} days)' else: - expires += ' (expired)' - print('{0:>20s}: {1}'.format('Expires', expires)) - print('{0:>20s}: {1}'.format('User Licenses', f'Plan: {x.get("number_of_seats", "")} Active: {x.get("seats_allocated", "")} Invited: {x.get("seats_pending", "")}')) + expires += f' ({Fore.RED}expired{Fore.RESET})' + print(label_value('Expires', expires)) + seats_plan = x.get("number_of_seats", "") + seats_active = x.get("seats_allocated", "") + seats_invited = x.get("seats_pending", "") + print(label_value('User Licenses', f'Plan: {seats_plan} Active: {seats_active} Invited: {seats_invited}')) file_plan = x.get('file_plan') - file_plan_lookup = {x[0]: x[2] for x in constants.ENTERPRISE_FILE_PLANS} - print('{0:>20s}: {1}'.format('Secure File Storage', file_plan_lookup.get(file_plan, ''))) + file_plan_lookup = {fp[0]: fp[2] for fp in constants.ENTERPRISE_FILE_PLANS} + print(label_value('Secure File Storage', file_plan_lookup.get(file_plan, ''))) addons = [] addon_lookup = {a[0]: a[1] for a in constants.MSP_ADDONS} for ao in x.get('add_ons'): @@ -1349,9 +1533,10 @@ def execute(self, params, **kwargs): addon_name += f' ({seats} licenses)' addons.append(addon_name) for i, addon in enumerate(addons): - print('{0:>20s}: {1}'.format('Secure Add Ons' if i == 0 else '', addon)) + print(label_value('Add-ons' if i == 0 else '', addon)) + print() 
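A small but worthwhile cleanup in the hunk above: the file_plan_lookup comprehension variable is renamed from x to fp because the enclosing loop already binds x to the current license. Python 3 comprehensions have their own scope, so the outer x was never overwritten, but reusing the name made x inside the braces refer to a plan tuple rather than the license. A minimal demonstration with illustrative plan rows (not the real ENTERPRISE_FILE_PLANS values):

```python
# Illustrative rows shaped like (plan_id, internal_name, display_name)
ENTERPRISE_FILE_PLANS = [(4, 'STORAGE_100GB', '100GB'), (7, 'STORAGE_1TB', '1TB')]

for x in [{'file_plan': 7}]:  # x is the current license, as in WhoamiCommand
    # With 'fp', the license 'x' stays accessible inside the comprehension
    file_plan_lookup = {fp[0]: fp[2] for fp in ENTERPRISE_FILE_PLANS}
    print(file_plan_lookup.get(x.get('file_plan'), ''))  # -> 1TB
```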
else: - print('{0:>20s}:'.format('Not logged in')) + print(f"\n {Fore.YELLOW}Not logged in{Fore.RESET}\n") class VersionCommand(Command): @@ -1495,12 +1680,13 @@ def execute(self, params, **kwargs): params.password = password new_login = kwargs.get('new_login') is True + skip_sync = kwargs.get('skip_sync') is True try: api.login(params, new_login=new_login) except Exception as exc: logging.warning(str(exc)) - if params.session_token: + if params.session_token and not skip_sync: params.enterprise = None params._pedm_plugin = None SyncDownCommand().execute(params, force=True) @@ -1515,6 +1701,26 @@ def execute(self, params, **kwargs): logging.warning(f'A problem was encountered while updating BreachWatch/security data: {e}') logging.debug(e, exc_info=True) + # Auto-register device for persistent login (stores encrypted data key on server) + try: + loginv3.LoginV3API.register_encrypted_data_key_for_device(params) + except Exception as e: + logging.debug(f'Device registration: {e}') + + # Show post-login message + if params.batch_mode: + # One-shot login from terminal - show simple success message + print() + print(f'{Fore.GREEN}Keeper login successful.{Fore.RESET}') + print(f'Type "{Fore.GREEN}keeper shell{Fore.RESET}" for the interactive shell, "{Fore.GREEN}keeper supershell{Fore.RESET}" for the vault UI,') + print(f'or "{Fore.GREEN}keeper help{Fore.RESET}" to see all available commands.') + print() + else: + # Interactive shell - show full summary with tips + record_count = getattr(params, '_sync_record_count', 0) + breachwatch_count = getattr(params, '_sync_breachwatch_count', 0) + post_login_summary(record_count=record_count, breachwatch_count=breachwatch_count) + class CheckEnforcementsCommand(Command): def get_parser(self): @@ -1584,7 +1790,10 @@ def get_parser(self): return logout_parser def is_authorised(self): - return False + return True + + # Logout needs auth but not sync + skip_sync_on_auth = True def execute(self, params, **kwargs): if msp.current_mc_id: @@ -1643,7 +1852,10 @@ def execute(self, params, **kwargs): sp_url = urllib.parse.urlunparse(sp_url_builder) logging.info('SSO Logout URL\n%s', sp_url) + # Preserve commands queue (clear_session() clears it, but we need 'q' to exit) + saved_commands = list(params.commands) params.clear_session() + params.commands.extend(saved_commands) class EchoCommand(Command): @@ -1685,6 +1897,74 @@ class HelpCommand(Command): def get_parser(self): return help_parser + @staticmethod + def display_basics_help(): + """Display help for the most common commands with detailed explanations""" + from colorama import Fore, Style + + DIM = Fore.WHITE + GRN = Fore.GREEN + + print() + print(f" {Style.BRIGHT}Getting Started - Essential Commands{Style.RESET_ALL}") + print(f" {DIM}{'─' * 60}{Fore.RESET}") + print() + + # Each entry: (command, short_desc, details, examples) + # examples is a list of example strings, or None + basics = [ + ('whoami', 'Display information about the current user', + 'Shows your email, data center, account type, and license info', None), + + ('this-device', 'Display and modify settings of the current device', + 'View/change device name, logout timeout, and persistent login', + ['this-device rename "My Laptop"', 'this-device timeout 7d', 'this-device persistent_login on']), + + ('tree', 'Display the folder structure', + 'Shows a hierarchical view of all folders in your vault', None), + + ('ls', 'List folder contents', + 'Lists records and subfolders in the current folder', + ['ls', 'ls "My Folder"']), + + ('cd', 'Change current 
folder', + 'Navigate to a different folder in your vault', + ['cd "Business Accounts"', 'cd ..']), + + ('get', 'Get the details of a record/folder/team', + 'Display full details of a record by title or UID', + ['get "Amazon"', 'get AbCdEf123456']), + + ('search', 'Search the vault using a regular expression', + 'Find records matching a search pattern', + ['search amazon', 'search "bank.*account"']), + + ('list-sf', 'List all shared folders', + 'Shows all shared folders you have access to', None), + + ('enterprise-info', 'Display enterprise tenant structure', + 'Shows nodes, users, teams, and roles (Enterprise admins only)', None), + + ('supershell (ss)', 'Launch full terminal vault UI', + 'Opens the interactive TUI with vim-style navigation', None), + ] + + max_cmd_width = max(len(cmd) for cmd, _, _, _ in basics) + indent = ' ' * (max_cmd_width + 4) + + for cmd, short_desc, detail, examples in basics: + print(f" {GRN}{cmd:<{max_cmd_width}}{Fore.RESET} {short_desc}") + print(f" {indent}{DIM}{detail}{Fore.RESET}") + if examples: + for i, ex in enumerate(examples): + prefix = "Examples: " if i == 0 else " " + print(f" {indent}{DIM}{prefix}{GRN}{ex}{Fore.RESET}") + print() + + print(f" {DIM}Type {GRN}help <command>{DIM} for detailed help on any command{Fore.RESET}") + print(f" {DIM}Type {GRN}?{DIM} to see all available commands{Fore.RESET}") + print() + def execute(self, params, **kwargs): help_commands = kwargs.get('command') show_legacy = kwargs.get('legacy', False) @@ -1695,6 +1975,12 @@ def execute(self, params, **kwargs): if isinstance(help_commands, list) and len(help_commands) > 0: cmd = help_commands[0] + + # Handle "help basics" special case + if cmd.lower() == 'basics': + self.display_basics_help() + return + help_commands = help_commands[1:] if cmd in aliases: ali = aliases[cmd] @@ -2066,7 +2352,7 @@ def parse_input_records(): # type: () -> Set[str] api.sync_down(params) if not kwargs.get('quiet'): if num_updated: - logging.info(f'Updated security data for [{num_updated}] record(s)') + logging.info(f'Updated security data for {num_updated} {"record" if num_updated == 1 else "records"}') elif not kwargs.get('suppress_no_op') and not num_to_update: logging.info('No records requiring security-data updates found') diff --git a/keepercommander/commands/verify_records.py b/keepercommander/commands/verify_records.py index 086172596..0f619daaf 100644 --- a/keepercommander/commands/verify_records.py +++ b/keepercommander/commands/verify_records.py @@ -352,7 +352,7 @@ def execute(self, params, **kwargs): if len(records_v2_to_fix) > 0 or len(records_v3_to_fix) > 0: total_records = len(records_v2_to_fix) + len(records_v3_to_fix) - print(f'There are {total_records} record(s) to be corrected') + print(f'There {"is" if total_records == 1 else "are"} {total_records} {"record" if total_records == 1 else "records"} to be corrected') answer = user_choice('Do you want to proceed?', 'yn', 'n') if answer.lower() == 'y': success = 0 @@ -415,9 +415,9 @@ def execute(self, params, **kwargs): failed.append(f'{record_uid}: {status.message}') if success > 0: - logging.info('Successfully corrected %d record(s)', success) + logging.info('Successfully corrected %d %s', success, 'record' if success == 1 else 'records') if len(failed) > 0: - logging.warning('Failed to correct %d record(s)', len(failed)) + logging.warning('Failed to correct %d %s', len(failed), 'record' if len(failed) == 1 else 'records') logging.info('\n'.join(failed)) params.sync_data = True diff --git a/keepercommander/constants.py
b/keepercommander/constants.py index 6417c3aec..6b532a99d 100644 --- a/keepercommander/constants.py +++ b/keepercommander/constants.py @@ -323,6 +323,72 @@ def get_cron_month_day(text): # type: (Optional[str]) -> Optional[int] 'GOV': 'govcloud.keepersecurity.us' } +# All valid Keeper server hosts including dev and QA environments +KEEPER_SERVERS = { + # Production + 'US': 'keepersecurity.com', + 'EU': 'keepersecurity.eu', + 'AU': 'keepersecurity.com.au', + 'CA': 'keepersecurity.ca', + 'JP': 'keepersecurity.jp', + 'GOV': 'govcloud.keepersecurity.us', + # Dev environments + 'US_DEV': 'dev.keepersecurity.com', + 'EU_DEV': 'dev.keepersecurity.eu', + 'AU_DEV': 'dev.keepersecurity.com.au', + 'CA_DEV': 'dev.keepersecurity.ca', + 'JP_DEV': 'dev.keepersecurity.jp', + 'GOV_DEV': 'govcloud.dev.keepersecurity.us', + # QA environments + 'US_QA': 'qa.keepersecurity.com', + 'EU_QA': 'qa.keepersecurity.eu', + 'AU_QA': 'qa.keepersecurity.com.au', + 'CA_QA': 'qa.keepersecurity.ca', + 'JP_QA': 'qa.keepersecurity.jp', + 'GOV_QA': 'govcloud.qa.keepersecurity.us', +} + + +def resolve_server(server_input): + """ + Resolve a server input to a valid Keeper host. + + Args: + server_input: Can be a region code (US, EU, GOV_DEV, etc.) or a full hostname + + Returns: + The resolved hostname if valid, None if invalid + + Examples: + resolve_server('US') -> 'keepersecurity.com' + resolve_server('us') -> 'keepersecurity.com' + resolve_server('GOV_DEV') -> 'govcloud.dev.keepersecurity.us' + resolve_server('keepersecurity.eu') -> 'keepersecurity.eu' + resolve_server('foo.com') -> None + """ + if not server_input: + return None + + # Normalize input - uppercase for lookup + server_upper = server_input.upper().replace('-', '_') + + # Check if it's a region code + if server_upper in KEEPER_SERVERS: + return KEEPER_SERVERS[server_upper] + + # Check if it's already a valid hostname (direct match) + server_lower = server_input.lower() + if server_lower in KEEPER_SERVERS.values(): + return server_lower + + # Not a valid server + return None + + +def get_valid_server_codes(): + """Return list of valid server codes for help text""" + return sorted(KEEPER_SERVERS.keys()) + def get_abbrev_by_host(host): # Return abbreviation of the Keeper's public host diff --git a/keepercommander/discovery_common/__version__.py b/keepercommander/discovery_common/__version__.py index 1a72d32e5..7b344eca4 100644 --- a/keepercommander/discovery_common/__version__.py +++ b/keepercommander/discovery_common/__version__.py @@ -1 +1 @@ -__version__ = '1.1.0' +__version__ = '1.1.2' diff --git a/keepercommander/discovery_common/constants.py b/keepercommander/discovery_common/constants.py index e792806fa..ea6e1f855 100644 --- a/keepercommander/discovery_common/constants.py +++ b/keepercommander/discovery_common/constants.py @@ -21,11 +21,31 @@ PAM_USER = "pamUser" LOCAL_USER = "local" +PAM_RESOURCES = [ + PAM_DIRECTORY, + PAM_DATABASE, + PAM_MACHINE +] + +PAM_DOMAIN_CONFIGURATION = "pamDomainConfiguration" +PAM_AZURE_CONFIGURATION = "pamAzureConfiguration" +PAM_AWS_CONFIGURATION = "pamAwsConfiguration" +PAM_NETWORK_CONFIGURATION = "pamNetworkConfiguration" +PAM_GCP_CONFIGURATION = "pamGcpConfiguration" + +PAM_CONFIGURATIONS = [ + PAM_DOMAIN_CONFIGURATION, + PAM_AZURE_CONFIGURATION, + PAM_AWS_CONFIGURATION, + PAM_NETWORK_CONFIGURATION, + PAM_GCP_CONFIGURATION +] + # These are configuration that could domain users. # Azure included because of AADDS. 
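Editor's aside on the `resolve_server` helper introduced in `keepercommander/constants.py` above: a minimal usage sketch, assuming the package is importable. The expected values follow directly from the `KEEPER_SERVERS` table in this diff; the `example.com` input is hypothetical.
```
from keepercommander.constants import resolve_server, get_valid_server_codes

# Region codes are case-insensitive, and '-' is accepted in place of '_'.
assert resolve_server('US') == 'keepersecurity.com'
assert resolve_server('gov-dev') == 'govcloud.dev.keepersecurity.us'

# A hostname that is already a valid Keeper host passes through unchanged.
assert resolve_server('keepersecurity.eu') == 'keepersecurity.eu'

# Anything else resolves to None; a caller can then list the valid codes.
if resolve_server('example.com') is None:
    print('Valid regions:', ', '.join(get_valid_server_codes()))
```
Returning `None` rather than raising keeps the helper usable both for validation and for building help text.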
DOMAIN_USER_CONFIGS = [ - "pamDomainConfiguration", - "pamAzureConfiguration" + PAM_DOMAIN_CONFIGURATION, + PAM_AZURE_CONFIGURATION ] # The record types to process. diff --git a/keepercommander/discovery_common/process.py b/keepercommander/discovery_common/process.py index eea654008..2e835112a 100644 --- a/keepercommander/discovery_common/process.py +++ b/keepercommander/discovery_common/process.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging import os -from .constants import PAM_DIRECTORY, PAM_USER, VERTICES_SORT_MAP, LOCAL_USER +from .constants import PAM_DIRECTORY, PAM_USER, VERTICES_SORT_MAP, LOCAL_USER, PAM_CONFIGURATIONS from .jobs import Jobs from .infrastructure import Infrastructure from .record_link import RecordLink @@ -55,6 +55,14 @@ class NoDiscoveryDataException(Exception): class Process: + + """ + Process discovery results + + While this class updates the PAM/record linking graph, it does not save it. + + """ + # Warn when bulk record lists exceed this size (potential memory issue) BULK_LIST_WARNING_THRESHOLD = 10000 # Hard limit for bulk record lists (safety mechanism) @@ -144,7 +152,7 @@ def set_user_based_ids(configuration_uid: str, content: DiscoveryObject, parent_ object_id = object_id.split("\\")[1] if parent_content.record_type == PAM_DIRECTORY: domain = parent_content.name - if object_id.endswith(domain) is False: + if not object_id.endswith(domain): object_id += f"@{domain}" else: object_id += parent_content.id @@ -198,7 +206,7 @@ def _update_with_record_uid(self, record_cache: dict, current_vertex: DAGVertex) # If the current vertex is not active, then return. # It won't have a DATA edge. - if current_vertex.active is False: + if not current_vertex.active: return for vertex in current_vertex.has_vertices(): @@ -236,16 +244,18 @@ def _prepare_record(record_prepare_func: Callable, content: DiscoveryObject, parent_content: DiscoveryObject, vertex: DAGVertex, + admin_uid: Optional[str] = None, context: Optional[Any] = None) -> DiscoveryObject: """ Prepare a record to be added. - :param record_prepare_func: - :param bulk_add_records: - :param content: - :param parent_content: - :param vertex: - :param context: + :param record_prepare_func: Function to call to prepare a record to be created. + :param bulk_add_records: List of records to be added. + :param content: Discovery content of the current discovery item. + :param parent_content: Discovery content of the parent of the current discovery item. + :param vertex: Infrastructure vertex of the current discovery item. + :param admin_uid: For a resource that has an admin user, the UID of that PAM User record. + :param context: The context; a dictionary of arbitrary instances. :return: """ @@ -268,7 +278,8 @@ def _prepare_record(record_prepare_func: Callable, record_type=content.record_type, record_uid=record_uid, parent_record_uid=parent_record_uid, - shared_folder_uid=content.shared_folder_uid + shared_folder_uid=content.shared_folder_uid, + admin_uid=admin_uid ) ) @@ -358,7 +369,7 @@ def _directory_exists(self, domain: str, directory_info_func: Callable, context: self.logger.debug(f"search for directories: {', '.join(domains)}") # Some providers provider directory type services.
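A brief aside before the next hunk: the new named constant lists make membership tests at call sites read naturally. A small sketch of the pattern; the variable values here are hypothetical, but the check mirrors the `is_iam_user` logic added to `process.py` later in this diff.
```
from keepercommander.discovery_common.constants import (
    PAM_CONFIGURATIONS, PAM_AWS_CONFIGURATION, PAM_USER)

# A pamUser discovered directly under a PAM configuration record is an
# IAM-style user rather than a local account on a resource.
parent_record_type = PAM_AWS_CONFIGURATION   # hypothetical parent
child_record_type = PAM_USER                 # hypothetical child

is_iam_user = (parent_record_type in PAM_CONFIGURATIONS
               and child_record_type == PAM_USER)
assert is_iam_user
```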
- # They can also provide mulitple domains + # They can also provide multiple domains provider_vertices = self.infra.dag.search_content({ "record_type": ["pamAzureConfiguration", "pamDomainConfiguration"], }, ignore_case=True) @@ -371,9 +382,9 @@ def _directory_exists(self, domain: str, directory_info_func: Callable, context: if domain.lower() in provider_domain.lower(): found = True break - if found is True: + if found: break - if found is True: + if found: found_provider_directories.append(provider_vertex) if len(found_provider_directories) > 0: return found_provider_directories @@ -412,7 +423,7 @@ def _find_directory_user(self, find_dn: Optional[str] = None) -> Optional[DirectoryUserResult]: # If the passed in results were a DirectoryInfo then check the Vault for users. - if isinstance(results, DirectoryInfo) is True: + if isinstance(results, DirectoryInfo): self.logger.debug("search for directory user from vault records") self.logger.debug(f"have {len(results.directory_user_record_uids)} users") for user_record_id in results.directory_user_record_uids: @@ -532,66 +543,6 @@ def _record_link_directory_users(self, else: self.logger.debug("could not find user vertex") - def _record_link_user_to_directories(self, - directory_vertex: DAGVertex, - directory_content: DiscoveryObject, - user_content: DiscoveryObject, - directory_info_func: Callable, - context: Optional[Any] = None): - - """ - Connect a user to all the directories for a domain. - - Directories may be in the vault or in the discovery graph. - The first step is to get all vault directories. - - """ - - self.logger.debug("resource is directory and we are a user; handle record links to others") - - record_link = context.get("record_link") # type: RecordLink - - # Get the directory user record UIDs from the vault that belong to directories using the same domain. - # We can skip getting directory users. - directory_record_uids = [] - directory_info = directory_info_func( - domain=directory_content.name, - skip_users=True, - context=context - ) # type: DirectoryInfo - if directory_info is not None: - directory_record_uids = directory_info.directory_record_uids - - self.logger.debug(f"found {len(directory_record_uids)} directories in records.") - - # Check our current discovery data. - # This is a delta, it will not contain discovery from prior runs. - # This will only contain objects in this run. - # Make sure the object is a directory and the domain is the same. - # Also make sure there is a record UID; it might not be added yet. 
- for parent_vertex in directory_vertex.belongs_to_vertices(): - self.logger.debug("finding directories in discovery vertices") - for child_vertex in parent_vertex.has_vertices(): - try: - other_directory_content = DiscoveryObject.get_discovery_object(child_vertex) - self.logger.debug(f"{other_directory_content.record_type}, {other_directory_content.name}, " - f"{directory_content.name}, {other_directory_content.record_uid}") - if (other_directory_content.record_type != PAM_DIRECTORY or - other_directory_content.name != directory_content.name): - continue - if (other_directory_content.record_uid is not None and - other_directory_content.record_uid not in directory_record_uids): - self.logger.debug(f" * adding {other_directory_content.record_uid}") - directory_record_uids.append(other_directory_content.record_uid) - except Exception as err: - self.logger.debug(f"could not link user to directory {directory_content.name}: {err}") - - self.logger.debug(f"found {len(directory_record_uids)} directories in records and discovery data.") - - for directory_record_uid in directory_record_uids: - if record_link.get_acl(user_content.record_uid, directory_record_uid) is None: - record_link.belongs_to(user_content.record_uid, directory_record_uid, acl=UserAcl.default()) - def _find_admin_directory_user(self, domain: str, admin_acl: UserAcl, @@ -619,7 +570,7 @@ def _find_admin_directory_user(self, # No need to create a record, just link, belongs_to is False # Since we are using records, just the belongs_to method instead of # discovery_belongs_to. - if isinstance(directory_user, NormalizedRecord) is True: + if isinstance(directory_user, NormalizedRecord): admin_acl.belongs_to = False return directory_user.record_uid else: @@ -637,6 +588,8 @@ def _find_admin_directory_user(self, if admin_content.record_uid is not None: admin_acl.belongs_to = False return admin_content.record_uid + + return None else: raise UserNotFoundException(f"Could not find the directory user in domain {domain}") else: @@ -650,16 +603,12 @@ def _process_auto_add_level(self, record_prepare_func: Callable, directory_info_func: Callable, record_cache: dict, - smart_add: bool = False, - add_all: bool = False, context: Optional[Any] = None): """ This method will add items to the bulk_add_records queue to be added by the client. - Items are added because: - * Smart Add is enabled, and the resource was logged into with credentials. - * The rule engine flagged an item as ADD + These are items that the rule engine has flagged to be added. :param current_vertex: The current/parent discovery vertex. :param bulk_add_records: List of records to be added. @@ -668,25 +617,23 @@ def _process_auto_add_level(self, :param record_prepare_func: Function to convert content into an unsaved record. :param directory_info_func: Function to lookup directories. :param record_cache: - :param smart_add: Add the resource record if the admin exists. - :param add_all: Just add the record. This is not the params from Commander. :param context: Client context; could be anything. :return: """ - if current_vertex.active is False: + if not current_vertex.active: self.logger.debug(f"vertex {current_vertex.uid} is not active, skip") return # Check if this vertex has a record. # We cannot add child vertices to a vertex that does not have a record.
- current_content = current_vertex.content_as_object(DiscoveryObject) + current_content = DiscoveryObject.get_discovery_object(current_vertex) if current_content.record_uid is None: self.logger.debug(f"vertex {current_content.uid} does not have a record id") return self.logger.debug(f"Current Vertex: {current_content.record_type}, {current_vertex.uid}, " - f"{current_content.name}, smart add {smart_add}, add all {add_all}") + f"{current_content.name}") # Sort all the vertices under the current vertex. # Return a dictionary where the record type is the key. @@ -698,8 +645,8 @@ def _process_auto_add_level(self, self.logger.debug(f" processing {record_type}") for vertex in record_type_to_vertices_map[record_type]: - content = DiscoveryObject.get_discovery_object(vertex) - self.logger.debug(f" child vertex {vertex.uid}, {content.name}") + child_content = DiscoveryObject.get_discovery_object(vertex) + self.logger.debug(f" child vertex {vertex.uid}, {child_content.name}") # If we are going to add an admin user, this is the default ACL # This is for the smart add feature @@ -708,10 +655,10 @@ def _process_auto_add_level(self, # This ACL is None for resource, and populated for users. default_acl = None - if content.record_type == PAM_USER: + if child_content.record_type == PAM_USER: default_acl = self._default_acl( discovery_vertex=vertex, - content=content, + content=child_content, discovery_parent_vertex=current_vertex) # Check for a vault record, if it exists. @@ -720,16 +667,17 @@ def _process_auto_add_level(self, # We are doing this because the record might be an active directory user, that we have # not created a record for yet, however it might have been assigned a record UID from a prior prompt. - existing_record = content.record_exists + existing_record = child_content.record_exists if record_lookup_func is not None: check_the_vault = True for item in bulk_add_records: - if item.record_uid == content.record_uid: + if item.record_uid == child_content.record_uid: self.logger.debug(f" record is in the bulk add list, do not check the vault if exists") check_the_vault = False break - if check_the_vault is True: - existing_record = record_lookup_func(record_uid=content.record_uid, context=context) is not None + if check_the_vault: + existing_record = record_lookup_func(record_uid=child_content.record_uid, + context=context) is not None self.logger.debug(f" record exists in the vault: {existing_record}") else: self.logger.debug(f" record lookup function not defined, record existing: {existing_record}") @@ -737,247 +685,58 @@ def _process_auto_add_level(self, # Determine if we are going to add the item. # If the item has a record UID already, we don't need to add. add_record = False - add_all_users = False - if content.record_exists is False: - - ################################################################################################# - # - # RULE ENGINE ADD - - if content.action_rules_result == RuleActionEnum.ADD.value: - self.logger.debug(f" vertex {vertex.uid} had an ADD result for the rule engine, auto add") - add_record = True - - ################################################################################################# - # - # SMART ADD - - # If we are using smart add and the there was an admin user, add it. 
- elif smart_add is True and content.access_user is not None and content.record_type != PAM_USER: - self.logger.debug(f" resource has credentials, and using smart add") - add_record = True - add_all_users = True - - ################################################################################################# - # - # ADD ALL FLAG (not Commander's) - - # If add_all is set, then add it. - # This is normally used with smart_add to add the resource's users. - elif add_all is True: - # If the current content/parent is not a Directory - # and the content is a User and the source is not 'local' user, - # then don't add the user. - # We don't want an AD user to belongs_to a machine. - if (current_content.record_type != PAM_DIRECTORY - and content.record_type == PAM_USER - and content.item.source != LOCAL_USER): - add_record = False + if (child_content.record_exists is False + and child_content.action_rules_result == RuleActionEnum.ADD.value): + self.logger.debug(f" vertex {vertex.uid} had an ADD result for the rule engine, auto add") + add_record = True + + if add_record: + + self.logger.debug(f"adding resource record") + + # For a resource, the ACL will be None. + # It will be a UserAcl for a user. + self.record_link.belongs_to(child_content.record_uid, current_content.record_uid, acl=default_acl) + + admin_uid = None + # If the rules have set the admin_uid, then connect the user to the resource. + if (child_content.admin_uid is not None + and child_content.record_type != PAM_USER + and record_lookup_func is not None): + + self.logger.debug("the admin UID has been set for this resource") + + admin_record = record_lookup_func(record_uid=child_content.admin_uid, + context=context) # type: NormalizedRecord + if admin_record is not None and admin_record.record_type == PAM_USER: + self.logger.debug("was able to find the admin record, connect to resource") + admin_uid = child_content.admin_uid + admin_acl.is_admin = True + self.record_link.belongs_to(child_content.admin_uid, child_content.record_uid, + acl=admin_acl) else: - self.logger.debug(f" items is a user, add all is True, adding record") - add_record = True - - if add_record is True: - - # If we can create an admin user record, then the admin_user_record_uid will be populated. - admin_user_record_uid = None - admin_content = None - admin_vertex = None - - # If this is a resource, then auto add the admin user if one exists. - # In this scenario ... - # There is a rule to auto add. - # A credential was passed to discovery and it worked. - # Along with the resource, auto create the admin user. - # First we need to make sure the current record type is a resource and logged in. - if smart_add is True and content.access_user is not None: - - # Get the username and DN. - # Lowercase them for the comparison. - access_username_and_domain = content.access_user.user - access_username = access_username_and_domain - access_domain = None - if access_username_and_domain is not None: - access_username_and_domain = access_username_and_domain.lower() - access_username, access_domain = split_user_and_domain(access_username_and_domain) - - # We want to pay attention to the admin source. - # The users from the user list might not contain a source. - # For example, Linux PAM that are remote users will not have a domain in their username. - admin_source = content.access_user.source - - # If the admin source is the current directory name, then it local to the resource (directory).
- if content.record_type == PAM_DIRECTORY and content.name == admin_source: - self.logger.debug(" change source to local for directory user") - admin_source = LOCAL_USER - - access_dn = content.access_user.dn - if access_dn is not None: - access_dn = access_dn.lower() - - # Go through the users to find the administrative user. - found_user_in_discovery_user_list = False - for user_vertex in vertex.has_vertices(): - - user_content = DiscoveryObject.get_discovery_object(user_vertex) - if user_content.record_type != PAM_USER: - continue - - # Get the user from the content. - # We want to use the full username and also one without the domain, if there is a domain. - user_and_domain = user_content.item.user - user = user_and_domain - domain = None - if user_and_domain is not None: - user_and_domain = user_and_domain.lower() - user, domain = split_user_and_domain(user_and_domain) - if user is None: - continue - - # Get the dn, if it exists. - dn = user_content.item.dn - if dn is not None: - dn = dn.lower() - - if (access_username_and_domain == user_and_domain - or access_username_and_domain == user - or access_username == user - or access_dn == dn): - - self.logger.debug(" access user matches the current user") - self.logger.debug(f" access user source is {user_content.item.source}") - - # If the user has a record UID, it has already been created. - # This means the record already belongs to another resource, so belongs_to is False. - if user_content.record_uid is not None: - self.logger.debug(" user has a record uid, add this user as admin") - admin_acl.belongs_to = False - admin_user_record_uid = user_content.record_uid - found_user_in_discovery_user_list = True - break - - # Is this user a local user? - # If so prepare a record and link it. Since its local belongs_to is True - if admin_source == LOCAL_USER or admin_source is None: - - self.logger.debug(" user is new local user, add this user as admin") - admin_acl.belongs_to = True - admin_content = user_content - admin_vertex = user_vertex - found_user_in_discovery_user_list = True - break - - # The user is a remote user. - else: - self.logger.debug(" check directory for remote user") - domain = content.access_user.source - if content.record_type == PAM_DIRECTORY: - domain = content.name - - try: - admin_user_record_uid = self._find_admin_directory_user( - domain=domain, - admin_acl=admin_acl, - directory_info_func=directory_info_func, - record_lookup_func=record_lookup_func, - context=context, - user=access_username, - dn=access_dn - ) - self.logger.debug(" found directory user for admin") - found_user_in_discovery_user_list = True - except (DirectoryNotFoundException, UserNotFoundException) as err: - # Not an error. - # Just could not find the directory or directory user. - self.logger.debug(f" did not find the directory user: {err}") - - self.logger.debug("done checking user list") - - # If the user_record_uid is None, and it's a domain user, and we didn't find a user - # then there is chance that it's dirctory user not picked up while getting users in - # discovery. - # This is similar to the remote user code above, except the access user was not found in - # the user list. 
- if (found_user_in_discovery_user_list is False and admin_user_record_uid is None - and access_domain is not None): - self.logger.debug("could not find admin user in the user list, " - "attempt to find in directory") - try: - admin_user_record_uid = self._find_admin_directory_user( - domain=access_domain, - admin_acl=admin_acl, - directory_info_func=directory_info_func, - record_lookup_func=record_lookup_func, - context=context, - user=access_username, - dn=access_dn - ) - except (DirectoryNotFoundException, UserNotFoundException): - # Not an error. - # Just could not find the directory or directory user. - pass - - # Create the record if we are not using smart add. - # If we are using smart add, only added if we could make an admin record. - if smart_add is False or (smart_add is True - and (admin_user_record_uid is not None or admin_content is not None)): - - self.logger.debug(f"adding resource record, smart add {smart_add}") - # The record could be a resource or user record. - self._prepare_record( - record_prepare_func=record_prepare_func, - bulk_add_records=bulk_add_records, - content=content, - parent_content=current_content, - vertex=vertex, - context=context - ) - if content.record_uid is None: - raise Exception(f"the record uid is blank for {content.description} after prepare") - - # For a resource, the ACL will be None. - # It will a UserAcl if a user. - self.record_link.belongs_to(content.record_uid, current_content.record_uid, acl=default_acl) - - # user_record_uid will only be populated if using smart add. - # Link the admin user to the resource. - if admin_user_record_uid is not None or admin_content is not None: + self.logger.info(f"The PAM User record {child_content.admin_uid} does not exist. " + "Cannot set the administrator for an auto-added " + f"record {child_content.title}.") - if admin_content is not None: - self.logger.debug("the admin record does not exists, create it") - - # Create the local admin here since we need the resource record added. - self._prepare_record( - record_prepare_func=record_prepare_func, - bulk_add_records=bulk_add_records, - content=admin_content, - parent_content=content, - vertex=admin_vertex, - context=context - ) - if admin_content.record_uid is None: - raise Exception(f"the record uid is blank for {admin_content.description} " - "after prepare") - - admin_user_record_uid = admin_content.record_uid - - self.logger.debug("connecting admin user to resource") - self.record_link.belongs_to(admin_user_record_uid, content.record_uid, acl=admin_acl) + # The record could be a resource or user record. + self._prepare_record( + record_prepare_func=record_prepare_func, + bulk_add_records=bulk_add_records, + content=child_content, + parent_content=current_content, + vertex=vertex, + context=context, + admin_uid=admin_uid + ) + if child_content.record_uid is None: + raise Exception(f"the record uid is blank for {child_content.description} after prepare") # If the record type is a PAM User, we don't need to go deeper. # In the future we might need to change if PAM User becomes a branch and not a leaf. # This is for safety reasons - if content.record_type != PAM_USER: + if child_content.record_type != PAM_USER: # Process the vertices that belong to the current vertex.
- - next_smart_add = smart_add - if add_all_users is True: - add_all = True - if add_all is True: - self.logger.debug("turning off smart add since add_all is enabled") - next_smart_add = False - self.logger.debug(f"smart add = {next_smart_add}, add all = {add_all}") - self._process_auto_add_level( current_vertex=vertex, bulk_add_records=bulk_add_records, @@ -986,19 +745,22 @@ def _process_auto_add_level(self, record_prepare_func=record_prepare_func, directory_info_func=directory_info_func, record_cache=record_cache, - - # Use the value of smart_add if add_all is False. - # If add_all is True, we don't have to run it through the logic, we are going add a record. - smart_add=next_smart_add, - - # If we could access a resource, add all it's users. - add_all=add_all_users, context=context ) self.logger.debug(f" finished auto add processing {record_type}") self.logger.debug(f" Finished auto add current Vertex: {current_vertex.uid}, {current_content.name}") + @staticmethod + def _apply_admin_uid(bulk_add_records: List[BulkRecordAdd], + resource_uid: str, + admin_uid: str): + + for item in bulk_add_records: + if item.record_uid == resource_uid: + item.admin_uid = admin_uid + break + def _process_level(self, current_vertex: DAGVertex, bulk_add_records: List[BulkRecordAdd], @@ -1030,13 +792,13 @@ def _process_level(self, :return: """ - if current_vertex.active is False: + if not current_vertex.active: self.logger.debug(f"vertex {current_vertex.uid} is not active, skip") return # Check if this vertex has a record. # We cannot add child vertices to a vertex that does not have a record. - current_content = current_vertex.content_as_object(DiscoveryObject) + current_content = DiscoveryObject.get_discovery_object(current_vertex) if current_content.record_uid is None: self.logger.debug(f"vertex {current_content.uid} does not have a record id") return @@ -1054,14 +816,14 @@ def _process_level(self, self.logger.debug(f" processing {record_type}") for vertex in record_type_to_vertices_map[record_type]: - content = DiscoveryObject.get_discovery_object(vertex) - self.logger.debug(f" child vertex {vertex.uid}, {content.name}") + child_content = DiscoveryObject.get_discovery_object(vertex) + self.logger.debug(f" child vertex {vertex.uid}, {child_content.name}") default_acl = None - if content.record_type == PAM_USER: + if child_content.record_type == PAM_USER: default_acl = self._default_acl( discovery_vertex=vertex, - content=content, + content=child_content, discovery_parent_vertex=current_vertex) # Check for a vault record, if it exists. @@ -1070,16 +832,17 @@ def _process_level(self, # We are doing this because the record might be an active directory user, that we have # not created a record for yet, however it might have been assigned a record UID from a prior prompt. 
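An aside on the `_apply_admin_uid` helper added just above: it is a linear scan that tags the queued resource record with the UID of the PAM User that administers it. A self-contained toy of the same pattern; this reduced `BulkRecordAdd` stand-in is hypothetical and models only the two fields the helper touches (the real type in the diff carries more).
```
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class BulkRecordAdd:          # reduced, hypothetical stand-in
    record_uid: str
    admin_uid: Optional[str] = None

def apply_admin_uid(queue: List[BulkRecordAdd], resource_uid: str, admin_uid: str):
    # Find the queued resource record and attach its admin's UID.
    for item in queue:
        if item.record_uid == resource_uid:
            item.admin_uid = admin_uid
            break

queue = [BulkRecordAdd('res-1'), BulkRecordAdd('res-2')]
apply_admin_uid(queue, 'res-2', 'admin-9')
assert queue[1].admin_uid == 'admin-9'
```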
- existing_record = content.record_exists + existing_record = child_content.record_exists if record_lookup_func is not None: check_the_vault = True for item in bulk_add_records: - if item.record_uid == content.record_uid: + if item.record_uid == child_content.record_uid: self.logger.debug(f" record is in the bulk add list, do not check the vault if exists") check_the_vault = False break - if check_the_vault is True: - existing_record = record_lookup_func(record_uid=content.record_uid, context=context) is not None + if check_the_vault: + existing_record = record_lookup_func(record_uid=child_content.record_uid, + context=context) is not None self.logger.debug(f" record exists in the vault: {existing_record}") else: self.logger.debug(f" record lookup function not defined, record existing: {existing_record}") @@ -1093,14 +856,14 @@ # If the rule engine result is to ignore this object, then continue. # This normally would not happen since discovery wouldn't add the object. # However, make sure we skip any object where the rule engine action is to ignore the object. - elif content.action_rules_result == RuleActionEnum.IGNORE.value: + elif child_content.action_rules_result == RuleActionEnum.IGNORE.value: self.logger.debug(f" vertex {vertex.uid} had a IGNORE result for the rule engine, " "skip processing") # If the rule engine result is to ignore this object, then continue. continue # If this flag is set, the user set the ignore_object flag when prompted. - elif content.ignore_object is True: + elif child_content.ignore_object: self.logger.debug(f" vertex {vertex.uid} was flagged as ignore, skip processing") # If the ignore_object flag is set, then continue. continue @@ -1113,7 +876,7 @@ # If not, prompt the user if they want to add this user as the admin. # The returned ACL will have the is_admin flag set to True if they do. resource_has_admin = False - if content.record_type == PAM_USER: + if child_content.record_type == PAM_USER: resource_has_admin = (self.record_link.get_admin_record_uid(current_content.record_uid) is not None) self.logger.debug(f"resource has an admin is {resource_has_admin}") @@ -1121,8 +884,8 @@ # If the current resource does not allow an admin, then it has and admin, it's just controlled by # us. # This is going to be a resource record, or a configuration record. - if hasattr(current_content.item, "allows_admin") is True: - if current_content.item.allows_admin is False: + if hasattr(current_content.item, "allows_admin"): + if not current_content.item.allows_admin: self.logger.debug(f"resource allows an admin is {current_content.item.allows_admin}") resource_has_admin = True else: @@ -1132,7 +895,7 @@ result = prompt_func( vertex=vertex, parent_vertex=current_vertex, - content=content, + content=child_content, acl=default_acl, resource_has_admin=resource_has_admin, indent=indent, @@ -1163,22 +926,26 @@ # Use the content from the prompt. # The user may have modified it. - content = result.content + add_content = result.content acl = result.acl + # If the current content is a PAM configuration and the child content/add content is a PAM User, + # then the user is an IAM user. + if current_content.record_type in PAM_CONFIGURATIONS and add_content.record_type == PAM_USER: + acl.is_iam_user = True + # The record could be a resource or user record. - # The content + # add_content will have a record UID after this.
self._prepare_record( record_prepare_func=record_prepare_func, bulk_add_records=bulk_add_records, - content=content, + content=add_content, parent_content=current_content, vertex=vertex, context=context ) - # Update the DATA edge for this vertex. - # vertex.add_data(content) + admin_uid = None # Make a record link. # The acl will be None if not a pamUser. @@ -1187,25 +954,48 @@ # If the object is NOT a pamUser and the resource allows an admin. # Prompt the user to create an admin. should_prompt_for_admin = True - self.logger.debug(f" added record type was {content.record_type}") + self.logger.debug(f" added record type was {add_content.record_type}") - if (content.record_type != PAM_USER and content.item.allows_admin is True and + if (add_content.record_type != PAM_USER and add_content.item.allows_admin is True and prompt_admin_func is not None): + self.logger.debug("checking if can add admin") + + # If the rule engine sets the admin UID + if child_content.admin_uid is not None and record_lookup_func is not None: + + self.logger.debug(f"the resource rule set the admin uid to {add_content.admin_uid}") + + admin_record = record_lookup_func(record_uid=add_content.admin_uid, + context=context) # type: NormalizedRecord + if admin_record is not None and admin_record.record_type == PAM_USER: + self.logger.debug("was able to find the admin record, connect to resource") + + admin_uid = add_content.admin_uid + # admin_acl = UserAcl.default() + # admin_acl.is_admin = True + # self.record_link.belongs_to(child_content.admin_uid, child_content.record_uid, + # acl=admin_acl) + should_prompt_for_admin = False + else: + self.logger.info(f"The PAM User record {child_content.admin_uid} does not exist. " + "Cannot set the administrator for an auto-added " + f"record {child_content.title}.") + # This block checks to see if the admin is a directory user that exists. # We don't want to prompt the user for an admin if we have one already. - if content.access_user is not None and content.access_user.user is not None: + elif add_content.access_user is not None and add_content.access_user.user is not None: self.logger.debug(" for this resource, credentials were provided.") - self.logger.error(f" {content.access_user.user}, {content.access_user.dn}, " - f"{content.access_user.password}") + self.logger.error(f" {add_content.access_user.user}, {add_content.access_user.dn}, " - f"{add_content.access_user.password}") # Check if this user is a directory users, first check the source. # If local, check the username incase the domain in part of the username.
- source = content.access_user.source - if content.record_type == PAM_DIRECTORY: - source = content.name + source = add_content.access_user.source + if add_content.record_type == PAM_DIRECTORY: + source = add_content.name elif source == LOCAL_USER: - _, domain = split_user_and_domain(content.access_user.user) + _, domain = split_user_and_domain(add_content.access_user.user) if domain is not None: source = domain @@ -1215,35 +1005,35 @@ def _process_level(self, acl = UserAcl.default() acl.is_admin = True - admin_record_uid = None try: - admin_record_uid = self._find_admin_directory_user( + admin_uid = self._find_admin_directory_user( domain=source, admin_acl=acl, directory_info_func=directory_info_func, record_lookup_func=record_lookup_func, context=context, - user=content.access_user.user, - dn=content.access_user.dn + user=add_content.access_user.user, + dn=add_content.access_user.dn ) + + if admin_uid is not None: + self.logger.debug(" found directory user admin, connect to resource") + # self.record_link.belongs_to(admin_uid, add_content.record_uid, acl=acl) + should_prompt_for_admin = False + else: + self.logger.debug(" did not find the directory user for the admin, " + "prompt the user") except DirectoryNotFoundException: self.logger.debug(f" directory {source} was not found for admin user") except UserNotFoundException: self.logger.debug(f" directory user was not found in directory {source}") - if admin_record_uid is not None: - self.logger.debug(" found directory user admin, connect to resource") - self.record_link.belongs_to(admin_record_uid, content.record_uid, acl=acl) - should_prompt_for_admin = False - else: - self.logger.debug(" did not find the directory user for the admin, " - "prompt the user") - - if should_prompt_for_admin is True: + + if should_prompt_for_admin: self.logger.debug(f" prompt for admin user") - self._process_admin_user( + admin_uid = self._process_admin_user( resource_vertex=vertex, - resource_content=content, + resource_content=add_content, bulk_add_records=bulk_add_records, bulk_convert_records=bulk_convert_records, record_lookup_func=record_lookup_func, @@ -1254,12 +1044,22 @@ def _process_level(self, context=context ) + # If we have an admin UID, add it to the last bulk record. + # It will be the one we added above. + if admin_uid is not None: + + self._apply_admin_uid( + bulk_add_records=bulk_add_records, + resource_uid=add_content.record_uid, + admin_uid=admin_uid + ) + items_left -= 1 # If the record type is a PAM User, we don't need to go deeper. # In the future we might need to change if PAM User becomes a branch and not a leaf. # This is for safety reasons - if content.record_type != PAM_USER: + if child_content.record_type != PAM_USER: # Process the vertices that belong to the current vertex. self._process_level( current_vertex=vertex, @@ -1289,229 +1089,228 @@ def _process_admin_user(self, prompt_admin_func: Callable, record_prepare_func: Callable, indent: int = 0, - context: Optional[Any] = None): - - # Find the record UID that admins this resource. - # If it is None, there is a user vertex that has an ACL with is_admin with a true value. - record_uid = self.record_link.get_record_uid(resource_vertex) - admin = self.record_link.get_admin_record_uid(record_uid) - if admin is None: - - # If the access_user is None, create an empty one. - # We will need this below when adding values to the fields. - if resource_content.access_user is None: - resource_content.access_user = DiscoveryUser() - - # Initialize a discovery object for the admin user. 
- # The PLACEHOLDER will be replaced after the admin user prompt. - - values = {} - for field in ["user", "password", "private_key", "dn", "database"]: - value = getattr(resource_content.access_user, field) - if value is None: - value = [] - else: - value = [value] - values[field] = value - - managed = [False] - if resource_content.access_user.managed is not None: - managed = [resource_content.access_user.managed] - - admin_content = DiscoveryObject( - uid="PLACEHOLDER", - object_type_value="users", - parent_record_uid=resource_content.record_uid, - record_type=PAM_USER, - id="PLACEHOLDER", - name="PLACEHOLDER", - description=resource_content.description + ", Administrator", - title=resource_content.title + ", Administrator", - item=DiscoveryUser( - user="PLACEHOLDER" - ), - fields=[ - RecordField(type="login", label="login", value=values["user"], required=True), - RecordField(type="password", label="password", value=values["password"], required=False), - RecordField(type="secret", label="privatePEMKey", value=values["private_key"], required=False), - RecordField(type="text", label="distinguishedName", value=values["dn"], required=False), - RecordField(type="text", label="connectDatabase", value=values["database"], required=False), - RecordField(type="checkbox", label="managed", value=managed, required=False), - ] - ) + context: Optional[Any] = None) -> Optional[str]: - admin_acl = UserAcl.default() - admin_acl.is_admin = True + # If the access_user is None, create an empty one. + # We will need this below when adding values to the fields. + if resource_content.access_user is None: + resource_content.access_user = DiscoveryUser() - # Prompt to add an admin user to this resource. - # We are not passing an ACL instance. - # We'll make it based on if the user is adding a new record or linking to an existing record. - admin_result = prompt_admin_func( - parent_vertex=resource_vertex, - content=admin_content, - acl=admin_acl, - bulk_convert_records=bulk_convert_records, - indent=indent, - context=context - ) + # Initialize a discovery object for the admin user. + # The PLACEHOLDER will be replaced after the admin user prompt. - # If the action is to ADD, replace the PLACEHOLDER data. - if admin_result.action == PromptActionEnum.ADD: - self.logger.debug("adding admin user") - - source = "local" - if resource_content.record_type == PAM_DIRECTORY: - source = resource_content.name - - admin_record_uid = admin_result.record_uid - - if admin_record_uid is None: - admin_content = admin_result.content - - # With the result, we can fill in information in the object item. - admin_content.item.user = admin_content.get_field_value("login") - admin_content.item.password = admin_content.get_field_value("password") - admin_content.item.private_key = admin_content.get_field_value("privatePEMKey") - admin_content.item.dn = admin_content.get_field_value("distinguishedName") - admin_content.item.database = admin_content.get_field_value("connectDatabase") - admin_content.item.managed = value_to_boolean( - admin_content.get_field_value("managed")) or False - admin_content.item.source = source - admin_content.name = admin_content.item.user - - self.logger.debug(f"added admin user from content") - - if admin_content.item.user is None or admin_content.item.user == "": - raise ValueError("The user name is missing or is blank. 
Cannot create the administrator user.") - - if admin_content.name is not None: - admin_content.description = (resource_content.description + ", User " + - admin_content.name) - - # We need to populate the id and uid of the content, now that we have data in the content. - self.populate_admin_content_ids(admin_content, resource_vertex) - - ad_user, ad_domain = split_user_and_domain(admin_content.item.user) - if ad_domain is not None and admin_content.item.source == LOCAL_USER: - self.logger.debug("The admin is an directory user, but the source is set to a local user") - - found_admin_record_uid = None - try: - found_admin_record_uid = self._find_admin_directory_user( - domain=ad_domain, - admin_acl=admin_acl, - directory_info_func=directory_info_func, - record_lookup_func=record_lookup_func, - context=context, - user=admin_content.item.user, - dn=admin_content.item.dn - ) - except DirectoryNotFoundException: - self.logger.debug(f" directory {source} was not found for admin user") - except UserNotFoundException: - self.logger.debug(f" directory user was not found in directory {source}") - - if found_admin_record_uid is not None: - self.logger.debug(" found directory user admin, connect to resource") - found_admin_vertices = self.infra.dag.search_content({"record_uid": found_admin_record_uid}) - if len(found_admin_vertices) == 1: - found_admin_vertices[0].belongs_to(resource_vertex, edge_type=EdgeType.KEY) - self.record_link.belongs_to(found_admin_record_uid, resource_content.record_uid, - acl=admin_acl) - return - - # Does an admin vertex already exist for this user? - # This most likely user on the gateway, since without a resource record users can be discovered. - # If we did find it, get the content for the admin; we really want any existing record uid. - admin_vertex = self.infra.dag.get_vertex(admin_content.uid) - if admin_vertex is not None and admin_vertex.active is True and admin_vertex.has_data is True: - self.logger.debug("admin exists in the graph") - found_content = DiscoveryObject.get_discovery_object(admin_vertex) - admin_record_uid = found_content.record_uid - else: - self.logger.debug("admin does not exists in the graph") + values = {} + for field in ["user", "password", "private_key", "dn", "database"]: + value = getattr(resource_content.access_user, field) + if value is None: + value = [] + else: + value = [value] + values[field] = value + + managed = [False] + if resource_content.access_user.managed is not None: + managed = [resource_content.access_user.managed] + + admin_content = DiscoveryObject( + uid="PLACEHOLDER", + object_type_value="users", + parent_record_uid=resource_content.record_uid, + record_type=PAM_USER, + id="PLACEHOLDER", + name="PLACEHOLDER", + description=resource_content.description + ", Administrator", + title=resource_content.title + ", Administrator", + item=DiscoveryUser( + user="PLACEHOLDER" + ), + fields=[ + RecordField(type="login", label="login", value=values["user"], required=True), + RecordField(type="password", label="password", value=values["password"], required=False), + RecordField(type="secret", label="privatePEMKey", value=values["private_key"], required=False), + RecordField(type="text", label="distinguishedName", value=values["dn"], required=False), + RecordField(type="text", label="connectDatabase", value=values["database"], required=False), + RecordField(type="checkbox", label="managed", value=managed, required=False), + ] + ) - # If there is a record UID for the admin user, connect it. 
- if admin_record_uid is not None: - self.logger.debug("the admin has a record UID") + admin_acl = UserAcl.default() + admin_acl.is_admin = True + + # Prompt to add an admin user to this resource. + # We are not passing an ACL instance. + # We'll make it based on if the user is adding a new record or linking to an existing record. + admin_result = prompt_admin_func( + parent_vertex=resource_vertex, + content=admin_content, + acl=admin_acl, + bulk_convert_records=bulk_convert_records, + indent=indent, + context=context + ) - # If the admin record does not belong to another resource, make this resource its owner. - if self.record_link.get_parent_record_uid(admin_record_uid) is None: - self.logger.debug("the admin does not belong to another resourse, " - "setting it belong to this resource") - admin_acl.belongs_to = True + # If the action is to ADD, replace the PLACEHOLDER data. + if admin_result.action == PromptActionEnum.ADD: + self.logger.debug("adding admin user") + + source = "local" + if resource_content.record_type == PAM_DIRECTORY: + source = resource_content.name + + admin_record_uid = admin_result.record_uid + + if admin_record_uid is None: + admin_content = admin_result.content + + # With the result, we can fill in information in the object item. + admin_content.item.user = admin_content.get_field_value("login") + admin_content.item.password = admin_content.get_field_value("password") + admin_content.item.private_key = admin_content.get_field_value("privatePEMKey") + admin_content.item.dn = admin_content.get_field_value("distinguishedName") + admin_content.item.database = admin_content.get_field_value("connectDatabase") + admin_content.item.managed = value_to_boolean( + admin_content.get_field_value("managed")) or False + admin_content.item.source = source + admin_content.name = admin_content.item.user + + self.logger.debug(f"added admin user from content") + + if admin_content.item.user is None or admin_content.item.user == "": + raise ValueError("The user name is missing or is blank. Cannot create the administrator user.") + + if admin_content.name is not None: + admin_content.description = (resource_content.description + ", User " + + admin_content.name) + + # We need to populate the id and uid of the content, now that we have data in the content. 
+ self.populate_admin_content_ids(admin_content, resource_vertex) + + ad_user, ad_domain = split_user_and_domain(admin_content.item.user) + if ad_domain is not None and admin_content.item.source == LOCAL_USER: + self.logger.debug("The admin is a directory user, but the source is set to a local user") + + found_admin_record_uid = None + try: + found_admin_record_uid = self._find_admin_directory_user( + domain=ad_domain, + admin_acl=admin_acl, + directory_info_func=directory_info_func, + record_lookup_func=record_lookup_func, + context=context, + user=admin_content.item.user, + dn=admin_content.item.dn + ) + except DirectoryNotFoundException: + self.logger.debug(f" directory {source} was not found for admin user") + except UserNotFoundException: + self.logger.debug(f" directory user was not found in directory {source}") + + if found_admin_record_uid is not None: + self.logger.debug(" found directory user admin, connect to resource") + found_admin_vertices = self.infra.dag.search_content({"record_uid": found_admin_record_uid}) + if len(found_admin_vertices) == 1: + found_admin_vertices[0].belongs_to(resource_vertex, edge_type=EdgeType.KEY) + self.record_link.belongs_to(found_admin_record_uid, resource_content.record_uid, + acl=admin_acl) + return found_admin_record_uid + + # Does an admin vertex already exist for this user? + # This is most likely a user on the gateway, since without a resource record users can be discovered. + # If we did find it, get the content for the admin; we really want any existing record uid. + admin_vertex = self.infra.dag.get_vertex(admin_content.uid) + if admin_vertex is not None and admin_vertex.active is True and admin_vertex.has_data is True: + self.logger.debug("admin exists in the graph") + found_content = DiscoveryObject.get_discovery_object(admin_vertex) + admin_record_uid = found_content.record_uid + else: + self.logger.debug("admin does not exist in the graph") - admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) - admin_vertex.add_data(admin_content) + # If there is a record UID for the admin user, connect it. + if admin_record_uid is not None: + self.logger.debug("the admin has a record UID") - # The record will be a user record; admin_acl will not be None - self._prepare_record( - record_prepare_func=record_prepare_func, - bulk_add_records=bulk_add_records, - content=admin_content, - parent_content=resource_content, - vertex=admin_vertex, - context=context - ) + # If the admin record does not belong to another resource, make this resource its owner. + if self.record_link.get_parent_record_uid(admin_record_uid) is None: + self.logger.debug("the admin does not belong to another resource, " + "setting it to belong to this resource") + admin_acl.belongs_to = True - self.record_link.discovery_belongs_to(admin_vertex, resource_vertex, acl=admin_acl) + admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) + self.record_link.belongs_to(admin_record_uid, resource_content.record_uid, acl=admin_acl) + else: + if admin_vertex is None: + self.logger.debug("creating an entry in the graph for the admin") + admin_vertex = self.infra.dag.add_vertex(uid=admin_content.uid, + name=admin_content.description) - else: - self.logger.debug("add admin user from existing record") + # Since this record does not exist, it will belong to the resource, + admin_acl.belongs_to = True - # If this is NOT existing directory user, we want to convert the record rotation setting to - # work with this gateway/controller. - # If it is a directory user, we just want link this record; no conversion. - if admin_result.is_directory_user is False: + # Connect the user vertex to the resource vertex. + # We need to add a KEY edge for the admin content stored on the DATA edge.
- admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) - admin_vertex.add_data(admin_content) + admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) + self.record_link.belongs_to(admin_record_uid, resource_content.record_uid, acl=admin_acl) + else: + if admin_vertex is None: + self.logger.debug("creating an entry in the graph for the admin") + admin_vertex = self.infra.dag.add_vertex(uid=admin_content.uid, + name=admin_content.description) - # The record will be a user record; admin_acl will not be None - self._prepare_record( - record_prepare_func=record_prepare_func, - bulk_add_records=bulk_add_records, - content=admin_content, - parent_content=resource_content, - vertex=admin_vertex, - context=context - ) + # Since this record does not exist, it will belong to the resource, + admin_acl.belongs_to = True - self.record_link.discovery_belongs_to(admin_vertex, resource_vertex, acl=admin_acl) + # Connect the user vertex to the resource vertex. + # We need to add a KEY edge for the admin content stored on the DATA edge. + admin_vertex.belongs_to(resource_vertex, edge_type=EdgeType.KEY) + admin_vertex.add_data(admin_content) - else: - self.logger.debug("add admin user from existing record") + # The record will be a user record; admin_acl will not be None + self._prepare_record( + record_prepare_func=record_prepare_func, + bulk_add_records=bulk_add_records, + content=admin_content, + parent_content=resource_content, + vertex=admin_vertex, + context=context + ) - # If this is NOT existing directory user, we want to convert the record rotation setting to - # work with this gateway/controller. - # If it is a directory user, we just want link this record; no conversion. - if admin_result.is_directory_user is False: + self.record_link.discovery_belongs_to(admin_vertex, resource_vertex, acl=admin_acl) - self.logger.debug("the admin user is NOT a directory user, convert record's rotation settings") + admin_record_uid = admin_content.record_uid + else: + self.logger.debug("add admin user from existing record") - # This is a pamUser record that may need to have the controller set. - # Add it to this queue to make sure the protobuf items are current. - parent_record_uid = resource_content.record_uid - if resource_content.object_type_value == "providers": - parent_record_uid = None + # If this is NOT existing directory user, we want to convert the record rotation setting to + # work with this gateway/controller. + # If it is a directory user, we just want link this record; no conversion. + if admin_result.is_directory_user is False: - bulk_convert_records.append( - BulkRecordConvert( - record_uid=admin_record_uid, - parent_record_uid=parent_record_uid - ) - ) + self.logger.debug("the admin user is NOT a directory user, convert record's rotation settings") - # If this user record does not belong to another resource, make it belong to this one. - record_vertex = self.record_link.acl_has_belong_to_record_uid(admin_record_uid) - if record_vertex is None: - admin_acl.belongs_to = True + # This is a pamUser record that may need to have the controller set. + # Add it to this queue to make sure the protobuf items are current. + parent_record_uid = resource_content.record_uid + if resource_content.object_type_value == "providers": + parent_record_uid = None - # There is _prepare_record, the record exists. - # Needs to add to records linking. - else: - self.logger.debug("the admin user is a directory user") - - # Link the record UIDs. - # We might not have this user in discovery data. 
-                        # It might not belong to the resource; if so, it cannot be rotated.
-                        # It only has is_admin in the ACL.
-                        self.record_link.belongs_to(
-                            admin_record_uid,
-                            record_uid,
-                            acl=admin_acl
+                bulk_convert_records.append(
+                    BulkRecordConvert(
+                        record_uid=admin_record_uid,
+                        parent_record_uid=parent_record_uid
+                    )
                 )

+                # If this user record does not belong to another resource, make it belong to this one.
+                record_vertex = self.record_link.acl_has_belong_to_record_uid(admin_record_uid)
+                if record_vertex is None:
+                    admin_acl.belongs_to = True
+
+            # No _prepare_record call here; the record already exists.
+            # It only needs to be added to record linking.
+            else:
+                self.logger.debug("the admin user is a directory user")
+
+                # Link the record UIDs.
+                # We might not have this user in discovery data.
+                # It might not belong to the resource; if so, it cannot be rotated.
+                # It only has is_admin in the ACL.
+                # self.record_link.belongs_to(
+                #     admin_record_uid,
+                #     record_uid,
+                #     acl=admin_acl
+                # )
+
+            return admin_record_uid
+
+        return None
+
     def _get_count(self, current_vertex: DAGVertex) -> int:
         """
@@ -1534,7 +1333,7 @@ def _get_count(self, current_vertex: DAGVertex) -> int:
         count = 0
         for vertex in current_vertex.has_vertices():
-            if vertex.active is False:
+            if not vertex.active:
                 continue

             content = DiscoveryObject.get_discovery_object(vertex)
@@ -1579,7 +1378,7 @@ def run(self,

         :param record_cache: A dictionary of record types to keys to record UID.
         :param prompt_func: Function to call when the user needs to make a decision about an object.
-        :param smart_add: If we have resource cred, add the resource and the users.
+        :param smart_add: If we have the resource credentials, add the resource and the users. DEPRECATED
         :param record_lookup_func: Function to look up a record by UID.
         :param record_prepare_func: Function to call to prepare a record to be created.
         :param record_create_func: Function to call to save the prepared records.
@@ -1602,7 +1401,7 @@

         # There will be only one.
         self.logger.debug(f"loading the graph at sync point {sync_point}")
         self.infra.load(sync_point=sync_point)
-        if self.infra.has_discovery_data is False:
+        if not self.infra.has_discovery_data:
             raise NoDiscoveryDataException("There is no discovery data to process.")

         # If the graph is corrupted, delete the bad vertices.
@@ -1656,7 +1455,6 @@
             current_vertex=configuration,
             bulk_add_records=bulk_add_records,
             bulk_convert_records=bulk_convert_records,
-            smart_add=smart_add,
             record_lookup_func=record_lookup_func,
             record_prepare_func=record_prepare_func,
             directory_info_func=directory_info_func,
@@ -1693,7 +1491,7 @@

                 # This mainly for testing.
                 # If throw and quit exception, so we can prompt the user.
-                if force_quit is True:
+                if force_quit:
                     raise QuitException()

         except QuitException:
@@ -1713,7 +1511,7 @@
             should_add_records = False

         # We should add the record, and a method was passed in to create them; then add the records.
-        if should_add_records is True:
+        if should_add_records:

             self.logger.debug("# ####################################################################################")
             self.logger.debug("# CREATE NEW RECORD")
@@ -1764,19 +1562,7 @@
             self.infra.save(delta_graph=False)
             self.logger.debug("# ####################################################################################")

-        # Save the record linking, only if we added records.
-        # This will be the additions and any changes to ACL.
-        if should_add_records is True:
-
-            self.logger.debug("# ####################################################################################")
-            self.logger.debug("# Save RECORD LINKING graph")
-            self.logger.debug("#")
-
-            self.logger.debug(f"save additions from record linking ")
-            self.record_link.save()
-            self.logger.debug("# ####################################################################################")
-
-        # Map user to service/task on a machine
-        self.user_service.run(infra=self.infra)
+        # Update the user service mapping
+        self.user_service.run(infra=self.infra)

         return bulk_process_results
diff --git a/keepercommander/discovery_common/record_link.py b/keepercommander/discovery_common/record_link.py
index e43ed31c0..5afc2b450 100644
--- a/keepercommander/discovery_common/record_link.py
+++ b/keepercommander/discovery_common/record_link.py
@@ -408,7 +408,9 @@ def delete(vertex: DAGVertex):
             vertex.delete()

     def save(self):
-        if self.dag.has_graph is True:
+
+        self.logger.info("DISCOVERY COMMON RECORD LINKING GRAPH SAVE CALLED")
+        if self.dag.has_graph:
             self.logger.debug("saving the record linking.")
             self.dag.save(delta_graph=False)
         else:
diff --git a/keepercommander/discovery_common/rm_types.py b/keepercommander/discovery_common/rm_types.py
index c8c2433c6..3f6d00b3c 100644
--- a/keepercommander/discovery_common/rm_types.py
+++ b/keepercommander/discovery_common/rm_types.py
@@ -335,7 +335,7 @@ class RmBaseLdapUserAddMeta(RmMetaBase):
     auto_uid_number: bool = True
     gid_number_match_uid: bool = True
     home_dir_base: Optional[str] = "/home"
-    first_rdn_component: Optional[str] = None
+    first_rdn_component: Optional[str] = "CN"
     attributes: Optional[dict] = {}
     groups: List[str] = []
diff --git a/keepercommander/discovery_common/rule.py b/keepercommander/discovery_common/rule.py
index c5ec0836b..328ae289f 100644
--- a/keepercommander/discovery_common/rule.py
+++ b/keepercommander/discovery_common/rule.py
@@ -4,7 +4,7 @@
 from .utils import value_to_boolean, get_connection, make_agent
 from ..keeper_dag import DAG, EdgeType
 from ..keeper_dag.exceptions import DAGException
-from ..keeper_dag.types import PamGraphId, PamEndpoints
+from ..keeper_dag.types import PamGraphId
 from time import time
 import base64
 import os
@@ -57,6 +57,9 @@ class Rules:
         "domainName": {"type": str},
         "directoryId": {"type": str},
         "directoryType": {"type": str},
+
+        # Programmatically added
+        "ip": {"type": str},
     }

     BREAK_OUT = {
@@ -66,6 +69,7 @@ class Rules:
         }
     }

+    # If creating an ignore rule, these fields are used in the rule.
     RECORD_FIELD = {
         "pamMachine": ["pamHostname"],
         "pamDatabase": ["pamHostname", "databaseType"],
@@ -133,11 +137,19 @@ def close(self):
         Clean up resources held by this Rules instance.
         Releases the DAG instance and connection to prevent memory leaks.
""" - if self._dag is not None: - if self.logger: - self.logger.debug("closing Rules DAG instance") - self._dag = None - self.conn = None + + try: + if hasattr(self, "_dag"): + self.conn = None + del self._dag + if hasattr(self, "conn"): + self.conn = None + del self.conn + if hasattr(self, "record"): + self.conn = None + del self.conn + except (Exception,): + pass def __enter__(self): """Context manager entry.""" @@ -148,6 +160,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() return False + def __del__(self): + self.close() + + @staticmethod def data_path(rule_type: RuleTypeEnum): return f"/{rule_type.value}" @@ -236,6 +252,23 @@ def _remove_rule(r: RuleItem, rs: List[RuleItem]): func=_remove_rule ) + def remove_all(self, rule_type: RuleTypeEnum): + + def _remove_all_rules(r: Any, rs: List[RuleItem]): + return [] + + # _rule_transaction determines the graph vertex from Rule class type + fake_rule = None + if rule_type == RuleTypeEnum.ACTION: + fake_rule = ActionRuleItem(statement=[]) + else: + raise ValueError("rule type not supported with remove_all") + + self._rule_transaction( + rule=fake_rule, + func=_remove_all_rules + ) + def rule_list(self, rule_type: RuleTypeEnum, search: Optional[str] = None) -> List[RuleItem]: rule_list = [] for rule_item in self.get_ruleset(rule_type).rules: @@ -319,18 +352,32 @@ def make_action_rule_statement_str(statement: List[Statement]) -> str: field_type = Rules.RULE_FIELDS.get(item.field).get("type") if field_type is None: raise ValueError("Unknown field in rule") - if field_type is str: - statement_str += f"'{item.value}'" - elif field_type is bool: - if value_to_boolean(item.value) is True: - statement_str += "true" - else: - statement_str += "false" - elif field_type is float: - if int(item.value) == item.value: - statement_str += str(int(item.value)) + + values = item.value + new_values = [] + if item.operator != "in": + values = [values] + + for value in values: + if field_type is str: + new_value = f"'{value}'" + elif field_type is bool: + if value_to_boolean(value) is True: + new_value = "true" + else: + new_value = "false" + elif field_type is float: + if int(value) == value: + new_value = str(int(value)) + else: + new_value = str(value) else: - statement_str += str(item.value) + raise ValueError("Cannot determine the field type for rule statement.") + + new_values.append(new_value) + + if item.operator == "in": + statement_str += "[" + ", ".join(new_values) + "]" else: - raise ValueError("Cannot determine the field type for rule statement.") + statement_str += new_values[0] return statement_str diff --git a/keepercommander/discovery_common/types.py b/keepercommander/discovery_common/types.py index dc1053802..cea781fb3 100644 --- a/keepercommander/discovery_common/types.py +++ b/keepercommander/discovery_common/types.py @@ -179,10 +179,11 @@ class RuleActionEnum(BaseEnum): class Statement(BaseModel): field: str operator: str - value: Union[str, bool, float] + value: Any class RuleItem(BaseModel): + name: Optional[str] = None added_ts: Optional[int] = None rule_id: Optional[str] = None enabled: bool = True @@ -209,10 +210,23 @@ def search(self, search: str) -> bool: return False + def close(self): + try: + if self.engine_rule: + self.engine_rule.close() + self.engine_rule = None + del self.engine_rule + except (Exception,): + pass + + def __del__(self): + self.close() + class ActionRuleItem(RuleItem): action: RuleActionEnum = RuleActionEnum.PROMPT shared_folder_uid: Optional[str] = None + admin_uid: Optional[str] = None class 
ScheduleRuleItem(RuleItem): @@ -224,7 +238,18 @@ class ComplexityRuleItem(RuleItem): class RuleSet(BaseModel): - pass + rules: List[RuleItem] = [] + + @property + def count(self) -> int: + return len(self.rules) + + def __str__(self): + rule_set = [] + for item in self.rules: + rule_set.append(item.model_dump_json()) + + return "[" + ",\n" .join(rule_set) + "]" class ActionRuleSet(RuleSet): @@ -470,6 +495,7 @@ class DiscoveryObject(BaseModel): fields: List[RecordField] ignore_object: bool = False action_rules_result: Optional[str] = None + admin_uid: Optional[str] = None shared_folder_uid: Optional[str] = None name: str title: str @@ -650,6 +676,9 @@ class BulkRecordAdd(BaseModel): record: Any record_type: str + # If record_type is a PAM User, is this user the admin of the resource? + admin_uid: Optional[str] = None + # Normal record UID strings record_uid: str parent_record_uid: Optional[str] = None diff --git a/keepercommander/discovery_common/verify.py b/keepercommander/discovery_common/verify.py index a2d89b99d..5e90ab5a7 100644 --- a/keepercommander/discovery_common/verify.py +++ b/keepercommander/discovery_common/verify.py @@ -67,6 +67,49 @@ def __init__(self, record: Any, logger: Optional[Any] = None, debug_level: int = self.debug_level = debug_level self.logger.debug(f"configuration uid is {self.conn.get_record_uid(record)}") + def close(self): + """ + Clean up all resources held by Verify instance. + Pattern matches keeper-dag's DAG.cleanup() defensive approach. + """ + try: + # Close Infrastructure (has DAG connection) + if hasattr(self, 'infra') and self.infra is not None: + if hasattr(self.infra, 'close'): + self.infra.close() + self.infra = None + + # Close RecordLink + if hasattr(self, 'record_link') and self.record_link is not None: + if hasattr(self.record_link, 'close'): + self.record_link.close() + self.record_link = None + + # Close UserService + if hasattr(self, 'user_service') and self.user_service is not None: + if hasattr(self.user_service, 'close'): + self.user_service.close() + self.user_service = None + + # Close main connection + if hasattr(self, 'conn') and self.conn is not None: + if hasattr(self.conn, 'close'): + self.conn.close() + self.conn = None + + except (Exception,): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def __del__(self): + self.close() + def _msg(self, msg, color_name="NONE"): print(f"{self.colors.get(color_name, '')}{msg}{self.colors.get(Verify.COLOR_RESET, '')}", file=self.output) diff --git a/keepercommander/display.py b/keepercommander/display.py index 43737bdf0..a05ca5911 100644 --- a/keepercommander/display.py +++ b/keepercommander/display.py @@ -67,6 +67,23 @@ def keeper_colorize(text, color): return text +def show_government_warning(): + """Display U.S. Government Information System warning for GOV environments.""" + print('') + print(f'{bcolors.WARNING}' + '=' * 80 + f'{bcolors.ENDC}') + print(f'{bcolors.WARNING}U.S. GOVERNMENT INFORMATION SYSTEM{bcolors.ENDC}') + print(f'{bcolors.WARNING}' + '=' * 80 + f'{bcolors.ENDC}') + print('') + print('You are about to access a U.S. Government Information System. Although the') + print('encrypted vault adheres to a zero-knowledge security architecture, system') + print('access logs are subject to monitoring, recording and audit. Unauthorized') + print('use of this system is prohibited and may result in civil and criminal') + print('penalties. 
+    print('penalties. Your use of this system indicates your acknowledgement and consent.')
+    print('')
+    print(f'{bcolors.WARNING}' + '=' * 80 + f'{bcolors.ENDC}')
+    print('')
+
+
 def welcome():
     lines = []     # type: List[Union[str, Tuple[str, str]]]
@@ -107,7 +124,8 @@ def welcome():
         white_line = ''
         print('\033[2K' + Fore.LIGHTYELLOW_EX + yellow_line + Fore.LIGHTWHITE_EX + white_line)
-    print('\033[2K' + Fore.LIGHTBLACK_EX + f'{("v" + __version__):>93}\n' + Style.RESET_ALL)
+    print('\033[2K' + Fore.LIGHTBLACK_EX + f'{("v" + __version__):>93}' + Style.RESET_ALL)
+    print()

 def formatted_records(records, **kwargs):
@@ -227,3 +245,75 @@ def print_record(params, record_uid):
         raise Exception('Record not found: ' + record_uid)
     data = json.loads(cached_rec['data_unencrypted'].decode('utf-8'))
     print(data)
+
+
+import sys
+import time
+import threading
+
+
+class Spinner:
+    """Animated spinner for long-running operations."""
+
+    # Braille-dot spinner animation frames
+    FRAMES = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
+
+    def __init__(self, message=""):
+        self.message = message
+        self.running = False
+        self.thread = None
+
+    def _animate(self):
+        idx = 0
+        while self.running:
+            frame = self.FRAMES[idx % len(self.FRAMES)]
+            sys.stdout.write(f'\r{Fore.CYAN}{frame}{Fore.RESET} {self.message}')
+            sys.stdout.flush()
+            idx += 1
+            time.sleep(0.08)
+        # Clear the line when done
+        sys.stdout.write('\r' + ' ' * (len(self.message) + 4) + '\r')
+        sys.stdout.flush()
+
+    def start(self):
+        self.running = True
+        self.thread = threading.Thread(target=self._animate, daemon=True)
+        self.thread.start()
+
+    def stop(self):
+        self.running = False
+        if self.thread:
+            self.thread.join(timeout=0.5)
+
+
+def post_login_summary(record_count=0, breachwatch_count=0, show_tips=True):
+    """Display a polished post-login summary."""
+
+    ACCENT = Fore.GREEN
+    DIM = Fore.WHITE
+    WARN = Fore.YELLOW
+
+    print()
+
+    # Vault summary
+    if record_count > 0:
+        print(f"  {ACCENT}✓{Fore.RESET} Decrypted {record_count} records")
+
+    # BreachWatch warning
+    if breachwatch_count > 0:
+        print(f"  {WARN}⚠ {breachwatch_count} high-risk passwords{Fore.RESET} - run {ACCENT}breachwatch list{Fore.RESET}")
+
+    if show_tips:
+        print()
+        print(f"  {DIM}Quick Start:{Fore.RESET}")
+        print(f"    {ACCENT}ls{Fore.RESET}            List records")
+        print(f"    {ACCENT}ls -l -f{Fore.RESET}      List folders")
+        print(f"    {ACCENT}cd <folder>{Fore.RESET}   Change folder")
+        print(f"    {ACCENT}get <record>{Fore.RESET}  Get record or folder info")
+        print(f"    {ACCENT}supershell{Fore.RESET}    Launch vault TUI")
+        print(f"    {ACCENT}search{Fore.RESET}        Search your vault")
+        print(f"    {ACCENT}this-device{Fore.RESET}   Configure device settings")
+        print(f"    {ACCENT}whoami{Fore.RESET}        Display account info")
+        print(f"    {ACCENT}?{Fore.RESET}             List all commands")
+
+    print()
diff --git a/keepercommander/importer/commands.py b/keepercommander/importer/commands.py
index fcf2649f9..8e9869cfa 100644
--- a/keepercommander/importer/commands.py
+++ b/keepercommander/importer/commands.py
@@ -137,7 +137,7 @@ def register_command_info(aliases, command_info):

 load_record_type_parser = argparse.ArgumentParser(
-    prog='load-record-types', description='Loads custom record types from JSON file into Keeper.')
+    prog='load-record-types', description='Loads custom record types from JSON file')
 load_record_type_parser.add_argument(
     'name', type=str, nargs='?', help='Input file name. 
"record_types.json" if omitted.') diff --git a/keepercommander/keeper_dag/__version__.py b/keepercommander/keeper_dag/__version__.py index 12ce4098d..ffd0919d5 100644 --- a/keepercommander/keeper_dag/__version__.py +++ b/keepercommander/keeper_dag/__version__.py @@ -1 +1,2 @@ -__version__ = '1.1.0' # pragma: no cover +__version__ = '1.1.3' # pragma: no cover + diff --git a/keepercommander/keeper_dag/connection/commander.py b/keepercommander/keeper_dag/connection/commander.py index b69a28b4f..be023c1ce 100644 --- a/keepercommander/keeper_dag/connection/commander.py +++ b/keepercommander/keeper_dag/connection/commander.py @@ -65,13 +65,15 @@ def get_key_bytes(record: KeeperRecord) -> bytes: def hostname(self) -> str: # The host is connect.keepersecurity.com, connect.dev.keepersecurity.com, etc. Append "connect" in front # of host used for Commander. - configured_host = f'connect.{self.params.config.get("server")}' - - # In GovCloud environments, the router service is not under the govcloud subdomain - if 'govcloud.' in configured_host: - # "connect.govcloud.keepersecurity.com" -> "connect.keepersecurity.com" - configured_host = configured_host.replace('govcloud.', '') - + server = self.params.config.get("server") + + # Only PROD GovCloud strips the subdomain (workaround for prod infrastructure). + # DEV/QA GOV (govcloud.dev.keepersecurity.us, govcloud.qa.keepersecurity.us) keep govcloud. + if server == 'govcloud.keepersecurity.us': + configured_host = 'connect.keepersecurity.us' + else: + configured_host = f'connect.{server}' + return os.environ.get("ROUTER_HOST", configured_host) @property diff --git a/keepercommander/keeper_dag/connection/ksm.py b/keepercommander/keeper_dag/connection/ksm.py index d2ce1a01b..ed89697dc 100644 --- a/keepercommander/keeper_dag/connection/ksm.py +++ b/keepercommander/keeper_dag/connection/ksm.py @@ -116,9 +116,10 @@ def app_key(self) -> str: def router_url_from_ksm_config(self) -> str: hostname = self.hostname - # In GovCloud environments, the router service is not under the govcloud subdomain - if 'govcloud.' in hostname: - hostname = hostname.replace('govcloud.', '') + # Only PROD GovCloud strips the subdomain (workaround for prod infrastructure). + # DEV/QA GOV (govcloud.dev.keepersecurity.us, govcloud.qa.keepersecurity.us) keep govcloud. + if hostname == 'govcloud.keepersecurity.us': + hostname = 'keepersecurity.us' return f'connect.{hostname}' def ws_router_url_from_ksm_config(self, is_ws: bool = False) -> str: @@ -297,6 +298,7 @@ def rest_call_to_router(self, # If we get a 401 Unauthorized, and we have not yet refreshed, # refresh the signature. if response.status_code == 401 and refresh is False: + response.close() self.logger.debug("rest call was Unauthorized") # The attempt didn't count. diff --git a/keepercommander/keeper_dag/dag.py b/keepercommander/keeper_dag/dag.py index 45f8abd36..a2d6ddfed 100644 --- a/keepercommander/keeper_dag/dag.py +++ b/keepercommander/keeper_dag/dag.py @@ -219,29 +219,7 @@ def __init__(self, self.debug(f"{self.log_prefix} UID {self.uid}", level=1) self.debug(f"{self.log_prefix} UID HEX {urlsafe_str_to_bytes(self.uid).hex()}", level=1) - def __del__(self): - self.cleanup() - - def cleanup(self): - """ - Explicitly clean up the DAG and break circular references. - - This method allows users to manually trigger cleanup before the object - goes out of scope. 
This is useful in scenarios where you want to ensure
-        immediate memory release, such as:
-        - High-frequency DAG creation/destruction
-        - Long-running processes
-        - Memory-constrained environments
-
-        After calling this method, the DAG object should not be used.
-
-        Example:
-            dag = DAG(conn=conn, key_bytes=key)
-            # ... use the dag ...
-            dag.cleanup()  # Explicitly clean up
-            del dag
-        """
-
+    def close(self):
         try:
             # Safely get the root vertex without creating a new one
             if hasattr(self, '_vertices') and hasattr(self, 'uid') and hasattr(self, '_uid_lookup'):
@@ -266,12 +244,28 @@ def cleanup(self):
             pass

         # Clear all collections to break circular references
-        self.read_struct_obj = None
-        del self.read_struct_obj
-        self.write_struct_obj = None
-        del self.write_struct_obj
-        self.conn = None
-        del self.conn
+        try:
+            self.read_struct_obj = None
+            del self.read_struct_obj
+            self.write_struct_obj = None
+            del self.write_struct_obj
+            self.conn = None
+            del self.conn
+        except (Exception,):
+            pass
+
+    def cleanup(self):
+        self.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def __del__(self):
+        self.close()

     def debug(self, msg: str, level: int = 0):
         """
@@ -839,7 +833,7 @@ def _decrypt_data(self):
                 edge.content = content
                 edge.needs_encryption = False

-            self.debug(f"    * edge is not encrypted or key is incorrect.")
+            self.debug("    * edge is not encrypted or key is incorrect.", level=3)

             # Change the flag indicating that the content is in decrypted state.
             edge.is_encrypted = False
@@ -1266,7 +1260,7 @@ def _search(self, content: Any, value: QueryValue, ignore_case: bool = False):
             content = content.lower()
             value = value.lower()

-        return value in content
+        return value == content

     def search_content(self, query, ignore_case: bool = False):
         results = []
diff --git a/keepercommander/loginv3.py b/keepercommander/loginv3.py
index 68857a807..a7b27f6b0 100644
--- a/keepercommander/loginv3.py
+++ b/keepercommander/loginv3.py
@@ -48,7 +48,7 @@ def _fallback_to_password_auth(self, params, encryptedDeviceToken, clone_code_by
         logging.info("Falling back to default authentication...")
         return LoginV3API.startLoginMessage(params, encryptedDeviceToken, cloneCode=clone_code_bytes, loginType=login_type)

-    def login(self, params, new_device=False, new_login=False):  # type: (KeeperParams, bool, bool) -> None
+    def login(self, params, new_device=False, new_login=False, new_password_if_reset_required=None):  # type: (KeeperParams, bool, bool, Optional[str]) -> None

         logging.debug("Login v3 Start as '%s'", params.user)
@@ -351,13 +351,13 @@ def cancel(self):
             elif resp.loginState == APIRequest_pb2.UPGRADE:
                 raise Exception('Application or device is out of date and requires an update.')
             elif resp.loginState == APIRequest_pb2.LOGGED_IN:
-                LoginV3Flow.post_login_processing(params, resp)
+                LoginV3Flow.post_login_processing(params, resp, new_password_if_reset_required)
                 return
             else:
                 raise Exception("UNKNOWN LOGIN STATE [%s]" % resp.loginState)

     @staticmethod
-    def post_login_processing(params: KeeperParams, resp: APIRequest_pb2.LoginResponse):
+    def post_login_processing(params: KeeperParams, resp: APIRequest_pb2.LoginResponse, new_password_if_reset_required=None):
         """Processing after login

         Returns True if authentication is successful and False otherwise.
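The `new_password_if_reset_required` argument threads from `login()` into `post_login_processing()`. A minimal caller sketch; the construction of `LoginV3Flow` and the password value are illustrative, only the keyword itself comes from the change above:

```
# Sketch only: if the server reports ACCOUNT_RECOVERY (expired master
# password), the handling in the next hunk resets it with
# LoginV3API.change_master_password() instead of prompting interactively.
from keepercommander.loginv3 import LoginV3Flow


def login_with_forced_reset(params, new_password):
    # params: a prepared KeeperParams for the account
    # new_password: replacement master password to apply automatically
    # when the server requires a reset (illustrative value at call site)
    LoginV3Flow().login(params, new_password_if_reset_required=new_password)
```

This keeps automated/batch logins from blocking on the interactive change-password prompt.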
@@ -384,12 +384,17 @@ def post_login_processing(params: KeeperParams, resp: APIRequest_pb2.LoginRespon
             )
             raise Exception(msg)
         elif resp.sessionTokenType == APIRequest_pb2.ACCOUNT_RECOVERY:
-            print('Your Master Password has expired, you are required to change it before you can login.\n')
-            if LoginV3Flow.change_master_password(params):
+            if new_password_if_reset_required:
+                print('Resetting expired Master Password.\n')
+                LoginV3API.change_master_password(params, new_password_if_reset_required)  # always returns False
                 return False
-            else:
-                params.clear_session()
-                raise Exception('Change password failed')
+            elif new_password_if_reset_required is None:
+                print('Your Master Password has expired, you are required to change it before you can login.\n')
+                if LoginV3Flow.change_master_password(params):
+                    return False
+            # Raise an exception if the password change fails
+            params.clear_session()
+            raise Exception('Change password failed')
         elif resp.sessionTokenType == APIRequest_pb2.SHARE_ACCOUNT:
             logging.info('Account transfer required')
             accepted = api.accept_account_transfer_consent(params)
@@ -442,7 +447,7 @@ def get_data_key(params: KeeperParams, resp: APIRequest_pb2.LoginResponse):
         elif resp.encryptedDataKeyType == APIRequest_pb2.BY_PASSWORD:
             decrypted_data_key = \
                 utils.decrypt_encryption_params(resp.encryptedDataKey, params.password)
-            login_type_message = bcolors.UNDERLINE + "Password"
+            login_type_message = bcolors.UNDERLINE + "Master Password"
         elif resp.encryptedDataKeyType == APIRequest_pb2.BY_ALTERNATE:
             decryption_key = crypto.derive_keyhash_v2('data_key', params.password, params.salt, params.iterations)
diff --git a/keepercommander/plugins/commands.py b/keepercommander/plugins/commands.py
index 641e2953a..a1ee7ec34 100644
--- a/keepercommander/plugins/commands.py
+++ b/keepercommander/plugins/commands.py
@@ -37,7 +37,7 @@ def register_command_info(aliases, command_info):

 rotate_parser = argparse.ArgumentParser(
-    prog='rotate', description='Rotate the password for a Keeper record from this Commander.'
+    prog='rotate', description='Rotate the password for a Keeper record.'
) rotate_parser.add_argument( '--print', dest='print', action='store_true', help='display the record content after rotation' diff --git a/keepercommander/record.py b/keepercommander/record.py index 133c7914e..472254e16 100644 --- a/keepercommander/record.py +++ b/keepercommander/record.py @@ -94,7 +94,8 @@ def load(self, data, **kwargs): if 'notes' in data: self.notes = Record.xstr(data['notes']) - if self.version == 2: + if self.version in (1, 2): + self.record_type = 'general' if 'secret1' in data: self.login = Record.xstr(data['secret1']) if 'secret2' in data: @@ -110,7 +111,7 @@ def load(self, data, **kwargs): for field in extra['fields']: if field['field_type'] == 'totp': self.totp = field['data'] - elif self.version == 3: + elif self.version in (3, 5, 6): self.record_type = data.get('type', 'login') for field in itertools.chain(data['fields'], data.get('custom') or []): field_label = field.get('label', '') @@ -137,7 +138,8 @@ def load(self, data, **kwargs): continue if field_type: - field_name = f'{field_type}:{field_label}' + # Only include colon separator if there's a label + field_name = f'{field_type}:{field_label}' if field_label else field_type elif field_label: field_name = field_label else: @@ -197,13 +199,33 @@ def remove_field(self, name): if len(idxs) == 1: return self.custom_fields.pop(idxs[0]) + def get_unmasked_field_params(self): + """Return unmasked field parameters (login, url, etc.)""" + if self.login: + yield 'login', self.login + if self.login_url: + yield 'login_url', self.login_url + + def get_masked_field_params(self): + """Return masked field parameters (password, totp)""" + if self.password: + yield 'password', self.password + if self.totp: + yield 'totp', self.totp + + def get_typed_fields(self): + """Return typed fields (for v3 records) - returns empty for legacy Record""" + return [] + def display(self, unmask=False): print('') print('{0:>20s}: {1:<20s}'.format('UID', self.record_uid)) - print('{0:>20s}: {1:<20s}'.format('Type', '')) + print('{0:>20s}: {1:<20s}'.format('Type', self.record_type if self.record_type else '')) if self.title: print('{0:>20s}: {1:<20s}'.format('Title', self.title)) if self.login: print('{0:>20s}: {1:<20s}'.format('Login', self.login)) - if self.password: print('{0:>20s}: {1:<20s}'.format('Password', self.password if unmask else '********')) + if self.password: + display_password = (self.unmasked_password or self.password) if unmask else '********' + print('{0:>20s}: {1:<20s}'.format('Password', display_password)) if self.login_url: print('{0:>20s}: {1:<20s}'.format('URL', self.login_url)) # print('{0:>20s}: https://keepersecurity.com/vault#detail/{1}'.format('Link',self.record_uid)) @@ -211,7 +233,55 @@ def display(self, unmask=False): for c in self.custom_fields: if not 'value' in c: c['value'] = '' if not 'name' in c: c['name'] = c['type'] if 'type' in c else '' - print('{0:>20s}: {1:20s}:'.format('Passkey')) + # Format created date + created_ts = pk_value.get('createdDate', 0) + if created_ts: + created_dt = datetime.datetime.fromtimestamp(created_ts / 1000) + created_str = created_dt.strftime('%m/%d/%Y, %I:%M %p') + print('{0:>28s}: {1}'.format('Created', created_str)) + username = pk_value.get('username', '') + if username: + print('{0:>28s}: {1}'.format('Username', username)) + relying_party = pk_value.get('relyingParty', '') + if relying_party: + print('{0:>28s}: {1}'.format('Relying Party', relying_party)) + continue + # Strip type prefixes from field names (e.g., "text:Sign-In Address" -> "Sign-In Address") + 
field_type_prefixes = ('text:', 'multiline:', 'url:', 'phone:', 'email:', 'secret:', 'date:', 'name:', 'host:', 'address:')
+                display_name = field_name
+                for prefix in field_type_prefixes:
+                    if field_name.lower().startswith(prefix):
+                        display_name = field_name[len(prefix):]
+                        # If label was empty, use a friendly name based on type
+                        if not display_name:
+                            type_friendly_names = {
+                                'text:': 'Text',
+                                'multiline:': 'Note',
+                                'url:': 'URL',
+                                'phone:': 'Phone',
+                                'email:': 'Email',
+                                'secret:': 'Secret',
+                                'date:': 'Date',
+                                'name:': 'Name',
+                                'host:': 'Host',
+                                'address:': 'Address',
+                            }
+                            display_name = type_friendly_names.get(prefix, prefix.rstrip(':').title())
+                        break
+                print('{0:>20s}:'.format('Passkey'))
+                print('{0:>28s}: {1}'.format('Created', created_str))
+                if username:
+                    print('{0:>28s}: {1}'.format('Username', username))
+                if relying_party:
+                    print('{0:>28s}: {1}'.format('Relying Party', relying_party))
+                continue
+
+            elif ftyp not in record_types.RecordFields and not unmask:
+                fval = '********'
             else:
                 fval = json.dumps(flds[0], indent=2)
diff --git a/keepercommander/rest_api.py b/keepercommander/rest_api.py
index 5754169b4..ca9a80e93 100644
--- a/keepercommander/rest_api.py
+++ b/keepercommander/rest_api.py
@@ -165,6 +165,7 @@ def execute_rest(context, endpoint, payload):
             qrc_success = True
         except Exception as e:
             logging.warning(f"QRC encryption failed ({e}), falling back to EC encryption")
+            context.disable_qrc()

         # Fallback to EC encryption if QRC not available or failed
         if not qrc_success:
diff --git a/keepercommander/service/config/command_validator.py b/keepercommander/service/config/command_validator.py
index 6398c7f68..e00d61c9a 100644
--- a/keepercommander/service/config/command_validator.py
+++ b/keepercommander/service/config/command_validator.py
@@ -95,8 +95,10 @@ def _process_new_command_line(self, line: str, valid_commands: Set,
         main_command = parts[0]
         alias = None

-        # Check if there's an alias in parentheses as the next token
-        if len(parts) > 1 and parts[1].startswith('(') and parts[1].endswith(')'):
+        if len(parts) >= 3 and parts[2].startswith('(') and parts[2].endswith(')'):
+            main_command = f"{parts[0]} {parts[1]}"
+            alias = parts[2][1:-1].strip()
+        elif len(parts) > 1 and parts[1].startswith('(') and parts[1].endswith(')'):
             # Extract alias: "(alias)" -> "alias"
             alias = parts[1][1:-1].strip()
             # Check if command and alias are combined: "command (alias)"
diff --git a/keepercommander/sox/__init__.py b/keepercommander/sox/__init__.py
index 960d48532..0fd70834d 100644
--- a/keepercommander/sox/__init__.py
+++ b/keepercommander/sox/__init__.py
@@ -6,6 +6,26 @@
 from typing import Dict, Tuple

 from ..
import api, crypto, utils + +# Module-level connection cache to ensure single connection per database +_connection_cache = {} # type: Dict[str, sqlite3.Connection] + + +def get_cached_connection(database_name): # type: (str) -> sqlite3.Connection + """Get or create a cached connection for the given database.""" + if database_name not in _connection_cache: + _connection_cache[database_name] = sqlite3.connect(database_name) + return _connection_cache[database_name] + + +def close_cached_connection(database_name): # type: (str) -> None + """Close and remove a cached connection.""" + if database_name in _connection_cache: + try: + _connection_cache[database_name].close() + except Exception: + pass + del _connection_cache[database_name] from ..commands.helpers.enterprise import user_has_privilege, is_addon_enabled from ..error import CommandError, Error, KeeperApiError from ..params import KeeperParams @@ -161,7 +181,10 @@ def sync_all(): ecc_key = crypto.decrypt_aes_v2(ecc_key, tree_key) key = crypto.load_ec_private_key(ecc_key) storage = sqlite_storage.SqliteSoxStorage( - get_connection=lambda: sqlite3.connect(database_name), owner=params.user, database_name=database_name + get_connection=lambda: get_cached_connection(database_name), + owner=params.user, + database_name=database_name, + close_connection=lambda: close_cached_connection(database_name) ) last_updated = storage.last_prelim_data_update only_shared_cached = storage.shared_records_only @@ -171,7 +194,7 @@ def sync_all(): storage.clear_non_aging_data() sync_down(user_lookup, storage) storage.set_shared_records_only(shared_only) - return sox_data.SoxData(params, storage=storage, no_cache=no_cache) + return sox_data.SoxData(params, storage=storage) def get_compliance_data(params, node_id, enterprise_id=0, rebuild=False, min_updated=0, no_cache=False, shared_only=False): @@ -359,7 +382,7 @@ def save_records(records): user_node_ids = {e_user.get('enterprise_user_id'): e_user.get('node_id') for e_user in enterprise_users} sync_down(sd, node_id, user_node_id_lookup=user_node_ids) rebuild_task = sox_data.RebuildTask(is_full_sync=False, load_compliance_data=True) - sd.rebuild_data(rebuild_task, no_cache=no_cache) + sd.rebuild_data(rebuild_task) return sd diff --git a/keepercommander/sox/sox_data.py b/keepercommander/sox/sox_data.py index 26483c85a..c7193ff5a 100644 --- a/keepercommander/sox/sox_data.py +++ b/keepercommander/sox/sox_data.py @@ -39,8 +39,8 @@ def clear_lookup(lookup, uids=None): # type: (dict, Optional[Iterable]) -> None class SoxData: - def __init__(self, params, storage, no_cache=False): - # type: (KeeperParams, sqlite_storage.SqliteSoxStorage, Optional[bool]) -> None + def __init__(self, params, storage): + # type: (KeeperParams, sqlite_storage.SqliteSoxStorage) -> None self.storage = storage # type: sqlite_storage.SqliteSoxStorage self._records = {} # type: Dict[str, sox_types.Record] self._users = {} # type: Dict[int, sox_types.EnterpriseUser] @@ -49,7 +49,7 @@ def __init__(self, params, storage, no_cache=False): self.ec_private_key = get_ec_private_key(params) self.tree_key = params.enterprise.get('unencrypted_tree_key', b'') task = RebuildTask(True) - self.rebuild_data(task, no_cache) + self.rebuild_data(task) def get_records(self, record_ids=None): return self._records if record_ids is None else {uid: self._records.get(uid) for uid in record_ids} @@ -137,7 +137,7 @@ def clear_all(self): def record_count(self): # type: () -> int return len(self._records) - def rebuild_data(self, changes, no_cache=False): # type: 
(RebuildTask, Optional[bool]) -> None + def rebuild_data(self, changes): # type: (RebuildTask) -> None def decrypt(data): # type: (bytes or str) -> str if isinstance(data, str): return data @@ -267,5 +267,3 @@ def link_sf_teams(store, folder_lookup): self._records.update(load_records(self.storage, changes)) if changes.is_full_sync or changes.load_compliance_data: self._users.update(load_users(self.storage)) - if no_cache: - self.storage.delete_db() diff --git a/keepercommander/sox/sqlite_storage.py b/keepercommander/sox/sqlite_storage.py index 4bed6326c..de7066adc 100644 --- a/keepercommander/sox/sqlite_storage.py +++ b/keepercommander/sox/sqlite_storage.py @@ -29,8 +29,9 @@ def __init__(self): class SqliteSoxStorage: - def __init__(self, get_connection, owner, database_name=''): + def __init__(self, get_connection, owner, database_name='', close_connection=None): self.get_connection = get_connection + self.close_connection = close_connection self.owner = owner self.database_name = database_name @@ -207,8 +208,11 @@ def clear_all(self): def delete_db(self): try: - conn = self.get_connection() - conn.close() + if self.close_connection: + self.close_connection() + else: + conn = self.get_connection() + conn.close() os.remove(self.database_name) except Exception as e: logging.info(f'could not delete db from filesystem, name = {self.database_name}') diff --git a/keepercommander/sync_down.py b/keepercommander/sync_down.py index 4208a6026..bb76d0515 100644 --- a/keepercommander/sync_down.py +++ b/keepercommander/sync_down.py @@ -16,7 +16,7 @@ import google from . import api, utils, crypto, convert_keys -from .display import bcolors +from .display import bcolors, Spinner from .params import KeeperParams, RecordOwner from .proto import SyncDown_pb2, record_pb2, client_pb2, breachwatch_pb2 from .subfolder import RootFolderNode, UserFolderNode, SharedFolderNode, SharedFolderFolderNode, BaseFolderNode @@ -28,8 +28,12 @@ def sync_down(params, record_types=False): # type: (KeeperParams, bool) -> Non params.sync_data = False token = params.sync_down_token - if not token: - logging.info('Syncing...') + + # Use spinner animation for full sync (only in interactive mode, not batch/automation) + spinner = None + if not token and not params.batch_mode: + spinner = Spinner('Syncing...') + spinner.start() for record in params.record_cache.values(): if 'shares' in record: @@ -1017,26 +1021,27 @@ def convert_user_folder_shared_folder(ufsf): type_id += rt.scope * 1000000 params.record_type_cache[type_id] = rt.content + # Stop spinner if running + if spinner: + spinner.stop() + if full_sync: convert_keys.change_key_types(params) + # Count breachwatch issues and store on params for summary display + breachwatch_count = 0 if params.breach_watch: - weak_count = 0 for _ in params.breach_watch.get_records_by_status(params, ['WEAK', 'BREACHED']): - weak_count += 1 - if weak_count > 0: - logging.info(bcolors.WARNING + - f'The number of records that are affected by breaches or contain high-risk passwords: {weak_count}' + - '\nUse \"breachwatch list\" command to get more details' + - bcolors.ENDC) + breachwatch_count += 1 + params._sync_breachwatch_count = breachwatch_count + # Count records and store on params for summary display record_count = 0 valid_versions = {2, 3} for r in params.record_cache.values(): if r.get('version', 0) in valid_versions: record_count += 1 - if record_count: - logging.info('Decrypted [%d] record(s)', record_count) + params._sync_record_count = record_count def _sync_record_types(params): # type: 
(KeeperParams) -> Any
diff --git a/keepercommander/utils.py b/keepercommander/utils.py
index 1d674a049..a28cb5b99 100644
--- a/keepercommander/utils.py
+++ b/keepercommander/utils.py
@@ -10,6 +10,7 @@
 #

 import base64
+import datetime
 import json
 import logging
 import math
@@ -19,7 +20,7 @@
 import stat
 import subprocess
 import time
-from typing import Dict, Union
+from typing import Dict, Union, Optional
 from urllib.parse import urlparse, parse_qs, unquote
 from pathlib import Path
 import sys
@@ -155,6 +156,23 @@ def current_milli_time():  # type: () -> int
     return int(round(time.time() * 1000))

+def millis_to_datetime(millis, tz=None):  # type: (Union[int, float], Optional[datetime.timezone]) -> Optional[datetime.datetime]
+    if isinstance(millis, (int, float)):
+        epoch = datetime.datetime.fromtimestamp(0, tz=tz or datetime.timezone.utc)
+        seconds = int(millis // 1000)
+        if seconds != 0:
+            return epoch + datetime.timedelta(seconds=seconds)
+    return None
+
+
+def datetime_to_millis(dt):  # type: (datetime.datetime) -> int
+    epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc)
+    if dt.tzinfo is None:
+        dt = dt.replace(tzinfo=datetime.timezone.utc)
+    return int((dt - epoch).total_seconds() * 1000)
+
+
 def base64_url_decode(s):  # type: (str) -> bytes
     return base64.urlsafe_b64decode(s + '==')
diff --git a/keepercommander/vault.py b/keepercommander/vault.py
index 0dcf03c84..365e01530 100644
--- a/keepercommander/vault.py
+++ b/keepercommander/vault.py
@@ -10,7 +10,6 @@

 import abc
 import collections.abc
-import datetime
 import json
 import logging
 from typing import Optional, List, Tuple, Iterable, Type, Union, Dict, Any

 import itertools

 from .params import KeeperParams
-from . import record_types, constants
+from . import record_types, constants, utils

 def sanitize_str_field_value(value):  # type: (Any) -> str
@@ -253,7 +252,7 @@ def get_version(self):  # type: () -> int
         return 2

     def get_record_type(self):
-        return ''
+        return 'general'

     def load_record_data(self, data, extra=None):
         self.title = sanitize_str_field_value(data.get('title')).strip()
@@ -814,8 +813,9 @@ def get_exported_value(field_type, field_value):
         if isinstance(field_value, int):
             if ft and ft.name == 'date':
                 if field_value != 0:
-                    dt = datetime.datetime.fromtimestamp(int(field_value // 1000)).date()
-                    yield str(dt)
+                    dt = utils.millis_to_datetime(field_value)
+                    if dt:
+                        yield str(dt.date())
             else:
                 yield str(field_value)
         elif isinstance(field_value, list):
diff --git a/keepercommander/vault_extensions.py b/keepercommander/vault_extensions.py
index 1fc9bc2ee..118dae863 100644
--- a/keepercommander/vault_extensions.py
+++ b/keepercommander/vault_extensions.py
@@ -90,6 +90,15 @@ def find_records(params,  # type: KeeperParams

         if pattern:
             is_match = matches_record(record, pattern, search_fields)
+            if not is_match and isinstance(record, vault.TypedRecord):
+                field = record.get_typed_field('fileRef')
+                if field and isinstance(field.value, list):
+                    for file_uid in field.value:
+                        file_record = vault.KeeperRecord.load(params, file_uid)
+                        if isinstance(file_record, vault.FileRecord):
+                            is_match = matches_record(file_record, pattern, search_fields)
+                            if is_match:
+                                break
         else:
             is_match = True
         if is_match:
diff --git a/requirements.txt b/requirements.txt
index dc4fd7fc8..a58194f81 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,13 +5,14 @@ prompt_toolkit
 pycryptodomex
 pyperclip
 tabulate
+textual>=0.82.0; python_version>='3.9'
 websockets
 fido2>=2.0.0; python_version>='3.10'
 requests>=2.31.0
 cryptography>=39.0.1
 protobuf>=4.23.0
 keeper-secrets-manager-core>=16.6.0
-keeper_pam_webrtc_rs; python_version>='3.8'
+keeper_pam_webrtc_rs>=1.2.1; python_version>='3.8'
 pydantic>=2.6.4; python_version>='3.8'
 flask; python_version>='3.8'
 pyngrok>=7.5.0
@@ -24,4 +25,4 @@ pyobjc-framework-LocalAuthentication; sys_platform == "darwin" and python_versio
 winrt-runtime; sys_platform == "win32"
 winrt-Windows.Foundation; sys_platform == "win32"
 winrt-Windows.Security.Credentials.UI; sys_platform == "win32"
-keeper-mlkem; python_version>='3.11'
\ No newline at end of file
+keeper-mlkem; python_version>='3.11'
diff --git a/setup.cfg b/setup.cfg
index 3a673a360..3e145e711 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,7 +43,7 @@ install_requires =
     requests>=2.31.0
     tabulate
     websockets
-    keeper_pam_webrtc_rs; python_version>='3.8'
+    keeper_pam_webrtc_rs>=1.2.1; python_version>='3.8'
     pydantic>=2.6.4; python_version>='3.8'
     fpdf2>=2.8.3
     cbor2; sys_platform == "darwin" and python_version>='3.10'
@@ -52,6 +52,7 @@ install_requires =
     winrt-Windows.Foundation; sys_platform == "win32" and python_version>='3.10'
     winrt-Windows.Security.Credentials.UI; sys_platform == "win32" and python_version>='3.10'
     keeper-mlkem; python_version>='3.11'
+    textual; python_version>='3.9'

[options.package_data]
keepercommander = resources/*, resources/email_templates/*
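The `millis_to_datetime`/`datetime_to_millis` helpers added in `keepercommander/utils.py` above are meant to round-trip epoch milliseconds through timezone-aware datetimes (this is what the `vault.py` date-export hunk relies on). A small sketch with illustrative values, assuming the helpers as defined in this patch:

```
# Round-trip sketch for the epoch-millisecond helpers; the concrete values
# below are illustrative, not taken from the source.
import datetime
from keepercommander.utils import millis_to_datetime, datetime_to_millis

dt = millis_to_datetime(1700000000000)  # timezone-aware UTC datetime
assert dt == datetime.datetime(2023, 11, 14, 22, 13, 20, tzinfo=datetime.timezone.utc)
assert datetime_to_millis(dt) == 1700000000000  # whole seconds round-trip
assert millis_to_datetime(0) is None            # zero millis yields None

# Naive datetimes are treated as UTC by datetime_to_millis.
assert datetime_to_millis(datetime.datetime(1970, 1, 1)) == 0
```

Returning `None` for a zero timestamp lets callers such as `get_exported_value()` skip unset date fields instead of exporting the epoch.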