diff --git a/keepercommander/command_categories.py b/keepercommander/command_categories.py index a5046214a..dbb4945b2 100644 --- a/keepercommander/command_categories.py +++ b/keepercommander/command_categories.py @@ -23,7 +23,7 @@ # Record Type Commands 'Record Type Commands': { - 'record-type-info', 'record-type', 'convert' + 'record-type-info', 'record-type', 'convert', 'convert-all' }, # Import and Exporting Data diff --git a/keepercommander/commands/aram.py b/keepercommander/commands/aram.py index 6651c7be5..26b4a7a86 100644 --- a/keepercommander/commands/aram.py +++ b/keepercommander/commands/aram.py @@ -143,10 +143,10 @@ aging_report_parser.exit = suppress_exit action_report_parser = argparse.ArgumentParser(prog='action-report', description='Run an action based on user activity', parents=[report_output_parser]) -action_report_target_statuses = ['no-logon', 'no-update', 'locked', 'invited', 'no-recovery'] +action_report_target_statuses = ['all','no-logon', 'no-update', 'locked', 'invited', 'no-recovery'] action_report_parser.add_argument('--target', '-t', dest='target_user_status', action='store', - choices=action_report_target_statuses, default='no-logon', - help='user status to report on') + choices=action_report_target_statuses, default='all', + help='user status to report on') action_report_parser.add_argument('--days-since', '-d', dest='days_since', action='store', type=int, help='number of days since event of interest (e.g., login, record add/update, lock)') action_report_columns = {'name', 'status', 'transfer_status', 'node', 'team_count', 'teams', 'role_count', 'roles', @@ -156,16 +156,20 @@ action_report_parser.add_argument('--columns', dest='columns', action='store', type=str, help=columns_help) action_report_parser.add_argument('--apply-action', '-a', dest='apply_action', action='store', - choices=['lock', 'delete', 'transfer', 'none'], default='none', + choices=['lock', 'delete', 'transfer', 'move', 'none'], default='none', help='admin action to 
apply to each user in the report') target_user_help = 'username/email of account to transfer users to when --apply-action=transfer is specified' action_report_parser.add_argument('--target-user', action='store', help=target_user_help) +target_node_help = 'Node name/ID to move users to when --apply-action=move is specified' +action_report_parser.add_argument('--target-node', action='store', help=target_node_help) action_report_parser.add_argument('--dry-run', '-n', dest='dry_run', default=False, action='store_true', help='flag to enable dry-run mode') force_action_help = 'skip confirmation prompt when applying irreversible admin actions (e.g., delete, transfer)' action_report_parser.add_argument('--force', '-f', action='store_true', help=force_action_help) node_filter_help = 'filter users by node (node name or ID)' action_report_parser.add_argument('--node', dest='node', action='store', help=node_filter_help) +recursive_help = 'Search in node and subnodes when --node is specified' +action_report_parser.add_argument('--recursive', action='store_true', help=recursive_help) syslog_templates = None # type: Optional[List[str]] @@ -2264,14 +2268,14 @@ def transfer_accounts(from_users, to_user, dryrun=False): def apply_admin_action(targets, status='no-update', action='none', dryrun=False): # type: (List[Dict[str, Any]], Optional[str], Optional[str], Optional[bool]) -> str - default_allowed = {'none'} + default_allowed = {'none', 'move'} status_actions = { 'no-logon': {*default_allowed, 'lock'}, 'no-update': {*default_allowed}, 'locked': {*default_allowed, 'delete', 'transfer'}, 'invited': {*default_allowed, 'delete'}, 'no-recovery': default_allowed, - 'blocked': {*default_allowed, 'delete'} + 'all': default_allowed } actions_allowed = status_actions.get(status) @@ -2285,12 +2289,15 @@ def apply_admin_action(targets, status='no-update', action='none', dryrun=False) action_handlers = { 'none': partial(run_cmd, targets, None, None, dryrun), 'lock': partial(run_cmd, targets, - 
lambda: exec_fn(params, email=emails, lock=True, force=True, return_results=True), - 'lock', dryrun), + lambda: exec_fn(params, email=emails, lock=True, force=True, return_results=True), + 'lock', dryrun), 'delete': partial(run_cmd, targets, - lambda: exec_fn(params, email=emails, delete=True, force=True, return_results=True), - 'delete', dryrun), - 'transfer': partial(transfer_accounts, targets, kwargs.get('target_user'), dryrun) + lambda: exec_fn(params, email=emails, delete=True, force=True, return_results=True), + 'delete', dryrun), + 'transfer': partial(transfer_accounts, targets, kwargs.get('target_user'), dryrun), + 'move': partial(run_cmd, targets, + lambda: exec_fn(params, email=emails, node=kwargs.get('target_node'), force=True, return_results=True), + 'move', dryrun) } if action in ('delete', 'transfer') and not dryrun and not kwargs.get('force') and targets: @@ -2356,11 +2363,11 @@ def get_report_data_and_headers(targets, output_fmt, columns=None, lock_times=No logging.warning(f'More than one node "{node_name}" found. 
Use Node ID.') return - target_node_id = nodes[0]['node_id'] + node_id = nodes[0]['node_id'] - # Validate target_node_id - if not isinstance(target_node_id, int) or target_node_id <= 0: - logging.warning(f'Invalid node ID: {target_node_id}') + # Validate node_id + if not isinstance(node_id, int) or node_id <= 0: + logging.warning(f'Invalid node ID: {node_id}') return # Build parent-child lookup dictionary once to avoid deep recursion @@ -2386,7 +2393,7 @@ def get_descendant_nodes(node_id): queue.append(child_id) return descendants - target_nodes = get_descendant_nodes(target_node_id) + target_nodes = get_descendant_nodes(node_id) if kwargs.get('recursive') else [node_id] filtered_user_ids = {user['enterprise_user_id'] for user in params.enterprise.get('users', []) if user.get('node_id') in target_nodes} @@ -2397,9 +2404,10 @@ def get_descendant_nodes(node_id): target_status = kwargs.get('target_user_status', 'no-logon') days = kwargs.get('days_since') if days is None: - days = 90 if target_status == 'locked' else 30 + days = 90 if target_status == 'locked' else 0 if target_status == 'all' else 30 args_by_status = { + 'all': [active + locked + invited, days, []], 'no-logon': [active, days, ['login', 'login_console', 'chat_login', 'accept_invitation']], 'no-update': [active, days, ['record_add', 'record_update']], 'locked': [locked, days, ['lock_user'], 'to_username'], @@ -2413,7 +2421,7 @@ def get_descendant_nodes(node_id): logging.warning(f'Invalid target_user_status \'{target_status}\': value must be one of {valid_targets}') return - target_users = get_no_action_users(*args) + target_users = args[0] if target_status == 'all' else get_no_action_users(*args) usernames = {user['username'] for user in target_users} columns_arg = kwargs.get('columns') diff --git a/keepercommander/commands/convert.py b/keepercommander/commands/convert.py index 79591d86c..697a056cd 100644 --- a/keepercommander/commands/convert.py +++ b/keepercommander/commands/convert.py @@ -14,25 +14,35 @@ import json import logging import re +import shutil +import 
sys +import time from collections import OrderedDict -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, List, Dict from ..utils import is_url, is_email -from .base import raise_parse_exception, suppress_exit, Command +from .base import raise_parse_exception, suppress_exit, user_choice, Command from .folder import get_folder_path from .. import api, crypto, loginv3, utils from ..params import KeeperParams from ..proto import record_pb2 from ..subfolder import try_resolve_path, find_parent_top_folder +CONVERT_BATCH_LIMIT = 999 + def register_commands(commands): commands['convert'] = ConvertCommand() + commands['convert-all'] = ConvertAllCommand() def register_command_info(aliases, command_info): + aliases['ca'] = 'convert-all' command_info[convert_parser.prog] = convert_parser.description + command_info[convert_all_parser.prog] = convert_all_parser.description + +# Argument Parsers convert_parser = argparse.ArgumentParser(prog='convert', description='Convert record(s) to use record types') convert_parser.add_argument( @@ -64,48 +74,333 @@ def register_command_info(aliases, command_info): convert_parser.error = raise_parse_exception convert_parser.exit = suppress_exit +convert_all_parser = argparse.ArgumentParser( + prog='convert-all', + description='Convert all legacy General records in the vault to a typed record format' +) +convert_all_parser.add_argument( + '-t', '--record-type', '--record_type', dest='record_type', action='store', + help='Target record type (default: login)' +) +convert_all_parser.add_argument( + '-ia', '--include-attachments', dest='include_attachments', action='store_true', + help='Include records with file attachments. ' + 'Note: file attachments may not be decrypted by Keeper clients after conversion.' 
+) +convert_all_parser.add_argument( + '-f', '--force', dest='force', action='store_true', + help='Skip confirmation prompt' +) +convert_all_parser.add_argument( + '-n', '--dry-run', dest='dry_run', action='store_true', + help='Preview records that would be converted without making changes' +) +convert_all_parser.add_argument( + '-io', '--ignore-ownership', dest='ignore_owner', action='store_true', + help='Include records not owned by the current user' +) +convert_all_parser.error = raise_parse_exception +convert_all_parser.exit = suppress_exit -def get_matching_records_from_folder(params, folder_uid, regex, url_regex, attachments=False, ignore_owner=False): - records = [] - if folder_uid in params.subfolder_record_cache: - for uid in params.subfolder_record_cache[folder_uid]: - if not ignore_owner: - if uid not in params.record_owner_cache: - continue - own = params.record_owner_cache[uid] - if not own.owner is True: - continue - if uid not in params.record_cache: - continue - rv = params.record_cache[uid].get('version', 0) - if rv != 2: - continue - r = api.get_record(params, uid) - if not attachments and r.attachments: - continue - r_attrs = (r.title, r.record_uid) - if any(attr for attr in r_attrs if isinstance(attr, str) and len(attr) > 0 and regex(attr) is not None): - if url_regex: - url_match = r.login_url and url_regex(r.login_url) is not None - else: - url_match = True - if url_match: - records.append(r) - return records - - -def recurse_folder(params, folder_uid, folder_path, records_by_folder, regex, url_regex, recurse, attachments=False, ignore_owner=False): - folder_records = get_matching_records_from_folder(params, folder_uid, regex, url_regex, attachments, ignore_owner) - if len(folder_records) > 0: - if folder_uid not in folder_path: - folder_path[folder_uid] = get_folder_path(params, folder_uid) - records_by_folder[folder_uid] = set() - records_by_folder[folder_uid].update(folder_records) - - if recurse: - folder = params.folder_cache[folder_uid] 
if folder_uid else params.root_folder - for subfolder_uid in folder.subfolders: - recurse_folder(params, subfolder_uid, folder_path, records_by_folder, regex, url_regex, recurse, attachments, ignore_owner) + +class ConvertHelper: + """Shared utilities for v2-to-v3 record conversion used by both convert and convert-all.""" + + @staticmethod + def validate_v3_enabled(params): + # type: (KeeperParams) -> bool + if params.settings and isinstance(params.settings.get('record_types_enabled'), bool): + return params.settings['record_types_enabled'] + return False + + @staticmethod + def resolve_record_type(params, record_type_name): + # type: (KeeperParams, str) -> Optional[dict] + available_types = [json.loads(x) for x in params.record_type_cache.values()] + type_info = next((x for x in available_types if x.get('$id') == record_type_name), None) + if type_info is None: + valid = ', '.join(sorted(x.get('$id') for x in available_types)) + logging.warning('Specified record type "%s" is not valid. 
Valid types are:\n%s', + record_type_name, valid) + return type_info + + @staticmethod + def is_v2_record(params, uid): + # type: (KeeperParams, str) -> bool + if uid not in params.record_cache: + return False + return params.record_cache[uid].get('version', 0) == 2 + + @staticmethod + def is_record_owner(params, uid): + # type: (KeeperParams, str) -> bool + if uid not in params.record_owner_cache: + return False + return bool(params.record_owner_cache[uid].owner) + + @staticmethod + def has_attachments(params, uid): + # type: (KeeperParams, str) -> bool + record = api.get_record(params, uid) + return bool(record and record.attachments) + + @staticmethod + def infer_field_type(field_value): + # type: (Optional[str]) -> str + if not field_value: + return 'text' + if len(field_value) > 128: + return 'note' + if is_url(field_value): + return 'url' + if is_email(field_value): + return 'email' + return 'text' + + @staticmethod + def build_v2_to_v3_data(record_uid, params, type_info): + # type: (str, KeeperParams, dict) -> Optional[Tuple[dict, list]] + if not (record_uid and params and params.record_cache and record_uid in params.record_cache): + logging.warning('Record %s not found.', record_uid) + return None + + record = params.record_cache[record_uid] + if (record.get('version') or 0) != 2: + logging.warning('Record %s is not version 2.', record_uid) + return None + + try: + raw_data = record.get('data_unencrypted') + raw_extra = record.get('extra_unencrypted') + data = raw_data if isinstance(raw_data, dict) else json.loads(raw_data or '{}') + extra = raw_extra if isinstance(raw_extra, dict) else json.loads(raw_extra or '{}') + except (json.JSONDecodeError, TypeError) as e: + logging.warning('Record %s has malformed data: %s', record_uid, e) + return None + + v2_fields = ConvertHelper._extract_v2_fields(data, extra) + totp_extras = v2_fields.pop('_totp_extras', []) + fields, custom = ConvertHelper._map_fields_to_v3(v2_fields, type_info, data.get('custom') or []) + + 
if totp_extras: + custom.extend({'type': 'oneTimeCode', 'value': [x]} for x in totp_extras if x) + + v3_data = { + 'title': data.get('title') or '', + 'type': type_info['$id'], + 'fields': fields, + 'custom': custom, + 'notes': data.get('notes') or '' + } + file_info = extra.get('files') or [] + return v3_data, file_info + + @staticmethod + def _extract_v2_fields(data, extra): + # type: (dict, dict) -> dict + extra_fields = extra.get('fields') or [] + otps = [x['data'] for x in extra_fields if x.get('field_type') == 'totp' and 'data' in x] + + v2_fields = {} + for key, v2_key in [('login', 'secret1'), ('password', 'secret2'), ('url', 'link')]: + value = data.get(v2_key) or '' + if value: + v2_fields[key] = value + + if otps: + v2_fields['oneTimeCode'] = otps[0] + v2_fields['_totp_extras'] = otps[1:] + else: + v2_fields['_totp_extras'] = [] + + return v2_fields + + @staticmethod + def _map_fields_to_v3(v2_fields, type_info, custom_v2): + # type: (dict, dict, list) -> Tuple[list, list] + fields = [] + for field_def in type_info.get('fields', []): + ref = field_def.get('$ref', 'text') + label = field_def.get('label') + typed_field = {'type': ref, 'value': []} + if label: + typed_field['label'] = label + if not label and ref in v2_fields: + typed_field['value'].append(v2_fields.pop(ref)) + fields.append(typed_field) + + custom = [] + custom.extend({'type': k, 'value': [v]} for k, v in v2_fields.items() if k != '_totp_extras') + custom.extend( + { + 'type': ConvertHelper.infer_field_type(entry.get('value')), + 'label': entry.get('name') or '', + 'value': [entry['value']] if entry.get('value') else [] + } + for entry in custom_v2 if entry.get('name') or entry.get('value') + ) + return fields, custom + + @staticmethod + def build_convert_request(record_uid, params, type_info): + # type: (str, KeeperParams, dict) -> Optional[record_pb2.RecordConvertToV3] + convert_result = ConvertHelper.build_v2_to_v3_data(record_uid, params, type_info) + if not convert_result: + return 
None + + v3_data, file_info = convert_result + + try: + record_key = params.record_cache[record_uid]['record_key_unencrypted'] + except KeyError: + logging.warning('Record %s is missing its encryption key.', record_uid) + return None + + record = api.get_record(params, record_uid) + if not record: + logging.warning('Record %s could not be loaded.', record_uid) + return None + + rc = record_pb2.RecordConvertToV3() + rc.record_uid = loginv3.CommonHelperMethods.url_safe_str_to_bytes(record_uid) + rc.client_modified_time = api.current_milli_time() + rc.revision = record.revision + + if file_info: + ConvertHelper._attach_file_refs(rc, v3_data, file_info, record_key, params.data_key) + + rc.data = crypto.encrypt_aes_v2(api.get_record_data_json_bytes(v3_data), record_key) + + ConvertHelper._attach_shared_folder_keys(rc, params, record_uid, record_key) + ConvertHelper._attach_audit_data(rc, record, v3_data, params) + + return rc + + @staticmethod + def _attach_file_refs(rc, v3_data, file_info, record_key, data_key): + file_ref = next((x for x in v3_data['fields'] if x.get('type') == 'fileRef'), None) + if file_ref is None: + file_ref = {'type': 'fileRef'} + v3_data['fields'].append(file_ref) + if not isinstance(file_ref.get('value'), list): + file_ref['value'] = [] + + for f_info in file_info: + try: + file_uid = utils.generate_uid() + file_ref['value'].append(file_uid) + file_key = utils.base64_url_decode(f_info['key']) + + metadata = {k: f_info.get(k) for k in ('name', 'size', 'title', 'lastModified', 'type')} + + rf = record_pb2.RecordFileForConversion() + rf.record_uid = utils.base64_url_decode(file_uid) + rf.file_file_id = f_info['id'] + + thumbs = f_info.get('thumbs') or [] + if thumbs: + thumb = next((x for x in thumbs if isinstance(x, dict)), None) + if thumb: + rf.thumb_file_id = thumb.get('id', '') + + rf.data = crypto.encrypt_aes_v2(json.dumps(metadata).encode('utf-8'), file_key) + rf.record_key = crypto.encrypt_aes_v2(file_key, data_key) + rf.link_key = 
crypto.encrypt_aes_v2(file_key, record_key) + rc.record_file.append(rf) + except (KeyError, TypeError) as e: + logging.warning('Skipping file attachment due to incomplete metadata: %s', e) + + @staticmethod + def _attach_shared_folder_keys(rc, params, record_uid, record_key): + shared_folders = find_parent_top_folder(params, record_uid) + for shared_folder in shared_folders: + try: + sf = params.shared_folder_cache[shared_folder.uid] + fk = record_pb2.RecordFolderForConversion() + fk.folder_uid = utils.base64_url_decode(shared_folder.uid) + fk.record_folder_key = crypto.encrypt_aes_v2( + record_key, sf['shared_folder_key_unencrypted'] + ) + rc.folder_key.append(fk) + except KeyError: + logging.warning('Shared folder %s missing key, skipping folder key conversion.', shared_folder.uid) + + @staticmethod + def _attach_audit_data(rc, record, v3_data, params): + if not params.enterprise_ec_key: + return + audit_data = { + 'title': record.title or '', + 'record_type': v3_data['type'], + } + if record.login_url: + audit_data['url'] = utils.url_strip(record.login_url) + rc.audit.data = crypto.encrypt_ec( + json.dumps(audit_data).encode('utf-8'), params.enterprise_ec_key + ) + + @staticmethod + def send_batch(params, records, record_names=None, quiet=False): + # type: (KeeperParams, list, Optional[dict], bool) -> Tuple[int, int, list] + """Send a batch of RecordConvertToV3 to the API. 
Returns (success_count, fail_count, failures).""" + successes = 0 + failures = 0 + failed_list = [] # type: List[Tuple[str, str, str]] + + params.sync_data = True + rq = record_pb2.RecordsConvertToV3Request() + rq.records.extend(records) + rq.client_time = api.current_milli_time() + + rs = api.communicate_rest(params, rq, 'vault/records_convert3', + rs_type=record_pb2.RecordsModifyResponse) + + for r in rs.records: + uid = loginv3.CommonHelperMethods.bytes_to_url_safe_str(r.record_uid) + name = (record_names or {}).get(uid, uid) + if r.status == record_pb2.RS_SUCCESS: + successes += 1 + else: + failures += 1 + failed_list.append((uid, name, r.message)) + + if not quiet and record_names: + converted = [ + ' %s %s' % (loginv3.CommonHelperMethods.bytes_to_url_safe_str(r.record_uid), + record_names.get(loginv3.CommonHelperMethods.bytes_to_url_safe_str(r.record_uid), '')) + for r in rs.records if r.status == record_pb2.RS_SUCCESS + ] + if converted: + logging.info('Successfully converted the following %d record(s):', len(converted)) + logging.info('\n'.join(converted)) + + if failed_list: + logging.warning('Failed to convert the following %d record(s):', len(failed_list)) + logging.warning('\n'.join('%s %s : %s' % f for f in failed_list)) + + return successes, failures, failed_list + + @staticmethod + def render_progress(current, total, start_time, bar_width=30): + elapsed = time.time() - start_time + avg = elapsed / current if current > 0 else 0 + filled = int(bar_width * current / total) if total > 0 else bar_width + bar = '#' * filled + '-' * (bar_width - filled) + try: + term_width = shutil.get_terminal_size().columns + except Exception: + term_width = 80 + line = '\r [%s] %d/%d (%.1fs avg)' % (bar, current, total, avg) + sys.stderr.write(line[:term_width]) + sys.stderr.flush() + + @staticmethod + def print_failures(failed_records): + # type: (List[Tuple[str, str, str]]) -> None + logging.warning('') + logging.warning('Failed records:') + for uid, title, message in 
failed_records: + logging.warning(' %s %s — %s', uid, title, message) class ConvertCommand(Command): @@ -113,43 +408,56 @@ def get_parser(self): return convert_parser def execute(self, params, **kwargs): - if params.settings and isinstance(params.settings.get('record_types_enabled'), bool): - v3_enabled = params.settings.get('record_types_enabled') - else: - v3_enabled = False - if not v3_enabled: - logging.warning(f'Cannot convert record(s) if record types is not enabled') + if not ConvertHelper.validate_v3_enabled(params): + logging.warning('Cannot convert record(s) if record types is not enabled') return - recurse = kwargs.get('recursive', False) - url_pattern = kwargs.get('url') - url_regex = re.compile(fnmatch.translate(url_pattern)).match if url_pattern else None - record_type = kwargs.get('record_type') or 'login' - available_types = [json.loads(x) for x in params.record_type_cache.values()] - type_info = next((x for x in available_types if x.get('$id') == record_type), None) + type_info = ConvertHelper.resolve_record_type(params, record_type) if type_info is None: - logging.warning( - f'Specified record type "{record_type}" is not valid. 
' - f'Valid types are:\n{", ".join(sorted((x.get("$id") for x in available_types)))}' - ) return - attachments = kwargs.get('force', False) - if not isinstance(attachments, bool): - attachments = False - ignore_owner = kwargs.get('ignore_owner', False) - if not isinstance(ignore_owner, bool): - ignore_owner = False + include_attachments = bool(kwargs.get('force', False)) + ignore_owner = bool(kwargs.get('ignore_owner', False)) + quiet = bool(kwargs.get('quiet', False)) pattern_list = kwargs.get('record-uid-name-patterns', []) - if len(pattern_list) == 0: - logging.warning(f'Please specify a record to convert') + if not pattern_list: + logging.warning('Please specify a record to convert') return + url_pattern = kwargs.get('url') + try: + url_regex = re.compile(fnmatch.translate(url_pattern)).match if url_pattern else None + except re.error as e: + logging.warning('Invalid URL pattern "%s": %s', url_pattern, e) + return + + recurse = kwargs.get('recursive', False) + + record_uids, record_names = self._find_matching_records( + params, pattern_list, url_regex, url_pattern, recurse, include_attachments, ignore_owner + ) + if not record_uids: + return + + if kwargs.get('dry_run', False): + print( + 'The following %d record(s) that you own were matched' + ' and would be converted to records with type "%s":' % (len(record_uids), record_type) + ) + print('\n'.join(' %s %s' % (k, v) for k, v in record_names.items())) + return + + self._convert_and_send(params, record_uids, record_names, type_info, quiet) + + @staticmethod + def _find_matching_records(params, pattern_list, url_regex, url_pattern, recurse, include_attachments, ignore_owner): + # type: (...) 
-> Tuple[set, OrderedDict] folder = params.folder_cache.get(params.current_folder, params.root_folder) - records_by_folder = {} # type: Dict + records_by_folder = {} # type: Dict folder_path = {} + for pattern in pattern_list: if pattern in params.record_cache: record = api.get_record(params, pattern) @@ -160,239 +468,285 @@ def execute(self, params, **kwargs): records_by_folder[''] = set() records_by_folder[''].add(record) continue + rs = try_resolve_path(params, pattern) if rs is not None: folder, pattern = rs - regex = re.compile(fnmatch.translate(pattern)).match if pattern else None + + try: + regex = re.compile(fnmatch.translate(pattern)).match if pattern else None + except re.error as e: + logging.warning('Invalid pattern "%s": %s', pattern, e) + continue folder_uid = folder.uid or '' - recurse_folder(params, folder_uid, folder_path, records_by_folder, regex, url_regex, recurse, - attachments=attachments, ignore_owner=ignore_owner) + ConvertCommand._recurse_folder( + params, folder_uid, folder_path, records_by_folder, regex, + url_regex, recurse, include_attachments, ignore_owner + ) - if len(records_by_folder) == 0: + if not records_by_folder: patterns = ', '.join(pattern_list) - msg = f'No records that can be converted to record types can be found for pattern "{patterns}"' + msg = 'No records that can be converted to record types can be found for pattern "%s"' % patterns if url_pattern: - msg += f' with url matching "{url_pattern}"' + msg += ' with url matching "%s"' % url_pattern logging.warning(msg) - return + return set(), OrderedDict() - # Sort records and if dry run print record_uids = set() record_names = OrderedDict() - for folder_uid in sorted(folder_path, key=lambda x: folder_path[x]): - path = folder_path[folder_uid] - for record in sorted(records_by_folder[folder_uid], key=lambda x: getattr(x, 'title')): + for fuid in sorted(folder_path, key=lambda x: folder_path[x]): + path = folder_path[fuid] + for record in sorted(records_by_folder[fuid], 
key=lambda x: getattr(x, 'title', '')): if record.record_uid not in record_uids: record_uids.add(record.record_uid) record_names[record.record_uid] = path + record.title - dry_run = kwargs.get('dry_run', False) - if dry_run: - print( - f'The following {len(record_uids)} record(s) that you own were matched' - f' and would be converted to records with type "{record_type}":' - ) + return record_uids, record_names - print('\n'.join(f' {k} {v}' for k, v in record_names.items())) - else: - records = [] - for record_uid in record_uids: - convert_result = ConvertCommand.convert_to_record_type_data(record_uid, params, type_info) - if not convert_result: - logging.warning(f'Conversion failed for {record_names[record_uid]} ({record_uid})\n') - continue - v3_data, file_info = convert_result - record_key = params.record_cache[record_uid]['record_key_unencrypted'] - - rc = record_pb2.RecordConvertToV3() - rc.record_uid = loginv3.CommonHelperMethods.url_safe_str_to_bytes(record_uid) - rc.client_modified_time = api.current_milli_time() - record = api.get_record(params, record_uid) - rc.revision = record.revision - - if file_info: - file_ref = next((x for x in v3_data['fields'] if x.get('type') == 'fileRef'), None) - if file_ref is None: - file_ref = {'type': 'fileRef'} - v3_data['fields'].append(file_ref) - if not isinstance(file_ref.get('value'), list): - file_ref['value'] = [] - - for f_info in file_info: - file_uid = utils.generate_uid() - file_ref['value'].append(file_uid) - file_key = utils.base64_url_decode(f_info['key']) - - data = {} - for k in ('name', 'size', 'title', 'lastModified', 'type'): - data[k] = f_info[k] - - rf = record_pb2.RecordFileForConversion() - rf.record_uid = utils.base64_url_decode(file_uid) - rf.file_file_id = f_info['id'] - if 'thumbs' in f_info: - thumbs = f_info['thumbs'] - if len(thumbs) > 0: - thumb = next((x for x in thumbs if isinstance(x, dict)), None) - if thumb: - rf.thumb_file_id = thumbs[0]['id'] - rf.data = 
crypto.encrypt_aes_v2(json.dumps(data).encode('utf-8'), file_key) - rf.record_key = crypto.encrypt_aes_v2(file_key, params.data_key) - rf.link_key = crypto.encrypt_aes_v2(file_key, record_key) - rc.record_file.append(rf) - rc.data = crypto.encrypt_aes_v2(api.get_record_data_json_bytes(v3_data), record_key) - - # Get share folder of the record so that we can convert the Record Folder Key - shared_folders = find_parent_top_folder(params, record_uid) - - for shared_folder in shared_folders: - sf = params.shared_folder_cache[shared_folder.uid] - folder_key = record_pb2.RecordFolderForConversion() - folder_key.folder_uid = utils.base64_url_decode(shared_folder.uid) - folder_key.record_folder_key = crypto.encrypt_aes_v2(record_key, sf['shared_folder_key_unencrypted']) - rc.folder_key.append(folder_key) - - if params.enterprise_ec_key: - audit_data = { - 'title': record.title or '', - 'record_type': v3_data['type'], - } - if record.login_url: - audit_data['url'] = utils.url_strip(record.login_url) - rc.audit.data = crypto.encrypt_ec(json.dumps(audit_data).encode('utf-8'), params.enterprise_ec_key) - - records.append(rc) - - quiet = kwargs.get('quiet', False) + @staticmethod + def _get_matching_records_from_folder(params, folder_uid, regex, url_regex, include_attachments, ignore_owner): + records = [] + if folder_uid not in params.subfolder_record_cache: + return records + for uid in params.subfolder_record_cache[folder_uid]: + if not ignore_owner and not ConvertHelper.is_record_owner(params, uid): + continue + if not ConvertHelper.is_v2_record(params, uid): + continue + r = api.get_record(params, uid) + if not r: + continue + if not include_attachments and r.attachments: + continue + r_attrs = (r.title, r.record_uid) + if not any(isinstance(a, str) and a and regex(a) is not None for a in r_attrs): + continue + if url_regex and not (r.login_url and url_regex(r.login_url) is not None): + continue + records.append(r) + return records + + @staticmethod + def 
_recurse_folder(params, folder_uid, folder_path, records_by_folder, regex, + url_regex, recurse, include_attachments, ignore_owner): + folder_records = ConvertCommand._get_matching_records_from_folder( + params, folder_uid, regex, url_regex, include_attachments, ignore_owner + ) + if folder_records: + if folder_uid not in folder_path: + folder_path[folder_uid] = get_folder_path(params, folder_uid) + records_by_folder[folder_uid] = set() + records_by_folder[folder_uid].update(folder_records) + + if recurse: + folder = params.folder_cache[folder_uid] if folder_uid else params.root_folder + for subfolder_uid in folder.subfolders: + ConvertCommand._recurse_folder( + params, subfolder_uid, folder_path, records_by_folder, regex, + url_regex, recurse, include_attachments, ignore_owner + ) + + @staticmethod + def _convert_and_send(params, record_uids, record_names, type_info, quiet): + records = [] + for record_uid in record_uids: + rc = ConvertHelper.build_convert_request(record_uid, params, type_info) + if not rc: + logging.warning('Conversion failed for %s (%s)', record_names.get(record_uid, ''), record_uid) + continue + records.append(rc) + + if not quiet: + logging.info('Matched %d record(s)', len(record_uids)) + + if not records: if not quiet: - logging.info(f'Matched {len(record_uids)} record(s)') + logging.info('No records successfully converted') + return - if len(records) == 0: - if not quiet: - logging.info('No records successfully converted') - return + while records: + batch = records[:CONVERT_BATCH_LIMIT] + records = records[CONVERT_BATCH_LIMIT:] + ConvertHelper.send_batch(params, batch, record_names, quiet) - while len(records) > 0: - rq = record_pb2.RecordsConvertToV3Request() - rq.records.extend(records[:999]) - records = records[999:] - - params.sync_data = True - rq.client_time = api.current_milli_time() - records_modify_rs = api.communicate_rest(params, rq, 'vault/records_convert3', - rs_type=record_pb2.RecordsModifyResponse) - if not quiet: - 
converted_record_names = [ - f' {utils.base64_url_encode(r.record_uid)} {record_names[loginv3.CommonHelperMethods.bytes_to_url_safe_str(r.record_uid)]}' - for r in records_modify_rs.records if r.status == record_pb2.RS_SUCCESS - ] - if len(converted_record_names) > 0: - logging.info(f'Successfully converted the following {len(converted_record_names)} record(s):') - logging.info('\n'.join(converted_record_names)) - - convert_errors = [(f' {utils.base64_url_encode(x.record_uid)} {record_names[loginv3.CommonHelperMethods.bytes_to_url_safe_str(x.record_uid)]}', x.message) - for x in records_modify_rs.records if x.status != record_pb2.RS_SUCCESS] - if len(convert_errors) > 0: - logging.warning(f'Failed to convert the following {len(convert_errors)} record(s):') - logging.warning('\n'.join((f'{x[0]} : {x[1]}' for x in convert_errors))) - @staticmethod - def get_v3_field_type(field_value): - return_type = 'text' - if field_value: - if is_url(field_value): - return_type = 'url' - elif is_email(field_value): - return_type = 'email' - if len(field_value) > 128: - return_type = 'note' - return return_type +class ConvertAllCommand(Command): + def get_parser(self): + return convert_all_parser - @staticmethod - def convert_to_record_type_data(record_uid, params, type_info): - # type: (str, KeeperParams, dict) -> Optional[Tuple[dict, list]] + def execute(self, params, **kwargs): + if not ConvertHelper.validate_v3_enabled(params): + logging.warning('Cannot convert records: record types is not enabled for this account.') + return - if not (record_uid and params and params.record_cache and record_uid in params.record_cache): - logging.warning('Record %s not found.', record_uid) + record_type = kwargs.get('record_type') or 'login' + type_info = ConvertHelper.resolve_record_type(params, record_type) + if type_info is None: return - record = params.record_cache[record_uid] - version = record.get('version') or 0 - if version != 2: - logging.warning('Record %s is not version 2.', 
record_uid) + include_attachments = kwargs.get('include_attachments', False) + ignore_owner = kwargs.get('ignore_owner', False) + dry_run = kwargs.get('dry_run', False) + force = kwargs.get('force', False) + + api.sync_down(params) + + partition = self._partition_v2_records(params, ignore_owner) + + if not partition['all']: + logging.info('No General-type records found in the vault. Nothing to convert.') return - data = record.get('data_unencrypted') - extra = record.get('extra_unencrypted') + record_uids = partition['all'] if include_attachments else partition['without_attachments'] + skipped_attachment_count = len(partition['with_attachments']) if not include_attachments else 0 + skipped_not_owned_count = partition['skipped_not_owned'] - data = data if isinstance(data, dict) else json.loads(data or '{}') - extra = extra if isinstance(extra, dict) else json.loads(extra or '{}') + if not record_uids: + logging.info('All %d General-type record(s) have attachments. Use -ia to include them.', + len(partition['all'])) + return - # check for other non-convertible data - ex. fields[] has "field_type" != "totp" if present - extra_fields = extra.get('fields') or [] - otps = [x['data'] for x in extra_fields if 'totp' == (x.get('field_type') or '') and 'data' in x] - totp = otps[0] if len(otps) > 0 else '' - otps = otps[1:] - # label = otp.get('field_title') or '' - - title = data.get('title') or '' - login = data.get('secret1') or '' - password = data.get('secret2') or '' - url = data.get('link') or '' - v2_fields = {} - if login: - v2_fields['login'] = login - if password: - v2_fields['password'] = password - if url: - v2_fields['url'] = url - if totp: - v2_fields['oneTimeCode'] = totp - - notes = data.get('notes') or '' - custom2 = data.get('custom') or [] - # custom.type - Always "text" for legacy reasons. 
- fields = [] - for field in type_info.get('fields', []): - ref = field.get('$ref', 'text') - label = field.get('label') - typed_field = { - 'type': ref, - 'value': [] - } - if label: - typed_field['label'] = label - if not label and ref in v2_fields: - value = v2_fields.pop(ref) - typed_field['value'].append(value) - fields.append(typed_field) + logging.info('Found %d General-type record(s) to convert to "%s".', len(record_uids), record_type) + if include_attachments and partition['with_attachments']: + logging.warning( + ' %d record(s) have file attachments. ' + 'Note: file attachments may not be decrypted by Keeper clients after conversion.', + len(partition['with_attachments']) + ) + if skipped_attachment_count > 0: + logging.info(' %d record(s) with attachments were skipped. Use -ia to include them.', + skipped_attachment_count) + if skipped_not_owned_count > 0: + logging.info(' %d record(s) not owned by you were skipped. Use -io to include them.', + skipped_not_owned_count) - custom = [] - custom.extend(({ - 'type': x[0], - 'value': [x[1]], - } for x in v2_fields.items())) - custom.extend(({ - 'type': ConvertCommand.get_v3_field_type(x.get('value')), - 'label': x.get('name') or '', - 'value': [x.get('value')] if x.get('value') else [] - } for x in custom2 if x.get('name') or x.get('value'))) - - # Add any remaining TOTP codes to custom[] - if len(otps) > 0: - custom.extend(({ - 'type': 'oneTimeCode', - 'value': [x] - } for x in otps if x)) + if dry_run: + self._print_dry_run(params, record_uids) + return - v3_data = { - 'title': title, - 'type': type_info['$id'], - 'fields': fields, - 'custom': custom, - 'notes': notes + if not force: + if params.batch_mode: + logging.warning('Confirmation required. Use --force (-f) to skip in batch mode.') + return + answer = user_choice( + 'Convert %d General record(s) to "%s"?' 
% (len(record_uids), record_type), + 'yn', default='n' + ) + if answer.lower() != 'y': + logging.info('Operation cancelled.') + return + + self._execute_conversion(params, record_uids, type_info) + + @staticmethod + def _partition_v2_records(params, ignore_owner): + # type: (KeeperParams, bool) -> dict + """Returns dict with keys: all, with_attachments, without_attachments, skipped_not_owned.""" + all_v2 = [] + with_attachments = [] + without_attachments = [] + skipped_not_owned = 0 + + for uid in params.record_cache: + if not ConvertHelper.is_v2_record(params, uid): + continue + if not ConvertHelper.is_record_owner(params, uid): + if not ignore_owner: + skipped_not_owned += 1 + continue + all_v2.append(uid) + if ConvertHelper.has_attachments(params, uid): + with_attachments.append(uid) + else: + without_attachments.append(uid) + + return { + 'all': all_v2, + 'with_attachments': with_attachments, + 'without_attachments': without_attachments, + 'skipped_not_owned': skipped_not_owned, } - file_info = extra.get('files') or [] - return v3_data, file_info + @staticmethod + def _print_dry_run(params, record_uids): + logging.info('The following %d record(s) would be converted:\n', len(record_uids)) + for uid in record_uids: + record = api.get_record(params, uid) + title = record.title if record else '' + logging.info(' %s %s', uid, title) + + @staticmethod + def _execute_conversion(params, record_uids, type_info): + start_time = time.time() + converted_count = 0 + failed_count = 0 + failed_records = [] + + records_to_send = [] + record_names = {} + + logging.info('Preparing %d record(s) for conversion...', len(record_uids)) + + for record_uid in record_uids: + record = api.get_record(params, record_uid) + title = record.title if record else record_uid + record_names[record_uid] = title + + rc = ConvertHelper.build_convert_request(record_uid, params, type_info) + if not rc: + failed_count += 1 + failed_records.append((record_uid, title, 'Field conversion failed')) + 
continue + records_to_send.append(rc) + + if not records_to_send: + logging.warning('No records were successfully prepared for conversion.') + if failed_records: + ConvertHelper.print_failures(failed_records) + return + + total_to_send = len(records_to_send) + sent_count = 0 + + logging.info('Converting %d record(s)...', total_to_send) + + while records_to_send: + batch = records_to_send[:CONVERT_BATCH_LIMIT] + records_to_send = records_to_send[CONVERT_BATCH_LIMIT:] + + try: + successes, failures, batch_failures = ConvertHelper.send_batch( + params, batch, record_names, quiet=True + ) + converted_count += successes + failed_count += failures + failed_records.extend(batch_failures) + sent_count += successes + failures + except Exception as e: + for rc_item in batch: + uid = loginv3.CommonHelperMethods.bytes_to_url_safe_str(rc_item.record_uid) + failed_count += 1 + failed_records.append((uid, record_names.get(uid, uid), str(e))) + sent_count += 1 + + ConvertHelper.render_progress(sent_count, total_to_send, start_time) + + sys.stderr.write('\n') + sys.stderr.flush() + + elapsed = time.time() - start_time + total_processed = converted_count + failed_count + + logging.info('') + logging.info('Conversion complete.') + logging.info(' Converted: %d', converted_count) + if failed_count > 0: + logging.warning(' Failed: %d', failed_count) + logging.info(' Total: %d', total_processed) + logging.info(' Time: %.1fs', elapsed) + if total_processed > 0: + logging.info(' Average: %.2fs per record', elapsed / total_processed) + + if failed_records: + ConvertHelper.print_failures(failed_records) diff --git a/keepercommander/commands/discoveryrotation.py b/keepercommander/commands/discoveryrotation.py index b4db65add..1864a74a3 100644 --- a/keepercommander/commands/discoveryrotation.py +++ b/keepercommander/commands/discoveryrotation.py @@ -10,6 +10,7 @@ # import argparse import fnmatch +import itertools import json import logging import os.path @@ -19,7 +20,6 @@ from typing import 
Dict, Optional, Any, Set, List from urllib.parse import urlparse, urlunparse - import requests from keeper_secrets_manager_core.utils import url_safe_str_to_bytes @@ -47,6 +47,7 @@ from .record_edit import RecordEditMixin from .helpers.timeout import parse_timeout from .email_commands import find_email_config_record, load_email_config_from_record, update_oauth_tokens_in_record +from .enterprise_common import EnterpriseCommand from ..email_service import EmailSender, build_onboarding_email from .tunnel.port_forward.TunnelGraph import TunnelDAG from .tunnel.port_forward.tunnel_helpers import get_config_uid, get_keeper_tokens @@ -75,6 +76,8 @@ from .pam_debug.rotation_setting import PAMDebugRotationSettingsCommand from .pam_debug.vertex import PAMDebugVertexCommand from .pam_import.commands import PAMProjectCommand +from keepercommander.commands.pam_cloud.pam_privileged_workflow import PAMPrivilegedWorkflowCommand +from keepercommander.commands.pam_cloud.pam_privileged_access import PAMPrivilegedAccessCommand from .pam_launch.launch import PAMLaunchCommand from .pam_service.list import PAMActionServiceListCommand from .pam_service.add import PAMActionServiceAddCommand @@ -86,7 +89,6 @@ from .pam_saas.update import PAMActionSaasUpdateCommand from .tunnel_and_connections import PAMTunnelCommand, PAMConnectionCommand, PAMRbiCommand, PAMSplitCommand - # These characters are based on the Vault PAM_DEFAULT_SPECIAL_CHAR = '''!@#$%^?();',.=+[]<>{}-_/\\*&:"`~|''' @@ -105,6 +107,7 @@ def is_valid_number(n): parts = re.split(r'[,\-/]', field) return all(part == '*' or part in ('L', 'LW') or is_valid_number(part) for part in parts if part != '*') + def validate_cron_expression(expr, for_rotation=False): parts = expr.strip().split() @@ -113,11 +116,13 @@ def validate_cron_expression(expr, for_rotation=False): if for_rotation is True: if len(parts) != 6: return False, f"CRON: Rotation schedules require all 6 parts incl. seconds - ex. Daily at 04:00:00 cron: 0 0 4 * * ? 
got {len(parts)} parts" - if not(parts[3] == '?' or parts[5] == "?"): - logging.warning("CRON: Rotation schedule CRON format - must use ? character in one of these fields: day-of-week, day-of-month") + if not (parts[3] == '?' or parts[5] == "?"): + logging.warning( + "CRON: Rotation schedule CRON format - must use ? character in one of these fields: day-of-week, day-of-month") parts[3] = '*' if parts[3] == '?' else parts[3] parts[5] = '*' if parts[5] == '?' else parts[5] - logging.debug("WARNING! Validating CRON expression for rotation - if you get 500 type errors make sure to validate your CRON using web vault UI") + logging.debug( + "WARNING! Validating CRON expression for rotation - if you get 500 type errors make sure to validate your CRON using web vault UI") if len(parts) not in [5, 6]: return False, f"CRON: Expected 5 or 6 fields, got {len(parts)}" @@ -143,11 +148,12 @@ def validate_cron_expression(expr, for_rotation=False): return True, "Valid cron expression" + def parse_schedule_data(kwargs): schedule_json_data = kwargs.get('schedule_json_data') schedule_cron_data = kwargs.get('schedule_cron_data') schedule_on_demand = kwargs.get('on_demand') is True - schedule_data = None # type: Optional[List] + schedule_data = None # type: Optional[List] if isinstance(schedule_json_data, list): schedule_data = [json.loads(x) for x in schedule_json_data] elif isinstance(schedule_cron_data, list): @@ -155,9 +161,9 @@ def parse_schedule_data(kwargs): if schedule_cron_data and isinstance(schedule_cron_data[0], str): valid, err = validate_cron_expression(schedule_cron_data[0], for_rotation=True) if valid: - schedule_data = [{"type": "CRON", "cron": schedule_cron_data[0], "tz": "Etc/UTC"}] + schedule_data = [{"type": "CRON", "cron": schedule_cron_data[0], "tz": "Etc/UTC"}] else: - logging.error('', f'Invalid CRON "{schedule_cron_data[0]}" Error: {err}') + logging.error('', f'Invalid CRON "{schedule_cron_data[0]}" Error: {err}') elif schedule_on_demand is True: schedule_data 
= [] return schedule_data @@ -186,6 +192,10 @@ def __init__(self): self.register_command('rbi', PAMRbiCommand(), 'Manage Remote Browser Isolation', 'b') self.register_command('project', PAMProjectCommand(), 'PAM Project Import/Export', 'p') self.register_command('launch', PAMLaunchCommand(), 'Launch a connection to a PAM resource', 'l') + self.register_command('workflow', PAMPrivilegedWorkflowCommand(), + 'Manage workflow access operations', 'wf') + self.register_command('access', PAMPrivilegedAccessCommand(), + 'Manage privileged cloud access operations', 'ac') class PAMGatewayCommand(GroupCommand): @@ -194,6 +204,7 @@ def __init__(self): super(PAMGatewayCommand, self).__init__() self.register_command('list', PAMGatewayListCommand(), 'List Gateways', 'l') self.register_command('new', PAMCreateGatewayCommand(), 'Create new Gateway', 'n') + self.register_command('edit', PAMEditGatewayCommand(), 'Edit Gateway', 'e') self.register_command('remove', PAMGatewayRemoveCommand(), 'Remove Gateway', 'rm') self.register_command('set-max-instances', PAMSetMaxInstancesCommand(), 'Set maximum gateway instances', 'smi') # self.register_command('connect', PAMConnect(), 'Connect') @@ -217,7 +228,7 @@ class PAMRotationCommand(GroupCommand): def __init__(self): super(PAMRotationCommand, self).__init__() - self.register_command('edit', PAMCreateRecordRotationCommand(), 'Edits Record Rotation configuration', 'new') + self.register_command('edit', PAMCreateRecordRotationCommand(), 'Edits Record Rotation configuration', 'new') self.register_command('list', PAMListRecordRotationCommand(), 'List Record Rotation configuration', 'l') self.register_command('info', PAMRouterGetRotationInfo(), 'Get Rotation Info', 'i') self.register_command('script', PAMRouterScriptCommand(), 'Add, delete, or edit script field') @@ -230,8 +241,10 @@ def __init__(self): super(PAMDiscoveryCommand, self).__init__() self.register_command('start', PAMGatewayActionDiscoverJobStartCommand(), 'Start a discovery 
process', 's') self.register_command('status', PAMGatewayActionDiscoverJobStatusCommand(), 'Status of discovery jobs', 'st') - self.register_command('remove', PAMGatewayActionDiscoverJobRemoveCommand(), 'Cancel or remove of discovery jobs', 'r') - self.register_command('process', PAMGatewayActionDiscoverResultProcessCommand(), 'Process discovered items', 'p') + self.register_command('remove', PAMGatewayActionDiscoverJobRemoveCommand(), + 'Cancel or remove of discovery jobs', 'r') + self.register_command('process', PAMGatewayActionDiscoverResultProcessCommand(), 'Process discovered items', + 'p') self.register_command('rule', PAMDiscoveryRuleCommand(), 'Manage discovery rules') self.default_verb = 'status' @@ -316,8 +329,10 @@ def __init__(self): class PAMLegacyCommand(Command): - parser = argparse.ArgumentParser(prog='pam legacy', description="Toggle PAM Legacy mode: ON/OFF - PAM Legacy commands are obsolete") - parser.add_argument('--status', '-s', required=False, dest='status', action='store_true', help='Show the current status - Legacy mode: ON/OFF') + parser = argparse.ArgumentParser(prog='pam legacy', + description="Toggle PAM Legacy mode: ON/OFF - PAM Legacy commands are obsolete") + parser.add_argument('--status', '-s', required=False, dest='status', action='store_true', + help='Show the current status - Legacy mode: ON/OFF') def get_parser(self): return PAMLegacyCommand.parser @@ -333,7 +348,7 @@ def execute(self, params, **kwargs): if legacy: print("PAM Legacy mode: ON") else: - print ("PAM Legacy mode: OFF") + print("PAM Legacy mode: OFF") return toggle_pam_legacy_commands(not legacy) @@ -382,7 +397,8 @@ class PAMCreateRecordRotationCommand(Command): choices=['general', 'iam_user', 'scripts_only'], help='Rotation profile type: general (resource-based), iam_user (IAM/Azure user), ' 'scripts_only (run PAM scripts only)') - parser.add_argument('--resource', '-rs', dest='resource', action='store', help='UID or path of the resource record.') + 
parser.add_argument('--resource', '-rs', dest='resource', action='store', + help='UID or path of the resource record.') schedule_group = parser.add_mutually_exclusive_group() schedule_group.add_argument('--schedulejson', '-sj', required=False, dest='schedule_json_data', action='append', help='JSON of the scheduler. Example: -sj \'{"type": "WEEKLY", ' @@ -396,7 +412,7 @@ class PAMCreateRecordRotationCommand(Command): action='store_true', help='Schedule from Configuration') parser.add_argument('--schedule-only', '-so', dest='schedule_only', action='store_true', help='Only update the rotation schedule without changing other settings') - parser.add_argument('--complexity', '-x', required=False, dest='pwd_complexity', action='store', + parser.add_argument('--complexity', '-x', required=False, dest='pwd_complexity', action='store', help='Password complexity: length, upper, lower, digits, symbols. Ex. 32,5,5,5,5[,SPECIAL CHARS]') parser.add_argument('--admin-user', '-a', required=False, dest='admin', action='store', help='UID or path for the PAMUser record to configure the admin credential on the PAM Resource as the Admin when rotating') @@ -432,7 +448,7 @@ def config_resource(_dag, target_record, target_config_uid, silent=None): if not _dag.resource_belongs_to_config(target_record.record_uid): # Change DAG to this new configuration. 
resource_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, - target_record.record_uid, transmission_key=transmission_key) + target_record.record_uid, transmission_key=transmission_key) _dag.link_resource_to_config(target_record.record_uid) admin = kwargs.get('admin') @@ -447,7 +463,7 @@ def config_resource(_dag, target_record, target_config_uid, silent=None): if _rotation_enabled is not None: _dag.set_resource_allowed(target_record.record_uid, rotation=_rotation_enabled, - allowed_settings_name="rotation") + allowed_settings_name="rotation") if resource_dag is not None and resource_dag.linking_dag.has_graph: # TODO: Make sure this doesn't remove everything from the new dag too @@ -462,7 +478,8 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): # Handle schedule-only operations first to avoid unnecessary resource validation if schedule_only: - if kwargs.get('folder_name') and (not current_record_rotation or current_record_rotation.get('disabled')): + if kwargs.get('folder_name') and ( + not current_record_rotation or current_record_rotation.get('disabled')): skipped_records.append([target_record.record_uid, target_record.title, 'Rotation not enabled', 'Skipped']) return @@ -480,7 +497,8 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): record_schedule_data = json.loads(cs) if cs else [] except: record_schedule_data = [] - pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation.get('pwd_complexity', '')) + pwd_complexity_rule_list_encrypted = utils.base64_url_decode( + current_record_rotation.get('pwd_complexity', '')) record_resource_uid = current_record_rotation.get('resource_uid') # IAM users have resource_uid == config_uid; should be empty to preserve rotation profile if record_resource_uid == record_config_uid: @@ -494,7 +512,8 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): complexity = '' if pwd_complexity_rule_list_encrypted: 
try: - decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key) + decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, + target_record.record_key) c = json.loads(decrypted_complexity.decode()) complexity = f"{c.get('length', 0)},{c.get('caps', 0)},{c.get('lowercase', 0)},{c.get('digits', 0)},{c.get('special', 0)},{c.get('specialChars', PAM_DEFAULT_SPECIAL_CHAR)}" except Exception: @@ -559,17 +578,20 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): record_config_uid = current_record_rotation.get('configuration_uid') pc = vault.KeeperRecord.load(params, record_config_uid) if pc is None: - skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration was deleted', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'PAM Configuration was deleted', + 'Specify a configuration UID parameter [--config]']) return if not isinstance(pc, vault.TypedRecord) or pc.version != 6: - skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration is invalid', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'PAM Configuration is invalid', + 'Specify a configuration UID parameter [--config]']) return record_pam_config = pc else: - skipped_records.append([target_record.record_uid, target_record.title, 'No current PAM Configuration', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'No current PAM Configuration', + 'Specify a configuration UID parameter [--config]']) return # 2. Schedule @@ -591,7 +613,8 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): # 3. 
Password complexity if pwd_complexity_rule_list is None: if current_record_rotation: - pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation['pwd_complexity']) + pwd_complexity_rule_list_encrypted = utils.base64_url_decode( + current_record_rotation['pwd_complexity']) else: pwd_complexity_rule_list_encrypted = b'' else: @@ -624,9 +647,11 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): disabled = False # 5. Enable rotation if kwargs.get('enable'): - _dag.set_resource_allowed(target_iam_aad_config_uid, rotation=True, is_config=bool(target_iam_aad_config_uid)) + _dag.set_resource_allowed(target_iam_aad_config_uid, rotation=True, + is_config=bool(target_iam_aad_config_uid)) elif kwargs.get('disable'): - _dag.set_resource_allowed(target_iam_aad_config_uid, rotation=False, is_config=bool(target_iam_aad_config_uid)) + _dag.set_resource_allowed(target_iam_aad_config_uid, rotation=False, + is_config=bool(target_iam_aad_config_uid)) disabled = True schedule = 'On-Demand' @@ -636,18 +661,20 @@ def config_iam_aad_user(_dag, target_record, target_iam_aad_config_uid): complexity = '' if pwd_complexity_rule_list_encrypted: try: - decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key) + decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, + target_record.record_key) c = json.loads(decrypted_complexity.decode()) - complexity = f"{c.get('length', 0)},"\ - f"{c.get('caps', 0)},"\ - f"{c.get('lowercase', 0)},"\ - f"{c.get('digits', 0)},"\ - f"{c.get('special', 0)},"\ + complexity = f"{c.get('length', 0)}," \ + f"{c.get('caps', 0)}," \ + f"{c.get('lowercase', 0)}," \ + f"{c.get('digits', 0)}," \ + f"{c.get('special', 0)}," \ f"{c.get('specialChars', PAM_DEFAULT_SPECIAL_CHAR)}" except: pass valid_records.append( - [target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, schedule, + 
[target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, + schedule, complexity]) # 6. Construct Request object for IAM: empty resourceUid and noop=False @@ -669,7 +696,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None # Handle schedule-only operations first to avoid unnecessary resource validation if schedule_only: - if kwargs.get('folder_name') and (not current_record_rotation or current_record_rotation.get('disabled')): + if kwargs.get('folder_name') and ( + not current_record_rotation or current_record_rotation.get('disabled')): skipped_records.append([target_record.record_uid, target_record.title, 'Rotation not enabled', 'Skipped']) return @@ -687,7 +715,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None record_schedule_data = json.loads(cs) if cs else [] except: record_schedule_data = [] - pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation.get('pwd_complexity', '')) + pwd_complexity_rule_list_encrypted = utils.base64_url_decode( + current_record_rotation.get('pwd_complexity', '')) record_resource_uid = current_record_rotation.get('resource_uid') # IAM users have resource_uid == config_uid; should be empty to preserve rotation profile if record_resource_uid == record_config_uid: @@ -701,7 +730,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None complexity = '' if pwd_complexity_rule_list_encrypted: try: - decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key) + decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, + target_record.record_key) c = json.loads(decrypted_complexity.decode()) complexity = f"{c.get('length', 0)},{c.get('caps', 0)},{c.get('lowercase', 0)},{c.get('digits', 0)},{c.get('special', 0)},{c.get('specialChars', PAM_DEFAULT_SPECIAL_CHAR)}" except Exception: @@ -804,24 +834,24 @@ def 
config_user(_dag, target_record, target_resource_uid, target_config_uid=None return else: raise CommandError('', f'{bcolors.FAIL}Record "{target_record.record_uid}" is ' - f'associated with multiple resources so you must supply ' - f'{bcolors.OKBLUE}"--resource/-rs RESOURCE".{bcolors.ENDC}') + f'associated with multiple resources so you must supply ' + f'{bcolors.OKBLUE}"--resource/-rs RESOURCE".{bcolors.ENDC}') elif len(resource_uids) == 0: raise CommandError('', - f'{bcolors.FAIL}Record "{target_record.record_uid}" is not associated with' - f' any resource. Please use {bcolors.OKBLUE}"pam rotation user ' - f'{target_record.record_uid} --resource RESOURCE" {bcolors.FAIL}to associate ' - f'it.{bcolors.ENDC}') + f'{bcolors.FAIL}Record "{target_record.record_uid}" is not associated with' + f' any resource. Please use {bcolors.OKBLUE}"pam rotation user ' + f'{target_record.record_uid} --resource RESOURCE" {bcolors.FAIL}to associate ' + f'it.{bcolors.ENDC}') target_resource_uid = resource_uids[0] if not _dag.resource_belongs_to_config(target_resource_uid): # some rotations (iam_user/noop) link straight to pamConfiguration if target_resource_uid != _dag.record.record_uid: raise CommandError('', - f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated with the ' - f'configuration of the user "{target_record.record_uid}". To associated the resources ' - f'to this config run {bcolors.OKBLUE}"pam rotation resource {target_resource_uid} ' - f'--config {_dag.record.record_uid}"{bcolors.ENDC}') + f'{bcolors.FAIL}Resource "{target_resource_uid}" is not associated with the ' + f'configuration of the user "{target_record.record_uid}". 
To associated the resources ' + f'to this config run {bcolors.OKBLUE}"pam rotation resource {target_resource_uid} ' + f'--config {_dag.record.record_uid}"{bcolors.ENDC}') if not _dag.user_belongs_to_resource(target_record.record_uid, target_resource_uid): old_resource_uid = _dag.get_resource_uid(target_record.record_uid) if old_resource_uid is not None and old_resource_uid != target_resource_uid: @@ -841,17 +871,20 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None record_config_uid = current_record_rotation.get('configuration_uid') pc = vault.KeeperRecord.load(params, record_config_uid) if pc is None: - skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration was deleted', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'PAM Configuration was deleted', + 'Specify a configuration UID parameter [--config]']) return if not isinstance(pc, vault.TypedRecord) or pc.version != 6: - skipped_records.append([target_record.record_uid, target_record.title, 'PAM Configuration is invalid', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'PAM Configuration is invalid', + 'Specify a configuration UID parameter [--config]']) return record_pam_config = pc else: - skipped_records.append([target_record.record_uid, target_record.title, 'No current PAM Configuration', - 'Specify a configuration UID parameter [--config]']) + skipped_records.append( + [target_record.record_uid, target_record.title, 'No current PAM Configuration', + 'Specify a configuration UID parameter [--config]']) return # 2. Schedule @@ -875,7 +908,8 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None # 3. 
Password complexity if pwd_complexity_rule_list is None: if current_record_rotation: - pwd_complexity_rule_list_encrypted = utils.base64_url_decode(current_record_rotation['pwd_complexity']) + pwd_complexity_rule_list_encrypted = utils.base64_url_decode( + current_record_rotation['pwd_complexity']) else: pwd_complexity_rule_list_encrypted = b'' else: @@ -929,18 +963,20 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None complexity = '' if pwd_complexity_rule_list_encrypted: try: - decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, target_record.record_key) + decrypted_complexity = crypto.decrypt_aes_v2(pwd_complexity_rule_list_encrypted, + target_record.record_key) c = json.loads(decrypted_complexity.decode()) - complexity = f"{c.get('length', 0)},"\ - f"{c.get('caps', 0)},"\ - f"{c.get('lowercase', 0)},"\ - f"{c.get('digits', 0)},"\ + complexity = f"{c.get('length', 0)}," \ + f"{c.get('caps', 0)}," \ + f"{c.get('lowercase', 0)}," \ + f"{c.get('digits', 0)}," \ f"{c.get('special', 0)}," \ f"{c.get('specialChars', PAM_DEFAULT_SPECIAL_CHAR)}" except: pass valid_records.append( - [target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, schedule, + [target_record.record_uid, target_record.title, not disabled, record_config_uid, record_resource_uid, + schedule, complexity]) # 6. 
Construct Request object @@ -959,7 +995,7 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None r_requests.append(rq) # Main execute() logic starts here - record_uids = set() # type: Set[str] + record_uids = set() # type: Set[str] folder_uids = set() record_pattern = '' @@ -994,7 +1030,7 @@ def config_user(_dag, target_record, target_resource_uid, target_config_uid=None folder, record_title = rs if not record_title: - def add_folders(sub_folder): # type: (BaseFolderNode) -> None + def add_folders(sub_folder): # type: (BaseFolderNode) -> None folder_uids.add(sub_folder.uid or '') if isinstance(folder, BaseFolderNode): @@ -1025,7 +1061,7 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None continue record_uids.add(record_uid) - pam_records = [] # type: List[vault.TypedRecord] + pam_records = [] # type: List[vault.TypedRecord] valid_record_types = ['pamDatabase', 'pamDirectory', 'pamMachine', 'pamUser', 'pamRemoteBrowser'] for record_uid in record_uids: record = vault.KeeperRecord.load(params, record_uid) @@ -1039,14 +1075,15 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None if not kwargs.get('silent'): logging.info('Selected %d PAM record(s) for rotation', len(pam_records)) - pam_configurations = {x.record_uid: x for x in vault_extensions.find_records(params, record_version=6) if isinstance(x, vault.TypedRecord)} + pam_configurations = {x.record_uid: x for x in vault_extensions.find_records(params, record_version=6) if + isinstance(x, vault.TypedRecord)} config_uid = kwargs.get('config') cfg_rec = RecordMixin.resolve_single_record(params, kwargs.get('config', None)) if cfg_rec and cfg_rec.version == 6 and cfg_rec.record_uid in pam_configurations: config_uid = cfg_rec.record_uid - pam_config = None # type: Optional[vault.TypedRecord] + pam_config = None # type: Optional[vault.TypedRecord] if config_uid: if config_uid in pam_configurations: pam_config = pam_configurations[config_uid] @@ -1057,7 +1094,7 @@ def 
add_folders(sub_folder): # type: (BaseFolderNode) -> None schedule_data = parse_schedule_data(kwargs) pwd_complexity = kwargs.get("pwd_complexity") - pwd_complexity_rule_list = None # type: Optional[dict] + pwd_complexity_rule_list = None # type: Optional[dict] if pwd_complexity is not None: if pwd_complexity: pwd_complexity_list = [s.strip() for s in pwd_complexity.split(',', maxsplit=5)] @@ -1093,10 +1130,11 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None skipped_header = ['record_uid', 'record_title', 'problem', 'description'] skipped_records = [] - valid_header = ['record_uid', 'record_title', 'enabled', 'configuration_uid', 'resource_uid', 'schedule', 'complexity'] + valid_header = ['record_uid', 'record_title', 'enabled', 'configuration_uid', 'resource_uid', 'schedule', + 'complexity'] valid_records = [] - r_requests = [] # type: List[router_pb2.RouterRecordRotationRequest] + r_requests = [] # type: List[router_pb2.RouterRecordRotationRequest] # Note: --folder, -fd FOLDER_NAME sets up General rotation # use --schedule-only, -so to preserve individual setups (General, IAM, NOOP) @@ -1129,9 +1167,10 @@ def add_folders(sub_folder): # type: (BaseFolderNode) -> None effective_config_uid = current_rotation.get('configuration_uid') if not effective_config_uid: raise CommandError('', 'IAM user rotation requires a PAM Configuration. 
' - 'Use --config or --iam-aad-config to specify one.') + 'Use --config or --iam-aad-config to specify one.') if effective_config_uid not in pam_configurations: - raise CommandError('', f'Record uid {effective_config_uid} is not a PAM Configuration record.') + raise CommandError('', + f'Record uid {effective_config_uid} is not a PAM Configuration record.') config_iam_aad_user(tmp_dag, _record, effective_config_uid) elif rotation_profile == 'scripts_only': # Set noop flag for scripts_only profile @@ -1376,10 +1415,11 @@ def execute(self, params, **kwargs): if format_type == 'json': return json.dumps({"gateways": [], "message": "This Enterprise does not have Gateways yet."}) else: - print(f"{bcolors.OKBLUE}\nThis Enterprise does not have Gateways yet. To create new Gateway, use command " - f"`{bcolors.ENDC}{bcolors.OKGREEN}pam gateway new{bcolors.ENDC}{bcolors.OKBLUE}`\n\n" - f"NOTE: If you have added new Gateway, you might still need to initialize it before it is " - f"listed.{bcolors.ENDC}") + print( + f"{bcolors.OKBLUE}\nThis Enterprise does not have Gateways yet. 
To create new Gateway, use command " + f"`{bcolors.ENDC}{bcolors.OKGREEN}pam gateway new{bcolors.ENDC}{bcolors.OKBLUE}`\n\n" + f"NOTE: If you have added new Gateway, you might still need to initialize it before it is " + f"listed.{bcolors.ENDC}") return table = [] @@ -1388,8 +1428,8 @@ def execute(self, params, **kwargs): if format_type == 'json': headers = ['ksm_app_name_uid', 'gateway_name', 'gateway_uid', 'status', 'gateway_version'] if is_verbose: - headers.extend(['device_name', 'device_token', 'created_on', 'last_modified', 'node_id', - 'os', 'os_release', 'machine_type', 'os_version']) + headers.extend(['device_name', 'device_token', 'created_on', 'last_modified', 'node_id', + 'os', 'os_release', 'machine_type', 'os_version']) else: headers = [] headers.append('KSM Application Name (UID)') @@ -1502,7 +1542,8 @@ def execute(self, params, **kwargs): row_color = bcolors.OKGREEN row = [] - row.append(f'{row_color if ksm_app_accessible else bcolors.WHITE}{ksm_app_info_plain}{bcolors.ENDC}') + row.append( + f'{row_color if ksm_app_accessible else bcolors.WHITE}{ksm_app_info_plain}{bcolors.ENDC}') row.append(f'{row_color}{c.controllerName}{bcolors.ENDC}') row.append(f'{row_color}{gateway_uid_str}{bcolors.ENDC}') row.append(f'{row_color}{status}{bcolors.ENDC}') @@ -1570,7 +1611,8 @@ def execute(self, params, **kwargs): "device_name": c.deviceName, "device_token": c.deviceToken, "created_on": datetime.fromtimestamp(c.created / 1000).strftime('%Y-%m-%d %H:%M:%S'), - "last_modified": datetime.fromtimestamp(c.lastModified / 1000).strftime('%Y-%m-%d %H:%M:%S'), + "last_modified": datetime.fromtimestamp(c.lastModified / 1000).strftime( + '%Y-%m-%d %H:%M:%S'), "node_id": c.nodeId }) @@ -1581,7 +1623,8 @@ def execute(self, params, **kwargs): # Parent gateway row row = [] - row.append(f'{row_color if ksm_app_accessible else bcolors.WHITE}{ksm_app_info_plain}{bcolors.ENDC}') + row.append( + f'{row_color if ksm_app_accessible else 
bcolors.WHITE}{ksm_app_info_plain}{bcolors.ENDC}') row.append(f'{row_color}{c.controllerName}{bcolors.ENDC}') row.append(f'{row_color}{gateway_uid_str}{bcolors.ENDC}') row.append(f'{row_color}{overall_status}{bcolors.ENDC}') @@ -1609,7 +1652,8 @@ def execute(self, params, **kwargs): version = version_parts[0] if version_parts else instance.version ip_address = instance.ipAddress if hasattr(instance, 'ipAddress') else "" - connected_on = datetime.fromtimestamp(instance.connectedOn / 1000).strftime('%Y-%m-%d %H:%M:%S') if hasattr(instance, 'connectedOn') else "" + connected_on = datetime.fromtimestamp(instance.connectedOn / 1000).strftime( + '%Y-%m-%d %H:%M:%S') if hasattr(instance, 'connectedOn') else "" instance_row = [] instance_row.append('') # Empty KSM app column @@ -1626,7 +1670,8 @@ def execute(self, params, **kwargs): instance_row.append('') instance_row.append('') - instance_row.append(f'{row_color}{datetime.fromtimestamp(instance.connectedOn / 1000) if hasattr(instance, "connectedOn") else ""}{bcolors.ENDC}') + instance_row.append( + f'{row_color}{datetime.fromtimestamp(instance.connectedOn / 1000) if hasattr(instance, "connectedOn") else ""}{bcolors.ENDC}') instance_row.append('') instance_row.append('') instance_row.append(f'{row_color}{os_name}{bcolors.ENDC}') @@ -1639,7 +1684,7 @@ def execute(self, params, **kwargs): if format_type == 'json': # Sort JSON data by status and app name gateways_data.sort(key=lambda x: (x['status'], (x['ksm_app_name'] or '').lower())) - + if is_verbose: krouter_host = get_router_url(params) result = { @@ -1648,7 +1693,7 @@ def execute(self, params, **kwargs): } else: result = {"gateways": gateways_data} - + return json.dumps(result, indent=2) else: # Separate rows into groups: each parent with its instances @@ -1703,7 +1748,8 @@ def execute(self, params, **kwargs): if format_type == 'json' and result: return result else: # Print element configs (config that is not a root) - result = 
PAMConfigurationListCommand.print_pam_configuration_details(params, pam_configuration_uid, is_verbose, format_type) + result = PAMConfigurationListCommand.print_pam_configuration_details(params, pam_configuration_uid, + is_verbose, format_type) if format_type == 'json' and result: return result @@ -1734,12 +1780,12 @@ def print_pam_configuration_details(params, config_uid, is_verbose=False, format facade = PamConfigurationRecordFacade() facade.record = configuration - + folder_uid = facade.folder_uid sf = None if folder_uid in params.shared_folder_cache: sf = api.get_shared_folder(params, folder_uid) - + if format_type == 'json': config_data = { "uid": configuration.record_uid, @@ -1753,8 +1799,8 @@ def print_pam_configuration_details(params, config_uid, is_verbose=False, format "resource_record_uids": facade.resource_ref, "fields": {} } - - for field in configuration.fields: + + for field in itertools.chain(configuration.fields, configuration.custom): if field.type in ('pamResources', 'fileRef'): continue values = list(field.get_external_value()) @@ -1763,9 +1809,9 @@ def print_pam_configuration_details(params, config_uid, is_verbose=False, format field_name = field.get_field_name() if field.type == 'schedule': field_name = 'Default Schedule' - + config_data["fields"][field_name] = values - + return json.dumps(config_data, indent=2) else: table = [] @@ -1777,7 +1823,7 @@ def print_pam_configuration_details(params, config_uid, is_verbose=False, format table.append(['Gateway UID', facade.controller_uid]) table.append(['Resource Record UIDs', facade.resource_ref]) - for field in configuration.fields: + for field in itertools.chain(configuration.fields, configuration.custom): if field.type in ('pamResources', 'fileRef'): continue values = list(field.get_external_value()) @@ -1794,10 +1840,10 @@ def print_pam_configuration_details(params, config_uid, is_verbose=False, format def print_root_rotation_setting(params, is_verbose=False, format_type='table'): configurations 
= list(vault_extensions.find_records(params, record_version=6)) facade = PamConfigurationRecordFacade() - + configs_data = [] table = [] - + if format_type == 'json': headers = ['uid', 'config_name', 'config_type', 'shared_folder', 'gateway_uid', 'resource_record_uids'] if is_verbose: @@ -1808,12 +1854,13 @@ def print_root_rotation_setting(params, is_verbose=False, format_type='table'): headers.append('Fields') for c in configurations: # type: vault.TypedRecord - if c.record_type in ('pamAwsConfiguration', 'pamAzureConfiguration', 'pamGcpConfiguration', 'pamDomainConfiguration', 'pamNetworkConfiguration', 'pamOciConfiguration'): + if c.record_type in ('pamAwsConfiguration', 'pamAzureConfiguration', 'pamGcpConfiguration', + 'pamDomainConfiguration', 'pamNetworkConfiguration', 'pamOciConfiguration'): facade.record = c shared_folder_parents = find_parent_top_folder(params, c.record_uid) if shared_folder_parents: sf = shared_folder_parents[0] - + if format_type == 'json': config_data = { "uid": c.record_uid, @@ -1829,7 +1876,7 @@ def print_root_rotation_setting(params, is_verbose=False, format_type='table'): if is_verbose: fields = {} - for field in c.fields: + for field in itertools.chain(c.fields, c.custom): if field.type in ('pamResources', 'fileRef'): continue value = ', '.join(field.get_external_value()) @@ -1844,7 +1891,7 @@ def print_root_rotation_setting(params, is_verbose=False, format_type='table'): if is_verbose: fields = [] - for field in c.fields: + for field in itertools.chain(c.fields, c.custom): if field.type in ('pamResources', 'fileRef'): continue value = ', '.join(field.get_external_value()) @@ -1876,8 +1923,11 @@ def print_root_rotation_setting(params, is_verbose=False, format_type='table'): common_parser.add_argument('--shared-folder', '-sf', dest='shared_folder_uid', action='store', help='Share Folder where this PAM Configuration is stored. 
Should be one of the folders to ' 'which the gateway has access to.') -common_parser.add_argument('--schedule', '-sc', dest='default_schedule', action='store', help='Default Schedule: Use CRON syntax') +common_parser.add_argument('--schedule', '-sc', dest='default_schedule', action='store', + help='Default Schedule: Use CRON syntax') common_parser.add_argument('--port-mapping', '-pm', dest='port_mapping', action='append', help='Port Mapping') +common_parser.add_argument('--identity-provider', '-idp', dest='identity_provider_uid', + action='store', help='Identity Provider UID') network_group = common_parser.add_argument_group('network', 'Local network configuration') network_group.add_argument('--network-id', dest='network_id', action='store', help='Network ID') network_group.add_argument('--network-cidr', dest='network_cidr', action='store', help='Network CIDR') @@ -1898,26 +1948,33 @@ def print_root_rotation_setting(params, is_verbose=False, format_type='table'): domain_group.add_argument('--domain-id', dest='domain_id', action='store', help='Domain ID') domain_group.add_argument('--domain-hostname', dest='domain_hostname', action='store', help='Domain hostname') domain_group.add_argument('--domain-port', dest='domain_port', action='store', help='Domain port') -domain_group.add_argument('--domain-use-ssl', dest='domain_use_ssl', choices=['true', 'false'], help='Domain use SSL flag') -domain_group.add_argument('--domain-scan-dc-cidr', dest='domain_scan_dc_cidr', choices=['true', 'false'], help='Domain scan DC CIDR flag') -domain_group.add_argument('--domain-network-cidr', dest='domain_network_cidr', action='store', help='Domain Network CIDR') -domain_group.add_argument('--domain-admin', dest='domain_administrative_credential', action='store', help='Domain administrative credential') +domain_group.add_argument('--domain-use-ssl', dest='domain_use_ssl', choices=['true', 'false'], + help='Domain use SSL flag') +domain_group.add_argument('--domain-scan-dc-cidr', 
dest='domain_scan_dc_cidr', choices=['true', 'false'], + help='Domain scan DC CIDR flag') +domain_group.add_argument('--domain-network-cidr', dest='domain_network_cidr', action='store', + help='Domain Network CIDR') +domain_group.add_argument('--domain-admin', dest='domain_administrative_credential', action='store', + help='Domain administrative credential') oci_group = common_parser.add_argument_group('oci', 'OCI configuration') oci_group.add_argument('--oci-id', dest='oci_id', action='store', help='OCI ID') oci_group.add_argument('--oci-admin-id', dest='oci_admin_id', action='store', help='OCI Admin ID') -oci_group.add_argument('--oci-admin-public-key', dest='oci_admin_public_key', action='store', help='OCI admin public key') -oci_group.add_argument('--oci-admin-private-key', dest='oci_admin_private_key', action='store', help='OCI admin private key') +oci_group.add_argument('--oci-admin-public-key', dest='oci_admin_public_key', action='store', + help='OCI admin public key') +oci_group.add_argument('--oci-admin-private-key', dest='oci_admin_private_key', action='store', + help='OCI admin private key') oci_group.add_argument('--oci-tenancy', dest='oci_tenancy', action='store', help='OCI tenancy') oci_group.add_argument('--oci-region', dest='oci_region', action='store', help='OCI region') gcp_group = common_parser.add_argument_group('gcp', 'GCP configuration') gcp_group.add_argument('--gcp-id', dest='gcp_id', action='store', help='GCP Id') gcp_group.add_argument('--service-account-key', dest='service_account_key', action='store', - help='Service Account Key (JSON format)') + help='Service Account Key (JSON format)') gcp_group.add_argument('--google-admin-email', dest='google_admin_email', action='store', - help='Google Workspace Administrator Email Address') + help='Google Workspace Administrator Email Address') gcp_group.add_argument('--gcp-region', dest='region_names', action='append', help='GCP Region Names') + class PamConfigurationEditMixin(RecordEditMixin): 
pam_record_types = None @@ -2014,7 +2071,8 @@ def parse_pam_configuration(self, params, record, **kwargs): value['resourceRef'] = list(record_uids) @staticmethod - def resolve_single_record(params, record_name, rec_type=''): # type: (KeeperParams, str, str) -> Optional[vault.KeeperRecord] + def resolve_single_record(params, record_name, + rec_type=''): # type: (KeeperParams, str, str) -> Optional[vault.KeeperRecord] rec = RecordMixin.resolve_single_record(params, record_name) if not rec: recs = [] @@ -2040,10 +2098,15 @@ def parse_properties(self, params, record, **kwargs): # type: (KeeperParams, va if not valid: raise CommandError('', f'Invalid CRON "{schedule}" Error: {err}') if schedule: - extra_properties.append(f'schedule.defaultRotationSchedule=$JSON:{{"type": "CRON", "cron": "{schedule}", "tz": "Etc/UTC"}}') + extra_properties.append( + f'schedule.defaultRotationSchedule=$JSON:{{"type": "CRON", "cron": "{schedule}", "tz": "Etc/UTC"}}') else: extra_properties.append('schedule.defaultRotationSchedule=On-Demand') + identity_provider_uid = kwargs.get('identity_provider_uid') + if identity_provider_uid: + extra_properties.append(f'text.identityProviderUid={identity_provider_uid}') + if record.record_type == 'pamNetworkConfiguration': network_id = kwargs.get('network_id') if network_id: @@ -2192,7 +2255,8 @@ class PAMConfigurationNewCommand(Command, PamConfigurationEditMixin): help='Set TypeScript recording permissions for the resource') parser.add_argument('--ai-threat-detection', dest='ai_threat_detection', choices=choices, help='Set AI threat detection permissions') - parser.add_argument('--ai-terminate-session-on-detection', dest='ai_terminate_session_on_detection', choices=choices, + parser.add_argument('--ai-terminate-session-on-detection', dest='ai_terminate_session_on_detection', + choices=choices, help='Set AI session termination on threat detection permissions') def __init__(self): @@ -2221,7 +2285,7 @@ def execute(self, params, **kwargs): record_type = 
'pamOciConfiguration' else: raise CommandError('pam-config-new', f'--environment {config_type} is not supported' - ' - supported options: local, aws, azure, gcp, domain, oci') + ' - supported options: local, aws, azure, gcp, domain, oci') title = kwargs.get('title') if not title: @@ -2378,6 +2442,10 @@ def execute(self, params, **kwargs): if rt_fields: RecordEditMixin.adjust_typed_record_fields(configuration, rt_fields) + rt_fields = RecordEditMixin.get_record_type_fields(params, configuration.record_type) + if rt_fields: + RecordEditMixin.adjust_typed_record_fields(configuration, rt_fields) + title = kwargs.get('title') if title: configuration.title = title @@ -2412,7 +2480,8 @@ def execute(self, params, **kwargs): shared_folder_uid = value.get('folderUid') or '' if shared_folder_uid != orig_shared_folder_uid: FolderMoveCommand().execute(params, src=configuration.record_uid, dst=shared_folder_uid) - if configuration.type_name == 'pamDomainConfiguration' and not kwargs.get('force_domain_admin', False) is True: + if configuration.type_name == 'pamDomainConfiguration' and not kwargs.get('force_domain_admin', + False) is True: # pamUser must exist or "403 Insufficient PAM access to perform this operation" admin_cred_ref = value.get('adminCredentialRef') or '' @@ -2425,7 +2494,7 @@ def execute(self, params, **kwargs): _typescript_recording = kwargs.get('typescriptrecording', None) if (_connections is not None or _tunneling is not None or _rotation is not None or _rbi is not None or - _recording is not None or _typescript_recording is not None or orig_admin_cred_ref != admin_cred_ref): + _recording is not None or _typescript_recording is not None or orig_admin_cred_ref != admin_cred_ref): encrypted_session_token, encrypted_transmission_key, transmission_key = get_keeper_tokens(params) tmp_dag = TunnelDAG(params, encrypted_session_token, encrypted_transmission_key, configuration.record_uid, is_config=True, transmission_key=transmission_key) @@ -2498,8 +2567,10 @@ def 
execute(self, params, **kwargs): print(f'PAM Config UID: {bcolors.OKBLUE}{configuration_uid}{bcolors.ENDC}') print(f'Node ID: {bcolors.OKBLUE}{rri.nodeId}{bcolors.ENDC}') - print(f"Gateway Name where the rotation will be performed: {bcolors.OKBLUE}{(rri.controllerName if rri.controllerName else '-')}{bcolors.ENDC}") - print(f"Gateway Uid: {bcolors.OKBLUE}{(utils.base64_url_encode(rri.controllerUid) if rri.controllerUid else '-') } {bcolors.ENDC}") + print( + f"Gateway Name where the rotation will be performed: {bcolors.OKBLUE}{(rri.controllerName if rri.controllerName else '-')}{bcolors.ENDC}") + print( + f"Gateway Uid: {bcolors.OKBLUE}{(utils.base64_url_encode(rri.controllerUid) if rri.controllerUid else '-')} {bcolors.ENDC}") def is_resource_ok(resource_id, params, configuration_uid): if resource_id not in params.record_cache: @@ -2533,7 +2604,8 @@ def is_resource_ok(resource_id, params, configuration_uid): try: record = params.record_cache.get(record_uid) if record: - complexity = crypto.decrypt_aes_v2(utils.base64_url_decode(rri.pwdComplexity), record['record_key_unencrypted']) + complexity = crypto.decrypt_aes_v2(utils.base64_url_decode(rri.pwdComplexity), + record['record_key_unencrypted']) c = json.loads(complexity.decode()) print(f"Password Complexity Data: {bcolors.OKBLUE}" f"Length: {c.get('length')}; Lowercase: {c.get('lowercase')}; " @@ -2547,7 +2619,7 @@ def is_resource_ok(resource_id, params, configuration_uid): print(f"Password Complexity: {bcolors.OKGREEN}[not set]{bcolors.ENDC}") print(f"Is Rotation Disabled: {bcolors.OKGREEN}{rri.disabled}{bcolors.ENDC}") - + # Get schedule information rq = pam_pb2.PAMGenericUidsRequest() schedules_proto = router_get_rotation_schedules(params, rq) @@ -2568,7 +2640,7 @@ def is_resource_ok(resource_id, params, configuration_uid): schedule_str = s.scheduleData print(f"Schedule: {bcolors.OKBLUE}{schedule_str}{bcolors.ENDC}") break - + print(f"\nCommand to manually rotate: {bcolors.OKGREEN}pam action rotate -r 
{record_uid}{bcolors.ENDC}") else: print(f'{bcolors.WARNING}Rotation Status: Not ready to rotate ({rri_status_name}){bcolors.ENDC}') @@ -2873,28 +2945,31 @@ def execute(self, params, **kwargs): destination_gateway_uid_str=gateway_uid ) - print_router_response(router_response, 'job_info', original_conversation_id=conversation_id, gateway_uid=gateway_uid) + print_router_response(router_response, 'job_info', original_conversation_id=conversation_id, + gateway_uid=gateway_uid) class PAMGatewayActionRotateCommand(Command): parser = argparse.ArgumentParser(prog='pam action rotate') parser.add_argument('--record-uid', '-r', dest='record_uid', action='store', help='Record UID to rotate') - parser.add_argument('--folder', '-f', dest='folder', action='store', help='Shared folder UID or title pattern to rotate') + parser.add_argument('--folder', '-f', dest='folder', action='store', + help='Shared folder UID or title pattern to rotate') # parser.add_argument('--recursive', '-a', dest='recursive', default=False, action='store', help='Enable recursion to rotate sub-folders too') # parser.add_argument('--record-pattern', '-p', dest='pattern', action='store', help='Record title match pattern') - parser.add_argument('--dry-run', '-n', dest='dry_run', default=False, action='store_true', help='Enable dry-run mode') + parser.add_argument('--dry-run', '-n', dest='dry_run', default=False, action='store_true', + help='Enable dry-run mode') # parser.add_argument('--config', '-c', dest='configuration_uid', action='store', help='Rotation configuration UID') # Email and share link arguments parser.add_argument('--self-destruct', dest='self_destruct', action='store', - metavar='[(m)inutes|(h)ours|(d)ays]', - help='Create one-time share link that expires after duration') + metavar='[(m)inutes|(h)ours|(d)ays]', + help='Create one-time share link that expires after duration') parser.add_argument('--email-config', dest='email_config', action='store', - help='Email configuration name to use for 
sending (required with --send-email)') + help='Email configuration name to use for sending (required with --send-email)') parser.add_argument('--send-email', dest='send_email', action='store', - help='Email address to send credentials after rotation') + help='Email address to send credentials after rotation') parser.add_argument('--email-message', dest='email_message', action='store', - help='Custom message to include in email') + help='Custom message to include in email') def get_parser(self): return PAMGatewayActionRotateCommand.parser @@ -2915,7 +2990,8 @@ def execute(self, params, **kwargs): # Validate email setup early (before rotation) to avoid rotating password without being able to send email if self.send_email: if not self.email_config: - raise CommandError('pam action rotate', '--send-email requires --email-config to specify email configuration') + raise CommandError('pam action rotate', + '--send-email requires --email-config to specify email configuration') # Find and load email config to validate provider and dependencies try: @@ -2937,7 +3013,8 @@ def execute(self, params, **kwargs): # record, folder or pattern - at least one required if not record_uid and not folder: - print(f'the following arguments are required: {bcolors.OKBLUE}--record-uid/-r{bcolors.ENDC} or {bcolors.OKBLUE}--folder/-f{bcolors.ENDC}') + print( + f'the following arguments are required: {bcolors.OKBLUE}--record-uid/-r{bcolors.ENDC} or {bcolors.OKBLUE}--folder/-f{bcolors.ENDC}') return # single record UID - ignore all folder options @@ -2989,7 +3066,7 @@ def execute(self, params, **kwargs): path.append(child) child = params.folder_cache[child].parent_uid path.append(child) # add root shf - path = path[1:] if path else [] # skip child uid + path = path[1:] if path else [] # skip child uid if not set(path) & fldrset: # no intersect uniq.append(fldr) folders = list(set(uniq)) @@ -3041,14 +3118,14 @@ def execute(self, params, **kwargs): except Exception as e: msg = str(e) # what is 
considered a throttling error... if re.search(r"throttle", msg, re.IGNORECASE): - delay = (delay+10) % 100 # reset every 1.5 minutes + delay = (delay + 10) % 100 # reset every 1.5 minutes logging.debug(f'Record UID: {record_uid} was throttled (retry in {delay} sec)') - time.sleep(1+delay) + time.sleep(1 + delay) else: logging.error(f'Record UID: {record_uid} skipped: non-throttling, non-recoverable error: {msg}') break - def record_rotate(self, params, record_uid, slient:bool = False): + def record_rotate(self, params, record_uid, slient: bool = False): record = vault.KeeperRecord.load(params, record_uid) if not isinstance(record, vault.TypedRecord): print(f'{bcolors.FAIL}Record [{record_uid}] is not available.{bcolors.ENDC}') @@ -3068,7 +3145,8 @@ def record_rotate(self, params, record_uid, slient:bool = False): 'digits': 1, 'special': 1, } - ri_pwd_complexity_encrypted = utils.base64_url_encode(router_helper.encrypt_pwd_complexity(rule_list_dict, record.record_key)) + ri_pwd_complexity_encrypted = utils.base64_url_encode( + router_helper.encrypt_pwd_complexity(rule_list_dict, record.record_key)) # else: # rule_list_json = crypto.decrypt_aes_v2(utils.base64_url_decode(ri_pwd_complexity_encrypted), record.record_key) # complexity = json.loads(rule_list_json.decode()) @@ -3218,7 +3296,8 @@ def _handle_post_rotation_email(self, params, record): expire_seconds = int(expiration_period.total_seconds()) if expire_seconds <= 0: - logging.warning(f'{bcolors.WARNING}Invalid --self-destruct value. Skipping share link.{bcolors.ENDC}') + logging.warning( + f'{bcolors.WARNING}Invalid --self-destruct value. 
Skipping share link.{bcolors.ENDC}') return # Calculate human-readable expiration text @@ -3248,7 +3327,8 @@ def _handle_post_rotation_email(self, params, record): # Extract hostname from params.server parsed = urlparse(params.server) server_netloc = parsed.netloc if parsed.netloc else parsed.path - share_url = urlunparse(('https', server_netloc, '/vault/share', None, None, utils.base64_url_encode(client_key))) + share_url = urlunparse( + ('https', server_netloc, '/vault/share', None, None, utils.base64_url_encode(client_key))) logging.info(f'{bcolors.OKGREEN}Share link created successfully{bcolors.ENDC}') except Exception as e: logging.warning(f'{bcolors.WARNING}Failed to create share link: {e}{bcolors.ENDC}') @@ -3261,7 +3341,8 @@ def _handle_post_rotation_email(self, params, record): logging.info(f'Loading email configuration: {self.email_config}') config_uid = find_email_config_record(params, self.email_config) if not config_uid: - logging.warning(f'{bcolors.WARNING}Email configuration "{self.email_config}" not found. Skipping email.{bcolors.ENDC}') + logging.warning( + f'{bcolors.WARNING}Email configuration "{self.email_config}" not found. Skipping email.{bcolors.ENDC}') return # Load the email configuration @@ -3313,11 +3394,12 @@ def str_to_regex(self, text): text = str(text) try: pattern = re.compile(text, re.IGNORECASE) - except: # re.error: yet maybe TypeError, MemoryError, RecursionError etc. + except: # re.error: yet maybe TypeError, MemoryError, RecursionError etc. 
pattern = re.compile(re.escape(text), re.IGNORECASE) logging.debug(f"regex pattern {text} failed to compile (using it as plaintext pattern)") return pattern + class PAMGatewayActionServerInfoCommand(Command): parser = argparse.ArgumentParser(prog='dr-info-command') parser.add_argument('--gateway', '-g', required=False, dest='gateway_uid', action='store', help='Gateway UID') @@ -3337,11 +3419,11 @@ def execute(self, params, **kwargs): destination_gateway_uid_str=destination_gateway_uid_str ) - print_router_response(router_response, 'gateway_info', is_verbose=is_verbose, gateway_uid=destination_gateway_uid_str) + print_router_response(router_response, 'gateway_info', is_verbose=is_verbose, + gateway_uid=destination_gateway_uid_str) class PAMGatewayActionDiscoverCommandBase(Command): - """ The discover command base. @@ -3524,3 +3606,55 @@ def execute(self, params, **kwargs): print('-----------------------------------------------') print(bcolors.OKGREEN + one_time_token + bcolors.ENDC) print('-----------------------------------------------') + + +class PAMEditGatewayCommand(Command): + parser = argparse.ArgumentParser(prog='pam gateway edit') + parser.add_argument('--gateway', '-g', required=False, dest='gateway', + help='Gateway UID or Name', action='store') + parser.add_argument('--name', '-n', required=False, dest='gateway_name', + help='Name of the Gateway', action='store') + parser.add_argument('--node-id', '-i', required=False, dest='node_id', + help='Node ID', action='store') + + def get_parser(self): + return PAMEditGatewayCommand.parser + + def execute(self, params: KeeperParams, **kwargs): + gateway_uid = kwargs.get('gateway') + new_name = kwargs.get('gateway_name') + node_id_arg = kwargs.get('node_id') + + if not gateway_uid: + raise CommandError('pam gateway edit', 'Argument --gateway is required') + + gateways = gateway_helper.get_all_gateways(params) + + gateway = next((x for x in gateways + if utils.base64_url_encode(x.controllerUid) == gateway_uid + or 
x.controllerName.lower() == gateway_uid.lower()), None) + + if not gateway: + raise CommandError('', f'{bcolors.FAIL}Gateway "{gateway_uid}" not found{bcolors.ENDC}') + + if not node_id_arg and not new_name: + raise CommandError('pam gateway edit', 'Nothing to do. At least one of --name or --node-id is required to edit the gateway') + + if node_id_arg is not None and str(node_id_arg).strip(): + if not params.enterprise or 'nodes' not in params.enterprise: + raise CommandError('', f'{bcolors.FAIL}Enterprise node data is not loaded{bcolors.ENDC}') + nodes = list(EnterpriseCommand.resolve_nodes(params, str(node_id_arg))) + if len(nodes) == 0: + raise CommandError('', f'{bcolors.FAIL}Node "{node_id_arg}" not found{bcolors.ENDC}') + if len(nodes) > 1: + raise CommandError( + '', + f'{bcolors.FAIL}More than one node "{node_id_arg}" found. Use Node ID.{bcolors.ENDC}') + resolved_node_id = nodes[0]['node_id'] + else: + resolved_node_id = gateway.nodeId + + gateway_name = new_name if new_name else gateway.controllerName + + gateway_helper.edit_gateway(params, gateway.controllerUid, gateway_name, resolved_node_id) + logging.info('Gateway %s has been edited.', gateway_uid) \ No newline at end of file diff --git a/keepercommander/commands/ksm.py b/keepercommander/commands/ksm.py index 59a13e442..bd4bbfcaf 100644 --- a/keepercommander/commands/ksm.py +++ b/keepercommander/commands/ksm.py @@ -49,6 +49,9 @@ {bcolors.BOLD}Create Application:{bcolors.ENDC} {bcolors.OKGREEN}secrets-manager app create {bcolors.OKBLUE}[NAME]{bcolors.ENDC} + {bcolors.BOLD}Update Application:{bcolors.ENDC} + {bcolors.OKGREEN}secrets-manager app update {bcolors.OKBLUE}[APP NAME OR UID]{bcolors.OKGREEN} --name {bcolors.OKBLUE}[NEW NAME]{bcolors.ENDC} + {bcolors.BOLD}Remove Application:{bcolors.ENDC} {bcolors.OKGREEN}secrets-manager app remove {bcolors.OKBLUE}[APP NAME OR UID]{bcolors.ENDC} Options: @@ -73,15 +76,27 @@ {bcolors.BOLD}Remove Client Device:{bcolors.ENDC} {bcolors.OKGREEN}secrets-manager 
client remove --app {bcolors.OKBLUE}[APP NAME OR UID] {bcolors.OKGREEN}--client {bcolors.OKBLUE}[NAME OR ID]{bcolors.ENDC} - Options: + Options: --force : Do not prompt for confirmation --client : Client name or ID. Provide `*` or `all` to delete all clients at once - + + {bcolors.BOLD}Revoke Client Device (search all applications):{bcolors.ENDC} + {bcolors.OKGREEN}secrets-manager client revoke --client {bcolors.OKBLUE}[CLIENT ID]{bcolors.ENDC} + Searches all applications for the given client ID and revokes it. + Useful for quickly revoking a leaked device without knowing the application. + The client ID can be found in the device's configuration file as "clientId". + Options: + --force : Do not prompt for confirmation + {bcolors.BOLD}Add Secret to Application:{bcolors.ENDC} {bcolors.OKGREEN}secrets-manager share add --app {bcolors.OKBLUE}[APP NAME OR UID] {bcolors.OKGREEN}--secret {bcolors.OKBLUE}[RECORD OR SHARED FOLDER UID]{bcolors.ENDC} Options: --editable : Allow secrets to be editable by the client + {bcolors.BOLD}Update Secret Permissions:{bcolors.ENDC} + {bcolors.OKGREEN}secrets-manager share update --app {bcolors.OKBLUE}[APP NAME OR UID] {bcolors.OKGREEN}--secret {bcolors.OKBLUE}[RECORD OR SHARED FOLDER UID] {bcolors.OKGREEN}--editable{bcolors.ENDC} + {bcolors.OKGREEN}secrets-manager share update --app {bcolors.OKBLUE}[APP NAME OR UID] {bcolors.OKGREEN}--secret {bcolors.OKBLUE}[RECORD OR SHARED FOLDER UID] {bcolors.OKGREEN}--readonly{bcolors.ENDC} + {bcolors.BOLD}Remove Secret from Application:{bcolors.ENDC} {bcolors.OKGREEN}secrets-manager share remove --app {bcolors.OKBLUE}[APP NAME OR UID] {bcolors.OKGREEN}--secret {bcolors.OKBLUE}[RECORD OR SHARED FOLDER UID]{bcolors.ENDC} @@ -98,8 +113,8 @@ ksm_parser = argparse.ArgumentParser(prog='secrets-manager', description='Keeper Secrets Management (KSM) Commands', add_help=False) ksm_parser.add_argument('command', type=str, action='store', nargs="*", - help='One of: "app list", "app get", "app create", "app 
remove", "app share", "app unshare", ' + - '"client add", "client remove", "share add" or "share remove"') + help='One of: "app list", "app get", "app create", "app update", "app remove", "app share", ' + + '"app unshare", "client add", "client remove", "share add", "share update" or "share remove"') ksm_parser.add_argument('--secret', '-s', type=str, action='append', required=False, help='Record UID') ksm_parser.add_argument('--app', '-a', type=str, action='store', required=False, @@ -120,6 +135,8 @@ ksm_parser.add_argument('--help', '-h', dest='helpflag', action="store_true", help='Display help') ksm_parser.add_argument('--editable', '-e', action='store_true', required=False, help='Is this share going to be editable or not.') +ksm_parser.add_argument('--readonly', '-r', action='store_true', required=False, + help='Set this share to read-only (used with share update).') ksm_parser.add_argument('--unlock-ip', '-l', dest='unlockIp', action='store_true', help='Unlock IP Address.') ksm_parser.add_argument('--return-tokens', dest='returnTokens', action='store_true', help='Return Tokens') ksm_parser.add_argument('--name', '-n', type=str, dest='name', action='store', help='client name') @@ -223,6 +240,32 @@ def execute(self, params, **kwargs): return result return + if ksm_obj in ['app', 'apps'] and ksm_action in ['update', 'rename']: + if len(ksm_command) < 3: + print( + f'''{bcolors.WARNING}Application name or UID is missing.{bcolors.ENDC}\n''' + f'''\tEx: {bcolors.OKGREEN}secrets-manager app update {bcolors.OKBLUE}MyApp''' + f'''{bcolors.OKGREEN} --name {bcolors.OKBLUE}NewAppName{bcolors.ENDC}''' + ) + return + + app_name_or_uid = ksm_command[2] + new_name = kwargs.get('name') + + if not new_name: + print( + f'''{bcolors.WARNING}New application name is required.{bcolors.ENDC}\n''' + f'''\tEx: {bcolors.OKGREEN}secrets-manager app update {bcolors.OKBLUE}{app_name_or_uid}''' + f'''{bcolors.OKGREEN} --name {bcolors.OKBLUE}NewAppName{bcolors.ENDC}''' + ) + return + + 
format_type = kwargs.get('format', 'table') + result = KSMCommand.update_app(params, app_name_or_uid, new_name, format_type) + if format_type == 'json' and result: + return result + return + if ksm_obj in ['app', 'apps'] and ksm_action in ['remove', 'rem', 'rm']: app_name_or_uid = ksm_command[2] purge = kwargs.get('purge') @@ -277,6 +320,44 @@ def execute(self, params, **kwargs): KSMCommand.add_app_share(params, secret_uid, app_name_or_uid, is_editable) return + if ksm_obj in ['share', 'secret'] and ksm_action in ['update', 'edit']: + + app_name_or_uid = kwargs.get('app') + secret_uids = kwargs.get('secret') + is_editable = kwargs.get('editable') + is_readonly = kwargs.get('readonly') + + if not app_name_or_uid: + print(bcolors.WARNING + "\nApplication name or UID is required." + bcolors.ENDC) + print(f"Example:" + + bcolors.OKGREEN + " secrets-manager share update --app " + bcolors.OKBLUE + "[APP NAME or APP UID]" + + bcolors.OKGREEN + " --secret " + bcolors.OKBLUE + "[SECRET UID]" + + bcolors.OKGREEN + " --editable" + bcolors.ENDC) + return + + if not secret_uids: + print(bcolors.WARNING + "\nRecord or Shared Folder UID is required." + bcolors.ENDC) + print(f"Example:" + + bcolors.OKGREEN + " secrets-manager share update --app " + bcolors.OKBLUE + "[APP NAME or APP UID]" + + bcolors.OKGREEN + " --secret " + bcolors.OKBLUE + "[SECRET UID]" + + bcolors.OKGREEN + " --editable" + bcolors.ENDC) + return + + if not is_editable and not is_readonly: + print(bcolors.WARNING + "\nPlease specify either --editable or --readonly." + bcolors.ENDC) + print(f"Example:" + + bcolors.OKGREEN + " secrets-manager share update --app " + bcolors.OKBLUE + "[APP NAME or APP UID]" + + bcolors.OKGREEN + " --secret " + bcolors.OKBLUE + "[SECRET UID]" + + bcolors.OKGREEN + " --editable" + bcolors.ENDC) + return + + if is_editable and is_readonly: + print(bcolors.WARNING + "\nCannot specify both --editable and --readonly." 
+ bcolors.ENDC) + return + + KSMCommand.update_app_share(params, secret_uids, app_name_or_uid, is_editable) + return + if ksm_obj in ['share', 'secret'] and ksm_action in ['remove', 'rem', 'rm']: app_name_or_uid = kwargs['app'] if 'app' in kwargs else None secret_uids = kwargs.get('secret') @@ -284,6 +365,17 @@ def execute(self, params, **kwargs): KSMCommand.remove_share(params, app_name_or_uid, secret_uids) return + if ksm_obj in ['client', 'c'] and ksm_action == 'revoke': + client_names_or_ids = kwargs.get('client_names_or_ids') + if not client_names_or_ids: + print(f"{bcolors.WARNING}Client ID is required.{bcolors.ENDC}\n" + f" Usage: {bcolors.OKGREEN}secrets-manager client revoke --client {bcolors.OKBLUE}[CLIENT ID]{bcolors.ENDC}\n" + f" The client ID can be found in the device configuration file as \"clientId\".") + return + force = kwargs.get('force') + KSMCommand.revoke_client(params, client_names_or_ids, force) + return + if ksm_obj in ['client', 'c']: app_name_or_uid = kwargs['app'] if 'app' in kwargs else None @@ -588,9 +680,33 @@ def shorten_client_id(all_clients, original_id, number_of_characters): app_data = { "app_name": app.get("title"), "app_uid": app_uid_str, + "users": [], "client_devices": [], "shares": [] } + + # Fetch user permissions for this application record + app_rec = params.record_cache.get(app_uid_str) + if app_rec: + # Clear cached shares to force a fresh fetch + app_rec.pop('shares', None) + api.get_record_shares(params, [app_uid_str]) + app_rec = params.record_cache.get(app_uid_str) + user_perms = (app_rec or {}).get('shares', {}).get('user_permissions', []) + for up in user_perms: + role = 'owner' if up.get('owner') else 'member' + user_data = { + "username": up.get('username'), + "role": role, + "share_admin": up.get('share_admin', False), + "shareable": up.get('shareable', False), + "editable": up.get('editable', False), + } + if up.get('awaiting_approval'): + user_data["awaiting_approval"] = True + if up.get('expiration') and 
up['expiration'] > 0: + user_data["expiration"] = up['expiration'] + app_data["users"].append(user_data) if format_type == 'table': print(f'\nSecrets Manager Application\n' @@ -666,6 +782,19 @@ def shorten_client_id(all_clients, original_id, number_of_characters): print(f'\n\t{bcolors.WARNING}No client devices registered for this Application{bcolors.ENDC}') if format_type == 'table': + if app_data["users"]: + print(bcolors.BOLD + "\nApplication Users\n" + bcolors.ENDC) + users_table_fields = ['Username', 'Role', 'Editable', 'Shareable'] + users_table = [] + for u in app_data["users"]: + role_color = bcolors.OKGREEN if u["role"] == "owner" else bcolors.OKBLUE + role_str = role_color + u["role"].capitalize() + bcolors.ENDC + editable_str = (bcolors.OKGREEN + "Yes" + bcolors.ENDC) if u["editable"] else (bcolors.WARNING + "No" + bcolors.ENDC) + shareable_str = (bcolors.OKGREEN + "Yes" + bcolors.ENDC) if u["shareable"] else (bcolors.WARNING + "No" + bcolors.ENDC) + users_table.append([u["username"], role_str, editable_str, shareable_str]) + users_table.sort(key=lambda x: (0 if 'Owner' in x[1] else 1, x[0].lower())) + dump_report_data(users_table, users_table_fields, fmt='table') + print(bcolors.BOLD + "\nApplication Access\n" + bcolors.ENDC) if ai.shares: @@ -913,6 +1042,164 @@ def add_new_v5_app(params, app_name, force_to_add=False, format_type='table'): params.sync_data = True + @staticmethod + def update_app(params, app_name_or_uid, new_name, format_type='table'): + """Rename an existing KSM application.""" + app = KSMCommand.get_app_record(params, app_name_or_uid) + if not app: + if format_type == 'json': + return json.dumps({"error": f"Application '{app_name_or_uid}' not found."}) + else: + logging.warning('Application "%s" not found.' 
% app_name_or_uid) + return + + existing_app = KSMCommand.get_app_record(params, new_name) + if existing_app and existing_app.get('record_uid') != app.get('record_uid'): + if format_type == 'json': + return json.dumps({"error": f'Application with the name "{new_name}" already exists.'}) + else: + logging.warning('Application with the name "%s" already exists.' % new_name) + return + + app_uid = app.get('record_uid') + record_key = app.get('record_key_unencrypted') + revision = app.get('revision') + + data_dict = KSMCommand.record_data_as_dict(app) + old_name = data_dict.get('title') + data_dict['title'] = new_name + + data_json = json.dumps(data_dict) + data_padded = api.pad_aes_gcm(data_json) + rdata = bytes(data_padded, 'utf-8') if isinstance(data_padded, str) else data_padded + rdata = crypto.encrypt_aes_v2(rdata, record_key) + + ru = record_pb2.RecordUpdate() + ru.record_uid = utils.base64_url_decode(app_uid) + ru.client_modified_time = api.current_milli_time() + ru.revision = revision + ru.data = rdata + + rq = api.get_records_update_request(params) + rq.records.append(ru) + + try: + rs = api.communicate_rest(params, rq, 'vault/records_update', + rs_type=record_pb2.RecordsModifyResponse) + record_uid_bytes = utils.base64_url_decode(app_uid) + rs_status = next((x for x in rs.records if record_uid_bytes == x.record_uid), None) + if rs_status and rs_status.status != record_pb2.RS_SUCCESS: + raise KeeperApiError(record_pb2.RecordModifyResult.keys()[rs_status.status], rs_status.message) + + params.sync_data = True + if format_type == 'json': + return json.dumps({ + "app_uid": app_uid, + "old_name": old_name, + "new_name": new_name, + "message": "Application was successfully renamed" + }, indent=2) + else: + print(bcolors.OKGREEN + + f'Application "{old_name}" was successfully renamed to "{new_name}" (UID: {app_uid})' + + bcolors.ENDC) + except KeeperApiError as kae: + logging.error('Failed to update application: %s' % kae.message) + except Exception as e: + 
logging.error('Failed to update application: %s' % str(e)) + + @staticmethod + def update_app_share(params, secret_uids, app_name_or_uid, is_editable): + """Update the editable permission on secrets already shared with an application. + + Performs a remove + re-add (matching the web vault behaviour) so that + the encrypted secret key is re-supplied with the new editable flag. + """ + rec_cache_val = KSMCommand.get_app_record(params, app_name_or_uid) + if rec_cache_val is None: + logging.warning('Application "%s" not found.' % app_name_or_uid) + return + + app_record_uid = rec_cache_val.get('record_uid') + master_key = rec_cache_val.get('record_key_unencrypted') + + app_info = KSMCommand.get_app_info(params, app_record_uid) + existing_shares = { + utils.base64_url_encode(s.secretUid): s + for ai in app_info for s in (ai.shares or []) + } + + uids_to_update = [] + for uid in secret_uids: + if uid not in existing_shares: + logging.warning('Secret "%s" is not currently shared with this application. ' + 'Use "share add" to add it first.' % uid) + continue + current_share = existing_shares[uid] + if current_share.editable == is_editable: + perm = "editable" if is_editable else "read-only" + logging.info('Secret "%s" is already %s. No change needed.' % (uid, perm)) + continue + uids_to_update.append(uid) + + if not uids_to_update: + print(bcolors.WARNING + "No share permissions to update." 
+ bcolors.ENDC) + return + + rq_remove = APIRequest_pb2.RemoveAppSharesRequest() + rq_remove.appRecordUid = utils.base64_url_decode(app_record_uid) + rq_remove.shares.extend(utils.base64_url_decode(uid) for uid in uids_to_update) + + try: + api.communicate_rest(params, rq_remove, 'vault/app_share_remove') + except KeeperApiError as kae: + logging.error('Failed to remove shares for update: %s' % kae.message) + return + + app_shares = [] + for uid in uids_to_update: + is_record = uid in params.record_cache + is_shared_folder = api.is_shared_folder(params, uid) + + if is_record: + share_key = params.record_cache[uid]['record_key_unencrypted'] + share_type = 'SHARE_TYPE_RECORD' + elif is_shared_folder: + share_key = params.shared_folder_cache[uid].get('shared_folder_key_unencrypted') + share_type = 'SHARE_TYPE_FOLDER' + else: + logging.warning('UID "%s" not found in local cache. Run sync-down and try again.' % uid) + continue + + encrypted_secret_key = crypto.encrypt_aes_v2(share_key, master_key) + + app_share = APIRequest_pb2.AppShareAdd() + app_share.secretUid = utils.base64_url_decode(uid) + app_share.shareType = APIRequest_pb2.ApplicationShareType.Value(share_type) + app_share.encryptedSecretKey = encrypted_secret_key + app_share.editable = is_editable + + app_shares.append(app_share) + + if not app_shares: + return + + rq_add = APIRequest_pb2.AddAppSharesRequest() + rq_add.appRecordUid = utils.base64_url_decode(app_record_uid) + rq_add.shares.extend(app_shares) + + try: + api.communicate_rest(params, rq_add, 'vault/app_share_add') + perm = "editable" if is_editable else "read-only" + print(bcolors.OKGREEN + + f'\nSuccessfully updated share permissions to {perm} for app uid={app_record_uid}:' + + bcolors.ENDC) + for uid in uids_to_update: + print(f'\t{uid}') + print() + except KeeperApiError as kae: + logging.error('Failed to re-add shares with updated permissions: %s' % kae.message) + @staticmethod def remove_share(params, app_name_or_uid, secret_uids): app = 
KSMCommand.get_app_record(params, app_name_or_uid) @@ -1010,6 +1297,87 @@ def convert_ids_and_hashes_to_hashes(cnahs, app_uid): api.communicate_rest(params, rq, 'vault/app_client_remove') print(bcolors.OKGREEN + "\nClient removal was successful\n" + bcolors.ENDC) + @staticmethod + def revoke_client(params, client_ids, force=False): + """Search all SM applications for matching client IDs and revoke them. + + Accepts clientId values from device config files (standard or URL-safe base64). + """ + # Normalize input client IDs to URL-safe base64 for comparison + normalized_inputs = [] + for cid in client_ids: + # Config files may use standard base64 (+, /, =) or URL-safe base64 (-, _) + normalized = cid.replace('+', '-').replace('/', '_').rstrip('=') + normalized_inputs.append(normalized) + + # Collect all SM app UIDs from the vault cache + app_uids = [] + app_titles = {} + for rec_cache_val in params.record_cache.values(): + if rec_cache_val.get('version') == 5: + r_uid = rec_cache_val.get('record_uid') + try: + r_data = json.loads(rec_cache_val.get('data_unencrypted').decode('utf-8')) + app_titles[r_uid] = r_data.get('title', r_uid) + except Exception: + app_titles[r_uid] = r_uid + app_uids.append(r_uid) + + if not app_uids: + print(bcolors.WARNING + "No Secrets Manager applications found in the vault." 
+ bcolors.ENDC) + return + + # Fetch app info for all apps in a single API call + rq = APIRequest_pb2.GetAppInfoRequest() + for app_uid in app_uids: + rq.appRecordUid.append(utils.base64_url_decode(app_uid)) + rs = api.communicate_rest(params, rq, 'vault/get_app_info', rs_type=APIRequest_pb2.GetAppInfoResponse) + + # Search for matching clients across all apps + matches = [] # list of (app_uid, app_title, client_name, client_id_b64, client_id_bytes) + for ai in rs.appInfo: + app_uid_str = utils.base64_url_encode(ai.appRecordUid) + app_title = app_titles.get(app_uid_str, app_uid_str) + for c in ai.clients: + client_id_b64 = utils.base64_url_encode(c.clientId) + for norm_input in normalized_inputs: + if client_id_b64 == norm_input or \ + (len(norm_input) >= KSMCommand.CLIENT_SHORT_ID_LENGTH and client_id_b64.startswith(norm_input)): + matches.append((app_uid_str, app_title, c.id, client_id_b64, c.clientId)) + break + + if not matches: + print(bcolors.WARNING + "No matching client devices found across any application." 
+ bcolors.ENDC) + return + + # Display matches and confirm + print(f"\n{bcolors.BOLD}Found {len(matches)} matching client device(s):{bcolors.ENDC}\n") + for app_uid_str, app_title, device_name, client_id_b64, _ in matches: + print(f" Application: {bcolors.OKGREEN}{app_title}{bcolors.ENDC} ({app_uid_str})") + print(f" Device Name: {device_name}") + print(f" Client ID: {client_id_b64[:20]}...") + print() + + if not force: + uc = user_choice(f'\tAre you sure you want to revoke {len(matches)} client device(s)?', 'yn', default='n') + if uc.lower() != 'y': + return + + # Group matches by app and remove + sorted_matches = sorted(matches, key=lambda m: m[0]) + for app_uid_str, group in groupby(sorted_matches, key=lambda m: m[0]): + group_list = list(group) + app_title = group_list[0][1] + client_hashes = [m[4] for m in group_list] + + rm_rq = APIRequest_pb2.RemoveAppClientsRequest() + rm_rq.appRecordUid = utils.base64_url_decode(app_uid_str) + rm_rq.clients.extend(client_hashes) + api.communicate_rest(params, rm_rq, 'vault/app_client_remove') + print(bcolors.OKGREEN + f"Revoked {len(client_hashes)} client(s) from application \"{app_title}\"" + bcolors.ENDC) + + print(bcolors.OKGREEN + "\nClient revocation complete.\n" + bcolors.ENDC) + @staticmethod def add_client(params, app_name_or_uid, count, unlock_ip, first_access_expire_on, access_expire_in_min, client_name=None, config_init=None, silent=False, client_type=enterprise_pb2.GENERAL): diff --git a/keepercommander/commands/pam/gateway_helper.py b/keepercommander/commands/pam/gateway_helper.py index 7d057cec4..ec87111e9 100644 --- a/keepercommander/commands/pam/gateway_helper.py +++ b/keepercommander/commands/pam/gateway_helper.py @@ -1,3 +1,5 @@ +import threading +import time from typing import Sequence, Optional, List from keeper_secrets_manager_core.utils import url_safe_str_to_bytes @@ -9,6 +11,12 @@ from ...proto import pam_pb2, enterprise_pb2 +_gateway_cache_lock = threading.Lock() +_gateway_cache_result = None # 
type: Optional[Sequence[pam_pb2.PAMController]] +_gateway_cache_time = 0.0 +_GATEWAY_CACHE_TTL = 60 # seconds + + def find_one_gateway_by_uid_or_name(params, gateway_name_or_uid): all_gateways = get_all_gateways(params) gateway_uid_bytes = url_safe_str_to_bytes(gateway_name_or_uid) @@ -26,8 +34,26 @@ def find_one_gateway_by_uid_or_name(params, gateway_name_or_uid): def get_all_gateways(params): # type: (KeeperParams) -> Sequence[pam_pb2.PAMController] - rs = api.communicate_rest(params, None, 'pam/get_controllers', rs_type=pam_pb2.PAMControllersResponse) - return rs.controllers + global _gateway_cache_result, _gateway_cache_time + now = time.time() + if _gateway_cache_result is not None and (now - _gateway_cache_time) < _GATEWAY_CACHE_TTL: + return _gateway_cache_result + with _gateway_cache_lock: + # Re-check after acquiring lock (another thread may have refreshed) + now = time.time() + if _gateway_cache_result is not None and (now - _gateway_cache_time) < _GATEWAY_CACHE_TTL: + return _gateway_cache_result + rs = api.communicate_rest(params, None, 'pam/get_controllers', rs_type=pam_pb2.PAMControllersResponse) + _gateway_cache_result = rs.controllers + _gateway_cache_time = time.time() + return _gateway_cache_result + + +def invalidate_gateway_cache(): + global _gateway_cache_result, _gateway_cache_time + with _gateway_cache_lock: + _gateway_cache_result = None + _gateway_cache_time = 0.0 def find_connected_gateways(all_controllers, identifier): # type: (List[bytes], str) -> Optional[bytes] @@ -74,6 +100,7 @@ def remove_gateway(params, gateway_uid): # type: (KeeperParams, bytes) -> None rq = pam_pb2.PAMGenericUidRequest() rq.uid = gateway_uid rs = api.communicate_rest(params, rq, 'pam/remove_controller', rs_type=pam_pb2.PAMRemoveControllerResponse) + invalidate_gateway_cache() controller = next((x for x in rs.controllers if x.controllerUid == gateway_uid), None) if controller: raise Exception(controller.message) @@ -84,3 +111,12 @@ def 
set_gateway_max_instances(params, gateway_uid, max_instance_count): # type rq.controllerUid = gateway_uid rq.maxInstanceCount = max_instance_count api.communicate_rest(params, rq, 'pam/set_controller_max_instance_count') + + +def edit_gateway(params, gateway_uid, gateway_name, node_id): + rq = pam_pb2.PAMController() + rq.controllerUid = gateway_uid + rq.controllerName = gateway_name + rq.nodeId = node_id + api.communicate_rest(params, rq, 'pam/modify_controller') + \ No newline at end of file diff --git a/keepercommander/commands/pam/pam_dto.py b/keepercommander/commands/pam/pam_dto.py index b23eb4c46..e5d353671 100644 --- a/keepercommander/commands/pam/pam_dto.py +++ b/keepercommander/commands/pam/pam_dto.py @@ -230,6 +230,79 @@ def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) +# IDENTITY PROVIDER ACTIONS +# These use existing RM action strings with optional idpConfigUid for IdP credential scoping. +# See: KPC Track B (resolve_config_uid on DRAction base class in dr-controller) + + +class GatewayActionIdpInputs: + + def __init__(self, configuration_uid, idp_config_uid=None, **kwargs): + self.configurationUid = configuration_uid + if idp_config_uid and idp_config_uid != configuration_uid: + self.idpConfigUid = idp_config_uid + for key, value in kwargs.items(): + if value is not None: + setattr(self, key, value) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpCreateUser(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-create-user', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpDeleteUser(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-delete-user', 
inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpAddUserToGroup(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-add-user-to-group', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpRemoveUserFromGroup(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-remove-user-from-group', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpGroupList(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-group-list', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + +class GatewayActionIdpValidateDomain(GatewayAction): + + def __init__(self, inputs: GatewayActionIdpInputs, conversation_id=None): + super().__init__('rm-validate-domain', inputs=inputs, conversation_id=conversation_id, is_scheduled=False) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) + + # REMOTE MANAGEMENT ACTIONS (KC-1035) class GatewayActionRmCreateUserInputs: diff --git a/keepercommander/commands/pam_cloud/__init__.py b/keepercommander/commands/pam_cloud/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/keepercommander/commands/pam_cloud/pam_privileged_access.py b/keepercommander/commands/pam_cloud/pam_privileged_access.py new file mode 100644 index 000000000..21940efa0 --- 
/dev/null +++ b/keepercommander/commands/pam_cloud/pam_privileged_access.py @@ -0,0 +1,573 @@ +import argparse +import base64 +import json +import logging + +from keepercommander.commands.base import Command, GroupCommand +from keepercommander.commands.pam.pam_dto import ( + GatewayAction, + GatewayActionIdpInputs, + GatewayActionIdpCreateUser, + GatewayActionIdpDeleteUser, + GatewayActionIdpAddUserToGroup, + GatewayActionIdpRemoveUserFromGroup, + GatewayActionIdpGroupList, +) +from keepercommander.commands.pam.router_helper import router_send_action_to_gateway +from keepercommander.error import CommandError +from keepercommander import api, crypto, record_management, vault +from keepercommander.proto import pam_pb2 +from keepercommander.subfolder import find_parent_top_folder + + +logger = logging.getLogger(__name__) + + +VALID_CONFIG_TYPES = { + 'pamAzureConfiguration', + 'pamOktaConfiguration', + 'pamDomainConfiguration', + 'pamAwsConfiguration', + 'pamGcpConfiguration', +} + + +def resolve_pam_idp_config(params, config_uid): + """Resolve the Identity Provider config UID from a PAM configuration. + + Reads the 'identityProviderUid' custom text field on the PAM config record. + If set, returns the referenced config UID (the IdP). + If empty, returns config_uid itself (self-managing). 
+ """ + record = vault.KeeperRecord.load(params, config_uid) + if not record: + raise CommandError('pam-privileged-access', f'PAM configuration "{config_uid}" not found.') + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam-privileged-access', 'Only typed PAM configuration records are supported.') + + # Check custom field for identityProviderUid + for field in record.custom: + if field.type == 'text' and field.label == 'identityProviderUid': + values = list(field.get_external_value()) + if values and values[0]: + idp_uid = values[0] + idp_record = vault.KeeperRecord.load(params, idp_uid) + if not idp_record: + raise CommandError('pam-privileged-access', + f'PAM Identity Provider config "{idp_uid}" not found.') + if isinstance(idp_record, vault.TypedRecord): + if idp_record.record_type not in VALID_CONFIG_TYPES: + raise CommandError('pam-privileged-access', + f'Referenced config type "{idp_record.record_type}" ' + f'does not support identity provider operations.') + return idp_uid + + # Self-managing — verify config type supports IdP + if record.record_type in VALID_CONFIG_TYPES: + return config_uid + + raise CommandError('pam-privileged-access', + f'No Identity Provider available for config type "{record.record_type}". 
' + f'Link one with: pam config edit {config_uid} --identity-provider ') + + +def _get_record_key(params, config_uid): + """Get the record key for a PAM config record.""" + record = vault.KeeperRecord.load(params, config_uid) + if not record or not record.record_key: + raise CommandError('pam-privileged-access', 'Record key unavailable for config record.') + return record.record_key + + +def _encrypt_field(value, record_key): + """Encrypt a string value with the record key, return base64.""" + encrypted = crypto.encrypt_aes_v2(value.encode('utf-8'), record_key) + return base64.b64encode(encrypted).decode('utf-8') + + +def _decrypt_gateway_data(params, config_uid, encrypted_data): + """Decrypt record-key-encrypted data from gateway response.""" + record_key = _get_record_key(params, config_uid) + enc_bytes = base64.b64decode(encrypted_data) + decrypted = crypto.decrypt_aes_v2(enc_bytes, record_key) + return json.loads(decrypted.decode('utf-8')) + + +def _friendly_error(error_msg): + """Convert raw gateway/Azure error messages into user-friendly text.""" + msg_lower = error_msg.lower() + if 'request_resourcenotfound' in msg_lower: + if 'group' in msg_lower: + return 'User is not a member of this group.' + return 'The specified resource was not found.' + if 'request_badrequest' in msg_lower: + # Try to extract the Azure message + try: + parsed = json.loads(error_msg.split(':', 1)[1].strip()) if ':' in error_msg else {} + inner_msg = parsed.get('error', {}).get('message', '') + if inner_msg: + return inner_msg + except (json.JSONDecodeError, IndexError): + pass + if 'already exist' in msg_lower or 'one or more added object references already exist' in msg_lower: + return 'User is already a member of this group.' + if 'does not exist' in msg_lower and 'user' in msg_lower: + return 'User not found in the Identity Provider.' 
+ return error_msg + + +def _dispatch_idp_action(params, gateway_action, gateway_uid=None): + """Dispatch a GatewayAction to the gateway and return the response.""" + conversation_id = GatewayAction.generate_conversation_id() + gateway_action.conversationId = conversation_id + + router_response = router_send_action_to_gateway( + params=params, + gateway_action=gateway_action, + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid, + ) + + if not router_response: + raise CommandError('pam-privileged-access', 'No response received from gateway.') + + response = router_response.get('response', {}) + payload_str = response.get('payload') + if not payload_str: + raise CommandError('pam-privileged-access', 'Empty response payload from gateway.') + + payload = json.loads(payload_str) + + if not (payload.get('is_ok') or payload.get('isOk')): + error_msg = payload.get('error', payload.get('message', 'Unknown gateway error')) + raise CommandError('pam-privileged-access', f'Gateway error: {error_msg}') + + data = payload.get('data', {}) + if isinstance(data, dict) and not data.get('success', True): + error_msg = data.get('error', 'Unknown error') + raise CommandError('pam-privileged-access', _friendly_error(error_msg)) + + return payload + + +# --- Command Groups --- + + +class PAMPrivilegedAccessCommand(GroupCommand): + def __init__(self): + super().__init__() + self.register_command('user', PAMAccessUserCommand(), 'Manage privileged IdP users') + self.register_command('group', PAMAccessGroupCommand(), 'Manage privileged IdP groups') + + +class PAMAccessUserCommand(GroupCommand): + def __init__(self): + super().__init__() + self.register_command('provision', PAMAccessUserProvisionCommand(), + 'Provision a privileged user in the Identity Provider') + self.register_command('deprovision', PAMAccessUserDeprovisionCommand(), + 'Deprovision a privileged user from the Identity Provider') + self.register_command('list', 
PAMAccessUserListCommand(), + 'List users in the Identity Provider') + + +class PAMAccessGroupCommand(GroupCommand): + def __init__(self): + super().__init__() + self.register_command('add-user', PAMAccessGroupAddUserCommand(), + 'Add a user to a privileged group in the Identity Provider') + self.register_command('remove-user', PAMAccessGroupRemoveUserCommand(), + 'Remove a user from privileged group in the Identity Provider') + self.register_command('list', PAMAccessGroupListCommand(), + 'List groups in the Identity Provider') + + +# --- User Commands --- + + +class PAMAccessUserProvisionCommand(Command): + parser = argparse.ArgumentParser(prog='pam access user provision', + description='Provision a privileged user in the Identity Provider') + parser.add_argument('--config', '-c', required=True, dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--username', '-u', required=True, dest='username', + help='Username to create (e.g. testuser or testuser@domain.com)') + parser.add_argument('--domain', '-d', dest='domain', + help='Domain for the user (e.g. 
domain.com, if not included in --username)') + parser.add_argument('--display-name', '-n', dest='display_name', + help='Display name (defaults to --username)') + parser.add_argument('--password', '-p', dest='password', + help='Initial password (auto-generated if omitted)') + parser.add_argument('--save-record', '-s', dest='save_record', action='store_true', + help='Save provisioned credentials as a pamUser record') + parser.add_argument('--folder', '-f', dest='folder_uid', + help='Folder UID to save the record in (used with --save-record)') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessUserProvisionCommand.parser + + def execute(self, params, **kwargs): + config_uid = kwargs['config_uid'] + username = kwargs['username'] + domain = kwargs.get('domain') + + if '@' in username: + if domain: + logging.warning('Username already contains @domain, ignoring --domain flag.') + elif domain: + username = f'{username}@{domain}' + else: + raise CommandError('pam-privileged-access', + 'Username must include domain (e.g. 
user@domain.com), ' + 'or use --domain to specify one.') + idp_config_uid = resolve_pam_idp_config(params, config_uid) + record_key = _get_record_key(params, config_uid) + + meta = {} + display_name = kwargs.get('display_name') + if display_name: + meta['display_name'] = display_name + encrypted_meta = _encrypt_field(json.dumps(meta), record_key) if meta else None + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + user=_encrypt_field(username, record_key), + password=kwargs.get('password'), + ) + if encrypted_meta: + inputs.meta = encrypted_meta + action = GatewayActionIdpCreateUser(inputs=inputs) + + payload = _dispatch_idp_action(params, action, kwargs.get('gateway')) + + response_data = payload.get('data', {}) + if isinstance(response_data, str): + try: + response_data = json.loads(response_data) + except (json.JSONDecodeError, TypeError): + response_data = {} + + if isinstance(response_data, dict) and not response_data.get('success', True): + error = response_data.get('error', 'Unknown error') + raise CommandError('pam-privileged-access', f'Gateway reported failure: {error}') + + # Decrypt the response data if encrypted + data = {} + encrypted_content = response_data.get('data') if isinstance(response_data, dict) else None + if encrypted_content: + try: + data = _decrypt_gateway_data(params, config_uid, encrypted_content) + except Exception: + data = {} + + if isinstance(data, dict): + # Handle different response formats (Azure returns 'name' as string, GCP returns dict) + raw_name = data.get('name', username) + if isinstance(raw_name, dict): + user_name = data.get('primaryEmail', username) + else: + user_name = raw_name + user_password = data.get('password', '') + user_id = data.get('id', '') + else: + user_name = username + user_password = '' + user_id = '' + + logging.info(f'User provisioned successfully.') + print(f' Username: {user_name}') + if user_id: + print(f' User ID: {user_id}') + print(f' 
Password: {"**********" if user_password else "(none)"}') + + if kwargs.get('save_record'): + display_name = kwargs.get('display_name') or username + record = vault.TypedRecord() + record.type_name = 'pamUser' + record.title = display_name + record.fields.append(vault.TypedField.new_field('login', user_name)) + record.fields.append(vault.TypedField.new_field('password', user_password)) + if user_id: + idp_record = vault.KeeperRecord.load(params, idp_config_uid) + idp_type = idp_record.record_type if isinstance(idp_record, vault.TypedRecord) else '' + idp_label_map = { + 'pamAzureConfiguration': 'Azure User ID', + 'pamGcpConfiguration': 'GCP User ID', + 'pamOktaConfiguration': 'Okta User ID', + 'pamAwsConfiguration': 'AWS User ID', + 'pamDomainConfiguration': 'Domain User ID', + } + user_id_label = idp_label_map.get(idp_type, 'IdP User ID') + record.custom.append(vault.TypedField.new_field('text', user_id, user_id_label)) + + folder_uid = kwargs.get('folder_uid') + if not folder_uid: + shared_folders = find_parent_top_folder(params, config_uid) + if shared_folders: + sf = shared_folders[0] + folder_uid = sf.parent_uid if sf.parent_uid else sf.uid + record_management.add_record_to_folder(params, record, folder_uid) + params.sync_data = True + + print(f' Record UID: {record.record_uid}') + logging.info(f'Credentials saved as pamUser record.') + + +class PAMAccessUserDeprovisionCommand(Command): + parser = argparse.ArgumentParser(prog='pam access user deprovision', + description='Deprovision a privileged user from the Identity Provider') + parser.add_argument('--config', '-c', required=True, dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--username', '-u', required=True, dest='username', + help='Username or user principal name') + parser.add_argument('--delete-record', '-d', dest='delete_record', nargs='?', const='auto', + metavar='RECORD_UID', + help='Delete the associated pamUser record. 
Optionally pass a record UID, ' + 'or omit to auto-find by Azure user ID.') + parser.add_argument('--force', dest='force', action='store_true', + help='Skip confirmation prompt') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessUserDeprovisionCommand.parser + + def execute(self, params, **kwargs): + config_uid = kwargs['config_uid'] + username = kwargs['username'] + idp_config_uid = resolve_pam_idp_config(params, config_uid) + record_key = _get_record_key(params, config_uid) + + if not kwargs.get('force'): + try: + answer = input(f'Are you sure you want to deprovision user "{username}"? (y/N): ') + if answer.lower() not in ('y', 'yes'): + print('Cancelled.') + return + except EOFError: + print('Cancelled.') + return + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + user=_encrypt_field(username, record_key), + ) + action = GatewayActionIdpDeleteUser(inputs=inputs) + + _dispatch_idp_action(params, action, kwargs.get('gateway')) + + logging.info(f'User "{username}" deprovisioned successfully.') + + delete_record = kwargs.get('delete_record') + if delete_record: + if delete_record == 'auto': + record_uid = _find_pam_user_record_by_user_id(params, username) + if record_uid: + api.delete_record(params, record_uid) + logging.info(f'Deleted pamUser record {record_uid}.') + else: + logging.warning(f'No pamUser record with matching IdP User ID found for "{username}".') + + else: + record = vault.KeeperRecord.load(params, delete_record) + if record: + api.delete_record(params, delete_record) + logging.info(f'Deleted record {delete_record}.') + else: + logging.warning(f'Record "{delete_record}" not found.') + + +def _find_pam_user_record_by_user_id(params, username): + """Find a pamUser record with an IdP User ID custom field matching the given username.""" + username_lower = username.lower() + for record_uid in params.record_cache: + record = 
vault.KeeperRecord.load(params, record_uid) + if not isinstance(record, vault.TypedRecord): + continue + if record.type_name != 'pamUser': + continue + # Check login matches (exact or prefix match for username without domain) + login_match = False + for field in record.fields: + if field.type == 'login': + values = list(field.get_external_value()) + if values and values[0]: + login_lower = values[0].lower() + if login_lower == username_lower or login_lower.split('@')[0] == username_lower: + login_match = True + break + if not login_match: + continue + # Prefer records that have an IdP User ID custom field + idp_user_id_labels = {'Azure User ID', 'GCP User ID', 'Okta User ID', 'AWS User ID', + 'Domain User ID', 'IdP User ID'} + for field in record.custom: + if field.label in idp_user_id_labels: + values = list(field.get_external_value()) + if values and values[0]: + return record_uid + return None + + +class PAMAccessUserListCommand(Command): + parser = argparse.ArgumentParser(prog='pam access user list', + description='List users in the Identity Provider') + parser.add_argument('--config', '-c', required=True, dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessUserListCommand.parser + + def execute(self, params, **kwargs): + raise CommandError('pam-privileged-access', + 'User listing is not yet implemented. 
' + 'Use "pam idp group list" to list groups, or check the IdP portal directly.') + + +# --- Group Commands --- + + +class PAMAccessGroupListCommand(Command): + parser = argparse.ArgumentParser(prog='pam access group list', + description='List groups in the Identity Provider') + parser.add_argument('--config', '-c', required=True, dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--format', '-f', dest='output_format', choices=['table', 'json'], + default='table', help='Output format (default: table)') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessGroupListCommand.parser + + def execute(self, params, **kwargs): + config_uid = kwargs['config_uid'] + idp_config_uid = resolve_pam_idp_config(params, config_uid) + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + includeUsers=True, + ) + action = GatewayActionIdpGroupList(inputs=inputs) + + payload = _dispatch_idp_action(params, action, kwargs.get('gateway')) + + # Gateway response: data = {configurationUid, success, data: } + response_data = payload.get('data', {}) + if isinstance(response_data, str): + try: + response_data = json.loads(response_data) + except (json.JSONDecodeError, TypeError): + response_data = {} + + if not isinstance(response_data, dict) or not response_data.get('success'): + error = response_data.get('error', 'Unknown error') if isinstance(response_data, dict) else str(response_data) + raise CommandError('pam-privileged-access', f'Gateway reported failure: {error}') + + # Decrypt the inner encrypted data using the config record key + encrypted_content = response_data.get('data') + if not encrypted_content: + print('No groups found.') + return + + groups = _decrypt_gateway_data(params, config_uid, encrypted_content) + + if kwargs.get('output_format') == 'json': + print(json.dumps(groups, indent=2)) + return + + if not groups or not 
isinstance(groups, list): + print('No groups found.') + return + + from keepercommander.commands.base import dump_report_data + headers = ['Group ID', 'Name', 'Members'] + table = [] + for group in groups: + if isinstance(group, dict): + users = group.get('users', []) + member_count = len(users) if isinstance(users, list) else 0 + table.append([ + group.get('id', ''), + group.get('name', ''), + str(member_count), + ]) + dump_report_data(table, headers=headers) + + +class PAMAccessGroupAddUserCommand(Command): + parser = argparse.ArgumentParser(prog='pam access group add-user', + description='Add a user to a privileged group in the Identity Provider') + parser.add_argument('--config', '-c', required=True, dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--username', '-u', required=True, dest='username', + help='Username or user ID') + parser.add_argument('--group', '-gr', required=True, dest='group_id', + help='Group name or ID') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessGroupAddUserCommand.parser + + def execute(self, params, **kwargs): + config_uid = kwargs['config_uid'] + username = kwargs['username'] + group_id = kwargs['group_id'] + idp_config_uid = resolve_pam_idp_config(params, config_uid) + record_key = _get_record_key(params, config_uid) + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + user=_encrypt_field(username, record_key), + groupId=group_id, + ) + action = GatewayActionIdpAddUserToGroup(inputs=inputs) + + _dispatch_idp_action(params, action, kwargs.get('gateway')) + + logging.info(f'User "{username}" added to group "{group_id}".') + + +class PAMAccessGroupRemoveUserCommand(Command): + parser = argparse.ArgumentParser(prog='pam access group remove-user', + description='Remove a user from a privileged group in the Identity Provider') + parser.add_argument('--config', '-c', required=True, 
dest='config_uid', + help='PAM configuration UID') + parser.add_argument('--username', '-u', required=True, dest='username', + help='Username or user ID') + parser.add_argument('--group', '-gr', required=True, dest='group_id', + help='Group name or ID') + parser.add_argument('--gateway', '-g', dest='gateway', + help='Gateway UID or name') + + def get_parser(self): + return PAMAccessGroupRemoveUserCommand.parser + + def execute(self, params, **kwargs): + config_uid = kwargs['config_uid'] + username = kwargs['username'] + group_id = kwargs['group_id'] + idp_config_uid = resolve_pam_idp_config(params, config_uid) + record_key = _get_record_key(params, config_uid) + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + user=_encrypt_field(username, record_key), + groupId=group_id, + ) + action = GatewayActionIdpRemoveUserFromGroup(inputs=inputs) + + _dispatch_idp_action(params, action, kwargs.get('gateway')) + + logging.info(f'User "{username}" removed from group "{group_id}".') \ No newline at end of file diff --git a/keepercommander/commands/pam_cloud/pam_privileged_workflow.py b/keepercommander/commands/pam_cloud/pam_privileged_workflow.py new file mode 100644 index 000000000..2f0adc2f5 --- /dev/null +++ b/keepercommander/commands/pam_cloud/pam_privileged_workflow.py @@ -0,0 +1,433 @@ +import argparse +import base64 +import json +import logging + +from keeper_secrets_manager_core.utils import url_safe_str_to_bytes + +from keepercommander.commands.base import Command, RecordMixin, GroupCommand +from keepercommander.commands.pam.pam_dto import ( + GatewayAction, + GatewayActionIdpInputs, + GatewayActionIdpValidateDomain, +) +from keepercommander.commands.pam.router_helper import router_send_action_to_gateway, _post_request_to_router +from keepercommander.commands.pam_cloud.pam_privileged_access import resolve_pam_idp_config +from keepercommander.commands.tunnel.port_forward.tunnel_helpers import ( + 
get_config_uid_from_record, + get_gateway_uid_from_record, +) +from keepercommander.error import CommandError +from keepercommander import api, vault +from keepercommander.proto import GraphSync_pb2, pam_pb2, workflow_pb2 + + +ELIGIBLE_RECORD_TYPES = {'pamRemoteBrowser', 'pamDatabase', 'pamMachine', 'pamCloudAccess'} + + +# --- Command Groups --- + +class PAMPrivilegedWorkflowCommand(GroupCommand): + def __init__(self): + super().__init__() + self.register_command('request', PAMRequestAccessCommand(), + 'Request access for a shared record') + self.register_command('status', PAMAccessStateCommand(), + 'List your active access requests and statuses') + self.register_command('requests', PAMApprovalRequestsCommand(), + 'List pending workflow approval requests') + self.register_command('approve', PAMApproveAccessCommand(), + 'Approve or deny a workflow access request') + self.register_command('revoke', PAMRevokeAccessCommand(), + 'Revoke/end an active workflow access session') + self.register_command('config', PAMWorkflowConfigCommand(), + 'Read or configure workflow settings for a resource') + + +class PAMRequestAccessCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow request', description='Request access to a shared PAM record') + + parser.add_argument('record', action='store', help='Record UID or title of the shared PAM record') + parser.add_argument('--message', '-m', dest='message', action='store', + help='Justification message to include with the request') + + def get_parser(self): + return PAMRequestAccessCommand.parser + + def execute(self, params, **kwargs): + record_name = kwargs.get('record') + record = RecordMixin.resolve_single_record(params, record_name) + + if not record: + raise CommandError('pam-workflow-request-access', f'Record "{record_name}" not found.') + + if not isinstance(record, vault.TypedRecord): + raise CommandError('pam-workflow-request-access', 'Only typed records are supported.') + + if record.record_type not in 
ELIGIBLE_RECORD_TYPES: + allowed = ', '.join(sorted(ELIGIBLE_RECORD_TYPES)) + raise CommandError('pam-workflow-request-access', + f'Record type "{record.record_type}" is not eligible. Allowed types: {allowed}') + + # Load share info to find the record owner + api.get_record_shares(params, [record.record_uid]) + + rec_cached = params.record_cache.get(record.record_uid) + if not rec_cached: + raise CommandError('pam-workflow-request-access', 'Record not found in cache.') + + shares = rec_cached.get('shares', {}) + user_perms = shares.get('user_permissions', []) + + owner = next((up.get('username') for up in user_perms if up.get('owner')), None) + if not owner: + raise CommandError('pam-workflow-request-access', 'Could not determine record owner.') + + if owner == params.user: + raise CommandError('pam-workflow-request-access', 'You are the owner of this record.') + + # Resolve PAM config and IdP config for this resource + config_uid = get_config_uid_from_record(params, vault, record.record_uid) + if not config_uid: + raise CommandError('pam-workflow-request-access', 'Could not resolve PAM configuration for this resource.') + + gateway_uid = get_gateway_uid_from_record(params, vault, record.record_uid) + + # Validate the requesting user's domain against the IdP + try: + idp_config_uid = resolve_pam_idp_config(params, config_uid) + except CommandError: + idp_config_uid = config_uid + + inputs = GatewayActionIdpInputs( + configuration_uid=config_uid, + idp_config_uid=idp_config_uid, + user=params.user, + resourceUid=record.record_uid, + ) + action = GatewayActionIdpValidateDomain(inputs=inputs) + conversation_id = GatewayAction.generate_conversation_id() + action.conversationId = conversation_id + + router_response = router_send_action_to_gateway( + params=params, + gateway_action=action, + message_type=pam_pb2.CMT_GENERAL, + is_streaming=False, + destination_gateway_uid_str=gateway_uid, + ) + + if router_response: + response = router_response.get('response', {}) + 
payload_str = response.get('payload') + if payload_str: + payload = json.loads(payload_str) + data = payload.get('data', {}) + if isinstance(data, dict) and not data.get('success', True): + error_msg = data.get('error', 'Domain validation failed') + raise CommandError('pam-workflow-request-access', error_msg) + + # Domain validated — submit workflow access request to krouter + record_uid_bytes = url_safe_str_to_bytes(record.record_uid) + + access_request = workflow_pb2.WorkflowAccessRequest() + access_request.resource.type = GraphSync_pb2.RFT_REC + access_request.resource.value = record_uid_bytes + + message = kwargs.get('message') + if message: + access_request.reason = message + + try: + _post_request_to_router(params, 'request_workflow_access', rq_proto=access_request) + except Exception as e: + raise CommandError('pam-request-access', f'Failed to submit access request: {e}') + + logging.info(f'Access request submitted for record "{record.title}".') + + +class PAMAccessStateCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow status', description='List your active workflow access requests and their status') + + parser.add_argument('record', nargs='?', action='store', default=None, help='Optional: Record UID to check specific resource workflow state') + + def get_parser(self): + return PAMAccessStateCommand.parser + + def execute(self, params, **kwargs): + stage_names = { + 0: 'Ready to Start', + 1: 'Started', + 2: 'Needs Action', + 3: 'Waiting', + } + condition_names = { + 0: 'Approval', + 1: 'Check-in', + 2: 'MFA', + 3: 'Time', + 4: 'Reason', + 5: 'Ticket', + } + + record_uid = kwargs.get('record') + + if record_uid: + # Use get_workflow_state for a specific resource (more detailed, reads full state) + record_uid_bytes = url_safe_str_to_bytes(record_uid) + rq = workflow_pb2.WorkflowState() + rq.resource.type = GraphSync_pb2.RFT_REC + rq.resource.value = record_uid_bytes + try: + wf = _post_request_to_router( + params, 'get_workflow_state', + 
rq_proto=rq, + rs_type=workflow_pb2.WorkflowState + ) + except Exception as e: + raise CommandError('pam-access-state', f'Failed to get workflow state: {e}') + + if not wf: + logging.info('No active workflow for this resource.') + return + + workflows = [wf] + else: + # Use get_user_access_state for all workflows + try: + response = _post_request_to_router( + params, 'get_user_access_state', + rs_type=workflow_pb2.UserAccessState + ) + except Exception as e: + raise CommandError('pam-access-state', f'Failed to get access state: {e}') + + if not response or not response.workflows: + logging.info('No active access requests.') + return + + workflows = response.workflows + + import time + now_ms = int(time.time() * 1000) + + for wf in workflows: + flow_uid = base64.urlsafe_b64encode(wf.flowUid).rstrip(b'=').decode() + resource_uid = base64.urlsafe_b64encode(wf.resource.value).rstrip(b'=').decode() if wf.resource.value else 'N/A' + stage = stage_names.get(wf.status.stage, str(wf.status.stage)) if wf.status else 'Unknown' + conditions = ', '.join(condition_names.get(c, str(c)) for c in wf.status.conditions) if wf.status and wf.status.conditions else 'None' + print(f' Flow UID: {flow_uid}') + print(f' Resource UID: {resource_uid}') + print(f' Stage: {stage}') + print(f' Conditions: {conditions}') + if wf.status and wf.status.startedOn: + from datetime import datetime + started = datetime.fromtimestamp(wf.status.startedOn / 1000) + print(f' Started: {started.strftime("%Y-%m-%d %H:%M:%S")}') + if wf.status and wf.status.expiresOn: + from datetime import datetime + expires = datetime.fromtimestamp(wf.status.expiresOn / 1000) + remaining_ms = wf.status.expiresOn - now_ms + if remaining_ms > 0: + remaining_min = remaining_ms // 60000 + remaining_sec = (remaining_ms % 60000) // 1000 + print(f' Expires: {expires.strftime("%Y-%m-%d %H:%M:%S")} ({remaining_min}m {remaining_sec}s remaining)') + else: + print(f' Expires: {expires.strftime("%Y-%m-%d %H:%M:%S")} (expired)') + print() 
+ + +class PAMApprovalRequestsCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow requests', description='List pending workflow approval requests') + + def get_parser(self): + return PAMApprovalRequestsCommand.parser + + def execute(self, params, **kwargs): + try: + response = _post_request_to_router( + params, 'get_approval_requests', + rs_type=workflow_pb2.ApprovalRequests + ) + except Exception as e: + raise CommandError('pam-approval-requests', f'Failed to get approval requests: {e}') + + if not response or not response.workflows: + logging.info('No pending approval requests.') + return + + for wf in response.workflows: + flow_uid = base64.urlsafe_b64encode(wf.flowUid).rstrip(b'=').decode() + resource_uid = base64.urlsafe_b64encode(wf.resource.value).rstrip(b'=').decode() if wf.resource.value else 'N/A' + reason = wf.reason.decode() if wf.reason else '' + print(f' Flow UID: {flow_uid}') + print(f' User ID: {wf.userId}') + print(f' Resource UID: {resource_uid}') + if reason: + print(f' Reason: {reason}') + print() + + +class PAMApproveAccessCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow approve', description='Approve a workflow access request') + + parser.add_argument('flow_uid', action='store', help='Flow UID of the request to approve') + parser.add_argument('--deny', action='store_true', help='Deny instead of approve') + parser.add_argument('--reason', dest='denial_reason', action='store', help='Reason for denial') + + def get_parser(self): + return PAMApproveAccessCommand.parser + + def execute(self, params, **kwargs): + flow_uid_str = kwargs.get('flow_uid') + deny = kwargs.get('deny', False) + + # Pad base64url if needed + padding = 4 - len(flow_uid_str) % 4 + if padding != 4: + flow_uid_str += '=' * padding + flow_uid_bytes = base64.urlsafe_b64decode(flow_uid_str) + + approval = workflow_pb2.WorkflowApprovalOrDenial() + approval.flowUid = flow_uid_bytes + approval.deny = deny + + if deny and 
kwargs.get('denial_reason'): + approval.denialReason = kwargs['denial_reason'] + + endpoint = 'deny_workflow_access' if deny else 'approve_workflow_access' + + try: + _post_request_to_router(params, endpoint, rq_proto=approval) + except Exception as e: + action = 'deny' if deny else 'approve' + raise CommandError('pam-approve-access', f'Failed to {action} access request: {e}') + + if deny: + logging.info(f'Access request denied.') + else: + logging.info(f'Access request approved.') + + +class PAMRevokeAccessCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow revoke', description='Revoke/end an active workflow access session') + + parser.add_argument('flow_uid', action='store', help='Flow UID of the active access to revoke') + + def get_parser(self): + return PAMRevokeAccessCommand.parser + + def execute(self, params, **kwargs): + flow_uid_str = kwargs.get('flow_uid') + + padding = 4 - len(flow_uid_str) % 4 + if padding != 4: + flow_uid_str += '=' * padding + flow_uid_bytes = base64.urlsafe_b64decode(flow_uid_str) + + ref = GraphSync_pb2.GraphSyncRef() + ref.type = GraphSync_pb2.RFT_WORKFLOW + ref.value = flow_uid_bytes + + try: + _post_request_to_router(params, 'end_workflow', rq_proto=ref) + except Exception as e: + raise CommandError('pam-revoke-access', f'Failed to revoke access: {e}') + + logging.info(f'Access revoked.') + + +class PAMWorkflowConfigCommand(Command): + parser = argparse.ArgumentParser(prog='pam workflow config', description='Read or configure workflow settings for a resource') + + parser.add_argument('record', action='store', help='Record UID of the resource') + parser.add_argument('--set', action='store_true', help='Create or update workflow config') + parser.add_argument('--approvals-needed', type=int, default=None, help='Number of approvals required') + parser.add_argument('--approver', action='append', dest='approvers', help='Approver email (can specify multiple)') + parser.add_argument('--start-on-approval', 
action='store_true', default=False, help='Auto-start access on approval') + parser.add_argument('--access-length', type=int, default=None, help='Access duration in seconds') + + def get_parser(self): + return PAMWorkflowConfigCommand.parser + + def execute(self, params, **kwargs): + record_uid = kwargs.get('record') + record_uid_bytes = url_safe_str_to_bytes(record_uid) + + ref = GraphSync_pb2.GraphSyncRef() + ref.type = GraphSync_pb2.RFT_REC + ref.value = record_uid_bytes + + if not kwargs.get('set'): + # Read current config + try: + config = _post_request_to_router( + params, 'read_workflow_config', + rq_proto=ref, + rs_type=workflow_pb2.WorkflowConfig + ) + except Exception as e: + raise CommandError('pam-workflow-config', f'Failed to read workflow config: {e}') + + if not config or not config.parameters.approvalsNeeded: + print(' No workflow configuration found for this resource.') + return + + p = config.parameters + print(f' Approvals Needed: {p.approvalsNeeded}') + print(f' Checkout Needed: {p.checkoutNeeded}') + print(f' Start on Approval: {p.startAccessOnApproval}') + print(f' Require Reason: {p.requireReason}') + print(f' Require Ticket: {p.requireTicket}') + print(f' Require MFA: {p.requireMFA}') + print(f' Access Length: {p.accessLength // 1000}s' if p.accessLength else ' Access Length: unlimited') + if config.approvers: + print(f' Approvers:') + for a in config.approvers: + if a.user: + print(f' - {a.user}') + elif a.userId: + print(f' - User ID: {a.userId}') + return + + # Set/update config + wf_params = workflow_pb2.WorkflowParameters() + wf_params.resource.type = GraphSync_pb2.RFT_REC + wf_params.resource.value = record_uid_bytes + + approvals = kwargs.get('approvals_needed') + if approvals is not None: + wf_params.approvalsNeeded = approvals + else: + wf_params.approvalsNeeded = 1 + + wf_params.startAccessOnApproval = kwargs.get('start_on_approval', False) + + access_length_sec = kwargs.get('access_length') or 3600 + wf_params.accessLength = 
access_length_sec * 1000 # proto field is in milliseconds + + try: + _post_request_to_router(params, 'create_workflow_config', rq_proto=wf_params) + logging.info(f'Workflow config created (approvalsNeeded={wf_params.approvalsNeeded}).') + except Exception as e: + # Try update if create fails + try: + _post_request_to_router(params, 'update_workflow_config', rq_proto=wf_params) + logging.info(f'Workflow config updated (approvalsNeeded={wf_params.approvalsNeeded}).') + except Exception as e2: + raise CommandError('pam-workflow-config', f'Failed to set workflow config: {e2}') + + # Add approvers if specified + approvers = kwargs.get('approvers') + if approvers: + wf_config = workflow_pb2.WorkflowConfig() + wf_config.parameters.CopyFrom(wf_params) + for approver_email in approvers: + a = wf_config.approvers.add() + a.user = approver_email + + try: + _post_request_to_router(params, 'add_workflow_approvers', rq_proto=wf_config) + logging.info(f'Approvers added: {", ".join(approvers)}' ) + except Exception as e: + raise CommandError('pam-workflow-config', f'Failed to add approvers: {e}') \ No newline at end of file diff --git a/keepercommander/commands/pam_import/base.py b/keepercommander/commands/pam_import/base.py index 9bd9ab8d8..22137b8cf 100644 --- a/keepercommander/commands/pam_import/base.py +++ b/keepercommander/commands/pam_import/base.py @@ -125,6 +125,7 @@ def _initialize(self): self.attachments = None # PamAttachmentsObject # common settings (shared across all config types) + self.identity_provider_uid: str = "" # optional, text:identityProviderUid self.pam_resources = {} # {"folderUid": "", "controllerUid": ""} - "resourceRef": unused/legacy # Local environment: pamNetworkConfiguration @@ -245,6 +246,9 @@ def __init__(self, environment_type:str, settings:dict, controller_uid:str, fold self.scripts = PamScriptsObject.load(settings.get("scripts", None)) self.attachments = PamAttachmentsObject.load(settings.get("attachments", None)) + val = 
settings.get("identity_provider_uid", None) + if isinstance(val, str): self.identity_provider_uid = val + # Local Network if environment_type == "local": val = settings.get("network_id", None) @@ -2035,6 +2039,7 @@ def __init__( self.sftp = sftp if isinstance(sftp, SFTPConnectionSettings) else None self.disableAudio = disableAudio self.resizeMethod = resizeMethod # disable_dynamic_resizing ? "" : "display-update" + # resize-method: "" | "display-update" | "reconnect" # Performance Properties self.enableWallpaper = enableWallpaper self.enableFullWindowDrag = enableFullWindowDrag @@ -3036,6 +3041,24 @@ def is_blank_instance(obj, skiplist: Optional[List[str]] = None): return False return True +def is_database_protocol(protocol): + """ + Returns True if the protocol is one of the database protocols: MYSQL, POSTGRESQL, or SQLSERVER. + + Accepts ConnectionProtocol or the string wire values (e.g. 'mysql', 'postgresql', 'sql-server'). + """ + db_members = ( + ConnectionProtocol.MYSQL, + ConnectionProtocol.POSTGRESQL, + ConnectionProtocol.SQLSERVER, + ) + db_values = {m.value for m in db_members} + if isinstance(protocol, ConnectionProtocol): + return protocol in db_members + if isinstance(protocol, str): + return str(protocol).strip().lower() in db_values + return False + def get_sftp_attribute(obj, name: str) -> str: # Get one of pam_settings.connection.sftp.{sftpResource,sftpResourceUid,sftpUser,sftpUserUid} value: str = "" diff --git a/keepercommander/commands/pam_import/edit.py b/keepercommander/commands/pam_import/edit.py index 881887c56..2a6524266 100644 --- a/keepercommander/commands/pam_import/edit.py +++ b/keepercommander/commands/pam_import/edit.py @@ -451,6 +451,8 @@ def process_pam_config(self, params, project: dict) -> dict: "ai_terminate_session_on_detection": pce.ai_terminate_session_on_detection }) + if pce.identity_provider_uid: args["identity_provider_uid"] = pce.identity_provider_uid + if pce.environment == "local": if pce.network_cidr: 
args["network_cidr"] = pce.network_cidr if pce.network_id: args["network_id"] = pce.network_id diff --git a/keepercommander/commands/pam_launch/guac_cli/input.py b/keepercommander/commands/pam_launch/guac_cli/input.py index 693077806..ff8142d83 100644 --- a/keepercommander/commands/pam_launch/guac_cli/input.py +++ b/keepercommander/commands/pam_launch/guac_cli/input.py @@ -30,6 +30,7 @@ from __future__ import annotations import collections +import os import sys import logging import threading @@ -50,6 +51,11 @@ _CHORD_SHIFT_INSERT = '\x18' _CHORD_CTRL_INSERT = '\x19' +# Wake the input loop periodically so stop() can join without a keystroke; Unix +# path uses select() with this timeout and only reads stdin while running so +# bytes are not consumed after teardown (they stay for the Commander REPL). +_INPUT_POLL_SEC = 0.1 + class InputHandler: """ @@ -93,7 +99,7 @@ def __init__( def _get_stdin_reader(self): if sys.platform == 'win32': - return WindowsStdinReader() + return WindowsStdinReader(should_continue=lambda: self.running) return UnixStdinReader() def start(self): @@ -114,14 +120,46 @@ def stop(self): self.stdin_reader.restore() self.raw_mode_active = False if self.thread: - self.thread.join(timeout=1.0) + for _ in range(50): + self.thread.join(timeout=0.1) + if not self.thread.is_alive(): + break logging.debug('InputHandler stopped') def _input_loop(self): + import select + while self.running: try: - ch = self.stdin_reader.read_char() - if ch: + if sys.platform == 'win32': + ch = self.stdin_reader.read_char(timeout=_INPUT_POLL_SEC) + if not ch: + continue + if not self.running: + break + self._process_input(ch) + else: + ready, _, _ = select.select( + [sys.stdin], [], [], _INPUT_POLL_SEC + ) + if not self.running: + break + if not ready: + continue + if not self.running: + break + # Read exactly 1 byte at the fd level, bypassing Python's + # BufferedReader (which would call os.read(fd, 8192) and + # consume the whole escape sequence in one shot, leaving + # 
the fd empty so subsequent select() calls in + # _read_escape_sequence return not-ready and the remaining + # bytes get injected later as stray characters). + raw = os.read(sys.stdin.fileno(), 1) + if not raw: + continue + ch = raw.decode('utf-8', 'replace') + if not self.running: + break self._process_input(ch) except Exception as exc: logging.error(f'Error in input loop: {exc}') @@ -134,6 +172,8 @@ def _process_input(self, ch: str): Process a single character (or the first character of a buffered sequence) from stdin and emit the appropriate key event(s). """ + if not self.running: + return if not ch: return @@ -326,7 +366,14 @@ def read_char(self, timeout: Optional[float] = None) -> Optional[str]: if not ready: return None try: - return sys.stdin.read(1) + # Use os.read for exactly-1-byte fd-level reads so that select() + # and the actual read operate at the same kernel-buffer level. + # sys.stdin.read(1) goes through Python's 8192-byte BufferedReader + # which can consume the entire escape sequence (\x1b[A) in one + # os.read() call, leaving the fd empty and breaking subsequent + # select() calls in _read_escape_sequence. + raw = os.read(sys.stdin.fileno(), 1) + return raw.decode('utf-8', 'replace') if raw else None except Exception: return None @@ -383,7 +430,8 @@ class WindowsStdinReader: the queue transparently. 
""" - def __init__(self): + def __init__(self, should_continue: Optional[Callable[[], bool]] = None): + self._should_continue = should_continue self._queue: collections.deque = collections.deque() self._hstdin = None self._input_record_type = None @@ -465,6 +513,9 @@ def _read_via_console_input(self, timeout: Optional[float]) -> Optional[str]: if result != 0: # WAIT_OBJECT_0 = 0 return None + if self._should_continue is not None and not self._should_continue(): + return None + record = self._InputRecord() n_read = wintypes.DWORD(0) ok = self._ReadConsoleInputW( @@ -525,6 +576,8 @@ def _read_via_msvcrt(self, timeout: Optional[float]) -> Optional[str]: while True: if timeout is not None and (time.time() - start) >= timeout: return None + if self._should_continue is not None and not self._should_continue(): + return None if msvcrt.kbhit(): ch = msvcrt.getch() # Extended key prefix — read second byte immediately diff --git a/keepercommander/commands/pam_launch/guac_cli/instructions.py b/keepercommander/commands/pam_launch/guac_cli/instructions.py index b93310b51..7c0cbfed2 100644 --- a/keepercommander/commands/pam_launch/guac_cli/instructions.py +++ b/keepercommander/commands/pam_launch/guac_cli/instructions.py @@ -35,6 +35,8 @@ import sys from typing import Any, Callable, Dict, List, Optional, cast +from ..terminal_size import default_handshake_dpi + def is_stdout_pipe_stream_name(name: str) -> bool: """True if Guacamole named pipe is the terminal STDOUT stream (case/whitespace tolerant).""" @@ -112,7 +114,7 @@ def handle_size(args: List[str]) -> None: logging.debug(f"[SIZE] {width}x{height}") elif len(args) >= 3: layer, width, height = args[0], args[1], args[2] - dpi = args[3] if len(args) > 3 else "96" + dpi = args[3] if len(args) > 3 else str(default_handshake_dpi()) logging.debug(f"[SIZE] layer={layer}, {width}x{height} @ {dpi}dpi") else: logging.debug(f"[SIZE] {args}") diff --git a/keepercommander/commands/pam_launch/guac_cli/win_console_input.py 
b/keepercommander/commands/pam_launch/guac_cli/win_console_input.py index d2c04d05c..d2925708a 100644 --- a/keepercommander/commands/pam_launch/guac_cli/win_console_input.py +++ b/keepercommander/commands/pam_launch/guac_cli/win_console_input.py @@ -14,7 +14,8 @@ When ENABLE_PROCESSED_INPUT is set, the system handles Ctrl+C and raises SIGINT instead of placing 0x03 in the console input queue. Clearing that flag matches -Unix tty.setraw(ISIG off) behavior for the interactive session. +Unix tty.setraw(ISIG off) behavior for the interactive session. Line input and +echo are also cleared so local console echo does not duplicate remote SSH echo. """ from __future__ import annotations @@ -25,13 +26,33 @@ # https://learn.microsoft.com/en-us/windows/console/setconsolemode _ENABLE_PROCESSED_INPUT = 0x0001 +_ENABLE_LINE_INPUT = 0x0002 +_ENABLE_ECHO_INPUT = 0x0004 _STD_INPUT_HANDLE = -10 +# Flags to clear when entering raw mode for ReadConsoleInputW. +# Mirrors what tty.setraw() does on Unix (clears ICANON + ECHO + ISIG): +# ENABLE_PROCESSED_INPUT — deliver Ctrl+C as 0x03 (not SIGINT) +# ENABLE_LINE_INPUT — disable line-editing / cooked-mode buffer +# ENABLE_ECHO_INPUT — disable console host visual echo +# +# Without clearing ENABLE_LINE_INPUT + ENABLE_ECHO_INPUT, the Windows console +# host (conhost.exe / PowerShell window) visually echoes each typed character +# to the screen the moment it enters the input queue — before ReadConsoleInputW +# consumes it. Combined with the SSH server's remote echo (which arrives via +# the guacd STDOUT blob), the user sees every character twice. Windows Terminal +# (ConPTY) suppresses the visual echo on its own, which is why the duplicate is +# intermittent rather than universal. 
+_RAW_MODE_CLEAR = _ENABLE_PROCESSED_INPUT | _ENABLE_LINE_INPUT | _ENABLE_ECHO_INPUT + def win_stdin_disable_ctrl_c_process_input() -> Optional[int]: """ - Clear ENABLE_PROCESSED_INPUT on the stdin console handle so Ctrl+C is read - as character 0x03 (ReadConsoleInput / msvcrt) instead of raising SIGINT. + Set stdin console handle to raw mode for ReadConsoleInputW: + - Clear ENABLE_PROCESSED_INPUT so Ctrl+C is read as 0x03, not SIGINT. + - Clear ENABLE_LINE_INPUT + ENABLE_ECHO_INPUT to suppress the console + host's visual echo, preventing duplicate characters when the remote + SSH session also echoes typed input. Returns the previous mode for win_stdin_restore_console_mode, or None if not Windows, not a console, or the API failed. @@ -48,11 +69,11 @@ def win_stdin_disable_ctrl_c_process_input() -> Optional[int]: if not kernel32.GetConsoleMode(h, ctypes.byref(mode)): return None old = int(mode.value) - new = old & ~_ENABLE_PROCESSED_INPUT + new = old & ~_RAW_MODE_CLEAR if new == old: return old if not kernel32.SetConsoleMode(h, new): - logging.debug('SetConsoleMode(clear ENABLE_PROCESSED_INPUT) failed') + logging.debug('SetConsoleMode(raw mode) failed') return None return old except Exception as exc: diff --git a/keepercommander/commands/pam_launch/guacamole/client.py b/keepercommander/commands/pam_launch/guacamole/client.py index 7ddd7b720..62aa60b6d 100644 --- a/keepercommander/commands/pam_launch/guacamole/client.py +++ b/keepercommander/commands/pam_launch/guacamole/client.py @@ -457,17 +457,22 @@ def send_mouse_state(self, x: int, y: int, button_mask: int) -> None: return self.tunnel.send_message("mouse", x, y, button_mask) - def send_size(self, width: int, height: int) -> None: + def send_size(self, width: int, height: int, dpi: Optional[int] = None) -> None: """ Send the current screen size to the server. Args: width: Screen width in pixels. height: Screen height in pixels. 
+ dpi: Optional display DPI; when set, sent as a third ``size`` argument + (same shape as the Guacamole handshake ``size`` instruction). """ if not self._is_connected(): return - self.tunnel.send_message("size", width, height) + if dpi is None: + self.tunnel.send_message("size", width, height) + else: + self.tunnel.send_message("size", width, height, dpi) def send_ack(self, stream_index: int, message: str, code: int) -> None: """ diff --git a/keepercommander/commands/pam_launch/launch.py b/keepercommander/commands/pam_launch/launch.py index 8a6be003e..92f696551 100644 --- a/keepercommander/commands/pam_launch/launch.py +++ b/keepercommander/commands/pam_launch/launch.py @@ -33,7 +33,8 @@ _version_at_least, _pam_settings_connection_port, ) -from .terminal_size import get_terminal_size_pixels, is_interactive_tty +from .terminal_size import get_terminal_size_pixels, is_interactive_tty, PIXEL_MODE_GUACD, scale_screen_info +from .terminal_reset import reset_local_terminal_after_pam_session from .guac_cli.stdin_handler import StdinHandler from .guac_cli.input import InputHandler from .guac_cli.session_input import CtrlCCoordinator, PasteOrchestrator @@ -279,6 +280,9 @@ class PAMLaunchCommand(Command): help='Send typed input via stdin pipe bytes (pipe/blob/end, kcm-cli style) instead of ' 'the default Guacamole key-event mode. Paste and Ctrl+C double-tap behave the ' 'same in both modes.') + parser.add_argument('--scale', '-s', required=False, dest='scale', type=int, default=None, + help='Scale pixel width/height by this percentage (e.g. 50 = half canvas, 200 = double). ' + 'Range: [40-400]. 
Helps when fullscreen TUI programs show garbled layout.') def get_parser(self): return PAMLaunchCommand.parser @@ -812,14 +816,13 @@ def execute(self, params: KeeperParams, **kwargs): if fs_int != 12: fs_disp = fs_disp = str(fs_int) if fs_int is not None else str(pam_connection_font_size).strip() logging.warning( - 'Record %s sets connection.fontSize=%s (guacd default is 12); session recordings ' + 'Record %s sets connection.fontSize=%s (session overrides with 12); session recordings ' 'may look different from this Commander terminal session.', record_uid, fs_disp, ) print( - f'Warning: This record sets fontSize={fs_disp}; session recordings may look ' - 'different from this Commander terminal session.', + f'Warning: connection.fontSize={fs_disp} is ignored here; this session uses font size 12 ', ) # Launch terminal connection @@ -851,10 +854,17 @@ def execute(self, params: KeeperParams, **kwargs): # Always start interactive CLI session # Pass launch_credential_uid to know if ConnectAs payload is needed + _scale = kwargs.get('scale') + if isinstance(_scale, int): + if _scale < 40 or _scale > 400: + raise CommandError('pam launch', + f'--scale must be between 40 and 400 (got {_scale})') self._start_cli_session( - result, params, + result, + params, kwargs.get('launch_credential_uid'), use_stdin=kwargs.get('use_stdin', False), + cli_scale=_scale, ) else: error_msg = result.get('error', 'Unknown error') @@ -869,6 +879,7 @@ def _start_cli_session( params: KeeperParams, launch_credential_uid: Optional[str] = None, use_stdin: bool = False, + cli_scale: Optional[int] = None, ): """ Start CLI session using PythonHandler protocol mode. @@ -1074,14 +1085,17 @@ def signal_handler_fn(signum, frame): # Wait for Guacamole ready print("Waiting for Guacamole connection...") - # Clear screen by printing terminal height worth of newlines - # This prevents raw mode from overwriting existing screen lines + # Clear screen by printing terminal height worth of newlines. 
+ # This prevents raw mode from overwriting existing screen lines. + # Keep in sync: terminal_reset uses max(current rows, this) at exit. + pam_session_start_rows = None terminal_height = 24 try: terminal_size = shutil.get_terminal_size() terminal_height = terminal_size.lines except Exception: terminal_height = 24 + pam_session_start_rows = terminal_height print("\n" * terminal_height, end='', flush=True) guac_ready_timeout = 10.0 # Reduced from 30s - sync triggers readiness quickly @@ -1092,6 +1106,45 @@ def signal_handler_fn(signum, frame): logging.debug( 'Terminal session active. Ctrl+C → remote interrupt; double Ctrl+C (<400 ms) to exit.', ) + # Handshake ``size`` may have been sent while the local console was still + # changing during WebRTC/backend wait (before ``pre_offer_sync`` patches the + # handler). Push the current grid as a runtime ``size`` once data is flowing. + if is_interactive_tty(): + try: + _pr_raw = get_terminal_size_pixels() + if isinstance(cli_scale, int) and cli_scale > 0 and cli_scale != 100: + _pr = scale_screen_info( + _pr_raw["columns"], _pr_raw["rows"], cli_scale + ) + else: + _pr = _pr_raw + python_handler.send_size( + _pr['pixel_width'], + _pr['pixel_height'], + _pr['dpi'], + ) + logging.debug( + 'Post-ready Guacamole size sync: %sx%s -> %sx%spx @ %sdpi%s', + _pr['columns'], + _pr['rows'], + _pr['pixel_width'], + _pr['pixel_height'], + _pr['dpi'], + f' (--scale {cli_scale}%)' if cli_scale else '', + ) + except Exception as _e: + logging.debug('Post-ready size sync skipped: %s', _e) + + # Correct font-size to 12 via argv — GW may have applied the record's + # fontSize during the upstream handshake before our connect instruction. + # font-size=12 is required for pixel metrics to match the 96-DPI cell model. 
+ _record_font_size = tunnel_result.get('settings', {}).get('terminal', {}).get('fontSize') + if _record_font_size and str(_record_font_size) != '12': + try: + python_handler.send_argv('font-size', '12') + logging.debug('Post-ready argv: font-size corrected from %s to 12', _record_font_size) + except Exception as _e: + logging.debug('Post-ready argv font-size skipped: %s', _e) else: logging.warning(f"Guacamole did not report ready within {guac_ready_timeout}s") logging.warning("Terminal may still work if data is flowing.") @@ -1194,6 +1247,7 @@ def _remote_key_ctrl_c() -> None: # so the final resting size is always dispatched. _last_sent_cols = 0 _last_sent_rows = 0 + if _resize_enabled: try: _init_ts = shutil.get_terminal_size() @@ -1210,13 +1264,20 @@ def _remote_key_ctrl_c() -> None: state = tube_registry.get_connection_state(tube_id) if state and state.lower() in ('closed', 'disconnected', 'failed'): logging.debug(f"Tube/connection closed (state: {state}) - exiting") - python_handler.running = False + # Do not set python_handler.running = False here: the input thread would + # still read stdin and send_key/send_stdin would drop bytes (first key + # "eaten" after return to Commander). Stopping order is input_handler + # first in finally, then python_handler.stop() clears running. + shutdown_requested = True break except Exception: # If we can't check state, continue (tube might be closing) pass time.sleep(0.1) elapsed += 0.1 + # SIGINT may set shutdown_requested during sleep; exit before resize/status work. + if shutdown_requested or not python_handler.running: + break # --- Resize polling (Phase 1: cheap cols/rows check) --- # Check every _RESIZE_POLL_EVERY iterations AND at least @@ -1237,28 +1298,42 @@ def _remote_key_ctrl_c() -> None: except Exception: _cur_cols, _cur_rows = _last_sent_cols, _last_sent_rows + # Send only when cols or rows change; pixel values are derived + # from the grid via kcm_cli_approximate_pixels in get_terminal_size_pixels. 
if (_cur_cols, _cur_rows) != (_last_sent_cols, _last_sent_rows): # Phase 2: size changed - apply debounce then # fetch exact pixels and send. if _now - _last_resize_send_time >= _RESIZE_DEBOUNCE: + if shutdown_requested or not python_handler.running: + break try: - _si = get_terminal_size_pixels(_cur_cols, _cur_rows) + if isinstance(cli_scale, int) and cli_scale > 0 and cli_scale != 100: + _si = scale_screen_info( + _cur_cols, _cur_rows, cli_scale + ) + else: + _si = get_terminal_size_pixels( + _cur_cols, _cur_rows + ) python_handler.send_size( _si['pixel_width'], _si['pixel_height'], _si['dpi'], ) - _last_sent_cols = _cur_cols - _last_sent_rows = _cur_rows + # Track what get_terminal_size_pixels actually used + # (it re-queries internally), not just the poll value. + # If the inner query saw a transient size, the next + # poll will detect the mismatch and send a correction. + _last_sent_cols = _si['columns'] + _last_sent_rows = _si['rows'] _last_resize_send_time = _now logging.debug( - f"Terminal resized: {_cur_cols}x{_cur_rows} cols/rows " - f"-> {_si['pixel_width']}x{_si['pixel_height']}px " - f"@ {_si['dpi']}dpi" + f"Terminal resized: {_si['columns']}x{_si['rows']} cols/rows " + f"-> {_si['pixel_width']}x{_si['pixel_height']}px @ {_si['dpi']}dpi " ) except Exception as _e: logging.debug(f"Failed to send resize: {_e}") - # else: debounce active - _last_sent_cols/rows unchanged + # else: debounce active - last sent * not updated # so the change is re-detected on the next eligible poll. 
# Status indicator every 30 seconds @@ -1269,8 +1344,14 @@ def _remote_key_ctrl_c() -> None: logging.debug(f"[{int(elapsed)}s] Session active (rx={rx}, tx={tx}, syncs={syncs})") except KeyboardInterrupt: + shutdown_requested = True logging.debug("\n\nExiting CLI terminal mode...") + except Exception as e: + shutdown_requested = True + logging.debug(f"CLI session loop ended abnormally: {e}") + raise + finally: # Stop input handler first (restores terminal) logging.debug("Stopping input handler...") @@ -1279,6 +1360,15 @@ def _remote_key_ctrl_c() -> None: except Exception as e: logging.debug(f"Error stopping input handler: {e}") + # Fullscreen TUIs (nano, mcedit, etc.) may leave alternate screen / + # cursor modes in the outer terminal; restore after raw mode is back. + try: + reset_local_terminal_after_pam_session( + session_start_rows=pam_session_start_rows, + ) + except Exception as e: + logging.debug(f"Terminal reset after pam session: {e}") + # Cleanup - check if connection is already closed to avoid deadlock logging.debug("Stopping Python handler...") try: diff --git a/keepercommander/commands/pam_launch/python_handler.py b/keepercommander/commands/pam_launch/python_handler.py index 63694418a..2d789a3f1 100644 --- a/keepercommander/commands/pam_launch/python_handler.py +++ b/keepercommander/commands/pam_launch/python_handler.py @@ -28,7 +28,7 @@ Python Layer (this module): - Receive Guacamole data via callback - Parse Guacamole instructions - - Respond to 'args' with 'connect', 'size', 'audio', 'image' (handshake) + - Respond to 'args' with 'size', 'audio', 'video', 'image', 'connect' (handshake) - Send Guacamole responses back to Rust Event Types from Rust: @@ -40,7 +40,7 @@ 1. Gateway sends 'select' to guacd with protocol type (ssh, telnet, etc.) 2. guacd responds with 'args' listing required parameters 3. Gateway forwards 'args' to Python via WebRTC - 4. Python responds with 'connect', 'size', 'audio', 'image' + 4. 
Python responds with 'size', 'audio', 'video', 'image', 'connect' 5. guacd responds with 'ready' (optional, custom extension) 6. guacd sends first 'sync' (TRUE readiness signal - matches JS client behavior) 7. Terminal session begins @@ -59,10 +59,12 @@ import base64 import logging import threading +import time from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Any from .guacamole import Parser, to_instruction from .guac_cli.instructions import create_instruction_router, is_stdout_pipe_stream_name +from .terminal_size import default_handshake_dpi if TYPE_CHECKING: pass @@ -108,7 +110,7 @@ def __init__( - port: Target port - width: Terminal width in pixels - height: Terminal height in pixels - - dpi: Display DPI (default 96) + - dpi: Display DPI (default ``default_handshake_dpi()``) - audio_mimetypes: List of supported audio types (optional) - image_mimetypes: List of supported image types (optional) - guacd_params: Additional guacd parameters dict (optional) @@ -195,6 +197,8 @@ def __init__( # Clipboard stream counter — starts at 200 to avoid collision with # image streams (1–99) and named pipe streams (100–101). self._clipboard_stream_index: int = 200 + # argv stream counter — starts at 300 to avoid collision with clipboard streams. + self._argv_stream_index: int = 300 # Count `pipe` opcodes from guacd (diagnostics when STDOUT never binds). self._guac_pipe_instruction_count: int = 0 @@ -289,13 +293,16 @@ def _on_connection_opened(self, conn_no: int): 3. Gateway receives OpenConnection, starts guacd, connects to target 4. Gateway sends ConnectionOpened back to Rust 5. Rust notifies Python via this callback - 6. Gateway/guacd sends 'args' instruction (Guacamole handshake starts) + 6. Gateway may begin streaming Guacamole **server→client** display ops + (size/move/pipe/img/…). The initial **args** handshake from guacd is + often completed **upstream** (gateway↔guacd) and never forwarded here, + so ``_on_args`` may never run. 
""" logging.debug(f"✓ Connection opened: conn_no={conn_no}") self.conn_no = conn_no - # The connection is now ready for Guacamole protocol - # Gateway will send guacd's 'args' instruction next + # Guacamole protocol continues on the data channel; handshake may already + # be complete on the gateway side (no 'args' instruction to Python). def _on_data(self, conn_no: int, payload: bytes): """ @@ -350,15 +357,19 @@ def _on_args(self, args: List[str]) -> None: """ Handle args instruction from guacd (via Gateway). - This is the critical handshake step. When guacd receives 'select' from - the Gateway, it responds with 'args' listing the parameters it needs. - We must respond with 'connect' containing the parameter values, - followed by 'size', 'audio', and 'image' instructions. + **When this runs:** Only if the gateway **forwards** guacd's ``args`` + instruction on the PythonHandler data channel. Many PAM deployments + complete select/args/size/connect **between gateway and guacd** and only + then stream display + STDOUT to Python — in that mode this handler is + **never** called. + + When it does run, respond with 'size', 'audio', 'video', 'image', then + 'connect' (guacr-guacd order) so DPI in ``size`` precedes ``connect``. Guacamole handshake sequence: 1. Gateway sends 'select ' to guacd 2. guacd responds with 'args' (list of required params) - 3. We respond with 'connect' (param values), 'size', 'audio', 'image' + 3. We respond with 'size', 'audio', 'video', 'image', 'connect' (matches guacd client / guacr-guacd) 4. 
guacd responds with 'ready' Args: @@ -374,7 +385,7 @@ def _on_args(self, args: List[str]) -> None: # Build and send the handshake response self._send_handshake_response(list(args)) self.handshake_sent = True - logging.debug("✓ Guacamole handshake sent (connect+size+audio+image)") + logging.debug("✓ Guacamole handshake sent (size,audio,video,image,connect)") except Exception as e: logging.error(f"Error sending handshake response: {e}", exc_info=True) @@ -390,7 +401,9 @@ def _send_handshake_response(self, args_list: List[str]): # Get terminal dimensions (default to standard CLI size) width = settings.get('width', 800) height = settings.get('height', 600) - dpi = settings.get('dpi', 96) + # Default must match ``screen_info['dpi']`` / pixel mode (``default_handshake_dpi()``), + # or Cairo cell metrics diverge from client pixel sizing (wrong $COLUMNS / TUI layout). + dpi = settings.get('dpi', default_handshake_dpi()) # Get guacd parameters (hostname, port, username, password, etc.) guacd_params = settings.get('guacd_params', {}) @@ -421,34 +434,34 @@ def _send_handshake_response(self, args_list: List[str]): connect_args.append(value) - # Send connect instruction - connect_instruction = self._format_instruction('connect', *connect_args) - self._send_to_gateway(connect_instruction) - logging.debug(f"Sent 'connect' with {len(connect_args)} args") - - # Send size instruction + # Guacd expects the same order as guacr-guacd / JS client: size (with DPI), then + # audio/video/image, then connect — not connect before size. 
size_instruction = self._format_instruction('size', width, height, dpi) self._send_to_gateway(size_instruction) - logging.debug(f"Sent 'size': {width}x{height} @ {dpi}dpi") + logging.debug( + f"Sent 'size' (handshake): {width}x{height} @ {dpi}dpi, " + f"cols/rows:{settings.get('columns')}x{settings.get('rows')}" + ) - # Send audio instruction (supported audio mimetypes) audio_mimetypes = settings.get('audio_mimetypes', []) audio_instruction = self._format_instruction('audio', *audio_mimetypes) self._send_to_gateway(audio_instruction) logging.debug(f"Sent 'audio': {audio_mimetypes}") - # Send video instruction (supported video mimetypes - usually empty for terminal) video_mimetypes = settings.get('video_mimetypes', []) video_instruction = self._format_instruction('video', *video_mimetypes) self._send_to_gateway(video_instruction) logging.debug(f"Sent 'video': {video_mimetypes}") - # Send image instruction (supported image mimetypes) image_mimetypes = settings.get('image_mimetypes', ['image/png', 'image/jpeg', 'image/webp']) image_instruction = self._format_instruction('image', *image_mimetypes) self._send_to_gateway(image_instruction) logging.debug(f"Sent 'image': {image_mimetypes}") + connect_instruction = self._format_instruction('connect', *connect_args) + self._send_to_gateway(connect_instruction) + logging.debug(f"Sent 'connect' with {len(connect_args)} args") + def _on_sync(self, args: List[str]) -> None: """ Handle sync instruction from guacd. @@ -885,20 +898,28 @@ def send_mouse(self, x: int, y: int, buttons: int = 0): except Exception as e: logging.error(f"Error sending mouse event: {e}") - def send_size(self, width: int, height: int, dpi: int = 96): + def send_size(self, width: int, height: int, dpi: Optional[int] = None): """ - Send terminal size to guacd. + Send terminal size to guacd (runtime resize). 
+ + Sends ``size`` with width, height, and DPI (three arguments), matching the + handshake instruction shape and the same DPI as handshake (typically + ``default_handshake_dpi()`` for the active pixel mode) so guacd can + keep Cairo cell metrics aligned on every resize where supported. Only sends if session is active (running and data flowing). Args: width: Width in pixels height: Height in pixels - dpi: DPI (default 96) + dpi: Display DPI for font rasterisation (defaults to ``default_handshake_dpi()``) """ if not self.running or not self.data_flowing.is_set(): return + if dpi is None: + dpi = default_handshake_dpi() + try: instruction = self._format_instruction('size', width, height, dpi) self._send_to_gateway(instruction) @@ -949,6 +970,38 @@ def send_clipboard_stream(self, text: str) -> None: except Exception as exc: logging.error(f"Error sending clipboard stream: {exc}") + def send_argv(self, name: str, value: str) -> None: + """ + Change a guacd connection parameter at runtime via the argv stream protocol. + + Wire format (mirrors Guacamole JS client argv channel): + argv,,text/plain,; + blob,,; + end,; + + Args: + name: Parameter name (e.g. 'font-size', 'color-scheme'). + value: New parameter value as a string. + """ + if not self.running or not self.data_flowing.is_set(): + return + try: + stream_id = str(self._argv_stream_index) + self._argv_stream_index += 1 + value_b64 = base64.b64encode(value.encode('utf-8')).decode('ascii') + self._send_to_gateway( + self._format_instruction('argv', stream_id, 'text/plain', name) + ) + self._send_to_gateway( + self._format_instruction('blob', stream_id, value_b64) + ) + self._send_to_gateway( + self._format_instruction('end', stream_id) + ) + logging.debug('argv sent: %s=%r stream_id=%s', name, value, stream_id) + except Exception as exc: + logging.error('Error sending argv %s: %s', name, exc) + def wait_for_ready(self, timeout: float = 10.0) -> bool: """ Wait for the Guacamole connection to be ready. 
@@ -1064,7 +1117,7 @@ def create_python_handler( - port: Target port - width: Terminal width in pixels - height: Terminal height in pixels - - dpi: Display DPI (default 96) + - dpi: Display DPI (default ``default_handshake_dpi()``) - audio_mimetypes: List of supported audio types - image_mimetypes: List of supported image types - guacd_params: Dict of guacd connection parameters @@ -1082,7 +1135,7 @@ def create_python_handler( 'protocol': 'ssh', 'width': 800, 'height': 600, - 'dpi': 96, + 'dpi': 192, 'guacd_params': { 'hostname': '192.168.1.100', 'port': '22', diff --git a/keepercommander/commands/pam_launch/terminal_connection.py b/keepercommander/commands/pam_launch/terminal_connection.py index 952c1d96a..76d3c2777 100644 --- a/keepercommander/commands/pam_launch/terminal_connection.py +++ b/keepercommander/commands/pam_launch/terminal_connection.py @@ -22,6 +22,7 @@ import base64 import json import secrets +import shutil import time import uuid from typing import TYPE_CHECKING, Optional, Dict, Any @@ -102,8 +103,10 @@ from .terminal_size import ( DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS, + GUACAMOLE_HANDSHAKE_DPI, _build_screen_info, get_terminal_size_pixels, + scale_screen_info, ) # Computed at import time using the best available platform APIs so the initial @@ -1126,8 +1129,14 @@ def _build_guacamole_connection_settings( terminal_settings = settings.get('terminal', {}) if terminal_settings.get('colorScheme'): guacd_params['color-scheme'] = terminal_settings['colorScheme'] - if terminal_settings.get('fontSize'): - guacd_params['font-size'] = terminal_settings['fontSize'] + _record_font_size = terminal_settings.get('fontSize') + if _record_font_size and str(_record_font_size) != '12': + logging.debug( + "Record font-size %r is not supported for terminal sessions " + "(pixel metrics are calibrated for font-size 12); converting to font-size 12.", + _record_font_size, + ) + guacd_params['font-size'] = '12' # PAM clipboard → guacd: only pass disable-* when 
the record sets them (guacd "true" = on). _pam_clip = settings.get('clipboard') or {} @@ -1136,6 +1145,15 @@ def _build_guacamole_connection_settings( if _pam_clip.get('disableCopy'): guacd_params['disable-copy'] = 'true' + # Terminal dimensions and DPI must be in guacd_params so the 'connect' instruction + # carries them to guacd. Without these, guacd initialises its font metrics at its + # built-in default DPI (96), giving char_width ≈ 10 px. The kcm pixel formula uses + # char_width = 19 px (calibrated for DPI 192), so a missing DPI in 'connect' causes + # guacd to compute ~2× too many PTY columns from the pixel width we send. + guacd_params['width'] = str(screen_info.get('pixel_width', 800)) + guacd_params['height'] = str(screen_info.get('pixel_height', 600)) + guacd_params['dpi'] = str(screen_info.get('dpi', GUACAMOLE_HANDSHAKE_DPI)) + # Build final connection settings connection_settings = { 'protocol': protocol, @@ -1143,7 +1161,10 @@ def _build_guacamole_connection_settings( 'port': settings.get('port', 22), 'width': screen_info.get('pixel_width', 800), 'height': screen_info.get('pixel_height', 600), - 'dpi': screen_info.get('dpi', 96), + # DPI comes from screen_info (192 for KCM mode, 96 for guacd/scale mode) — also + # carried via guacd_params['dpi'] so the 'connect' instruction sets guacd's font + # metrics to the correct DPI from the start. + 'dpi': screen_info.get('dpi', GUACAMOLE_HANDSHAKE_DPI), 'guacd_params': guacd_params, # Supported mimetypes for terminal sessions 'audio_mimetypes': [], # No audio for terminal @@ -1383,6 +1404,28 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, # No linked user, no supply flags - use pamMachine credentials directly logging.debug("No linked user or supply flags - using pamMachine credentials directly") + # Fresh TTY metrics for the Python handshake — do not use ``screen_info`` from + # function start (DEFAULT_SCREEN_INFO snapshot); that can be import-time or stale. 
+ _scale = kwargs.get('scale') + if isinstance(_scale, int) and _scale > 0 and _scale != 100: + _ts = shutil.get_terminal_size(fallback=(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS)) + screen_info = scale_screen_info(_ts.columns, _ts.lines, _scale) + logging.debug( + "--scale %s%%: guacd-96 base, grid %sx%s → %sx%spx @ %sdpi", + _scale, + screen_info["columns"], + screen_info["rows"], + screen_info["pixel_width"], + screen_info["pixel_height"], + screen_info["dpi"], + ) + else: + try: + screen_info = get_terminal_size_pixels() + except Exception: + logging.debug("Falling back to default terminal size for PythonHandler connection_settings") + screen_info = _build_screen_info(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS) + # Build connection settings for Guacamole handshake # These are used when guacd sends 'args' instruction connection_settings = _build_guacamole_connection_settings( @@ -1484,16 +1527,43 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, # platform-specific APIs (Windows: GetCurrentConsoleFontEx; Unix: # TIOCGWINSZ) to obtain exact pixel dimensions before falling back to # the fixed cell-size estimate. 
- try: - screen_info = get_terminal_size_pixels() - except Exception: - logging.debug("Falling back to default terminal size for offer payload") - screen_info = _build_screen_info(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS) + _scale = kwargs.get('scale') + if isinstance(_scale, int) and _scale > 0 and _scale != 100: + _ts = shutil.get_terminal_size(fallback=(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS)) + screen_info = scale_screen_info(_ts.columns, _ts.lines, _scale) + logging.debug( + "--scale %s%% (offer): guacd-96 base, grid %sx%s → %sx%spx @ %sdpi", + _scale, + screen_info["columns"], + screen_info["rows"], + screen_info["pixel_width"], + screen_info["pixel_height"], + screen_info["dpi"], + ) + else: + try: + screen_info = get_terminal_size_pixels() + except Exception: + logging.debug("Falling back to default terminal size for offer payload") + screen_info = _build_screen_info(DEFAULT_TERMINAL_COLUMNS, DEFAULT_TERMINAL_ROWS) logging.debug( f"Using terminal metrics columns={screen_info['columns']} rows={screen_info['rows']} -> " f"{screen_info['pixel_width']}x{screen_info['pixel_height']}px @ {screen_info['dpi']}dpi" ) + # Offer payload and Guacamole ``size`` handshake must agree. The handler was created + # earlier; refresh its stored width/height/dpi so a slightly later ``args``/handshake + # matches what we send in the connection offer (avoids PTY geometry vs. local TTY drift). + if python_handler is not None: + python_handler.connection_settings['width'] = screen_info['pixel_width'] + python_handler.connection_settings['height'] = screen_info['pixel_height'] + python_handler.connection_settings['dpi'] = screen_info['dpi'] + # Keep connect-instruction lookup in sync with top-level handshake size/DPI. 
+ _gp_sync = python_handler.connection_settings.setdefault('guacd_params', {}) + _gp_sync['width'] = str(screen_info['pixel_width']) + _gp_sync['height'] = str(screen_info['pixel_height']) + _gp_sync['dpi'] = str(screen_info['dpi']) + offer_payload = offer.get("offer") decoded_offer_bytes = None decoded_offer_text = None @@ -1543,7 +1613,8 @@ def _open_terminal_webrtc_tunnel(params: KeeperParams, "offer": offer_payload, "audio": ["audio/L8", "audio/L16"], # Supported audio codecs "video": [], # Supported video codecs - None for terminal - "size": [screen_info['pixel_width'], screen_info['pixel_height'], screen_info['dpi']], # [width, height, dpi] + # [width, height, dpi] — matches screen_info / pixel mode (e.g. 96 guacd, 192 kcm) + "size": [screen_info['pixel_width'], screen_info['pixel_height'], screen_info['dpi']], "image": ["image/jpeg", "image/png", "image/webp"], # Supported image formats # CRITICAL: Gateway needs 'host' to configure guacd connection "host": { diff --git a/keepercommander/commands/pam_launch/terminal_reset.py b/keepercommander/commands/pam_launch/terminal_reset.py new file mode 100644 index 000000000..eac428487 --- /dev/null +++ b/keepercommander/commands/pam_launch/terminal_reset.py @@ -0,0 +1,268 @@ +# _ __ +# | |/ /___ ___ _ __ ___ _ _ ® +# | ' int: + """Fresh terminal row count at call time (handles resize since session start).""" + try: + if sys.stdout.isatty(): + return max(1, shutil.get_terminal_size().lines) + except Exception as exc: + logging.debug('post-reset line count: %s', exc) + return _FALLBACK_TERMINAL_ROWS + + +def _padding_line_count(session_start_rows: int | None) -> int: + """ + Rows of newline padding: max(current size, session start size). + + Session start printed ``session_start_rows`` newlines; if the user shrinks the + window before exit, ``current`` alone would under-pad; ``max`` fixes that. + If the window grew, ``current`` is larger and dominates. 
+ """ + current = _post_reset_line_count() + if session_start_rows is None: + return current + return max(current, max(1, int(session_start_rows))) + + +def _ansi_terminal_reset_string() -> str: + """VT sequences to undo common fullscreen TUI state (nano, vim, etc.).""" + return ( + '\x1b[?1049l' # rmcup — exit alternate screen + '\x1b[?47l' # old secondary screen off (no-op on modern terminals) + '\x1b[r' # reset scroll region / margins (DECSTBM full screen) + '\x1b[?6l' # origin mode off — cursor addressing to full screen + # xterm mouse / SGR mouse (nano may enable) + '\x1b[?1000l\x1b[?1002l\x1b[?1003l\x1b[?1006l' + '\x1b[?2004l' # bracketed paste off + '\x1b[?25h' # show cursor + '\x1b[0m' # SGR reset + '\x1b[?1l' # DECCKM off — normal cursor keys + '\x1b[?7h' # autowrap on + ) + + +def _post_reset_newlines(session_start_rows: int | None = None) -> str: + """Padding newlines: see ``_padding_line_count`` and launch.py pre-session clear.""" + n = _padding_line_count(session_start_rows) + if sys.platform == 'win32': + return '\r\n' * n + return '\n' * n + + +def _post_reset_clear_viewport() -> str: + """ + Full terminal reset via viewport clear. + + CSI 3 J + 2 J + H clears scrollback and the visible screen. That strips + remote TUI residue aggressively but loses all buffered scrollback in the + emulator — enable in :func:`reset_local_terminal_after_pam_session` only if + that trade-off is acceptable. + """ + return '\x1b[3J\x1b[2J\x1b[H' + + +def _windows_scroll_viewport_to_cursor() -> None: + """ + If the cursor row is outside the visible window, scroll the window so the + cursor is on-screen (typically at the bottom). Fixes ConPTY/Windows + Terminal leaving the viewport at the top while output continues below. 
+ """ + try: + import ctypes + from ctypes import wintypes + + kernel32 = ctypes.windll.kernel32 + STD_OUTPUT_HANDLE = -11 + h = kernel32.GetStdHandle(STD_OUTPUT_HANDLE) + if not h or h == wintypes.HANDLE(-1).value: + return + + class COORD(ctypes.Structure): + _fields_ = [('X', wintypes.SHORT), ('Y', wintypes.SHORT)] + + class SMALL_RECT(ctypes.Structure): + _fields_ = [ + ('Left', wintypes.SHORT), + ('Top', wintypes.SHORT), + ('Right', wintypes.SHORT), + ('Bottom', wintypes.SHORT), + ] + + class CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure): + _fields_ = [ + ('dwSize', COORD), + ('dwCursorPosition', COORD), + ('wAttributes', wintypes.WORD), + ('srWindow', SMALL_RECT), + ('dwMaximumWindowSize', COORD), + ] + + info = CONSOLE_SCREEN_BUFFER_INFO() + if not kernel32.GetConsoleScreenBufferInfo(h, ctypes.byref(info)): + return + + cy = int(info.dwCursorPosition.Y) + top = int(info.srWindow.Top) + bottom = int(info.srWindow.Bottom) + win_h = bottom - top + 1 + buf_rows = int(info.dwSize.Y) + if win_h <= 0 or buf_rows <= 0: + return + + # Put the cursor on the bottom row of the viewport (normal shell UX). + # ConPTY sometimes leaves the window pinned while the cursor moves down. 
+ max_top = max(0, buf_rows - win_h) + new_top = cy - win_h + 1 + if new_top < 0: + new_top = 0 + elif new_top > max_top: + new_top = max_top + + new_bottom = new_top + win_h - 1 + if new_bottom >= buf_rows: + new_bottom = buf_rows - 1 + new_top = max(0, new_bottom - win_h + 1) + + if new_top == top and new_bottom == bottom: + return + + sr = SMALL_RECT() + sr.Left = info.srWindow.Left + sr.Right = info.srWindow.Right + sr.Top = wintypes.SHORT(new_top) + sr.Bottom = wintypes.SHORT(new_bottom) + + if not kernel32.SetConsoleWindowInfo(h, True, ctypes.byref(sr)): + logging.debug( + 'SetConsoleWindowInfo failed: %s', kernel32.GetLastError() + ) + except Exception as exc: + logging.debug('Windows viewport scroll after pam session: %s', exc) + + +def _flush_stdin_queue_posix() -> None: + try: + if not sys.stdin.isatty(): + return + import termios + + termios.tcflush(sys.stdin.fileno(), termios.TCIFLUSH) + except Exception as exc: + logging.debug('tcflush stdin after pam session: %s', exc) + + +def _flush_stdin_queue_windows() -> None: + """Clear the console input queue (parity with POSIX tcflush TCIFLUSH).""" + try: + import ctypes + + kernel32 = ctypes.windll.kernel32 + STD_INPUT_HANDLE = -10 + h = kernel32.GetStdHandle(STD_INPUT_HANDLE) + # INVALID_HANDLE_VALUE (-1) is truthy in Python; still call Flush and rely on return. 
+ if h == 0: + return + if not kernel32.FlushConsoleInputBuffer(h): + logging.debug('FlushConsoleInputBuffer failed: %s', kernel32.GetLastError()) + except Exception as exc: + logging.debug('FlushConsoleInputBuffer after pam session: %s', exc) + + +def _stty_sane_posix() -> None: + try: + if not sys.stdin.isatty(): + return + subprocess.run( + ['stty', 'sane'], + stdin=sys.stdin, + check=False, + timeout=3, + capture_output=True, + ) + except Exception as exc: + logging.debug('stty sane after pam session: %s', exc) + + +def reset_local_terminal_after_pam_session( + session_start_rows: int | None = None, +) -> None: + """ + Best-effort reset of the interactive terminal after pam launch CLI mode. + + Call only after InputHandler/StdinHandler.stop() has restored raw mode / + Windows console mode so stdin matches the outer shell again. + + Args: + session_start_rows: Row count used at session start for the pre-session + newline clear (``launch.py``). When set, padding uses + ``max(fresh get_terminal_size().lines, session_start_rows)``. + """ + if not sys.stdout.isatty(): + return + + try: + sys.stdout.write(_ansi_terminal_reset_string()) + + # Optional full clear (scrollback loss) + sys.stdout.write(_post_reset_clear_viewport()) + + sys.stdout.write(_post_reset_newlines(session_start_rows=session_start_rows)) + sys.stdout.flush() + except Exception as exc: + logging.debug('Terminal ANSI reset: %s', exc) + + # Queued input: POSIX tcflush; Windows FlushConsoleInputBuffer (before stty on Unix). 
+ if sys.platform == 'win32': + _windows_scroll_viewport_to_cursor() + _flush_stdin_queue_windows() + else: + _flush_stdin_queue_posix() + _stty_sane_posix() diff --git a/keepercommander/commands/pam_launch/terminal_size.py b/keepercommander/commands/pam_launch/terminal_size.py index 5f304391e..d23f6cbc8 100644 --- a/keepercommander/commands/pam_launch/terminal_size.py +++ b/keepercommander/commands/pam_launch/terminal_size.py @@ -14,6 +14,22 @@ Provides get_terminal_size_pixels() which returns terminal dimensions in pixels and DPI for use in Guacamole 'size' instructions. +DPI for text terminals +---------------------- +For plaintext SSH/Telnet sessions, OS display DPI is **irrelevant** to the +remote terminal geometry. guacd uses handshake DPI only for Cairo font +rasterisation (glyph pixel size); runtime ``size`` instructions carry only +width and height — no DPI. The remote PTY rows/cols are derived purely from +pixel dimensions divided by the font cell size (which is set once at session +start from handshake DPI + ``font-size``). + +The ``dpi`` key in screen-info dicts follows the active **pixel mode** +(:data:`DEFAULT_PIXEL_MODE` and ``KEEPER_GUAC_PIXEL_MODE``): ``_KCM_DPI`` (192) +for ``kcm`` or ``_GUACD_DPI`` (96) for ``guacd``, matching the cell-size +formula used for ``pixel_width`` / ``pixel_height``. The platform DPI helpers +(_get_dpi_windows, etc.) are retained for possible future use but are **not +called** on the primary text-terminal code path. 
+ Also defines the screen-size constants and _build_screen_info() fallback that were previously in terminal_connection.py; terminal_connection.py imports them from here to avoid a circular dependency (terminal_connection imports @@ -22,11 +38,24 @@ from __future__ import annotations +import ctypes +import ctypes.wintypes +import ctypes.util import logging +import os import shutil import struct +import subprocess import sys -from typing import Dict, Optional +import time +from typing import Dict, Optional, Tuple + +if sys.platform != 'win32': + import fcntl + import termios +else: + fcntl = None # type: ignore[assignment] + termios = None # type: ignore[assignment] # --------------------------------------------------------------------------- @@ -37,25 +66,219 @@ # pixel-based values that Guacamole expects. DEFAULT_TERMINAL_COLUMNS = 80 DEFAULT_TERMINAL_ROWS = 24 -DEFAULT_CELL_WIDTH_PX = 10 -DEFAULT_CELL_HEIGHT_PX = 19 -DEFAULT_SCREEN_DPI = 96 +DEFAULT_SCREEN_DPI = 192 + +# --------------------------------------------------------------------------- +# Pixel-mode selection +# --------------------------------------------------------------------------- + +# Two pixel modes are supported, selectable via KEEPER_GUAC_PIXEL_MODE env var +# or the pixel_mode parameter on get_terminal_size_pixels / _build_screen_info: +# +# 'kcm' — matches kcm-cli/src/tty.js (default). +# DPI 192, char 19×38 px, plus canvas margins and scrollbar. +# Requires guacd to receive DPI=192 in the handshake 'size' +# instruction so its Pango font metrics yield char_width=19. +# +# 'guacd' — matches guacd's own defaults (terminal.h / display.c). +# DPI 96, char 10×20 px, scrollbar only (no canvas margin). +# Formula: cols = (width - SCROLLBAR_WIDTH) / char_width +# rows = height / char_height +# Guacd uses this when no DPI is supplied or DPI=96 is sent. 
+# +PIXEL_MODE_KCM = 'kcm' +PIXEL_MODE_GUACD = 'guacd' +DEFAULT_PIXEL_MODE = PIXEL_MODE_GUACD +# TODO: Switch to KCM mode once the fix is included in gateway builds. + +# --------------------------------------------------------------------------- +# kcm-cli pixel constants +# Source: kcm-cli/src/tty.js — calibrated for KCM's canvas renderer at DPI 192. +# columnsToPixels(c) = c * CHAR_WIDTH + TERM_MARGIN * 2 + SCROLLBAR_WIDTH +# rowsToPixels(r) = r * CHAR_HEIGHT + TERM_MARGIN * 2 +# Sanity check 80×24: width = 80*19 + 30 + 16 = 1566, height = 24*38 + 30 = 942. +# --------------------------------------------------------------------------- +_KCM_DPI = 192 # tty.js: export const DPI = 192 +_KCM_CHAR_WIDTH = 19 # tty.js: const CHAR_WIDTH = 19 +_KCM_CHAR_HEIGHT = 38 # tty.js: const CHAR_HEIGHT = 38 +# tty.js: const TERM_MARGIN = Math.floor(2 * DPI / 25.4) +# = 2 mm × DPI px/inch ÷ 25.4 mm/inch → 15 px at DPI 192 +_KCM_TERM_MARGIN = int(2 * _KCM_DPI / 25.4) +_KCM_SCROLLBAR_WIDTH = 16 # tty.js: const SCROLLBAR_WIDTH = 16 + +# --------------------------------------------------------------------------- +# guacd default pixel constants +# Source: guacamole-server terminal.h (GUAC_TERMINAL_DEFAULT_FONT_SIZE=12, +# GUAC_TERMINAL_SCROLLBAR_WIDTH=16) and display.c (Pango metrics). +# At DPI=96 with monospace 12pt, Pango yields char_width≈10, char_height≈20. +# guacd has NO canvas margin — the PTY formula is simply: +# cols = (width - SCROLLBAR_WIDTH) / char_width +# rows = height / char_height +# Sanity check 80×24: width = 80*10 + 16 = 816, height = 24*20 = 480. 
+# --------------------------------------------------------------------------- +_GUACD_DPI = 96 +_GUACD_CHAR_WIDTH = 10 # Pango approximate digit width, monospace 12pt @ 96 dpi +_GUACD_CHAR_HEIGHT = 20 # Pango ascent + descent, monospace 12pt @ 96 dpi +# Same formula: 2 mm × DPI px/inch ÷ 25.4 mm/inch → 7 px at DPI 96 +_GUACD_TERM_MARGIN = int(2 * _GUACD_DPI / 25.4) +_GUACD_SCROLLBAR_WIDTH = 16 # GUAC_TERMINAL_SCROLLBAR_WIDTH (scrollbar.h) + +# Re-query DPI after this interval so scaling / display changes can be picked up +# without re-running Commander. Individual API calls are cheap; this bounds work. +DPI_CACHE_TTL_SEC = 1.0 + + +# --------------------------------------------------------------------------- +# kcm-cli pixel approximation (primary path — strict KCM parity) +# --------------------------------------------------------------------------- + +def kcm_cli_approximate_pixels(columns: int, rows: int): + """Return (pixel_width, pixel_height) using kcm-cli tty.js formulas. + + Mirrors ``columnsToPixels`` / ``rowsToPixels`` from kcm-cli/src/tty.js so + Commander sends identical Guacamole ``size`` values for the same grid. + All arithmetic is integer (matches JS ``Math.floor`` semantics for the + TERM_MARGIN constant; the multiplications produce exact integers). + """ + pixel_width = columns * _KCM_CHAR_WIDTH + _KCM_TERM_MARGIN * 2 + _KCM_SCROLLBAR_WIDTH + pixel_height = rows * _KCM_CHAR_HEIGHT + _KCM_TERM_MARGIN * 2 + return pixel_width, pixel_height + + +def guacd_default_approximate_pixels(columns: int, rows: int): + """Return (pixel_width, pixel_height) using guacd's own default metrics. + + Mirrors guacd's PTY calculation (terminal.c): + cols = (width - GUAC_TERMINAL_SCROLLBAR_WIDTH) / char_width + rows = height / char_height + where char_width/char_height come from Pango metrics for monospace 12pt at + DPI 96 (≈10 × 20 px). There is no canvas margin — guacd subtracts only + the scrollbar from the total pixel width before dividing. 
+ """ + pixel_width = columns * _GUACD_CHAR_WIDTH + _GUACD_TERM_MARGIN * 2 + _GUACD_SCROLLBAR_WIDTH + pixel_height = rows * _GUACD_CHAR_HEIGHT + _GUACD_TERM_MARGIN * 2 + return pixel_width, pixel_height + + +def scale_screen_info(columns: int, rows: int, scale_pct: int) -> Dict[str, int]: + """Return screen_info using guacd-96 base metrics scaled by *scale_pct* percent. + + Uses :func:`guacd_default_approximate_pixels` (DPI 96, 10×20 px chars) as + the base, then multiplies pixel_width and pixel_height by ``scale_pct / 100``. + Canonical ``pam launch --scale`` path: local console columns/rows with + ``dpi`` 96 aligned to guacd's default PTY pixel model. + + Example: scale_pct=80 → multiply base pixels by 0.80 (shrink) + scale_pct=120 → multiply base pixels by 1.20 (enlarge) + """ + base_w, base_h = guacd_default_approximate_pixels(columns, rows) + factor = scale_pct / 100.0 + return { + "columns": columns, + "rows": rows, + "pixel_width": max(1, int(base_w * factor)), + "pixel_height": max(1, int(base_h * factor)), + "dpi": _GUACD_DPI, + } + + +def _coerce_pixel_mode(mode_value: str) -> str: + """Normalize *mode_value* to ``kcm`` or ``guacd``; unknown/empty → :data:`DEFAULT_PIXEL_MODE`.""" + m = (mode_value or "").strip().lower() + if not m: + m = DEFAULT_PIXEL_MODE.strip().lower() + if m in (PIXEL_MODE_KCM, PIXEL_MODE_GUACD): + return m + return DEFAULT_PIXEL_MODE.strip().lower() + + +def approximate_pixels(columns: int, rows: int, pixel_mode: str = DEFAULT_PIXEL_MODE): + """Return (pixel_width, pixel_height) for the given pixel mode. + + Parameters + ---------- + pixel_mode: + ``'kcm'`` — kcm-cli/tty.js formula (DPI 192, char 19×38, with margin). + ``'guacd'`` — guacd defaults (DPI 96, char 10×20, scrollbar only). + Other values are treated as :data:`DEFAULT_PIXEL_MODE` (see :func:`_coerce_pixel_mode`). 
+ """ + mode = _coerce_pixel_mode(pixel_mode) + if mode == PIXEL_MODE_GUACD: + return guacd_default_approximate_pixels(columns, rows) + return kcm_cli_approximate_pixels(columns, rows) # --------------------------------------------------------------------------- # Fallback helper (previously defined in terminal_connection.py) # --------------------------------------------------------------------------- -def _build_screen_info(columns: int, rows: int) -> Dict[str, int]: - """Convert character columns/rows into pixel measurements for the Gateway.""" +def _dpi_for_cell_fallback() -> int: + """DPI to embed when pixel dimensions come from cell estimates (same as Guacamole resize).""" + if sys.platform == 'win32': + return _get_dpi_windows() + if sys.platform == 'darwin': + return _get_dpi_macos() + return DEFAULT_SCREEN_DPI + +def _resolve_pixel_mode(pixel_mode: Optional[str] = None) -> str: + """Return the effective pixel mode, consulting env when *pixel_mode* is None. + + Unknown strings (including env typos) become :data:`DEFAULT_PIXEL_MODE`. + """ + if pixel_mode is not None: + return _coerce_pixel_mode(pixel_mode) + raw = os.environ.get("KEEPER_GUAC_PIXEL_MODE", DEFAULT_PIXEL_MODE) + return _coerce_pixel_mode(raw) + + +def _dpi_for_mode(pixel_mode: str) -> int: + """Return Guacamole handshake DPI for *pixel_mode*. + + Must stay aligned with :func:`approximate_pixels` cell metrics: + + - ``guacd`` → :data:`_GUACD_DPI` (96) + - ``kcm`` → :data:`_KCM_DPI` (192) + - Any other value → DPI for :data:`DEFAULT_PIXEL_MODE` (via :func:`_coerce_pixel_mode`). + """ + mode = _coerce_pixel_mode(pixel_mode) + if mode == PIXEL_MODE_GUACD: + return _GUACD_DPI + return _KCM_DPI + + +def default_handshake_dpi() -> int: + """DPI for Guacamole text-terminal handshake (matches ``screen_info['dpi']`` when mode is unset). + + Resolved from ``KEEPER_GUAC_PIXEL_MODE`` and :data:`DEFAULT_PIXEL_MODE` via + :func:`_dpi_for_mode`: **192** for ``kcm``, **96** for ``guacd``. 
+ """ + return _dpi_for_mode(_resolve_pixel_mode(None)) + + +# Default for imports and ``settings.get('dpi', …)`` fallbacks; set at import time. +GUACAMOLE_HANDSHAKE_DPI = default_handshake_dpi() + + +def _build_screen_info(columns: int, rows: int, pixel_mode: Optional[str] = None) -> Dict[str, int]: + """Convert character columns/rows into pixel measurements for the Gateway. + + Uses the pixel mode to select the DPI and pixel formula: + - ``'kcm'`` — DPI 192, kcm-cli tty.js formula. + - ``'guacd'`` — DPI 96, guacd-native formula (no canvas margin). + + The active mode is resolved from the *pixel_mode* argument first, then the + ``KEEPER_GUAC_PIXEL_MODE`` environment variable, then :data:`DEFAULT_PIXEL_MODE`. + """ + mode = _resolve_pixel_mode(pixel_mode) col_value = columns if isinstance(columns, int) and columns > 0 else DEFAULT_TERMINAL_COLUMNS row_value = rows if isinstance(rows, int) and rows > 0 else DEFAULT_TERMINAL_ROWS + pixel_w, pixel_h = approximate_pixels(col_value, row_value, mode) return { "columns": col_value, "rows": row_value, - "pixel_width": col_value * DEFAULT_CELL_WIDTH_PX, - "pixel_height": row_value * DEFAULT_CELL_HEIGHT_PX, - "dpi": DEFAULT_SCREEN_DPI, + "pixel_width": pixel_w, + "pixel_height": pixel_h, + "dpi": _dpi_for_mode(mode), } @@ -63,11 +286,11 @@ def _build_screen_info(columns: int, rows: int) -> Dict[str, int]: # Module-level caches # --------------------------------------------------------------------------- -# DPI is cached for the lifetime of the process. Display DPI rarely changes -# during a session - it would only change if the user moves the console window -# to a different-DPI monitor, which is not worth the overhead of re-querying -# on every resize event. +# DPI cache: refreshed at most once per :data:`DPI_CACHE_TTL_SEC` when +# :func:`get_terminal_size_pixels` / resize runs need a DPI value. Not tied to +# "current window screen" on all platforms yet — see platform DPI helpers. 
_dpi: Optional[int] = None +_dpi_cache_mono: Optional[float] = None # time.monotonic() when *_dpi* was stored # TIOCGWINSZ pixel support: None = untested, True = returns non-zero pixels, # False = permanently disabled (returned all-zero pixel fields). When False, @@ -95,30 +318,49 @@ def is_interactive_tty() -> bool: _is_tty = sys.stdin.isatty() and sys.stdout.isatty() except Exception: _is_tty = False - return _is_tty + return bool(_is_tty) # --------------------------------------------------------------------------- # Platform DPI helpers # --------------------------------------------------------------------------- +def _invalidate_dpi_cache_if_stale() -> None: + """Clear module DPI cache when the TTL has expired.""" + global _dpi, _dpi_cache_mono + if _dpi is None: + return + if _dpi_cache_mono is None: + _dpi = None + return + if time.monotonic() - _dpi_cache_mono >= DPI_CACHE_TTL_SEC: + _dpi = None + _dpi_cache_mono = None + + +def _store_dpi(value: int) -> int: + """Store DPI and refresh cache timestamp.""" + global _dpi, _dpi_cache_mono + _dpi = int(value) + _dpi_cache_mono = time.monotonic() + return _dpi + + def _get_dpi_windows() -> int: """Return display DPI on Windows via ctypes, cached for the session. Tries GetDpiForSystem (shcore.dll, Windows 8.1+) first, then falls back - to GetDeviceCaps(LOGPIXELSX). Returns DEFAULT_SCREEN_DPI (96) on failure. + to GetDeviceCaps(LOGPIXELSX). Returns :data:`DEFAULT_SCREEN_DPI` on failure. 
""" - global _dpi + _invalidate_dpi_cache_if_stale() if _dpi is not None: return _dpi try: - import ctypes # GetDpiForSystem - available on Windows 8.1+ via shcore.dll try: dpi = ctypes.windll.shcore.GetDpiForSystem() if dpi and dpi > 0: - _dpi = int(dpi) - return _dpi + return _store_dpi(int(dpi)) except Exception: pass # Fallback: GDI GetDeviceCaps(LOGPIXELSX) @@ -128,48 +370,221 @@ def _get_dpi_windows() -> int: try: dpi = ctypes.windll.gdi32.GetDeviceCaps(hdc, LOGPIXELSX) if dpi and dpi > 0: - _dpi = int(dpi) - return _dpi + return _store_dpi(int(dpi)) finally: ctypes.windll.user32.ReleaseDC(0, hdc) except Exception as e: logging.debug(f"Could not query Windows DPI: {e}") - _dpi = DEFAULT_SCREEN_DPI - return _dpi + return _store_dpi(DEFAULT_SCREEN_DPI) + + +def _get_dpi_macos() -> int: + """Return approximate main-display DPI on macOS via CoreGraphics. + + Uses ``CGDisplayScreenSize`` (physical size in mm) and ``CGDisplayPixelsWide`` / + ``CGDisplayPixelsHigh`` to compute horizontal/vertical DPI and averages them. + This uses ``CGMainDisplayID()`` (system menu-bar display), not the screen that + holds the terminal window. Falls back to :data:`DEFAULT_SCREEN_DPI` on failure. + + Cached for :data:`DPI_CACHE_TTL_SEC` (then refreshed on next query). 
+ """ + _invalidate_dpi_cache_if_stale() + if _dpi is not None: + return _dpi + try: + path = ctypes.util.find_library('CoreGraphics') + if not path: + return _store_dpi(DEFAULT_SCREEN_DPI) + cg = ctypes.CDLL(path) + CGDirectDisplayID = ctypes.c_uint32 + + cg.CGMainDisplayID.restype = CGDirectDisplayID + cg.CGMainDisplayID.argtypes = [] + + class CGSize(ctypes.Structure): + _fields_ = [('width', ctypes.c_double), ('height', ctypes.c_double)] + + cg.CGDisplayScreenSize.argtypes = [CGDirectDisplayID] + cg.CGDisplayScreenSize.restype = CGSize + + cg.CGDisplayPixelsWide.argtypes = [CGDirectDisplayID] + cg.CGDisplayPixelsWide.restype = ctypes.c_size_t + cg.CGDisplayPixelsHigh.argtypes = [CGDirectDisplayID] + cg.CGDisplayPixelsHigh.restype = ctypes.c_size_t + + main = cg.CGMainDisplayID() + size_mm = cg.CGDisplayScreenSize(main) + pw = float(cg.CGDisplayPixelsWide(main)) + ph = float(cg.CGDisplayPixelsHigh(main)) + + if size_mm.width > 0 and size_mm.height > 0 and pw > 0 and ph > 0: + dpi_x = pw / (size_mm.width / 25.4) + dpi_y = ph / (size_mm.height / 25.4) + dpi = int(round((dpi_x + dpi_y) / 2.0)) + if 72 <= dpi <= 600: + return _store_dpi(dpi) + except Exception as e: + logging.debug(f"Could not query macOS DPI: {e}") + return _store_dpi(DEFAULT_SCREEN_DPI) + + +def _try_linux_x11_xft_dpi() -> Optional[int]: + """Read Xft.dpi from the X11 resource database (``XGetDefault``). + + Works on X11 sessions and often under XWayland when ``DISPLAY`` is set. + Returns None if libX11 is unavailable, the display cannot be opened, or + Xft.dpi is unset. 
+ """ + try: + lib_path = ctypes.util.find_library('X11') + if not lib_path: + return None + x11 = ctypes.CDLL(lib_path) + XOpenDisplay = x11.XOpenDisplay + XOpenDisplay.argtypes = [ctypes.c_char_p] + XOpenDisplay.restype = ctypes.c_void_p + XCloseDisplay = x11.XCloseDisplay + XCloseDisplay.argtypes = [ctypes.c_void_p] + XCloseDisplay.restype = ctypes.c_int + XGetDefault = x11.XGetDefault + XGetDefault.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_char_p] + XGetDefault.restype = ctypes.c_char_p + + dpy = XOpenDisplay(None) + if not dpy: + return None + try: + raw = XGetDefault(dpy, b'Xft', b'dpi') + if not raw: + return None + val = raw.decode('utf-8', errors='ignore').strip() + dpi_f = float(val) + dpi = int(round(dpi_f)) + if 72 <= dpi <= 600: + return dpi + finally: + XCloseDisplay(dpy) + except Exception as e: + logging.debug(f"Linux X11 Xft.dpi query failed: {e}") + return None + + +def _try_linux_gnome_text_scaling_dpi() -> Optional[int]: + """GNOME (and many Wayland) sessions: ``text-scaling-factor`` × :data:`DEFAULT_SCREEN_DPI`.""" + try: + if not shutil.which('gsettings'): + return None + r = subprocess.run( + [ + 'gsettings', + 'get', + 'org.gnome.desktop.interface', + 'text-scaling-factor', + ], + capture_output=True, + text=True, + timeout=0.5, + check=False, + ) + if r.returncode != 0 or not (r.stdout or '').strip(): + return None + line = (r.stdout or '').strip().strip("'\"") + factor = float(line) + if factor <= 0: + return None + dpi = int(round(DEFAULT_SCREEN_DPI * factor)) + if 72 <= dpi <= 600: + return dpi + except Exception as e: + logging.debug(f"Linux gsettings text-scaling query failed: {e}") + return None + + +def _try_linux_env_scale_dpi() -> Optional[int]: + """Derive effective DPI from common toolkit / Qt environment variables.""" + gdk = os.environ.get('GDK_SCALE') + if gdk: + try: + factor = float(gdk) + if factor > 0: + dpi = int(round(DEFAULT_SCREEN_DPI * factor)) + if 72 <= dpi <= 600: + return dpi + except ValueError: + 
pass + qt = os.environ.get('QT_SCALE_FACTOR') or os.environ.get('QT_SCREEN_SCALE_FACTORS') + if qt: + first = qt.split(';')[0].strip().split(',')[0].strip() + try: + factor = float(first) + if factor > 0: + dpi = int(round(DEFAULT_SCREEN_DPI * factor)) + if 72 <= dpi <= 600: + return dpi + except ValueError: + pass + return None + + +def _get_dpi_linux() -> int: + """Return display DPI on Linux without extra Python dependencies. + + Tries in order: + + 1. **X11** — ``Xft.dpi`` from the X resource database (``libX11``). + 2. **GNOME** — ``gsettings`` ``text-scaling-factor`` × :data:`DEFAULT_SCREEN_DPI`. + 3. **Environment** — ``GDK_SCALE`` or ``QT_SCALE_FACTOR`` / ``QT_SCREEN_SCALE_FACTORS`` + multiplied by :data:`DEFAULT_SCREEN_DPI`. + + Falls back to :data:`DEFAULT_SCREEN_DPI` when nothing applies (e.g. SSH + with no display, minimal containers, or non-GNOME Wayland without Xft). + """ + _invalidate_dpi_cache_if_stale() + if _dpi is not None: + return _dpi + for probe in ( + _try_linux_x11_xft_dpi, + _try_linux_gnome_text_scaling_dpi, + _try_linux_env_scale_dpi, + ): + found = probe() + if found is not None: + return _store_dpi(found) + return _store_dpi(DEFAULT_SCREEN_DPI) def _get_dpi_unix() -> int: - """Return display DPI on Unix/macOS, cached for the session. + """Return display DPI on Unix. - There is no portable, connection-independent way to query DPI from a - terminal process on Unix without a display-server connection. Standard - Guacamole sessions use 96 DPI as the baseline, so we return that. + On **macOS**, uses CoreGraphics physical screen size + pixel dimensions + (see :func:`_get_dpi_macos`). On **Linux**, uses :func:`_get_dpi_linux`. + Other Unix systems use :data:`DEFAULT_SCREEN_DPI`. + + Cached for :data:`DPI_CACHE_TTL_SEC` (then refreshed). 
""" - global _dpi - if _dpi is None: - _dpi = DEFAULT_SCREEN_DPI - return _dpi + _invalidate_dpi_cache_if_stale() + if _dpi is not None: + return _dpi + if sys.platform == 'darwin': + return _get_dpi_macos() + if sys.platform.startswith('linux'): + return _get_dpi_linux() + return _store_dpi(DEFAULT_SCREEN_DPI) # --------------------------------------------------------------------------- # Platform pixel-dimension helpers # --------------------------------------------------------------------------- -def _get_pixels_windows(columns: int, rows: int): - """Return (pixel_width, pixel_height) on Windows via GetCurrentConsoleFontEx. - - Retrieves the console font glyph size in pixels (dwFontSize.X / .Y) and - multiplies by columns/rows to get the total terminal window pixel size. - Returns (0, 0) on any failure so the caller can fall back gracefully. - """ +def _windows_console_font_cell() -> Optional[Tuple[int, int]]: + """Return console font cell (dwFontSize.X, dwFontSize.Y), or None if unavailable.""" + if sys.platform != 'win32': + return None try: - import ctypes - import ctypes.wintypes - STD_OUTPUT_HANDLE = -11 handle = ctypes.windll.kernel32.GetStdHandle(STD_OUTPUT_HANDLE) if not handle or handle == ctypes.wintypes.HANDLE(-1).value: - return 0, 0 + return None class COORD(ctypes.Structure): _fields_ = [('X', ctypes.c_short), ('Y', ctypes.c_short)] @@ -188,15 +603,27 @@ class CONSOLE_FONT_INFOEX(ctypes.Structure): font_info.cbSize = ctypes.sizeof(CONSOLE_FONT_INFOEX) if ctypes.windll.kernel32.GetCurrentConsoleFontEx(handle, False, ctypes.byref(font_info)): - fw = font_info.dwFontSize.X - fh = font_info.dwFontSize.Y + fw = int(font_info.dwFontSize.X) + fh = int(font_info.dwFontSize.Y) if fw > 0 and fh > 0: - return columns * fw, rows * fh - - return 0, 0 + return fw, fh except Exception as e: logging.debug(f"GetCurrentConsoleFontEx failed: {e}") + return None + + +def _get_pixels_windows(columns: int, rows: int): + """Return (pixel_width, pixel_height) on Windows 
via GetCurrentConsoleFontEx. + + Retrieves the console font glyph size in pixels (dwFontSize.X / .Y) and + multiplies by columns/rows to get the total terminal window pixel size. + Returns (0, 0) on any failure so the caller can fall back gracefully. + """ + cell = _windows_console_font_cell() + if not cell: return 0, 0 + fw, fh = cell + return columns * fw, rows * fh def _get_pixels_unix(columns: int, rows: int): @@ -210,10 +637,9 @@ def _get_pixels_unix(columns: int, rows: int): global _tiocgwinsz_works if _tiocgwinsz_works is False: return 0, 0 + if fcntl is None or termios is None: + return 0, 0 try: - import fcntl - import termios - buf = struct.pack('HHHH', 0, 0, 0, 0) result = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, buf) # struct winsize layout: ws_row, ws_col, ws_xpixel, ws_ypixel @@ -237,6 +663,7 @@ def _get_pixels_unix(columns: int, rows: int): def get_terminal_size_pixels( columns: Optional[int] = None, rows: Optional[int] = None, + pixel_mode: Optional[str] = None, ) -> Dict[str, int]: """Return terminal size in pixels and DPI for a Guacamole 'size' instruction. @@ -244,30 +671,21 @@ def get_terminal_size_pixels( for maximum accuracy. The optional *columns* and *rows* arguments serve as a fallback used only when the internal query fails. - Platform behaviour - ------------------ - Windows - Uses GetCurrentConsoleFontEx to obtain the console font glyph size in - pixels, then multiplies columns × rows for exact pixel dimensions. - DPI is obtained via GetDpiForSystem (or GetDeviceCaps as fallback). - Both are cached for the session. - - Unix / macOS - Tries TIOCGWINSZ ws_xpixel / ws_ypixel for pixel dimensions. If those - fields are zero (common - many terminal emulators do not fill them in), - the failure is cached permanently and the cell-size fallback is used on - every subsequent call without retrying the ioctl. 
- - Fallback - When platform-specific pixel APIs return (0, 0), falls back to - _build_screen_info(columns, rows) which uses DEFAULT_CELL_WIDTH_PX / - DEFAULT_CELL_HEIGHT_PX to estimate pixel dimensions from char cells. + Parameters + ---------- + pixel_mode: + ``'kcm'`` — kcm-cli/tty.js formula, DPI 192, char 19×38 + margin. + ``'guacd'`` — guacd-native formula, DPI 96, char 10×20, scrollbar only. + ``None`` — resolved from ``KEEPER_GUAC_PIXEL_MODE`` env var, then + :data:`DEFAULT_PIXEL_MODE`. Returns ------- dict with keys: columns, rows, pixel_width, pixel_height, dpi - (same structure as _build_screen_info - drop-in compatible) + (same structure as _build_screen_info — drop-in compatible) """ + mode = _resolve_pixel_mode(pixel_mode) + # Resolve caller-supplied hints as fallback values fallback_cols = columns if (isinstance(columns, int) and columns > 0) else DEFAULT_TERMINAL_COLUMNS fallback_rows = rows if (isinstance(rows, int) and rows > 0) else DEFAULT_TERMINAL_ROWS @@ -281,22 +699,11 @@ def get_terminal_size_pixels( actual_cols = fallback_cols actual_rows = fallback_rows - # Platform-specific pixel dimensions - if sys.platform == 'win32': - pixel_w, pixel_h = _get_pixels_windows(actual_cols, actual_rows) - dpi = _get_dpi_windows() - else: - pixel_w, pixel_h = _get_pixels_unix(actual_cols, actual_rows) - dpi = _get_dpi_unix() - - # Fallback: platform API returned (0, 0) - use fixed cell-size estimate - if pixel_w <= 0 or pixel_h <= 0: - return _build_screen_info(actual_cols, actual_rows) - + pixel_w, pixel_h = approximate_pixels(actual_cols, actual_rows, mode) return { "columns": actual_cols, "rows": actual_rows, "pixel_width": pixel_w, "pixel_height": pixel_h, - "dpi": dpi, + "dpi": _dpi_for_mode(mode), } diff --git a/keepercommander/commands/pam_service/list.py b/keepercommander/commands/pam_service/list.py index 56ba15d0c..995be7d1d 100644 --- a/keepercommander/commands/pam_service/list.py +++ b/keepercommander/commands/pam_service/list.py @@ -56,7 +56,7 
@@ def execute(self, params: KeeperParams, **kwargs): if user_record is None: continue acl = user_service.get_acl(resource_record.record_uid, user_record.record_uid) - if acl is None or (acl.is_service is False and acl.is_task is False): + if acl is None or (acl.is_service is False and acl.is_task is False and acl.is_iis_pool is False): continue if user_record.record_uid not in service_map: service_map[user_record.record_uid] = { diff --git a/keepercommander/commands/pedm/pedm_admin.py b/keepercommander/commands/pedm/pedm_admin.py index 6365cbedf..372b84e2a 100644 --- a/keepercommander/commands/pedm/pedm_admin.py +++ b/keepercommander/commands/pedm/pedm_admin.py @@ -1227,13 +1227,7 @@ def get_policy_controls(policy_type_name: str, **kwargs) -> Optional[List[str]]: if not p_controls: return None - allowed_controls: Set[str] = set() - if policy_type_name == 'PrivilegeElevation': - allowed_controls.update(('audit', 'notify', 'mfa', 'justify', 'approval')) - elif policy_type_name == 'Access': - allowed_controls.update(('audit', 'notify', 'allow', 'deny')) - elif policy_type_name == 'CommandLine': - allowed_controls.update(('audit', 'notify', 'allow', 'deny')) + allowed_controls: Set[str] = {'audit', 'notify', 'mfa', 'justify', 'approval', 'allow', 'deny'} controls: List[str] = [] if isinstance(p_controls, str): @@ -1357,6 +1351,11 @@ def __init__(self): help='Policy Status') parser.add_argument('--enable', dest='enable', action='store', choices=['on', 'off'], help='Enables or disables policy') + parser.add_argument('--message', dest='notification_message', action='store', + help='Notification message (only for monitor_and_notify status)') + parser.add_argument('--require-acknowledgement', dest='require_acknowledgement', + action='store', choices=['on', 'off'], default=None, + help='Require policy acknowledgement (only for monitor_and_notify status)') super().__init__(parser) @@ -1378,6 +1377,11 @@ def execute(self, context: KeeperParams, **kwargs) -> None: policy_uid = 
utils.generate_uid() controls = PedmPolicyMixin.get_policy_controls(policy_type, **kwargs) + arg_status = kwargs.get('status') + effective_status = arg_status if isinstance(arg_status, str) else 'enforce' + if policy_type != 'LeastPrivilege' and effective_status == 'enforce' and not controls: + raise base.CommandError(f'At least one --control is required for {policy_type} policy type when status is enforce') + policy_data: Dict[str, Any] = { 'PolicyName': kwargs.get('policy_name') or '', 'PolicyType': policy_type, @@ -1445,10 +1449,22 @@ def execute(self, context: KeeperParams, **kwargs) -> None: else: policy_data['Status'] = 'enforce' + notification_message = kwargs.get('notification_message') + require_ack = kwargs.get('require_acknowledgement') + if notification_message is not None or require_ack is not None: + if policy_data['Status'] != 'monitor_and_notify': + raise base.CommandError('--message and --require-acknowledgement are only valid when --status is monitor_and_notify') + if notification_message is not None: + policy_data['NotificationMessage'] = notification_message + if require_ack is not None: + policy_data['NotificationRequiresAcknowledge'] = require_ack == 'on' + disabled: bool = False arg_enable = kwargs.get('enable') if isinstance(arg_enable, str): - disabled = True if arg_enable == 'off' else False + disabled = arg_enable == 'off' + if disabled: + policy_data['Status'] = 'off' policy_key = utils.generate_aes_key() add_policy = admin_types.PedmPolicy( @@ -1459,11 +1475,24 @@ def execute(self, context: KeeperParams, **kwargs) -> None: if isinstance(status, admin_types.EntityStatus) and not status.success: raise base.CommandError(f'Failed to add policy "{status.entity_uid}": {status.message}') + policy_name = policy_data.get('PolicyName') or '' + print(f'Successfully created policy "{policy_name}" with Policy ID: {policy_uid}') + class PedmPolicyEditCommand(base.ArgparseCommand, PedmPolicyMixin): + POLICY_TYPE_MAP = { + 'elevation': 
'PrivilegeElevation', + 'file_access': 'FileAccess', + 'command': 'CommandLine', + 'least_privilege': 'LeastPrivilege', + } + def __init__(self): parser = argparse.ArgumentParser(prog='edit', description='Edit EPM policy', parents=[PedmPolicyMixin.policy_filter]) - parser.add_argument('policy', help='Policy UID') + parser.add_argument('policy', help='Policy UID or name') + parser.add_argument('--policy-type', dest='policy_type', action='store', + choices=['elevation', 'file_access', 'command', 'least_privilege'], + help='Change policy type') parser.add_argument('--policy-name', dest='policy_name', action='store', help='Policy name') parser.add_argument('--control', dest='control', action='append', @@ -1474,6 +1503,11 @@ def __init__(self): help='Policy Status') parser.add_argument('--enable', dest='enable', action='store', choices=['on', 'off'], help='Enables or disables policy') + parser.add_argument('--message', dest='notification_message', action='store', + help='Notification message (only for monitor_and_notify status)') + parser.add_argument('--require-acknowledgement', dest='require_acknowledgement', + action='store', choices=['on', 'off'], default=None, + help='Require policy acknowledgement (only for monitor_and_notify status)') super().__init__(parser) def execute(self, context: KeeperParams, **kwargs) -> None: @@ -1482,7 +1516,17 @@ def execute(self, context: KeeperParams, **kwargs) -> None: policy = PedmUtils.resolve_single_policy(plugin, kwargs.get('policy')) policy_data = copy.deepcopy(policy.data or {}) - policy_type = policy_data.get('PolicyType') or 'Unknown' + + p_type = kwargs.get('policy_type') + if p_type: + new_policy_type = PedmPolicyEditCommand.POLICY_TYPE_MAP.get(p_type) + if not new_policy_type: + raise base.CommandError(f'"policy-type: {p_type}" is not supported') + policy_data['PolicyType'] = new_policy_type + policy_type = new_policy_type + else: + policy_type = policy_data.get('PolicyType') or 'Unknown' + controls = 
PedmPolicyMixin.get_policy_controls(policy_type, **kwargs) if isinstance(controls, list): actions = policy_data.get('Actions') @@ -1506,10 +1550,35 @@ def execute(self, context: KeeperParams, **kwargs) -> None: if isinstance(arg_status, str): policy_data['Status'] = arg_status + effective_status = policy_data.get('Status', '') + + if p_type and policy_type != 'LeastPrivilege' and effective_status == 'enforce' and not controls: + existing_actions = policy_data.get('Actions') + existing_controls = [] + if isinstance(existing_actions, dict): + on_success = existing_actions.get('OnSuccess') + if isinstance(on_success, dict): + existing_controls = on_success.get('Controls') or [] + if not existing_controls: + raise base.CommandError(f'At least one --control is required for {policy_type} policy type when status is enforce') + notification_message = kwargs.get('notification_message') + require_ack = kwargs.get('require_acknowledgement') + if notification_message is not None or require_ack is not None: + if effective_status != 'monitor_and_notify': + raise base.CommandError('--message and --require-acknowledgement are only valid when status is monitor_and_notify') + if notification_message is not None: + policy_data['NotificationMessage'] = notification_message + if require_ack is not None: + policy_data['NotificationRequiresAcknowledge'] = require_ack == 'on' + disabled: Optional[bool] = None arg_enable = kwargs.get('enable') if isinstance(arg_enable, str): - disabled = True if arg_enable == 'off' else False + disabled = arg_enable == 'off' + if disabled: + policy_data['Status'] = 'off' + elif policy_data.get('Status') == 'off': + policy_data['Status'] = arg_status if isinstance(arg_status, str) else 'enforce' pu = admin_types.PedmUpdatePolicy(policy_uid=policy.policy_uid, data=policy_data, disabled=disabled) @@ -1519,6 +1588,9 @@ def execute(self, context: KeeperParams, **kwargs) -> None: if isinstance(status, admin_types.EntityStatus) and not status.success: raise 
base.CommandError(f'Failed to update policy "{status.entity_uid}": {status.message}') + updated_name = policy_data.get('PolicyName') or policy.policy_uid + print(f'Successfully updated policy "{updated_name}" (Policy ID: {policy.policy_uid})') + class PedmPolicyViewCommand(base.ArgparseCommand): def __init__(self): diff --git a/keepercommander/commands/record.py b/keepercommander/commands/record.py index 829a5f5fa..7353a0e35 100644 --- a/keepercommander/commands/record.py +++ b/keepercommander/commands/record.py @@ -288,11 +288,18 @@ def execute(self, params, **kwargs): "can_share": sf.default_can_share } if sf.records: - sfo['records'] = [{ - 'record_uid': r['record_uid'], - 'can_edit': r['can_edit'], - 'can_share': r['can_share'] - } for r in sf.records] + records_list = [] + for r in sf.records: + rec_entry = { + 'record_uid': r['record_uid'], + 'can_edit': r['can_edit'], + 'can_share': r['can_share'] + } + rec = vault.KeeperRecord.load(params, r['record_uid']) + if rec: + rec_entry['record_name'] = rec.title + records_list.append(rec_entry) + sfo['records'] = records_list def _format_expiration(expiration_value): if expiration_value is None or expiration_value <= 0: return 'never' diff --git a/keepercommander/commands/supershell/app.py b/keepercommander/commands/supershell/app.py index 42f7a3f93..8c12dbd55 100644 --- a/keepercommander/commands/supershell/app.py +++ b/keepercommander/commands/supershell/app.py @@ -1566,8 +1566,8 @@ def _get_rotation_info(self, record_uid: str) -> Optional[Dict[str, Any]]: # Only fetch DAG data for pamUser records (requires PAM infrastructure) if is_pam_user: try: - from .tunnel.port_forward.tunnel_helpers import get_keeper_tokens - from .tunnel.port_forward.TunnelGraph import TunnelDAG + from keepercommander.commands.tunnel.port_forward.tunnel_helpers import get_keeper_tokens + from keepercommander.commands.tunnel.port_forward.TunnelGraph import TunnelDAG from keeper_dag.edge import EdgeType encrypted_session_token, 
encrypted_transmission_key, transmission_key = get_keeper_tokens(self.params) @@ -2641,8 +2641,9 @@ def _display_secrets_manager_app(self, app_uid: str): t = self.theme_colors try: - from ...proto import APIRequest_pb2, enterprise_pb2 - from .. import api, utils + from keepercommander.proto import APIRequest_pb2, enterprise_pb2 + from keepercommander import api + from keepercommander.commands import utils import json record = self.records[app_uid] @@ -3859,7 +3860,7 @@ def action_sync_vault(self): # Run enterprise-down if available (enterprise users) try: - from .enterprise import EnterpriseDownCommand + from keepercommander.commands.enterprise import EnterpriseDownCommand EnterpriseDownCommand().execute(self.params) except Exception: pass # Not an enterprise user or command not available diff --git a/keepercommander/importer/keepass/keepass.py b/keepercommander/importer/keepass/keepass.py index a1040281f..4268ad67b 100644 --- a/keepercommander/importer/keepass/keepass.py +++ b/keepercommander/importer/keepass/keepass.py @@ -34,6 +34,22 @@ class XmlUtils(object): + @staticmethod + def is_valid_xml_char(char): # type: (str) -> bool + code_point = ord(char) + return code_point in {0x09, 0x0A, 0x0D} or \ + 0x20 <= code_point <= 0xD7FF or \ + 0xE000 <= code_point <= 0xFFFD or \ + 0x10000 <= code_point <= 0x10FFFF + + @staticmethod + def sanitize_xml_text(value): # type: (any) -> str + if not isinstance(value, str): + value = str(value) if value else '' + if not value: + return '' + return ''.join((char for char in value if XmlUtils.is_valid_xml_char(char))) + @staticmethod def escape_string(plain): # type: (str) -> str if not plain: @@ -261,9 +277,12 @@ def to_keepass_value(keeper_value): # type: (any) -> str if isinstance(keeper_value, list): return ','.join((KeepassExporter.to_keepass_value(x) for x in keeper_value)) elif isinstance(keeper_value, dict): - return ';\n'.join((f'{k}:{KeepassExporter.to_keepass_value(v)}' for k, v in keeper_value.items())) + return 
';\n'.join(( + f'{XmlUtils.sanitize_xml_text(k)}:{KeepassExporter.to_keepass_value(v)}' + for k, v in keeper_value.items() + )) else: - return str(keeper_value) + return XmlUtils.sanitize_xml_text(keeper_value) def do_export(self, filename, records, file_password=None, **kwargs): master_password = file_password @@ -302,44 +321,57 @@ def do_export(self, filename, records, file_password=None, **kwargs): if path: comps = list(path_components(path)) for i in range(len(comps)): - comp = comps[i] + comp = self.sanitize_xml_text(comps[i]) + if not comp: + continue sub_node = next((x for x in node.subgroups if x.name == comp), None) if sub_node is None: sub_node = kdb.add_group(node, comp) node = sub_node + entry_title = self.to_keepass_value(r.title) + entry_login = self.to_keepass_value(r.login) + entry_password = self.to_keepass_value(r.password) + entry_url = self.to_keepass_value(r.login_url) + entry_notes = self.to_keepass_value(r.notes) entry = None entries = node.entries for en in entries: - if en.title == r.title and en.username == r.login and en.password == r.password: + if en.title == entry_title and en.username == entry_login and en.password == entry_password: entry = en break if entry is None: - entry = kdb.add_entry(node, title=r.title or '', username=r.login or '', - password=r.password or '', url=r.login_url or '', - notes=r.notes or '') + entry = kdb.add_entry(node, title=entry_title, username=entry_login, + password=entry_password, url=entry_url, + notes=entry_notes) if r.uid: entry.UUID = uuid.UUID(bytes=utils.base64_url_decode(r.uid)) if r.type: - entry.set_custom_property('$type', r.type) + entry.set_custom_property('$type', self.to_keepass_value(r.type)) if r.fields: custom_names = {} # type: Dict[str, int] for cf in r.fields: if cf.type == 'oneTimeCode': - entry.otp = cf.value + otp_value = self.to_keepass_value(cf.value) + entry.otp = otp_value # Set custom fields for Pleasant Password TOTP compatibility - totp_props = parse_totp_uri(cf.value) + 
totp_props = parse_totp_uri(otp_value) for key in ['secret', 'period', 'issuer', 'digits']: val = totp_props.get(key) - val and entry.set_custom_property(f'TOTP{key.capitalize()}', str(val)) + val and entry.set_custom_property( + f'TOTP{key.capitalize()}', + self.to_keepass_value(val) + ) continue - if cf.type and cf.label: - title = f'${cf.type}:{cf.label}' - elif cf.type: - title = f'${cf.type}' + field_type = self.sanitize_xml_text(cf.type) + field_label = self.sanitize_xml_text(cf.label) + if field_type and field_label: + title = f'${field_type}:{field_label}' + elif field_type: + title = f'${field_type}' else: - title = cf.label or '' + title = field_label or '' if title in custom_names: no = custom_names[title] no += 1 @@ -367,7 +399,7 @@ def do_export(self, filename, records, file_password=None, **kwargs): if binary: binary_id = kdb.add_binary(binary, compressed=True, protected=False) - entry.add_attachment(binary_id, atta.name) + entry.add_attachment(binary_id, self.to_keepass_value(atta.name)) else: scale = '' msize = self.max_size diff --git a/keepercommander/proto/workflow_pb2.py b/keepercommander/proto/workflow_pb2.py new file mode 100644 index 000000000..ace3a6982 --- /dev/null +++ b/keepercommander/proto/workflow_pb2.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: Workflow.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . 
import GraphSync_pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0eWorkflow.proto\x12\x08Workflow\x1a\x0fGraphSync.proto\"b\n\x15WorkflowAccessRequest\x12)\n\x08resource\x18\x01 \x01(\x0b\x32\x17.GraphSync.GraphSyncRef\x12\x0e\n\x06reason\x18\x02 \x01(\x0c\x12\x0e\n\x06ticket\x18\x03 \x01(\x0c\"\xcb\x01\n\x0fWorkflowProcess\x12\x0f\n\x07\x66lowUid\x18\x01 \x01(\x0c\x12\x0e\n\x06userId\x18\x02 \x01(\x03\x12)\n\x08resource\x18\x03 \x01(\x0b\x32\x17.GraphSync.GraphSyncRef\x12\x11\n\tstartedOn\x18\x04 \x01(\x03\x12\x11\n\texpiresOn\x18\x05 \x01(\x03\x12\x0e\n\x06reason\x18\x06 \x01(\x0c\x12\x13\n\x0bmfaVerified\x18\x07 \x01(\x08\x12\x13\n\x0b\x65xternalRef\x18\x08 \x01(\x0c\x12\x0c\n\x04user\x18\t \x01(\t\"@\n\x10\x41pprovalRequests\x12,\n\tworkflows\x18\x01 \x03(\x0b\x32\x19.Workflow.WorkflowProcess\"O\n\x18WorkflowApprovalOrDenial\x12\x0f\n\x07\x66lowUid\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x65ny\x18\x02 \x01(\x08\x12\x14\n\x0c\x64\x65nialReason\x18\x03 \x01(\x0c\"U\n\x10WorkflowApproval\x12\x0e\n\x06userId\x18\x01 \x01(\x03\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x0f\n\x07\x66lowUid\x18\x03 \x01(\x0c\x12\x12\n\napprovedOn\x18\x04 \x01(\x03\"\xe6\x01\n\x0eWorkflowStatus\x12&\n\x05stage\x18\x01 \x01(\x0e\x32\x17.Workflow.WorkflowStage\x12-\n\nconditions\x18\x02 \x03(\x0e\x32\x19.Workflow.AccessCondition\x12.\n\napprovedBy\x18\x03 \x03(\x0b\x32\x1a.Workflow.WorkflowApproval\x12\x11\n\tstartedOn\x18\x04 \x01(\x03\x12\x11\n\texpiresOn\x18\x05 \x01(\x03\x12\x11\n\tescalated\x18\x06 \x01(\x08\x12\x14\n\x0c\x63heckedOutBy\x18\x07 \x01(\t\"u\n\rWorkflowState\x12\x0f\n\x07\x66lowUid\x18\x01 \x01(\x0c\x12)\n\x08resource\x18\x02 \x01(\x0b\x32\x17.GraphSync.GraphSyncRef\x12(\n\x06status\x18\x03 \x01(\x0b\x32\x18.Workflow.WorkflowStatus\"=\n\x0fUserAccessState\x12*\n\tworkflows\x18\x01 \x03(\x0b\x32\x17.Workflow.WorkflowState\"g\n\x10WorkflowApprover\x12\x0e\n\x04user\x18\x01 \x01(\tH\x00\x12\x10\n\x06userId\x18\x02 
\x01(\x05H\x00\x12\x11\n\x07teamUid\x18\x03 \x01(\x0cH\x00\x12\x12\n\nescalation\x18\x04 \x01(\x08\x42\n\n\x08\x61pprover\"\xe7\x01\n\x12WorkflowParameters\x12)\n\x08resource\x18\x01 \x01(\x0b\x32\x17.GraphSync.GraphSyncRef\x12\x17\n\x0f\x61pprovalsNeeded\x18\x02 \x01(\x05\x12\x16\n\x0e\x63heckoutNeeded\x18\x03 \x01(\x08\x12\x1d\n\x15startAccessOnApproval\x18\x04 \x01(\x08\x12\x15\n\rrequireReason\x18\x05 \x01(\x08\x12\x15\n\rrequireTicket\x18\x06 \x01(\x08\x12\x12\n\nrequireMFA\x18\x07 \x01(\x08\x12\x14\n\x0c\x61\x63\x63\x65ssLength\x18\x08 \x01(\x03\"\x84\x01\n\x0eWorkflowConfig\x12\x30\n\nparameters\x18\x01 \x01(\x0b\x32\x1c.Workflow.WorkflowParameters\x12-\n\tapprovers\x18\x02 \x03(\x0b\x32\x1a.Workflow.WorkflowApprover\x12\x11\n\tcreatedOn\x18\x03 \x01(\x03*[\n\rWorkflowStage\x12\x15\n\x11WS_READY_TO_START\x10\x00\x12\x0e\n\nWS_STARTED\x10\x01\x12\x13\n\x0fWS_NEEDS_ACTION\x10\x02\x12\x0e\n\nWS_WAITING\x10\x03*i\n\x0f\x41\x63\x63\x65ssCondition\x12\x0f\n\x0b\x41\x43_APPROVAL\x10\x00\x12\x0e\n\nAC_CHECKIN\x10\x01\x12\n\n\x06\x41\x43_MFA\x10\x02\x12\x0b\n\x07\x41\x43_TIME\x10\x03\x12\r\n\tAC_REASON\x10\x04\x12\r\n\tAC_TICKET\x10\x05\x42$\n\x18\x63om.keepersecurity.protoB\x08Workflowb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'workflow_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals['DESCRIPTOR']._options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\030com.keepersecurity.protoB\010Workflow' +# @@protoc_insertion_point(module_scope) diff --git a/keepercommander/service/util/parse_keeper_response.py b/keepercommander/service/util/parse_keeper_response.py index 99a80fd9c..848f70f32 100644 --- a/keepercommander/service/util/parse_keeper_response.py +++ b/keepercommander/service/util/parse_keeper_response.py @@ -79,6 +79,7 @@ def _find_parser_method(command: str) -> str: 'enterprise-push': '_parse_enterprise_push_command', 
'search record': '_parse_search_record_command', 'search folder': '_parse_search_folder_command', + 'policy add': '_parse_epm_policy_add_command', } for pattern, method_name in substring_patterns.items(): @@ -132,17 +133,20 @@ def parse_response(command: str, response: Any, log_output: str = None) -> Dict[ if not response_str: return KeeperResponseParser._handle_empty_response(command) - # If from log output, use logging-based parsing directly + # Find the appropriate parser method (used for both log and non-log paths) + parser_method_name = KeeperResponseParser._find_parser_method(command) + + # If from log output, use command-specific parser if available, else generic logging parser if is_from_log: return KeeperResponseParser._parse_logging_based_command(command, response_str) - # Find and call the appropriate parser method parser_method_name = KeeperResponseParser._find_parser_method(command) parser_method = getattr(KeeperResponseParser, parser_method_name) # Call the parser method with appropriate arguments if parser_method_name in ['_parse_generate_command', '_parse_json_format_command', - '_parse_pam_project_import_command', '_parse_enterprise_push_command']: + '_parse_pam_project_import_command', '_parse_enterprise_push_command', + '_parse_epm_policy_add_command']: return parser_method(command, response_str) else: return parser_method(response_str) if parser_method_name != '_parse_logging_based_command' else parser_method(command, response_str) @@ -923,6 +927,35 @@ def _handle_empty_response(command: str) -> Dict[str, Any]: "data": None } + @staticmethod + def _parse_epm_policy_add_command(command: str, response_str: str) -> Dict[str, Any]: + """Parse 'epm policy add' command output to extract policy ID and name.""" + response_str = KeeperResponseParser._filter_login_messages(response_str.strip()) + + result = { + "status": "success", + "command": "epm policy add", + "message": response_str, + "data": {} + } + + policy_match = re.search( + r'Successfully 
created policy "([^"]*)" with Policy ID:\s*(\S+)', + response_str + ) + if policy_match: + result["data"]["policy_name"] = policy_match.group(1) + result["data"]["policy_id"] = policy_match.group(2) + else: + response_lower = response_str.lower() + if any(kw in response_lower for kw in ["error", "failed", "not supported"]): + result["status"] = "error" + result["error"] = response_str + del result["data"] + del result["message"] + + return result + @staticmethod def _parse_enterprise_push_command(command: str, response_str: str) -> Dict[str, Any]: """Parse enterprise-push command responses.""" diff --git a/tests/test_pam_privileged_cloud.py b/tests/test_pam_privileged_cloud.py new file mode 100644 index 000000000..aed7f2ee8 --- /dev/null +++ b/tests/test_pam_privileged_cloud.py @@ -0,0 +1,273 @@ +"""Tests for PAM Identity Provider CLI commands (KPC Track D).""" + +import json +import unittest +from unittest.mock import MagicMock, patch + +from keepercommander.commands.pam.pam_dto import ( + GatewayActionIdpInputs, + GatewayActionIdpCreateUser, + GatewayActionIdpDeleteUser, + GatewayActionIdpAddUserToGroup, + GatewayActionIdpRemoveUserFromGroup, + GatewayActionIdpGroupList, +) + +from keepercommander.commands.pam_cloud.pam_privileged_access import ( + resolve_pam_idp_config, + VALID_CONFIG_TYPES, + PAMPrivilegedAccessCommand, + PAMAccessUserCommand, + PAMAccessGroupCommand, +) + +from keepercommander.error import CommandError + + +class TestGatewayActionIdpInputs(unittest.TestCase): + """Test GatewayActionIdpInputs serialization.""" + + def test_inputs_with_idp_config(self): + """idpConfigUid included when different from configurationUid.""" + inputs = GatewayActionIdpInputs('config-123', 'azure-456', user='john') + data = json.loads(inputs.toJSON()) + self.assertEqual(data['configurationUid'], 'config-123') + self.assertEqual(data['idpConfigUid'], 'azure-456') + self.assertEqual(data['user'], 'john') + + def test_inputs_self_managing(self): + """idpConfigUid 
class TestGatewayActionIdpInputs(unittest.TestCase):
    """Serialization tests for GatewayActionIdpInputs.toJSON()."""
    # NOTE(review): this class begins before the visible chunk; the name is
    # inferred from the type under test and one leading test method (cut at the
    # chunk boundary) is not reproduced here -- confirm against the full file.

    def test_inputs_no_idp(self):
        """idpConfigUid omitted when not provided."""
        inputs = GatewayActionIdpInputs('config-123', user='john')
        data = json.loads(inputs.toJSON())
        self.assertEqual(data['configurationUid'], 'config-123')
        self.assertNotIn('idpConfigUid', data)

    def test_inputs_none_values_excluded(self):
        """None kwargs are not included in serialization."""
        inputs = GatewayActionIdpInputs('config-123', user='john', password=None)
        data = json.loads(inputs.toJSON())
        self.assertNotIn('password', data)

    def test_inputs_with_all_fields(self):
        """All fields serialized correctly."""
        inputs = GatewayActionIdpInputs(
            'config-123', 'azure-456',
            user='john@contoso.com',
            displayName='John Doe',
            password='secret123',
            groupId='group-789',
        )
        data = json.loads(inputs.toJSON())
        self.assertEqual(data['user'], 'john@contoso.com')
        self.assertEqual(data['displayName'], 'John Doe')
        self.assertEqual(data['password'], 'secret123')
        self.assertEqual(data['groupId'], 'group-789')


class TestGatewayActionSubclasses(unittest.TestCase):
    """Test GatewayAction subclasses use correct RM action strings."""

    def _get_payload(self, action_class, **input_kwargs):
        """Build an *action_class* instance and return its serialized payload."""
        inputs = GatewayActionIdpInputs('config-123', 'azure-456', **input_kwargs)
        action = action_class(inputs=inputs)
        return json.loads(action.toJSON())

    def test_create_user_action_string(self):
        payload = self._get_payload(GatewayActionIdpCreateUser, user='john')
        self.assertEqual(payload['action'], 'rm-create-user')

    def test_delete_user_action_string(self):
        payload = self._get_payload(GatewayActionIdpDeleteUser, user='john')
        self.assertEqual(payload['action'], 'rm-delete-user')

    def test_add_user_to_group_action_string(self):
        payload = self._get_payload(GatewayActionIdpAddUserToGroup,
                                    user='john', groupId='grp-1')
        self.assertEqual(payload['action'], 'rm-add-user-to-group')

    def test_remove_user_from_group_action_string(self):
        payload = self._get_payload(GatewayActionIdpRemoveUserFromGroup,
                                    user='john', groupId='grp-1')
        self.assertEqual(payload['action'], 'rm-remove-user-from-group')

    def test_group_list_action_string(self):
        payload = self._get_payload(GatewayActionIdpGroupList)
        self.assertEqual(payload['action'], 'rm-group-list')

    def test_all_actions_not_scheduled(self):
        """All IdP actions should be non-scheduled (synchronous)."""
        for cls in [GatewayActionIdpCreateUser, GatewayActionIdpDeleteUser,
                    GatewayActionIdpAddUserToGroup, GatewayActionIdpRemoveUserFromGroup,
                    GatewayActionIdpGroupList]:
            inputs = GatewayActionIdpInputs('config-123')
            action = cls(inputs=inputs)
            payload = json.loads(action.toJSON())
            self.assertFalse(payload['is_scheduled'],
                             f'{cls.__name__} should not be scheduled')

    def test_idp_config_uid_in_inputs(self):
        """idpConfigUid should be inside the inputs object, not top-level."""
        payload = self._get_payload(GatewayActionIdpCreateUser, user='john')
        self.assertIn('idpConfigUid', payload['inputs'])
        self.assertNotIn('idpConfigUid', payload)


class TestResolveIdpConfig(unittest.TestCase):
    # Fixed: docstring previously named a non-existent resolve_idp_config().
    """Test resolve_pam_idp_config() helper."""

    def _make_mock_record(self, record_type, idp_uid=None):
        """Create a mock TypedRecord with an optional identityProviderUid custom field."""
        from keepercommander import vault
        record = MagicMock(spec=vault.TypedRecord)
        record.record_type = record_type

        custom_fields = []
        if idp_uid:
            field = MagicMock()
            field.type = 'text'
            field.label = 'identityProviderUid'
            # side_effect yields a fresh iterator on every call; a plain
            # return_value iterator would be exhausted after the first read.
            field.get_external_value.side_effect = lambda *args, **kwargs: iter([idp_uid])
            custom_fields.append(field)

        record.custom = custom_fields
        return record

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_self_managing_azure(self, mock_load):
        """Azure config without identityProviderUid returns self."""
        record = self._make_mock_record('pamAzureConfiguration')
        mock_load.return_value = record
        params = MagicMock()

        result = resolve_pam_idp_config(params, 'azure-123')
        self.assertEqual(result, 'azure-123')

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_cross_reference(self, mock_load):
        """Config with identityProviderUid returns the referenced UID."""
        net_record = self._make_mock_record('pamNetworkConfiguration', idp_uid='azure-456')
        azure_record = self._make_mock_record('pamAzureConfiguration')

        def load_side_effect(params, uid):
            if uid == 'net-123':
                return net_record
            elif uid == 'azure-456':
                return azure_record
            return None

        mock_load.side_effect = load_side_effect
        params = MagicMock()

        result = resolve_pam_idp_config(params, 'net-123')
        self.assertEqual(result, 'azure-456')

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_config_not_found(self, mock_load):
        """Raises error when config UID doesn't exist."""
        mock_load.return_value = None
        params = MagicMock()

        with self.assertRaises(CommandError):
            resolve_pam_idp_config(params, 'nonexistent')

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_non_idp_type_without_ref(self, mock_load):
        """Raises error for a non-IdP config type without identityProviderUid."""
        record = self._make_mock_record('pamNetworkConfiguration')
        mock_load.return_value = record
        params = MagicMock()

        with self.assertRaises(CommandError) as ctx:
            resolve_pam_idp_config(params, 'net-123')
        self.assertIn('No Identity Provider available', str(ctx.exception))

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_referenced_config_not_found(self, mock_load):
        """Raises error when referenced IdP config doesn't exist."""
        net_record = self._make_mock_record('pamNetworkConfiguration', idp_uid='missing-456')

        def load_side_effect(params, uid):
            if uid == 'net-123':
                return net_record
            return None

        mock_load.side_effect = load_side_effect
        params = MagicMock()

        with self.assertRaises(CommandError) as ctx:
            resolve_pam_idp_config(params, 'net-123')
        self.assertIn('not found', str(ctx.exception))

    @patch('keepercommander.commands.pam_cloud.pam_idp.vault.KeeperRecord.load')
    def test_referenced_config_invalid_type(self, mock_load):
        """Raises error when a referenced config type doesn't support IdP."""
        net_record = self._make_mock_record('pamNetworkConfiguration', idp_uid='other-456')
        other_record = self._make_mock_record('pamLocalConfiguration')

        def load_side_effect(params, uid):
            if uid == 'net-123':
                return net_record
            elif uid == 'other-456':
                return other_record
            return None

        mock_load.side_effect = load_side_effect
        params = MagicMock()

        with self.assertRaises(CommandError) as ctx:
            resolve_pam_idp_config(params, 'net-123')
        self.assertIn('does not support identity provider', str(ctx.exception))


class TestValidIdpConfigTypes(unittest.TestCase):
    # Fixed: docstring previously named VALID_IDP_CONFIG_TYPES, but the
    # constant asserted against below is VALID_CONFIG_TYPES.
    """Test VALID_CONFIG_TYPES constant."""

    def test_azure_is_valid(self):
        self.assertIn('pamAzureConfiguration', VALID_CONFIG_TYPES)

    def test_okta_is_valid(self):
        self.assertIn('pamOktaConfiguration', VALID_CONFIG_TYPES)

    def test_domain_is_valid(self):
        self.assertIn('pamDomainConfiguration', VALID_CONFIG_TYPES)

    def test_aws_is_valid(self):
        self.assertIn('pamAwsConfiguration', VALID_CONFIG_TYPES)

    def test_gcp_is_valid(self):
        self.assertIn('pamGcpConfiguration', VALID_CONFIG_TYPES)

    def test_network_is_not_valid(self):
        self.assertNotIn('pamNetworkConfiguration', VALID_CONFIG_TYPES)


class TestCommandGroupStructure(unittest.TestCase):
    """Test command group hierarchy."""

    def test_idp_has_user_and_group_subgroups(self):
        cmd = PAMPrivilegedAccessCommand()
        self.assertIn('user', cmd.subcommands)
        self.assertIn('group', cmd.subcommands)

    def test_user_has_provision_deprovision_list(self):
        cmd = PAMAccessUserCommand()
        self.assertIn('provision', cmd.subcommands)
        self.assertIn('deprovision', cmd.subcommands)
        self.assertIn('list', cmd.subcommands)

    def test_group_has_add_remove_list(self):
        cmd = PAMAccessGroupCommand()
        self.assertIn('add-user', cmd.subcommands)
        self.assertIn('remove-user', cmd.subcommands)
        self.assertIn('list', cmd.subcommands)


if __name__ == '__main__':
    unittest.main()
record.fields.append(importer.RecordField('text', 'custom\x10label', 'value\x10data')) + + with tempfile.NamedTemporaryFile(suffix='.kdbx', delete=False) as temp_file: + file_name = temp_file.name + + try: + KeepassExporter().do_export(file_name, [record], file_password='password') + with PyKeePass(file_name, password='password') as kdb: + self.assertEqual(len(kdb.entries), 1) + entry = kdb.entries[0] + self.assertEqual(entry.title, 'badtitle') + self.assertEqual(entry.username, 'username') + self.assertEqual(entry.password, 'password') + self.assertEqual(entry.url, 'https://example.com/path') + self.assertEqual(entry.notes, 'notebody') + self.assertEqual(entry.group.name, 'groupname') + self.assertEqual(entry.custom_properties.get('$text:customlabel'), 'valuedata') + finally: + if os.path.exists(file_name): + os.unlink(file_name) + def test_host_serialization(self): host = { 'hostName': 'keepersecurity.com',