From af0b616316c2b65766dd121aec71e1ccc7f92dd5 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Wed, 28 Jan 2026 18:20:15 +0530 Subject: [PATCH 1/3] SK-2496: extract hard coded values to constants --- skyflow/client/skyflow.py | 55 ++++---- skyflow/service_account/_utils.py | 69 ++++----- skyflow/utils/_skyflow_messages.py | 3 + skyflow/utils/_utils.py | 109 ++++++++------- skyflow/utils/constants.py | 163 ++++++++++++++++++++++ skyflow/utils/logger/_log_helpers.py | 13 +- skyflow/utils/validations/_validations.py | 9 +- skyflow/vault/client/client.py | 29 ++-- skyflow/vault/controller/_connections.py | 5 +- skyflow/vault/controller/_detect.py | 39 +++--- skyflow/vault/controller/_vault.py | 12 +- 11 files changed, 340 insertions(+), 166 deletions(-) diff --git a/skyflow/client/skyflow.py b/skyflow/client/skyflow.py index 9f0d9dbf..0bfde34e 100644 --- a/skyflow/client/skyflow.py +++ b/skyflow/client/skyflow.py @@ -3,6 +3,7 @@ from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages from skyflow.utils.logger import log_info, Logger +from skyflow.utils.constants import OptionField from skyflow.utils.validations import validate_vault_config, validate_connection_config, validate_update_vault_config, \ validate_update_connection_config, validate_credentials, validate_log_level from skyflow.vault.client.client import VaultClient @@ -30,7 +31,7 @@ def update_vault_config(self,config): self.__builder.update_vault_config(config) def get_vault_config(self, vault_id): - return self.__builder.get_vault_config(vault_id).get("vault_client").get_config() + return self.__builder.get_vault_config(vault_id).get(OptionField.VAULT_CLIENT).get_config() def add_connection_config(self, config): self.__builder._Builder__add_connection_config(config) @@ -45,7 +46,7 @@ def update_connection_config(self, config): return self def get_connection_config(self, connection_id): - return self.__builder.get_connection_config(connection_id).get("vault_client").get_config() + return self.__builder.get_connection_config(connection_id).get(OptionField.VAULT_CLIENT).get_config() def add_skyflow_credentials(self, credentials): self.__builder._Builder__add_skyflow_credentials(credentials) @@ -66,15 +67,15 @@ def update_log_level(self, log_level): def vault(self, vault_id = None) -> Vault: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("vault_controller") + return vault_config.get(OptionField.VAULT_CONTROLLER) def connection(self, connection_id = None) -> Connection: connection_config = self.__builder.get_connection_config(connection_id) - return connection_config.get("controller") + return connection_config.get(OptionField.CONTROLLER) def detect(self, vault_id = None) -> Detect: vault_config = self.__builder.get_vault_config(vault_id) - return vault_config.get("detect_controller") + return vault_config.get(OptionField.DETECT_CONTROLLER) class Builder: def __init__(self): @@ -87,13 +88,13 @@ def __init__(self): self.__logger = Logger(LogLevel.ERROR) def add_vault_config(self, config): - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) if not isinstance(vault_id, str) or not vault_id: raise SkyflowError( SkyflowMessages.Error.INVALID_VAULT_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if vault_id in [vault.get("vault_id") for vault in self.__vault_list]: + if vault_id in [vault.get(OptionField.VAULT_ID) for vault in self.__vault_list]: log_info(SkyflowMessages.Info.VAULT_CONFIG_EXISTS.value.format(vault_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.VAULT_ID_ALREADY_EXISTS.value.format(vault_id), @@ -112,9 +113,9 @@ def remove_vault_config(self, vault_id): def update_vault_config(self, config): validate_update_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_config = self.__vault_configs[vault_id] - vault_config.get("vault_client").update_config(config) + vault_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_vault_config(self, vault_id): if vault_id is None: @@ -129,13 +130,13 @@ def get_vault_config(self, vault_id): def add_connection_config(self, config): - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) if not isinstance(connection_id, str) or not connection_id: raise SkyflowError( SkyflowMessages.Error.INVALID_CONNECTION_ID.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value ) - if connection_id in [connection.get("connection_id") for connection in self.__connection_list]: + if connection_id in [connection.get(OptionField.CONNECTION_ID) for connection in self.__connection_list]: log_info(SkyflowMessages.Info.CONNECTION_CONFIG_EXISTS.value.format(connection_id), self.__logger) raise SkyflowError( SkyflowMessages.Error.CONNECTION_ID_ALREADY_EXISTS.value.format(connection_id), @@ -153,9 +154,9 @@ def remove_connection_config(self, connection_id): def update_connection_config(self, config): validate_update_connection_config(self.__logger, config) - connection_id = config['connection_id'] + connection_id = config[OptionField.CONNECTION_ID] connection_config = self.__connection_configs[connection_id] - connection_config.get("vault_client").update_config(config) + connection_config.get(OptionField.VAULT_CLIENT).update_config(config) def get_connection_config(self, connection_id): if connection_id is None: @@ -183,32 +184,32 @@ def get_logger(self): def __add_vault_config(self, config): validate_vault_config(self.__logger, config) - vault_id = config.get("vault_id") + vault_id = config.get(OptionField.VAULT_ID) vault_client = VaultClient(config) self.__vault_configs[vault_id] = { - "vault_client": vault_client, - "vault_controller": Vault(vault_client), - "detect_controller": Detect(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.VAULT_CONTROLLER: Vault(vault_client), + OptionField.DETECT_CONTROLLER: Detect(vault_client) } - log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) - log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get("vault_id")), self.__logger) + log_info(SkyflowMessages.Info.VAULT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) + log_info(SkyflowMessages.Info.DETECT_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.VAULT_ID)), self.__logger) def __add_connection_config(self, config): validate_connection_config(self.__logger, config) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) vault_client = VaultClient(config) self.__connection_configs[connection_id] = { - "vault_client": vault_client, - "controller": Connection(vault_client) + OptionField.VAULT_CLIENT: vault_client, + OptionField.CONTROLLER: Connection(vault_client) } - log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get("connection_id")), self.__logger) + log_info(SkyflowMessages.Info.CONNECTION_CONTROLLER_INITIALIZED.value.format(config.get(OptionField.CONNECTION_ID)), self.__logger) def __update_vault_client_logger(self, log_level, logger): for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_logger(log_level,logger) + vault_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_logger(log_level,logger) + connection_config.get(OptionField.VAULT_CLIENT).set_logger(log_level,logger) def __set_log_level(self, log_level): validate_log_level(self.__logger, log_level) @@ -223,10 +224,10 @@ def __add_skyflow_credentials(self, credentials): self.__skyflow_credentials = credentials validate_credentials(self.__logger, credentials) for vault_id, vault_config in self.__vault_configs.items(): - vault_config.get("vault_client").set_common_skyflow_credentials(credentials) + vault_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(credentials) for connection_id, connection_config in self.__connection_configs.items(): - connection_config.get("vault_client").set_common_skyflow_credentials(self.__skyflow_credentials) + connection_config.get(OptionField.VAULT_CLIENT).set_common_skyflow_credentials(self.__skyflow_credentials) def build(self): validate_log_level(self.__logger, self.__log_level) self.__logger.set_log_level(self.__log_level) diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 715716d8..3f21ba21 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -6,6 +6,7 @@ from skyflow.service_account.client.auth_client import AuthClient from skyflow.utils.logger import log_info, log_error_log from skyflow.utils import get_base_url, format_scope, SkyflowMessages +from skyflow.utils.constants import JWT, CredentialField, JwtField, OptionField, ResponseField invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value @@ -17,8 +18,8 @@ def is_expired(token, logger = None): try: decoded = jwt.decode( - token, options={"verify_signature": False, "verify_aud": False}) - if time.time() >= decoded['exp']: + token, options={OptionField.VERIFY_SIGNATURE: False, OptionField.VERIFY_AUD: False}) + if time.time() >= decoded[JwtField.EXP]: log_info(SkyflowMessages.Info.BEARER_TOKEN_EXPIRED.value, logger) log_error_log(SkyflowMessages.ErrorLogs.INVALID_BEARER_TOKEN.value) return True @@ -59,22 +60,22 @@ def generate_bearer_token_from_creds(credentials, options = None, logger = None) def get_service_account_token(credentials, options, logger): try: - private_key = credentials["privateKey"] + private_key = credentials[CredentialField.PRIVATE_KEY] except: log_error_log(SkyflowMessages.ErrorLogs.PRIVATE_KEY_IS_REQUIRED.value, logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_PRIVATE_KEY.value, invalid_input_error_code) try: - client_id = credentials["clientID"] + client_id = credentials[CredentialField.CLIENT_ID] except: log_error_log(SkyflowMessages.ErrorLogs.CLIENT_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_CLIENT_ID.value, invalid_input_error_code) try: - key_id = credentials["keyID"] + key_id = credentials[CredentialField.KEY_ID] except: log_error_log(SkyflowMessages.ErrorLogs.KEY_ID_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_KEY_ID.value, invalid_input_error_code) try: - token_uri = credentials["tokenURI"] + token_uri = credentials[CredentialField.TOKEN_URI] except: log_error_log(SkyflowMessages.ErrorLogs.TOKEN_URI_IS_REQUIRED.value, logger=logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TOKEN_URI.value, invalid_input_error_code) @@ -85,27 +86,27 @@ def get_service_account_token(credentials, options, logger): auth_api = auth_client.get_auth_api() formatted_scope = None - if options and "role_ids" in options: - formatted_scope = format_scope(options.get("role_ids")) + if options and OptionField.ROLE_IDS in options: + formatted_scope = format_scope(options.get(OptionField.ROLE_IDS)) response = auth_api.authentication_service_get_auth_token(assertion = signed_token, - grant_type="urn:ietf:params:oauth:grant-type:jwt-bearer", + grant_type=JWT.GRANT_TYPE_JWT_BEARER, scope=formatted_scope) log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_SUCCESS.value, logger) return response.access_token, response.token_type def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): payload = { - "iss": client_id, - "key": key_id, - "aud": token_uri, - "sub": client_id, - "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=60) + JwtField.ISS: client_id, + JwtField.KEY: key_id, + JwtField.AUD: token_uri, + JwtField.SUB: client_id, + JwtField.EXP: datetime.datetime.utcnow() + datetime.timedelta(minutes=60) } - if options and "ctx" in options: - payload["ctx"] = options.get("ctx") + if options and JwtField.CTX in options: + payload[JwtField.CTX] = options.get(JwtField.CTX) try: - return jwt.encode(payload=payload, key=private_key, algorithm="RS256") + return jwt.encode(payload=payload, key=private_key, algorithm=JWT.ALGORITHM_RS256) except Exception: raise SkyflowError(SkyflowMessages.Error.JWT_INVALID_FORMAT.value, invalid_input_error_code) @@ -113,25 +114,25 @@ def get_signed_jwt(options, client_id, key_id, token_uri, private_key, logger): def get_signed_tokens(credentials_obj, options): try: - expiry_time = int(time.time()) + options.get("time_to_live", 60) - prefix = "signed_token_" + expiry_time = int(time.time()) + options.get(OptionField.TIME_TO_LIVE, 60) + prefix = JWT.SIGNED_TOKEN_PREFIX - if options and options.get("data_tokens"): - for token in options["data_tokens"]: + if options and options.get(OptionField.DATA_TOKENS): + for token in options[OptionField.DATA_TOKENS]: claims = { - "iss": "sdk", - "key": credentials_obj.get("keyID"), - "exp": expiry_time, - "sub": credentials_obj.get("clientID"), - "tok": token, - "iat": int(time.time()), + JwtField.ISS: JWT.ISSUER_SDK, + JwtField.KEY: credentials_obj.get(CredentialField.KEY_ID), + JwtField.EXP: expiry_time, + JwtField.SUB: credentials_obj.get(CredentialField.CLIENT_ID), + JwtField.TOK: token, + JwtField.IAT: int(time.time()), } - if "ctx" in options: - claims["ctx"] = options["ctx"] + if JwtField.CTX in options: + claims[JwtField.CTX] = options[JwtField.CTX] - private_key = credentials_obj.get("privateKey") - signed_jwt = jwt.encode(claims, private_key, algorithm="RS256") + private_key = credentials_obj.get(CredentialField.PRIVATE_KEY) + signed_jwt = jwt.encode(claims, private_key, algorithm=JWT.ALGORITHM_RS256) response_object = get_signed_data_token_response_object(prefix + signed_jwt, token) log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKEN_SUCCESS.value) return response_object @@ -170,7 +171,7 @@ def generate_signed_data_tokens_from_creds(credentials, options): def get_signed_data_token_response_object(signed_token, actual_token): response_object = { - "token": actual_token, - "signed_token": signed_token + ResponseField.TOKEN: actual_token, + ResponseField.SIGNED_TOKEN: signed_token } - return response_object.get("token"), response_object.get("signed_token") + return response_object.get(ResponseField.TOKEN), response_object.get(ResponseField.SIGNED_TOKEN) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..1954ed4d 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -71,6 +71,9 @@ class Error(Enum): RESPONSE_NOT_JSON = f"{error_prefix} Response {{}} is not valid JSON." API_ERROR = f"{error_prefix} Server returned status code {{}}" + INVALID_JSON_RESPONSE = f"{error_prefix} Invalid JSON response received." + UNKNOWN_ERROR_DEFAULT_MESSAGE = f"{error_prefix} An unknown error occurred." + INVALID_FILE_INPUT = f"{error_prefix} Validation error. Invalid file input. Specify a valid file input." INVALID_DETECT_ENTITIES_TYPE = f"{error_prefix} Validation error. Invalid type of detect entities. Specify detect entities as list of DetectEntities enum." INVALID_TYPE_FOR_DEFAULT_TOKEN_TYPE = f"{error_prefix} Validation error. Invalid type of default token type. Specify default token type as TokenType enum." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..c6f294cd 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -20,7 +20,8 @@ from skyflow.vault.detect import DeidentifyTextResponse, ReidentifyTextResponse from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION -from .constants import PROTOCOL +from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, + EncodingType, BooleanString, ResponseField, CredentialField) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -44,7 +45,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg try: env_creds = env_skyflow_credentials.replace('\n', '\\n') return { - 'credentials_string': env_creds + CredentialField.CREDENTIALS_STRING: env_creds } except json.JSONDecodeError: raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) @@ -52,7 +53,7 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False api_key_pattern = re.compile(r'^sky-[a-zA-Z0-9]{5}-[a-fA-F0-9]{32}$') @@ -113,13 +114,13 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep except Exception: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - if not 'Content-Type'.lower() in header: - header['content-type'] = ContentType.JSON.value + if not HttpHeader.CONTENT_TYPE.lower() in header: + header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value try: if isinstance(request.body, dict): json_data, files = get_data_from_content_type( - request.body, header["content-type"] + request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE] ) else: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) @@ -216,30 +217,30 @@ def parse_insert_response(api_response, continue_on_error): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) inserted_fields = [] errors = [] insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response['Status'] == 200: - body = response['Body'] - if 'records' in body: - for record in body['records']: + if response[ResponseField.STATUS] == 200: + body = response[ResponseField.BODY] + if ResponseField.RECORDS in body: + for record in body[ResponseField.RECORDS]: inserted_field = { - 'skyflow_id': record['skyflow_id'], - 'request_index': idx + ResponseField.SKYFLOW_ID: record[ResponseField.SKYFLOW_ID], + ResponseField.REQUEST_INDEX: idx } - if 'tokens' in record: - inserted_field.update(record['tokens']) + if ResponseField.TOKENS in record: + inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response['Status'] == 400: + elif response[ResponseField.STATUS] == 400: error = { - 'request_index': idx, - 'request_id': request_id, - 'error': response['Body']['error'], - 'http_code': response['Status'], + ResponseField.REQUEST_INDEX: idx, + ResponseField.REQUEST_ID: request_id, + ResponseField.ERROR: response[ResponseField.BODY][ResponseField.ERROR], + ResponseField.HTTP_CODE: response[ResponseField.STATUS], } errors.append(error) @@ -248,7 +249,7 @@ def parse_insert_response(api_response, continue_on_error): else: for record in api_response_data.records: field_data = { - 'skyflow_id': record.skyflow_id + ResponseField.SKYFLOW_ID: record.skyflow_id } if record.tokens: @@ -263,7 +264,7 @@ def parse_insert_response(api_response, continue_on_error): def parse_update_record_response(api_response: V1UpdateRecordResponse): update_response = UpdateResponse() updated_field = dict() - updated_field['skyflow_id'] = api_response.skyflow_id + updated_field[ResponseField.SKYFLOW_ID] = api_response.skyflow_id if api_response.tokens is not None: updated_field.update(api_response.tokens) @@ -293,23 +294,23 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): api_response_headers = api_response.headers api_response_data = api_response.data # Retrieve the request ID from the headers - request_id = api_response_headers.get('x-request-id') + request_id = api_response_headers.get(HttpHeader.X_REQUEST_ID) detokenized_fields = [] errors = [] for record in api_response_data.records: if record.error: errors.append({ - "token": record.token, - "error": record.error, - "request_id": request_id + ResponseField.TOKEN: record.token, + ResponseField.ERROR: record.error, + ResponseField.REQUEST_ID: request_id }) else: value_type = record.value_type if record.value_type else None detokenized_fields.append({ - "token": record.token, - "value": record.value, - "type": value_type + ResponseField.TOKEN: record.token, + ResponseField.VALUE: record.value, + ResponseField.TYPE: value_type }) detokenized_fields = detokenized_fields @@ -322,7 +323,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): def parse_tokenize_response(api_response: V1TokenizeResponse): tokenize_response = TokenizeResponse() - tokenized_fields = [{"token": record.token} for record in api_response.records] + tokenized_fields = [{ResponseField.TOKEN: record.token} for record in api_response.records] tokenize_response.tokenized_fields = tokenized_fields @@ -334,7 +335,7 @@ def parse_query_response(api_response: V1GetQueryResponse): for record in api_response.records: field_object = { **record.fields, - "tokenized_data": {} + ResponseField.TOKENIZED_DATA: {} } fields.append(field_object) query_response.fields = fields @@ -344,14 +345,14 @@ def parse_invoke_connection_response(api_response: requests.Response): status_code = api_response.status_code content = api_response.content if isinstance(content, bytes): - content = content.decode('utf-8') + content = content.decode(EncodingType.UTF_8) try: api_response.raise_for_status() try: data = json.loads(content) metadata = {} - if 'x-request-id' in api_response.headers: - metadata['request_id'] = api_response.headers['x-request-id'] + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: @@ -360,19 +361,19 @@ def parse_invoke_connection_response(api_response: requests.Response): message = SkyflowMessages.Error.API_ERROR.value.format(status_code) try: error_response = json.loads(content) - request_id = api_response.headers['x-request-id'] - error_from_client = api_response.headers.get('error-from-client') - - status_code = error_response.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = error_response.get('error', {}).get('http_status') - grpc_code = error_response.get('error', {}).get('grpc_code') - details = error_response.get('error', {}).get('details') - message = error_response.get('error', {}).get('message', "An unknown error occurred.") + request_id = api_response.headers[HttpHeader.X_REQUEST_ID] + error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) + + status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) + message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) if error_from_client is not None: if details is None: details = [] - error_from_client_bool = error_from_client.lower() == 'true' - details.append({'error_from_client': error_from_client_bool}) + error_from_client_bool = error_from_client.lower() == BooleanString.TRUE + details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) except json.JSONDecodeError: @@ -399,14 +400,14 @@ def handle_exception(error, logger): if (isinstance(error, httpx.ConnectError)): handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) - request_id = error.headers.get('x-request-id', 'unknown-request-id') - content_type = error.headers.get('content-type') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') + content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body if content_type: - if 'application/json' in content_type: + if ContentTypeConstants.APPLICATION_JSON in content_type: handle_json_error(error, data, request_id, logger) - elif 'text/plain' in content_type: + elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: handle_generic_error(error, request_id, logger) @@ -421,15 +422,15 @@ def handle_json_error(err, data, request_id, logger): description = data.dict() else: description = json.loads(data) - status_code = description.get('error', {}).get('http_code', 500) # Default to 500 if not found - http_status = description.get('error', {}).get('http_status') - grpc_code = description.get('error', {}).get('grpc_code') - details = description.get('error', {}).get('details', []) + status_code = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + http_status = description.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) + grpc_code = description.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) + details = description.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS, []) - description_message = description.get('error', {}).get('message', "An unknown error occurred.") + description_message = description.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) log_and_reject_error(description_message, status_code, request_id, http_status, grpc_code, details, logger = logger) except json.JSONDecodeError: - log_and_reject_error("Invalid JSON response received.", err, request_id, logger = logger) + log_and_reject_error(SkyflowMessages.Error.INVALID_JSON_RESPONSE.value, err, request_id, logger = logger) def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index ef20faf8..30cb124d 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -2,3 +2,166 @@ PROTOCOL='https' SKY_META_DATA_HEADER='sky-metadata' +class SKYFLOW: + SKYFLOW_ID = 'skyflowId' + X_SKYFLOW_AUTHORIZATION = 'x-skyflow-authorization' + + +class HttpHeader: + CONTENT_TYPE = 'Content-Type' + CONTENT_TYPE_LOWERCASE = 'content-type' + X_REQUEST_ID = 'x-request-id' + ERROR_FROM_CLIENT = 'error-from-client' + AUTHORIZATION = 'Authorization' + + +class HttpStatusCode: + OK = 200 + BAD_REQUEST = 400 + INTERNAL_SERVER_ERROR = 500 + + +class ContentType: + APPLICATION_JSON = 'application/json' + APPLICATION_X_WWW_FORM_URLENCODED = 'application/x-www-form-urlencoded' + TEXT_PLAIN = 'text/plain' + + +class DetectStatus: + IN_PROGRESS = 'IN_PROGRESS' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + UNKNOWN = 'UNKNOWN' + + +class FileExtension: + JSON = 'json' + MP3 = 'mp3' + WAV = 'wav' + PDF = 'pdf' + TXT = 'txt' + DOC = 'doc' + DOCX = 'docx' + JPG = 'jpg' + JPEG = 'jpeg' + PNG = 'png' + BMP = 'bmp' + TIF = 'tif' + TIFF = 'tiff' + PPT = 'ppt' + PPTX = 'pptx' + CSV = 'csv' + XLS = 'xls' + XLSX = 'xlsx' + XML = 'xml' + + +class FileProcessing: + PROCESSED_PREFIX = 'processed-' + DEIDENTIFIED_PREFIX = 'deidentified.' + ENTITIES = 'entities' + + +class EncodingType: + UTF8 = 'utf8' + UTF_8 = 'utf-8' + BASE64 = 'base64' + BINARY = 'binary' + + +class JWT: + ALGORITHM_RS256 = 'RS256' + GRANT_TYPE_JWT_BEARER = 'urn:ietf:params:oauth:grant-type:jwt-bearer' + ISSUER_SDK = 'sdk' + SIGNED_TOKEN_PREFIX = 'signed_token_' + ROLE_PREFIX = 'role:' + + +class ApiKey: + SKY_PREFIX = 'sky-' + LENGTH = 42 + + +class UrlProtocol: + HTTPS = 'https' + HTTP = 'http' + + +class BooleanString: + TRUE = 'true' + FALSE = 'false' + + +class ResponseField: + STATUS = 'Status' + BODY = 'Body' + RECORDS = 'records' + TOKENS = 'tokens' + ERROR = 'error' + SKYFLOW_ID = 'skyflow_id' + REQUEST_INDEX = 'request_index' + REQUEST_ID = 'request_id' + HTTP_CODE = 'http_code' + HTTP_STATUS = 'http_status' + GRPC_CODE = 'grpc_code' + DETAILS = 'details' + MESSAGE = 'message' + ERROR_FROM_CLIENT = 'error_from_client' + TOKEN = 'token' + VALUE = 'value' + TYPE = 'type' + TOKENIZED_DATA = 'tokenized_data' + SIGNED_TOKEN = 'signed_token' + + +class CredentialField: + PRIVATE_KEY = 'privateKey' + CLIENT_ID = 'clientID' + KEY_ID = 'keyID' + TOKEN_URI = 'tokenURI' + CREDENTIALS_STRING = 'credentials_string' + API_KEY = 'api_key' + TOKEN = 'token' + PATH = 'path' + + +class JwtField: + ISS = 'iss' + KEY = 'key' + AUD = 'aud' + SUB = 'sub' + EXP = 'exp' + CTX = 'ctx' + TOK = 'tok' + IAT = 'iat' + + +class OptionField: + ROLE_IDS = 'role_ids' + DATA_TOKENS = 'data_tokens' + TIME_TO_LIVE = 'time_to_live' + ROLES = 'roles' + CTX = 'ctx' + VAULT_ID = 'vault_id' + CONNECTION_ID = 'connection_id' + CONNECTION_URL = 'connection_url' + VAULT_CLIENT = 'vault_client' + VAULT_CONTROLLER = 'vault_controller' + DETECT_CONTROLLER = 'detect_controller' + CONTROLLER = 'controller' + VERIFY_SIGNATURE = 'verify_signature' + VERIFY_AUD = 'verify_aud' + + +class ConfigField: + CREDENTIALS = 'credentials' + CLUSTER_ID = 'cluster_id' + ENV = 'env' + VAULT_ID = 'vault_id' + + +class RequestParameter: + VALUE = 'value' + COLUMN_GROUP = 'column_group' + REDACTION = 'redaction' + diff --git a/skyflow/utils/logger/_log_helpers.py b/skyflow/utils/logger/_log_helpers.py index fdb11ea9..3fff980b 100644 --- a/skyflow/utils/logger/_log_helpers.py +++ b/skyflow/utils/logger/_log_helpers.py @@ -1,5 +1,6 @@ from ..enums import LogLevel from . import Logger +from ..constants import ResponseField def log_info(message, logger = None): @@ -18,17 +19,17 @@ def log_error(message, http_code, request_id=None, grpc_code=None, http_status=N logger = Logger(LogLevel.ERROR) log_data = { - 'http_code': http_code, - 'message': message + ResponseField.HTTP_CODE: http_code, + ResponseField.MESSAGE: message } if grpc_code is not None: - log_data['grpc_code'] = grpc_code + log_data[ResponseField.GRPC_CODE] = grpc_code if http_status is not None: - log_data['http_status'] = http_status + log_data[ResponseField.HTTP_STATUS] = http_status if request_id is not None: - log_data['request_id'] = request_id + log_data[ResponseField.REQUEST_ID] = request_id if details is not None: - log_data['details'] = details + log_data[ResponseField.DETAILS] = details logger.error(log_data) \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..779fdfcc 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,6 +6,7 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages +from skyflow.utils.constants import ApiKey, ResponseField from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest @@ -50,11 +51,11 @@ def validate_required_field(logger, config, field_name, expected_type, empty_err raise SkyflowError(empty_error, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: - if not api_key.startswith('sky-'): + if not api_key.startswith(ApiKey.SKY_PREFIX): log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger=logger) return False - if len(api_key) != 42: + if len(api_key) != ApiKey.LENGTH: log_error_log(SkyflowMessages.ErrorLogs.INVALID_API_KEY.value, logger = logger) return False @@ -582,10 +583,10 @@ def validate_get_request(logger, request): def validate_update_request(logger, request): skyflow_id = "" - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} try: - skyflow_id = request.data.get("skyflow_id") + skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) except Exception: log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..2d77330e 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -2,6 +2,7 @@ from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages from skyflow.utils.logger import log_info +from skyflow.utils.constants import OptionField, CredentialField, ConfigField class VaultClient: @@ -23,11 +24,11 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) + credentials = get_credentials(self.__config.get(ConfigField.CREDENTIALS), self.__common_skyflow_credentials, logger = self.__logger) token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), + vault_url = get_vault_url(self.__config.get(ConfigField.CLUSTER_ID), + self.__config.get(ConfigField.ENV), + self.__config.get(ConfigField.VAULT_ID), logger = self.__logger) self.initialize_api_client(vault_url, token) @@ -50,29 +51,29 @@ def get_detect_file_api(self): return self.__api_client.files def get_vault_id(self): - return self.__config.get("vault_id") + return self.__config.get(ConfigField.VAULT_ID) def get_bearer_token(self, credentials): - if 'api_key' in credentials: - return credentials.get('api_key') - elif 'token' in credentials: - return credentials.get("token") + if CredentialField.API_KEY in credentials: + return credentials.get(CredentialField.API_KEY) + elif CredentialField.TOKEN in credentials: + return credentials.get(CredentialField.TOKEN) options = { - "role_ids": self.__config.get("roles"), - "ctx": self.__config.get("ctx") + OptionField.ROLE_IDS: self.__config.get(OptionField.ROLES), + OptionField.CTX: self.__config.get(OptionField.CTX) } if self.__bearer_token is None or self.__is_config_updated: - if 'path' in credentials: - path = credentials.get("path") + if CredentialField.PATH in credentials: + path = credentials.get(CredentialField.PATH) self.__bearer_token, _ = generate_bearer_token( path, options, self.__logger ) else: - credentials_string = credentials.get('credentials_string') + credentials_string = credentials.get(CredentialField.CREDENTIALS_STRING) log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, self.__logger) self.__bearer_token, _ = generate_bearer_token_from_creds( credentials_string, diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 81c6ea10..83b0ffbd 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,6 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW class Connection: @@ -23,9 +24,9 @@ def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: - invoke_connection_request.headers['x-skyflow-authorization'] = bearer_token + invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token - invoke_connection_request.headers['sky-metadata'] = json.dumps(get_metrics()) + invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 44ef2540..4f2f50f2 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -8,7 +8,8 @@ FileDataDeidentifyImage, Format, FileDataDeidentifyAudio, WordCharacterCount, DetectRunsResponse from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, + FileProcessing, EncodingType) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -64,7 +65,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): while True: response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data status = response.status - if status == 'IN_PROGRESS': + if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') else: @@ -76,7 +77,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): wait_time = next_wait_time current_wait_time = next_wait_time time.sleep(wait_time) - elif status == 'SUCCESS' or status == 'FAILED': + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: return response except Exception as e: raise e @@ -88,7 +89,7 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o if not os.path.exists(output_directory): return - deidentify_file_prefix = "processed-" + deidentify_file_prefix = FileProcessing.PROCESSED_PREFIX output_list = response.output base_original_filename = os.path.basename(original_file_name) @@ -159,7 +160,7 @@ def output_to_dict_list(output): output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == "entities"] + entities = [o for o in output_list if o.get("type") == FileProcessing.ENTITIES] base64_string = first_output.get("file", None) extension = first_output.get("extension", None) @@ -167,14 +168,14 @@ def output_to_dict_list(output): if base64_string is not None: file_bytes = base64.b64decode(base64_string) file_obj = io.BytesIO(file_bytes) - file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else "processed_file" else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", "UNKNOWN"), + type=first_output.get("type", DetectStatus.UNKNOWN), extension=extension, word_count=word_count, char_count=char_count, @@ -282,11 +283,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_name = getattr(file_obj, 'name', None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) try: - if file_extension == 'txt': - req_file = FileDataDeidentifyText(base_64=base64_string, data_format="txt") + if file_extension == FileExtension.TXT: + req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { 'vault_id': self.__vault_client.get_vault_id(), @@ -299,7 +300,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['mp3', 'wav']: + elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio api_kwargs = { @@ -319,7 +320,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension == 'pdf': + elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { @@ -334,7 +335,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['jpeg', 'jpg', 'png', 'bmp', 'tif', 'tiff']: + elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { @@ -350,7 +351,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['ppt', 'pptx']: + elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { @@ -363,7 +364,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['csv', 'xls', 'xlsx']: + elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { @@ -376,7 +377,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['doc', 'docx']: + elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { @@ -389,7 +390,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): 'request_options': self.__get_headers() } - elif file_extension in ['json', 'xml']: + elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { @@ -423,7 +424,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, 'run_id', None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == 'SUCCESS': + if request.output_directory and processed_response.status == DetectStatus.SUCCESS: name_without_ext, _ = os.path.splitext(file_name) self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) @@ -450,7 +451,7 @@ def get_detect_run(self, request: GetDetectRunRequest): vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers() ) - if response.data.status == 'IN_PROGRESS': + if response.data.status == DetectStatus.IN_PROGRESS: parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 7cc9ec77..a5cd94fd 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -125,7 +125,7 @@ def update(self, request: UpdateRequest): validate_update_request(self.__vault_client.get_logger(), request) log_info(SkyflowMessages.Info.UPDATE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() - field = {key: value for key, value in request.data.items() if key != "skyflow_id"} + field = {key: value for key, value in request.data.items() if key != ResponseField.SKYFLOW_ID} record = V1FieldRecords(fields=field, tokens = request.tokens) records_api = self.__vault_client.get_records_api() @@ -134,7 +134,7 @@ def update(self, request: UpdateRequest): api_response = records_api.record_service_update_record( self.__vault_client.get_vault_id(), request.table, - id=request.data.get("skyflow_id"), + id=request.data.get(ResponseField.SKYFLOW_ID), record=record, tokenization=request.return_tokens, byot=request.token_mode.value, @@ -225,8 +225,8 @@ def detokenize(self, request: DetokenizeRequest): self.__initialize() tokens_list = [ V1DetokenizeRecordRequest( - token=item.get('token'), - redaction=item.get('redaction', RedactionType.DEFAULT) + token=item.get(ResponseField.TOKEN), + redaction=item.get(RequestParameter.REDACTION, RedactionType.DEFAULT) ) for item in request.data ] @@ -253,7 +253,7 @@ def tokenize(self, request: TokenizeRequest): self.__initialize() records_list = [ - V1TokenizeRecordRequest(value=item["value"], column_group=item["column_group"]) + V1TokenizeRecordRequest(value=item[RequestParameter.VALUE], column_group=item[RequestParameter.COLUMN_GROUP]) for item in request.values ] tokens_api = self.__vault_client.get_tokens_api() From d17e71d2fd34134221232d9ad55506dd6b011e86 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Mon, 2 Feb 2026 17:44:50 +0530 Subject: [PATCH 2/3] SK-2496: addressed review comments and suggestions --- ruff.toml | 2 +- skyflow/utils/_skyflow_messages.py | 1 + skyflow/utils/_utils.py | 23 +- skyflow/utils/constants.py | 115 +++++++++ skyflow/utils/validations/_validations.py | 288 ++++++++++++---------- skyflow/vault/controller/_connections.py | 4 +- skyflow/vault/controller/_detect.py | 274 ++++++++++---------- skyflow/vault/controller/_vault.py | 4 +- 8 files changed, 423 insertions(+), 288 deletions(-) diff --git a/ruff.toml b/ruff.toml index b6795704..8b0d5278 100644 --- a/ruff.toml +++ b/ruff.toml @@ -14,6 +14,6 @@ exclude = [ line-length = 120 [lint] -select = ["N"] +select = ["N", "PLR2004"] [lint.pep8-naming] diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 1954ed4d..21665972 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -389,6 +389,7 @@ class ErrorLogs(Enum): SAVING_DEIDENTIFY_FILE_FAILED = f"{ERROR}: [{error_prefix}] Error while saving deidentified file to output directory." REIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Reidentify text resulted in failure." DETECT_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." + EMPTY_FILE_COLUMN_NAME = f"{ERROR}: [{error_prefix}] Empty column name in FILE_UPLOAD" class Interface(Enum): INSERT = "INSERT" diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index c6f294cd..83c93b0c 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -21,7 +21,8 @@ from skyflow.vault.detect import EntityInfo, TextIndex from . import SkyflowMessages, SDK_VERSION from .constants import (PROTOCOL, HttpHeader, ApiKey, ContentType as ContentTypeConstants, - EncodingType, BooleanString, ResponseField, CredentialField) + EncodingType, BooleanString, ResponseField, CredentialField, SdkPrefix, + SdkMetricsKey, ErrorDefaults, HttpStatusCode) from .enums import Env, ContentType, EnvUrls from skyflow.vault.data import InsertResponse, UpdateResponse, DeleteResponse, QueryResponse, GetResponse from .validations import validate_invoke_connection_params @@ -129,7 +130,7 @@ def construct_invoke_connection_request(request, connection_url, logger) -> Prep validate_invoke_connection_params(logger, request.query_params, request.path_params) - if not hasattr(request.method, 'value'): + if not hasattr(request.method, ResponseField.VALUE): raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_METHOD.value, invalid_input_error_code) try: @@ -187,7 +188,7 @@ def get_data_from_content_type(data, content_type): def get_metrics(): - sdk_name_version = "skyflow-python@" + SDK_VERSION + sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION try: sdk_client_device_model = platform.node() @@ -205,10 +206,10 @@ def get_metrics(): sdk_runtime_details = "" details_dic = { - 'sdk_name_version': sdk_name_version, - 'sdk_client_device_model': sdk_client_device_model, - 'sdk_client_os_details': sdk_client_os_details, - 'sdk_runtime_details': "Python " + sdk_runtime_details, + SdkMetricsKey.SDK_NAME_VERSION: sdk_name_version, + SdkMetricsKey.SDK_CLIENT_DEVICE_MODEL: sdk_client_device_model, + SdkMetricsKey.SDK_CLIENT_OS_DETAILS: sdk_client_os_details, + SdkMetricsKey.SDK_RUNTIME_DETAILS: SdkPrefix.PYTHON_RUNTIME + sdk_runtime_details, } return details_dic @@ -223,7 +224,7 @@ def parse_insert_response(api_response, continue_on_error): insert_response = InsertResponse() if continue_on_error: for idx, response in enumerate(api_response_data.responses): - if response[ResponseField.STATUS] == 200: + if response[ResponseField.STATUS] == HttpStatusCode.OK: body = response[ResponseField.BODY] if ResponseField.RECORDS in body: for record in body[ResponseField.RECORDS]: @@ -235,7 +236,7 @@ def parse_insert_response(api_response, continue_on_error): if ResponseField.TOKENS in record: inserted_field.update(record[ResponseField.TOKENS]) inserted_fields.append(inserted_field) - elif response[ResponseField.STATUS] == 400: + elif response[ResponseField.STATUS] == HttpStatusCode.BAD_REQUEST: error = { ResponseField.REQUEST_INDEX: idx, ResponseField.REQUEST_ID: request_id, @@ -352,7 +353,7 @@ def parse_invoke_connection_response(api_response: requests.Response): data = json.loads(content) metadata = {} if HttpHeader.X_REQUEST_ID in api_response.headers: - metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] + metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID] return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: @@ -400,7 +401,7 @@ def handle_exception(error, logger): if (isinstance(error, httpx.ConnectError)): handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) - request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') + request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID) content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) data = error.body diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 30cb124d..62aa4d11 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -13,11 +13,13 @@ class HttpHeader: X_REQUEST_ID = 'x-request-id' ERROR_FROM_CLIENT = 'error-from-client' AUTHORIZATION = 'Authorization' + X_SKYFLOW_AUTHORIZATION_HEADER = 'X-Skyflow-Authorization' class HttpStatusCode: OK = 200 BAD_REQUEST = 400 + UNAUTHORIZED = 401 INTERNAL_SERVER_ERROR = 500 @@ -123,6 +125,8 @@ class CredentialField: API_KEY = 'api_key' TOKEN = 'token' PATH = 'path' + CONTEXT = 'context' + ROLES = 'roles' class JwtField: @@ -165,3 +169,114 @@ class RequestParameter: COLUMN_GROUP = 'column_group' REDACTION = 'redaction' + +class FileUploadField: + TABLE = 'table' + SKYFLOW_ID = 'skyflow_id' + COLUMN_NAME = 'column_name' + FILE_PATH = 'file_path' + BASE64 = 'base64' + FILE_OBJECT = 'file_object' + FILE_NAME = 'file_name' + FILE = 'file' + NAME = 'name' + + +class DeidentifyFileRequestField: + ENTITIES = 'entities' + ALLOW_REGEX_LIST = 'allow_regex_list' + RESTRICT_REGEX_LIST = 'restrict_regex_list' + OUTPUT_PROCESSED_IMAGE = 'output_processed_image' + OUTPUT_OCR_TEXT = 'output_ocr_text' + MASKING_METHOD = 'masking_method' + PIXEL_DENSITY = 'pixel_density' + MAX_RESOLUTION = 'max_resolution' + OUTPUT_PROCESSED_AUDIO = 'output_processed_audio' + OUTPUT_TRANSCRIPTION = 'output_transcription' + BLEEP = 'bleep' + OUTPUT_DIRECTORY = 'output_directory' + WAIT_TIME = 'wait_time' + + +class DeidentifyField: + TEXT = 'text' + ENTITY_TYPES = 'entity_types' + TOKEN_TYPE = 'token_type' + ALLOW_REGEX = 'allow_regex' + RESTRICT_REGEX = 'restrict_regex' + TRANSFORMATIONS = 'transformations' + FORMAT = 'format' + OUTPUT = 'output' + STATUS = 'status' + RUN_ID = 'run_id' + WORD_CHARACTER_COUNT = 'word_character_count' + WORD_COUNT = 'word_count' + CHARACTER_COUNT = 'character_count' + SIZE = 'size' + DURATION = 'duration' + PAGES = 'pages' + SLIDES = 'slides' + PROCESSED_FILE = 'processed_file' + PROCESSED_FILE_TYPE = 'processed_file_type' + PROCESSED_FILE_EXTENSION = 'processed_file_extension' + REDACTED_FILE = 'redacted_file' + SHIFT_DATES = 'shift_dates' + DEFAULT = 'default' + ENTITY_UNQ_COUNTER = 'entity_unq_counter' + ENTITY_UNIQUE_COUNTER = 'entity_unique_counter' + ENTITY_ONLY = 'entity_only' + ENTITIES = 'entities' + MAX_DAYS = 'max_days' + MIN_DAYS = 'min_days' + MAX = 'max' + MIN = 'min' + FILE = 'file' + TYPE = 'type' + EXTENSION = 'extension' + IN_PROGRESS = 'IN_PROGRESS' + REQUEST_OPTIONS = 'request_options' + BLEEP_GAIN = 'bleep_gain' + BLEEP_FREQUENCY = 'bleep_frequency' + BLEEP_START_PADDING = 'bleep_start_padding' + BLEEP_STOP_PADDING = 'bleep_stop_padding' + DENSITY = 'density' + TOKEN_FORMAT = 'token_format' + PROCESSED_FILE_RESPONSE_KEY = 'processedFile' + PROCESSED_FILE_TYPE_RESPONSE_KEY = 'processedFileType' + PROCESSED_FILE_EXTENSION_RESPONSE_KEY = 'processedFileExtension' + + +class RequestOperation: + INSERT = 'INSERT' + DELETE = 'DELETE' + GET = 'GET' + UPDATE = 'UPDATE' + QUERY = 'QUERY' + TOKENIZE = 'TOKENIZE' + DETOKENIZE = 'DETOKENIZE' + FILE_UPLOAD = 'FILE_UPLOAD' + + +class ConfigType: + VAULT = 'vault' + CONNECTION = 'connection' + + +class SqlCommand: + SELECT = 'SELECT' + + +class SdkPrefix: + SKYFLOW_PYTHON = 'skyflow-python@' + PYTHON_RUNTIME = 'Python ' + + +class SdkMetricsKey: + SDK_NAME_VERSION = 'sdk_name_version' + SDK_CLIENT_DEVICE_MODEL = 'sdk_client_device_model' + SDK_CLIENT_OS_DETAILS = 'sdk_client_os_details' + SDK_RUNTIME_DETAILS = 'sdk_runtime_details' + + +class ErrorDefaults: + UNKNOWN_REQUEST_ID = 'unknown-request-id' diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 779fdfcc..2ac5783c 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -6,47 +6,66 @@ MaskingMethod from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages -from skyflow.utils.constants import ApiKey, ResponseField +from skyflow.utils.constants import ( + ApiKey, ResponseField, RequestParameter, + FileUploadField, + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField +) from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest from skyflow.vault.detect._file_input import FileInput -valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] -valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] -valid_credentials_keys = ["path", "roles", "context", "token", "credentials_string"] +valid_vault_config_keys = [ + ConfigField.VAULT_ID, + ConfigField.CLUSTER_ID, + ConfigField.CREDENTIALS, + ConfigField.ENV +] +valid_connection_config_keys = [ + OptionField.CONNECTION_ID, + OptionField.CONNECTION_URL, + ConfigField.CREDENTIALS +] +valid_credentials_keys = [ + CredentialField.PATH, + CredentialField.ROLES, + CredentialField.CONTEXT, + CredentialField.TOKEN, + CredentialField.CREDENTIALS_STRING +] invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def validate_required_field(logger, config, field_name, expected_type, empty_error, invalid_error): field_value = config.get(field_name) if field_name not in config or not isinstance(field_value, expected_type): - if field_name == "vault_id": + if field_name == ConfigField.VAULT_ID: logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) - if field_name == "cluster_id": + if field_name == ConfigField.CLUSTER_ID: logger.error(SkyflowMessages.ErrorLogs.CLUSTER_ID_IS_REQUIRED.value) - if field_name == "connection_id": + if field_name == OptionField.CONNECTION_ID: logger.error(SkyflowMessages.ErrorLogs.CONNECTION_ID_IS_REQUIRED.value) - if field_name == "connection_url": + if field_name == OptionField.CONNECTION_URL: logger.error(SkyflowMessages.ErrorLogs.INVALID_CONNECTION_URL.value) raise SkyflowError(invalid_error, invalid_input_error_code) if isinstance(field_value, str) and not field_value.strip(): - if field_name == "vault_id": + if field_name == ConfigField.VAULT_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_VAULTID.value) - if field_name == "cluster_id": + if field_name == ConfigField.CLUSTER_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CLUSTER_ID.value) - if field_name == "connection_id": + if field_name == OptionField.CONNECTION_ID: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_ID.value) - if field_name == "connection_url": + if field_name == OptionField.CONNECTION_URL: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CONNECTION_URL.value) - if field_name == "path": + if field_name == CredentialField.PATH: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_PATH.value) - if field_name == "credentials_string": + if field_name == CredentialField.CREDENTIALS_STRING: logger.error(SkyflowMessages.ErrorLogs.EMPTY_CREDENTIALS_STRING.value) - if field_name == "token": + if field_name == CredentialField.TOKEN: logger.error(SkyflowMessages.ErrorLogs.EMPTY_TOKEN_VALUE.value) - if field_name == "api_key": + if field_name == CredentialField.API_KEY: logger.error(SkyflowMessages.ErrorLogs.EMPTY_API_KEY_VALUE.value) raise SkyflowError(empty_error, invalid_input_error_code) @@ -62,7 +81,7 @@ def validate_api_key(api_key: str, logger = None) -> bool: return True def validate_credentials(logger, credentials, config_id_type=None, config_id=None): - key_present = [k for k in ["path", "token", "credentials_string", "api_key"] if credentials.get(k)] + key_present = [k for k in [CredentialField.PATH, CredentialField.TOKEN, CredentialField.CREDENTIALS_STRING, CredentialField.API_KEY] if credentials.get(k)] if len(key_present) == 0: error_message = ( @@ -79,63 +98,63 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) raise SkyflowError(error_message, invalid_input_error_code) - if "roles" in credentials: + if CredentialField.ROLES in credentials: validate_required_field( - logger, credentials, "roles", list, + logger, credentials, CredentialField.ROLES, list, SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_ROLES_KEY_TYPE.value, SkyflowMessages.Error.EMPTY_ROLES_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_ROLES.value ) - if "context" in credentials: + if CredentialField.CONTEXT in credentials: validate_required_field( - logger, credentials, "context", str, + logger, credentials, CredentialField.CONTEXT, str, SkyflowMessages.Error.EMPTY_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CONTEXT.value, SkyflowMessages.Error.INVALID_CONTEXT_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CONTEXT.value ) - if "credentials_string" in credentials: + if CredentialField.CREDENTIALS_STRING in credentials: validate_required_field( - logger, credentials, "credentials_string", str, + logger, credentials, CredentialField.CREDENTIALS_STRING, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_STRING.value, SkyflowMessages.Error.INVALID_CREDENTIALS_STRING_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_STRING.value ) - elif "path" in credentials: + elif CredentialField.PATH in credentials: validate_required_field( - logger, credentials, "path", str, + logger, credentials, CredentialField.PATH, str, SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIAL_FILE_PATH.value, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value ) - elif "token" in credentials: + elif CredentialField.TOKEN in credentials: validate_required_field( - logger, credentials, "token", str, + logger, credentials, CredentialField.TOKEN, str, SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_CREDENTIALS_TOKEN.value, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value ) - if is_expired(credentials.get("token"), logger): + if is_expired(credentials.get(CredentialField.TOKEN), logger): raise SkyflowError( SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, invalid_input_error_code ) - elif "api_key" in credentials: + elif CredentialField.API_KEY in credentials: validate_required_field( - logger, credentials, "api_key", str, + logger, credentials, CredentialField.API_KEY, str, SkyflowMessages.Error.EMPTY_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.EMPTY_API_KEY.value, SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value ) - if not validate_api_key(credentials.get("api_key"), logger): + if not validate_api_key(credentials.get(CredentialField.API_KEY), logger): raise SkyflowError(SkyflowMessages.Error.INVALID_API_KEY.value.format(config_id_type, config_id) if config_id_type and config_id else SkyflowMessages.Error.INVALID_API_KEY.value, invalid_input_error_code) @@ -158,27 +177,27 @@ def validate_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) # Validate cluster_id (string, not empty) validate_required_field( - logger, config, "cluster_id", str, + logger, config, ConfigField.CLUSTER_ID, str, SkyflowMessages.Error.EMPTY_CLUSTER_ID.value.format(vault_id), SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id) ) # Validate credentials (dict, not empty) - if "credentials" in config and not config.get("credentials"): - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS in config and not config.get(ConfigField.CREDENTIALS): + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - if "credentials" in config and config.get("credentials"): - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + if ConfigField.CREDENTIALS in config and config.get(ConfigField.CREDENTIALS): + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) # Validate env (optional, should be one of LogLevel values) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: logger.error(SkyflowMessages.ErrorLogs.VAULTID_IS_REQUIRED.value) raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) @@ -190,23 +209,23 @@ def validate_update_vault_config(logger, config): # Validate vault_id (string, not empty) validate_required_field( - logger, config, "vault_id", str, + logger, config, ConfigField.VAULT_ID, str, SkyflowMessages.Error.EMPTY_VAULT_ID.value, SkyflowMessages.Error.INVALID_VAULT_ID.value ) - vault_id = config.get("vault_id") + vault_id = config.get(ConfigField.VAULT_ID) - if "cluster_id" in config and not config.get("cluster_id"): + if ConfigField.CLUSTER_ID in config and not config.get(ConfigField.CLUSTER_ID): raise SkyflowError(SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format(vault_id), invalid_input_error_code) - if "env" in config and config.get("env") not in Env: + if ConfigField.ENV in config and config.get(ConfigField.ENV) not in Env: raise SkyflowError(SkyflowMessages.Error.INVALID_ENV.value.format(vault_id), invalid_input_error_code) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", vault_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.VAULT, vault_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "vault", vault_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.VAULT, vault_id) return True @@ -215,23 +234,23 @@ def validate_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id" , str, + logger, config, OptionField.CONNECTION_ID , str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id) return True @@ -240,22 +259,22 @@ def validate_update_connection_config(logger, config): validate_keys(logger, config, valid_connection_config_keys) validate_required_field( - logger, config, "connection_id", str, + logger, config, OptionField.CONNECTION_ID, str, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value, SkyflowMessages.Error.INVALID_CONNECTION_ID.value ) - connection_id = config.get("connection_id") + connection_id = config.get(OptionField.CONNECTION_ID) validate_required_field( - logger, config, "connection_url", str, + logger, config, OptionField.CONNECTION_URL, str, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format(connection_id), SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id) ) - if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) - validate_credentials(logger, config.get("credentials")) + if ConfigField.CREDENTIALS not in config: + raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code) + validate_credentials(logger, config.get(ConfigField.CREDENTIALS)) return True @@ -263,8 +282,8 @@ def validate_file_from_request(file_input: FileInput): if file_input is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) - has_file = hasattr(file_input, 'file') and file_input.file is not None - has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None + has_file = hasattr(file_input, FileUploadField.FILE) and file_input.file is not None + has_file_path = hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None # Must provide exactly one of file or file_path if (has_file and has_file_path) or (not has_file and not has_file_path): @@ -273,7 +292,7 @@ def validate_file_from_request(file_input: FileInput): if has_file: file = file_input.file # Validate file object has required attributes - if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + if not hasattr(file, FileUploadField.FILE_NAME) or not isinstance(file.name, str) or not file.name.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) # Validate file name @@ -290,14 +309,14 @@ def validate_file_from_request(file_input: FileInput): raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): - if not hasattr(request, 'file') or request.file is None: + if not hasattr(request, FileUploadField.FILE) or request.file is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) # Validate file input first validate_file_from_request(request.file) # Optional: entities - if hasattr(request, 'entities') and request.entities is not None: + if hasattr(request, DeidentifyFileRequestField.ENTITIES) and request.entities is not None: if not isinstance(request.entities, list): raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) @@ -305,12 +324,12 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value, invalid_input_error_code) # Optional: allow_regex_list - if hasattr(request, 'allow_regex_list') and request.allow_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.ALLOW_REGEX_LIST) and request.allow_regex_list is not None: if not isinstance(request.allow_regex_list, list) or not all(isinstance(x, str) for x in request.allow_regex_list): raise SkyflowError(SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value, invalid_input_error_code) # Optional: restrict_regex_list - if hasattr(request, 'restrict_regex_list') and request.restrict_regex_list is not None: + if hasattr(request, DeidentifyFileRequestField.RESTRICT_REGEX_LIST) and request.restrict_regex_list is not None: if not isinstance(request.restrict_regex_list, list) or not all(isinstance(x, str) for x in request.restrict_regex_list): raise SkyflowError(SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value, invalid_input_error_code) @@ -323,43 +342,42 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value, invalid_input_error_code) # Optional: output_processed_image - if hasattr(request, 'output_processed_image') and request.output_processed_image is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE) and request.output_processed_image is not None: if not isinstance(request.output_processed_image, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value, invalid_input_error_code) # Optional: output_ocr_text - if hasattr(request, 'output_ocr_text') and request.output_ocr_text is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT) and request.output_ocr_text is not None: if not isinstance(request.output_ocr_text, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value, invalid_input_error_code) # Optional: masking_method - # Optional: masking_method - if hasattr(request, 'masking_method') and request.masking_method is not None: + if hasattr(request, DeidentifyFileRequestField.MASKING_METHOD) and request.masking_method is not None: if not isinstance(request.masking_method, MaskingMethod): raise SkyflowError(SkyflowMessages.Error.INVALID_MASKING_METHOD.value, invalid_input_error_code) # Optional: pixel_density - if hasattr(request, 'pixel_density') and request.pixel_density is not None: + if hasattr(request, DeidentifyFileRequestField.PIXEL_DENSITY) and request.pixel_density is not None: if not isinstance(request.pixel_density, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value, invalid_input_error_code) # Optional: max_resolution - if hasattr(request, 'max_resolution') and request.max_resolution is not None: + if hasattr(request, DeidentifyFileRequestField.MAX_RESOLUTION) and request.max_resolution is not None: if not isinstance(request.max_resolution, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value, invalid_input_error_code) # Optional: output_processed_audio - if hasattr(request, 'output_processed_audio') and request.output_processed_audio is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO) and request.output_processed_audio is not None: if not isinstance(request.output_processed_audio, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value, invalid_input_error_code) # Optional: output_transcription - if hasattr(request, 'output_transcription') and request.output_transcription is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION) and request.output_transcription is not None: if not isinstance(request.output_transcription, DetectOutputTranscriptions): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value, invalid_input_error_code) # Optional: bleep - if hasattr(request, 'bleep') and request.bleep is not None: + if hasattr(request, DeidentifyFileRequestField.BLEEP) and request.bleep is not None: if not isinstance(request.bleep, Bleep): raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_TYPE.value, invalid_input_error_code) @@ -380,53 +398,53 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.INVALID_BLEEP_STOP_PADDING.value, invalid_input_error_code) # Optional: output_directory - if hasattr(request, 'output_directory') and request.output_directory is not None: + if hasattr(request, DeidentifyFileRequestField.OUTPUT_DIRECTORY) and request.output_directory is not None: if not isinstance(request.output_directory, str): raise SkyflowError(SkyflowMessages.Error.INVALID_OUTPUT_DIRECTORY_VALUE.value, invalid_input_error_code) if not os.path.isdir(request.output_directory): raise SkyflowError(SkyflowMessages.Error.OUTPUT_DIRECTORY_NOT_FOUND.value.format(request.output_directory), invalid_input_error_code) # Optional: wait_time - if hasattr(request, 'wait_time') and request.wait_time is not None: + if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 and request.wait_time > 64: # noqa: PLR2004 raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) if not isinstance(request.values, list) or not all(isinstance(v, dict) for v in request.values): - log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.VALUES_IS_REQUIRED.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if not len(request.values): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format("INSERT"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_VALUES.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value, invalid_input_error_code) for i, item in enumerate(request.values, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format(RequestOperation.INSERT, key), logger = logger) if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) if request.homogeneous is not None and not isinstance(request.homogeneous, bool): raise SkyflowError(SkyflowMessages.Error.INVALID_HOMOGENEOUS_TYPE.value, invalid_input_error_code) if request.upsert and request.homogeneous: - log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), logger = logger) - raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), invalid_input_error_code) + log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), logger = logger) + raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format(RequestOperation.INSERT), invalid_input_error_code) if request.token_mode is not None: if not isinstance(request.token_mode, TokenMode): @@ -442,15 +460,15 @@ def validate_insert_request(logger, request): for i, item in enumerate(request.tokens, start=1): for key, value in item.items(): if key is None or key == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format("INSERT"), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_TOKENS.value.format(RequestOperation.INSERT), logger=logger) if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format("INSERT", key), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_TOKENS.value.format(RequestOperation.INSERT, key), logger=logger) if not isinstance(request.tokens, list) or not request.tokens or not all( isinstance(t, dict) for t in request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -461,29 +479,29 @@ def validate_insert_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(request.values) != len(request.tokens): - log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("INSERT"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.INSERT), logger = logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) for v, t in zip(request.values, request.tokens): if set(v.keys()) != set(t.keys()): - log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format("INSERT"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format(RequestOperation.INSERT), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) def validate_delete_request(logger, request): if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not request.ids: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("DELETE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.DELETE), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value, invalid_input_error_code) def validate_query_request(logger, request): if not request.query: - log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.QUERY_IS_REQUIRED.value.format(RequestOperation.QUERY), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) if not isinstance(request.query, str): @@ -491,10 +509,10 @@ def validate_query_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(query_type), invalid_input_error_code) if not request.query.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format("QUERY"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_QUERY.value.format(RequestOperation.QUERY), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_QUERY.value, invalid_input_error_code) - if not request.query.upper().startswith("SELECT"): + if not request.query.upper().startswith(SqlCommand.SELECT): command = request.query raise SkyflowError(SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(command), invalid_input_error_code) @@ -509,23 +527,23 @@ def validate_get_request(logger, request): download_url = request.download_url if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not skyflow_ids and not column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.NEITHER_IDS_NOR_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) if skyflow_ids and (not isinstance(skyflow_ids, list) or not skyflow_ids): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_IDS_TYPE.value.format(type(skyflow_ids)), invalid_input_error_code) if skyflow_ids: for index, skyflow_id in enumerate(skyflow_ids): if skyflow_id is None or skyflow_id == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format("GET", index), + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_ID_IN_IDS.value.format(RequestOperation.GET, index), logger=logger) if not isinstance(request.return_tokens, bool): @@ -535,7 +553,7 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction_type)), invalid_input_error_code) if fields is not None and (not isinstance(fields, list) or not fields): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FIELDS.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(fields)), invalid_input_error_code) if offset is not None and limit is not None: @@ -561,24 +579,24 @@ def validate_get_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if request.return_tokens and redaction_type: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format("GET"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_NOT_SUPPORTED_WITH_REDACTION.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value, invalid_input_error_code) if (column_name or column_values) and request.return_tokens: - log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format("GET"), + log_error_log(SkyflowMessages.ErrorLogs.TOKENIZATION_SUPPORTED_ONLY_WITH_IDS.value.format(RequestOperation.GET), logger=logger) raise SkyflowError(SkyflowMessages.Error.TOKENS_GET_COLUMN_NOT_SUPPORTED.value, invalid_input_error_code) if column_values and not column_name: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_GET.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code) if column_name and not column_values: - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format(RequestOperation.GET), logger = logger) SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) if (column_name or column_values) and skyflow_ids: - log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger) raise SkyflowError(SkyflowMessages.Error.BOTH_IDS_AND_COLUMN_DETAILS_SPECIFIED.value, invalid_input_error_code) def validate_update_request(logger, request): @@ -588,16 +606,16 @@ def validate_update_request(logger, request): try: skyflow_id = request.data.get(ResponseField.SKYFLOW_ID) except Exception: - log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.SKYFLOW_ID_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) if not skyflow_id.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_SKYFLOW_ID.value.format(RequestOperation.UPDATE), logger = logger) if not isinstance(request.table, str): - log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) if not request.table.strip(): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("UPDATE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format(RequestOperation.UPDATE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) if not isinstance(request.return_tokens, bool): @@ -615,7 +633,7 @@ def validate_update_request(logger, request): if request.tokens: if not isinstance(request.tokens, dict) or not request.tokens: - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("UPDATE"), logger=logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code) if request.token_mode == TokenMode.ENABLE and not request.tokens: @@ -628,14 +646,14 @@ def validate_update_request(logger, request): if request.token_mode == TokenMode.ENABLE_STRICT: if len(field) != len(request.tokens): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code) if set(field.keys()) != set(request.tokens.keys()): log_error_log( - SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"), + SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format(RequestOperation.UPDATE), logger=logger) raise SkyflowError( SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, @@ -649,20 +667,20 @@ def validate_detokenize_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code) if not len(request.data): - log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format("DETOKENIZE"), logger = logger) - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("DETOKENIZE"), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.TOKENS_REQUIRED.value.format(RequestOperation.DETOKENIZE), logger = logger) + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format(RequestOperation.DETOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code) for item in request.data: - if 'token' not in item: + if ResponseField.TOKEN not in item: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(type(request.data)), invalid_input_error_code) - token = item.get('token') - redaction = item.get('redaction', None) + token = item.get(ResponseField.TOKEN) + redaction = item.get(RequestParameter.REDACTION, None) if not isinstance(token, str) or not token: - raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"), + raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format(RequestOperation.DETOKENIZE), invalid_input_error_code) if redaction is not None and not isinstance(redaction, RedactionType): @@ -681,16 +699,16 @@ def validate_tokenize_request(logger, request): if not isinstance(param, dict): raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER.value.format(i, type(param)), invalid_input_error_code) - allowed_keys = {"value", "column_group"} + allowed_keys = {RequestParameter.VALUE, RequestParameter.COLUMN_GROUP} if set(param.keys()) != allowed_keys: raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(i), invalid_input_error_code) - if not param.get("value"): - log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.VALUE): + log_error_log(SkyflowMessages.ErrorLogs.COLUMN_VALUES_IS_REQUIRED_TOKENIZE.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_VALUE.value.format(i), invalid_input_error_code) - if not param.get("column_group"): - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format("TOKENIZE"), logger = logger) + if not param.get(RequestParameter.COLUMN_GROUP): + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES.value.format(RequestOperation.TOKENIZE), logger = logger) raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETER_COLUMN_GROUP.value.format(i), invalid_input_error_code) @@ -699,32 +717,32 @@ def validate_file_upload_request(logger, request): raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) # Table - table = getattr(request, "table", None) + table = getattr(request, FileUploadField.TABLE, None) if table is None: raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_VALUE.value, invalid_input_error_code) elif table.strip() == "": raise SkyflowError(SkyflowMessages.Error.EMPTY_TABLE_VALUE.value, invalid_input_error_code) # Skyflow ID - skyflow_id = getattr(request, "skyflow_id", None) + skyflow_id = getattr(request, FileUploadField.SKYFLOW_ID, None) if skyflow_id is None: raise SkyflowError(SkyflowMessages.Error.IDS_KEY_ERROR.value, invalid_input_error_code) elif skyflow_id.strip() == "": - raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format("FILE_UPLOAD"), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.EMPTY_SKYFLOW_ID.value.format(RequestOperation.FILE_UPLOAD), invalid_input_error_code) # Column Name - column_name = getattr(request, "column_name", None) + column_name = getattr(request, FileUploadField.COLUMN_NAME, None) if column_name is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) elif column_name.strip() == "": - logger.error("Empty column name in FILE_UPLOAD") + log_error_log(SkyflowMessages.ErrorLogs.EMPTY_FILE_COLUMN_NAME.value, logger) raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) # File-related attributes - file_path = getattr(request, "file_path", None) - base64_str = getattr(request, "base64", None) - file_object = getattr(request, "file_object", None) - file_name = getattr(request, "file_name", None) + file_path = getattr(request, FileUploadField.FILE_PATH, None) + base64_str = getattr(request, FileUploadField.BASE64, None) + file_object = getattr(request, FileUploadField.FILE_OBJECT, None) + file_name = getattr(request, FileUploadField.FILE_NAME, None) # Check file_path first if present if not is_none_or_empty(file_path): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 83b0ffbd..ca8c7a1d 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader class Connection: @@ -23,7 +23,7 @@ def invoke(self, request: InvokeConnectionRequest): invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: + if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers: invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token invoke_connection_request.headers[SKY_META_DATA_HEADER] = json.dumps(get_metrics()) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 4f2f50f2..c6ef2fb1 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -9,7 +9,7 @@ from skyflow.utils._skyflow_messages import SkyflowMessages from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response from skyflow.utils.constants import (SKY_META_DATA_HEADER, DetectStatus, FileExtension, - FileProcessing, EncodingType) + FileProcessing, EncodingType, DeidentifyField, DeidentifyFileRequestField, FileUploadField, OptionField) from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request from skyflow.utils.validations._validations import validate_deidentify_text_request, validate_reidentify_text_request @@ -34,12 +34,12 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[ deidentify_text_body = {} parsed_entity_types = request.entities - deidentify_text_body['text'] = request.text - deidentify_text_body['entity_types'] = parsed_entity_types - deidentify_text_body['token_type'] = self.__get_token_format(request) - deidentify_text_body['allow_regex'] = request.allow_regex_list - deidentify_text_body['restrict_regex'] = request.restrict_regex_list - deidentify_text_body['transformations'] = self.__get_transformations(request) + deidentify_text_body[DeidentifyField.TEXT] = request.text + deidentify_text_body[DeidentifyField.ENTITY_TYPES] = parsed_entity_types + deidentify_text_body[DeidentifyField.TOKEN_TYPE] = self.__get_token_format(request) + deidentify_text_body[DeidentifyField.ALLOW_REGEX] = request.allow_regex_list + deidentify_text_body[DeidentifyField.RESTRICT_REGEX] = request.restrict_regex_list + deidentify_text_body[DeidentifyField.TRANSFORMATIONS] = self.__get_transformations(request) return deidentify_text_body @@ -50,8 +50,8 @@ def ___build_reidentify_text_body(self, request: ReidentifyTextRequest) -> Dict[ plaintext=request.plain_text_entities ) reidentify_text_body = {} - reidentify_text_body['text'] = request.text - reidentify_text_body['format'] = parsed_format + reidentify_text_body[DeidentifyField.TEXT] = request.text + reidentify_text_body[DeidentifyField.FORMAT] = parsed_format return reidentify_text_body def _get_file_extension(self, filename: str): @@ -67,7 +67,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): status = response.status if status == DetectStatus.IN_PROGRESS: if current_wait_time >= max_wait_time: - return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS') + return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: next_wait_time = current_wait_time * 2 if next_wait_time >= max_wait_time: @@ -83,7 +83,7 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): raise e def __save_deidentify_file_response_output(self, response: DetectRunsResponse, output_directory: str, original_file_name: str, name_without_ext: str): - if not response or not hasattr(response, 'output') or not response.output or not output_directory: + if not response or not hasattr(response, DeidentifyField.OUTPUT) or not response.output or not output_directory: return if not os.path.exists(output_directory): @@ -97,16 +97,16 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o for idx, output in enumerate(output_list): try: - processed_file = get_attribute(output, 'processedFile', 'processed_file') - processed_file_type = get_attribute(output, 'processedFileType', 'processed_file_type') - processed_file_extension = get_attribute(output, 'processedFileExtension', 'processed_file_extension') + processed_file = get_attribute(output, DeidentifyField.PROCESSED_FILE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE) + processed_file_type = get_attribute(output, DeidentifyField.PROCESSED_FILE_TYPE_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_TYPE) + processed_file_extension = get_attribute(output, DeidentifyField.PROCESSED_FILE_EXTENSION_RESPONSE_KEY, DeidentifyField.PROCESSED_FILE_EXTENSION) if not processed_file: continue decoded_data = base64.b64decode(processed_file) - if idx == 0 or processed_file_type == 'redacted_file': + if idx == 0 or processed_file_type == DeidentifyField.REDACTED_FILE: output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) if processed_file_extension: output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") @@ -120,62 +120,62 @@ def __save_deidentify_file_response_output(self, response: DetectRunsResponse, o handle_exception(e, self.__vault_client.get_logger()) def __parse_deidentify_file_response(self, data, run_id=None, status=None): - output = getattr(data, "output", []) - status_val = getattr(data, "status", None) or status - run_id_val = getattr(data, "run_id", None) or run_id + output = getattr(data, DeidentifyField.OUTPUT, []) + status_val = getattr(data, DeidentifyField.STATUS, None) or status + run_id_val = getattr(data, DeidentifyField.RUN_ID, None) or run_id word_count = None char_count = None - word_character_count = getattr(data, "word_character_count", None) + word_character_count = getattr(data, DeidentifyField.WORD_CHARACTER_COUNT, None) if word_character_count and isinstance(word_character_count, WordCharacterCount): - word_count = word_character_count.word_count - char_count = word_character_count.character_count + word_count = getattr(word_character_count, DeidentifyField.WORD_COUNT, None) + char_count = getattr(word_character_count, DeidentifyField.CHARACTER_COUNT, None) - size = getattr(data, "size", None) + size = getattr(data, DeidentifyField.SIZE, None) size = float(size) if size is not None else None - duration = getattr(data, "duration", None) - pages = getattr(data, "pages", None) - slides = getattr(data, "slides", None) + duration = getattr(data, DeidentifyField.DURATION, None) + pages = getattr(data, DeidentifyField.PAGES, None) + slides = getattr(data, DeidentifyField.SLIDES, None) def output_to_dict_list(output): result = [] for o in output: if isinstance(o, dict): result.append({ - "file": o.get("processed_file"), - "type": o.get("processed_file_type"), - "extension": o.get("processed_file_extension") + DeidentifyField.FILE: o.get(DeidentifyField.PROCESSED_FILE), + DeidentifyField.TYPE: o.get(DeidentifyField.PROCESSED_FILE_TYPE), + DeidentifyField.EXTENSION: o.get(DeidentifyField.PROCESSED_FILE_EXTENSION) }) else: result.append({ - "file": getattr(o, "processed_file", None), - "type": getattr(o, "processed_file_type", None), - "extension": getattr(o, "processed_file_extension", None) + DeidentifyField.FILE: getattr(o, DeidentifyField.PROCESSED_FILE, None), + DeidentifyField.TYPE: getattr(o, DeidentifyField.PROCESSED_FILE_TYPE, None), + DeidentifyField.EXTENSION: getattr(o, DeidentifyField.PROCESSED_FILE_EXTENSION, None) }) return result output_list = output_to_dict_list(output) first_output = output_list[0] if output_list else {} - entities = [o for o in output_list if o.get("type") == FileProcessing.ENTITIES] + entities = [o for o in output_list if o.get(DeidentifyField.TYPE) == FileProcessing.ENTITIES] - base64_string = first_output.get("file", None) - extension = first_output.get("extension", None) + base64_string = first_output.get(DeidentifyField.FILE, None) + extension = first_output.get(DeidentifyField.EXTENSION, None) if base64_string is not None: file_bytes = base64.b64decode(base64_string) file_obj = io.BytesIO(file_bytes) - file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else "processed_file" + file_obj.name = f"{FileProcessing.DEIDENTIFIED_PREFIX}{extension}" if extension else DeidentifyField.PROCESSED_FILE else: file_obj = None return DeidentifyFileResponse( file_base64=base64_string, file=file_obj, - type=first_output.get("type", DetectStatus.UNKNOWN), + type=first_output.get(DeidentifyField.TYPE, DetectStatus.UNKNOWN), extension=extension, word_count=word_count, char_count=char_count, @@ -189,25 +189,25 @@ def output_to_dict_list(output): ) def __get_token_format(self, request): - if not hasattr(request, "token_format") or request.token_format is None: + if not hasattr(request, DeidentifyField.TOKEN_FORMAT) or request.token_format is None: return None return { - 'default': getattr(request.token_format, "default", None), - 'entity_unq_counter': getattr(request.token_format, "entity_unique_counter", None), - 'entity_only': getattr(request.token_format, "entity_only", None), + DeidentifyField.DEFAULT: getattr(request.token_format, DeidentifyField.DEFAULT, None), + DeidentifyField.ENTITY_UNQ_COUNTER: getattr(request.token_format, DeidentifyField.ENTITY_UNIQUE_COUNTER, None), + DeidentifyField.ENTITY_ONLY: getattr(request.token_format, DeidentifyField.ENTITY_ONLY, None), } def __get_transformations(self, request): - if not hasattr(request, "transformations") or request.transformations is None: + if not hasattr(request, DeidentifyField.TRANSFORMATIONS) or request.transformations is None: return None - shift_dates = getattr(request.transformations, "shift_dates", None) + shift_dates = getattr(request.transformations, DeidentifyField.SHIFT_DATES, None) if shift_dates is None: return None return { - 'shift_dates': { - 'max_days': getattr(shift_dates, "max", None), - 'min_days': getattr(shift_dates, "min", None), - 'entity_types': getattr(shift_dates, "entities", None) + DeidentifyField.SHIFT_DATES: { + DeidentifyField.MAX_DAYS: getattr(shift_dates, DeidentifyField.MAX, None), + DeidentifyField.MIN_DAYS: getattr(shift_dates, DeidentifyField.MIN, None), + DeidentifyField.ENTITY_TYPES: getattr(shift_dates, DeidentifyField.ENTITIES, None) } } @@ -223,12 +223,12 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.deidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=deidentify_text_body['text'], - entity_types=deidentify_text_body['entity_types'], - allow_regex=deidentify_text_body['allow_regex'], - restrict_regex=deidentify_text_body['restrict_regex'], - token_type=deidentify_text_body['token_type'], - transformations=deidentify_text_body['transformations'], + text=deidentify_text_body[DeidentifyField.TEXT], + entity_types=deidentify_text_body[DeidentifyField.ENTITY_TYPES], + allow_regex=deidentify_text_body[DeidentifyField.ALLOW_REGEX], + restrict_regex=deidentify_text_body[DeidentifyField.RESTRICT_REGEX], + token_type=deidentify_text_body[DeidentifyField.TOKEN_TYPE], + transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], request_options=self.__get_headers() ) deidentify_text_response = parse_deidentify_text_response(api_response) @@ -251,8 +251,8 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_TRIGGERED.value, self.__vault_client.get_logger()) api_response = detect_api.reidentify_string( vault_id=self.__vault_client.get_vault_id(), - text=reidentify_text_body['text'], - format=reidentify_text_body['format'], + text=reidentify_text_body[DeidentifyField.TEXT], + format=reidentify_text_body[DeidentifyField.FORMAT], request_options=self.__get_headers() ) reidentify_text_response = parse_reidentify_text_response(api_response) @@ -267,11 +267,11 @@ def __get_file_from_request(self, request: DeidentifyFileRequest): file_input = request.file # Check for file - if hasattr(file_input, 'file') and file_input.file is not None: + if hasattr(file_input, FileUploadField.FILE) and file_input.file is not None: return file_input.file # Check for file_path if file is not provided - if hasattr(file_input, 'file_path') and file_input.file_path is not None: + if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: return open(file_input.file_path, 'rb') def deidentify_file(self, request: DeidentifyFileRequest): @@ -280,7 +280,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response file_obj = self.__get_file_from_request(request) - file_name = getattr(file_obj, 'name', None) + file_name = getattr(file_obj, FileUploadField.NAME, None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) @@ -290,138 +290,138 @@ def deidentify_file(self, request: DeidentifyFileRequest): req_file = FileDataDeidentifyText(base_64=base64_string, data_format=FileExtension.TXT) api_call = files_api.deidentify_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.MP3, FileExtension.WAV]: req_file = FileDataDeidentifyAudio(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_audio api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'output_transcription': getattr(request, 'output_transcription', None), - 'output_processed_audio': getattr(request, 'output_processed_audio', None), - 'bleep_gain': getattr(request, 'bleep', None).gain if getattr(request, 'bleep', None) is not None else None, - 'bleep_frequency': getattr(request, 'bleep', None).frequency if getattr(request, 'bleep', None) is not None else None, - 'bleep_start_padding': getattr(request, 'bleep', None).start_padding if getattr(request, 'bleep', None) is not None else None, - 'bleep_stop_padding': getattr(request, 'bleep', None).stop_padding if getattr(request, 'bleep', None) is not None else None, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION: getattr(request, DeidentifyFileRequestField.OUTPUT_TRANSCRIPTION, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_AUDIO, None), + DeidentifyField.BLEEP_GAIN: getattr(request, DeidentifyFileRequestField.BLEEP, None).gain if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_FREQUENCY: getattr(request, DeidentifyFileRequestField.BLEEP, None).frequency if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_START_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).start_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.BLEEP_STOP_PADDING: getattr(request, DeidentifyFileRequestField.BLEEP, None).stop_padding if getattr(request, DeidentifyFileRequestField.BLEEP, None) is not None else None, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension == FileExtension.PDF: req_file = FileDataDeidentifyPdf(base_64=base64_string) api_call = files_api.deidentify_pdf api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'max_resolution': getattr(request, 'max_resolution', None), - 'density': getattr(request, 'pixel_density', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MAX_RESOLUTION: getattr(request, DeidentifyFileRequestField.MAX_RESOLUTION, None), + DeidentifyFileRequestField.PIXEL_DENSITY: getattr(request, DeidentifyFileRequestField.PIXEL_DENSITY, None), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.JPEG, FileExtension.JPG, FileExtension.PNG, FileExtension.BMP, FileExtension.TIF, FileExtension.TIFF]: req_file = FileDataDeidentifyImage(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_image api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'masking_method': getattr(request, 'masking_method', None), - 'output_ocr_text': getattr(request, 'output_ocr_text', None), - 'output_processed_image': getattr(request, 'output_processed_image', None), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyFileRequestField.MASKING_METHOD: getattr(request, DeidentifyFileRequestField.MASKING_METHOD, None), + DeidentifyFileRequestField.OUTPUT_OCR_TEXT: getattr(request, DeidentifyFileRequestField.OUTPUT_OCR_TEXT, None), + DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE: getattr(request, DeidentifyFileRequestField.OUTPUT_PROCESSED_IMAGE, None), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.PPT, FileExtension.PPTX]: req_file = FileDataDeidentifyPresentation(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_presentation api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.CSV, FileExtension.XLS, FileExtension.XLSX]: req_file = FileDataDeidentifySpreadsheet(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_spreadsheet api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.DOC, FileExtension.DOCX]: req_file = FileDataDeidentifyDocument(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_document api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } elif file_extension in [FileExtension.JSON, FileExtension.XML]: req_file = FileDataDeidentifyStructuredText(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_structured_text api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } else: req_file = FileData(base_64=base64_string, data_format=file_extension) api_call = files_api.deidentify_file api_kwargs = { - 'vault_id': self.__vault_client.get_vault_id(), - 'file': req_file, - 'entity_types': request.entities, - 'token_type': self.__get_token_format(request), - 'allow_regex': request.allow_regex_list, - 'restrict_regex': request.restrict_regex_list, - 'transformations': self.__get_transformations(request), - 'request_options': self.__get_headers() + OptionField.VAULT_ID: self.__vault_client.get_vault_id(), + DeidentifyField.FILE: req_file, + DeidentifyField.ENTITY_TYPES: request.entities, + DeidentifyField.TOKEN_TYPE: self.__get_token_format(request), + DeidentifyField.ALLOW_REGEX: request.allow_regex_list, + DeidentifyField.RESTRICT_REGEX: request.restrict_regex_list, + DeidentifyField.TRANSFORMATIONS: self.__get_transformations(request), + DeidentifyField.REQUEST_OPTIONS: self.__get_headers() } log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) api_response = api_call(**api_kwargs) - run_id = getattr(api_response.data, 'run_id', None) + run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) if request.output_directory and processed_response.status == DetectStatus.SUCCESS: @@ -452,7 +452,7 @@ def get_detect_run(self, request: GetDetectRunRequest): request_options=self.__get_headers() ) if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) + parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index a5cd94fd..856a1961 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -8,7 +8,7 @@ from skyflow.utils import SkyflowMessages, parse_insert_response, \ handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics -from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter +from skyflow.utils.constants import SKY_META_DATA_HEADER, ResponseField, RequestParameter, FileUploadField from skyflow.utils.enums import RequestMethod from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log @@ -82,7 +82,7 @@ def __get_file_for_file_upload(self, request: FileUploadRequest) -> Optional[Fil return (request.file_name, decoded_bytes) elif request.file_object is not None: - if hasattr(request.file_object, "name") and request.file_object.name: + if hasattr(request.file_object, FileUploadField.NAME) and request.file_object.name: file_name = os.path.basename(request.file_object.name) return (file_name, request.file_object) From 5eb3da98f28b72045345c1f465bc736a9f7fc347 Mon Sep 17 00:00:00 2001 From: skyflow-himanshu Date: Mon, 2 Feb 2026 17:50:14 +0530 Subject: [PATCH 3/3] SK-2496: added samples to ignore for linting --- ruff.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ruff.toml b/ruff.toml index 8b0d5278..aea6cce7 100644 --- a/ruff.toml +++ b/ruff.toml @@ -8,7 +8,8 @@ exclude = [ "venv", "build", "dist", - "tests" + "tests", + "samples" ] line-length = 120