From a306f16320ea5d2734b419ac2303c5f5e75e633b Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 2 Feb 2026 15:56:48 +0530 Subject: [PATCH 1/4] SK-2522: fix identified bugs --- skyflow/utils/_skyflow_messages.py | 10 +- skyflow/utils/_utils.py | 154 +++-- skyflow/utils/enums/content_types.py | 3 +- skyflow/utils/validations/_validations.py | 15 +- skyflow/vault/client/client.py | 3 +- skyflow/vault/controller/_connections.py | 13 +- tests/utils/test__utils.py | 663 ++++++++++++++++++- tests/utils/validations/test__validations.py | 63 ++ tests/vault/controller/test__connection.py | 174 ++++- 9 files changed, 1021 insertions(+), 77 deletions(-) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 1954ed4d..69228224 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -42,11 +42,12 @@ class Error(Enum): EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path." EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path." INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string." - INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string." + INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a valid file path." EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid token for {{}} with id {{}}.Specify a valid credentials token." EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." + EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired." EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." @@ -118,8 +119,9 @@ class Error(Enum): INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list." INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType." - INVALID_COLUMN_NAME = f"{error_prefix} Validation error. 'column' has a value of type {{}}. Specify 'column' as a string." - INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. columnValues key has a value of type {{}}. Specify columnValues key as list." + INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string." + INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." + INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list." BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time" INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." @@ -366,7 +368,7 @@ class ErrorLogs(Enum): SKYFLOW_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id is required." EMPTY_SKYFLOW_ID = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id can not be empty." - COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. ColumnValues are required." + COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. column_values are required." EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Column group can not be null or empty in column values at index %s2." EMPTY_QUERY= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Query can not be empty." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index c6f294cd..17d2f424 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -105,27 +105,41 @@ def convert_detected_entity_to_entity_info(detected_entity): def construct_invoke_connection_request(request, connection_url, logger) -> PreparedRequest: url = parse_path_params(connection_url.rstrip('/'), request.path_params) - try: - if isinstance(request.headers, dict): - header = to_lowercase_keys(json.loads( - json.dumps(request.headers))) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - except Exception: - raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + header = None + content_type = None - if not HttpHeader.CONTENT_TYPE.lower() in header: - header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value + if request.headers is not None: + try: + if isinstance(request.headers, dict): + header = to_lowercase_keys(json.loads( + json.dumps(request.headers))) + + content_type = header.get(HttpHeader.CONTENT_TYPE_LOWERCASE) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code) - try: - if isinstance(request.body, dict): - json_data, files = get_data_from_content_type( - request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE] - ) - else: + json_data = None + files = {} + + if request.body is not None: + try: + if isinstance(request.body, dict): + json_data, files = get_data_from_content_type( + request.body, content_type + ) + else: + raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + except SkyflowError: + raise + except Exception as e: raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) - except Exception as e: - raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code) + + if files and header and content_type == ContentType.FORMDATA.value: + header.pop(HttpHeader.CONTENT_TYPE_LOWERCASE, None) validate_invoke_connection_params(logger, request.query_params, request.path_params) @@ -175,16 +189,54 @@ def render_key(parents): def get_data_from_content_type(data, content_type): converted_data = data files = {} + if content_type == ContentType.URLENCODED.value: converted_data = http_build_query(data) elif content_type == ContentType.FORMDATA.value: - converted_data = r_urlencode(list(), dict(), data) - files = {(None, None)} + print("Hello") + converted_data = None + files = {} + for key, value in data.items(): + files[key] = (None, str(value)) elif content_type == ContentType.JSON.value: converted_data = json.dumps(data) + elif content_type == ContentType.XML.value or content_type == 'application/xml' or content_type == 'text/xml': + if isinstance(data, dict): + converted_data = dict_to_xml(data) + else: + converted_data = str(data) + elif content_type == ContentType.HTML.value or content_type == 'text/html': + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) + else: + if isinstance(data, dict): + converted_data = json.dumps(data) + else: + converted_data = str(data) return converted_data, files +def dict_to_xml(data, root_tag='root'): + def build_xml(d, tag='item'): + if isinstance(d, dict): + xml_parts = [f'<{tag}>'] + for key, value in d.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + elif isinstance(d, list): + return ''.join([build_xml(item, tag) for item in d]) + else: + return f'<{tag}>{d}' + + xml_parts = [f'<{root_tag}>'] + for key, value in data.items(): + xml_parts.append(build_xml(value, key)) + xml_parts.append(f'') + return ''.join(xml_parts) + def get_metrics(): sdk_name_version = "skyflow-python@" + SDK_VERSION @@ -346,39 +398,50 @@ def parse_invoke_connection_response(api_response: requests.Response): content = api_response.content if isinstance(content, bytes): content = content.decode(EncodingType.UTF_8) + try: api_response.raise_for_status() - try: - data = json.loads(content) - metadata = {} - if HttpHeader.X_REQUEST_ID in api_response.headers: - metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] + + content_type = api_response.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE, '').lower() + + if ContentTypeConstants.APPLICATION_JSON in content_type or not content_type: + try: + data = json.loads(content) + except json.JSONDecodeError: + data = content + else: + data = content + + metadata = {} + if HttpHeader.X_REQUEST_ID in api_response.headers: + metadata['request_id'] = api_response.headers[HttpHeader.X_REQUEST_ID] - return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) - except Exception as e: - raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) + except HTTPError: message = SkyflowMessages.Error.API_ERROR.value.format(status_code) + request_id = api_response.headers.get(HttpHeader.X_REQUEST_ID) + try: - error_response = json.loads(content) - request_id = api_response.headers[HttpHeader.X_REQUEST_ID] + error_response = json.loads(content) error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT) - status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found + status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, status_code) http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS) grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE) details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS) message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) if error_from_client is not None: - if details is None: details = [] + if details is None: + details = [] error_from_client_bool = error_from_client.lower() == BooleanString.TRUE details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool}) raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details) + except json.JSONDecodeError: - message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content) - raise SkyflowError(message, status_code) + raise SkyflowError(content if content else message, status_code, request_id) def parse_deidentify_text_response(api_response: DeidentifyStringResponse): entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities] @@ -396,9 +459,15 @@ def log_and_reject_error(description, status_code, request_id, http_status=None, raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details) def handle_exception(error, logger): - # handle invalid cluster ID error scenario - if (isinstance(error, httpx.ConnectError)): - handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger) + if isinstance(error, httpx.ConnectError): + description = SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger) + return + + if not hasattr(error, 'headers') or not hasattr(error, 'body') or error.headers is None or error.body is None: + description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value + log_and_reject_error(description, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=logger) + return request_id = error.headers.get(HttpHeader.X_REQUEST_ID, 'unknown-request-id') content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE) @@ -410,9 +479,9 @@ def handle_exception(error, logger): elif ContentTypeConstants.TEXT_PLAIN in content_type: handle_text_error(error, data, request_id, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) else: - handle_generic_error(error, request_id, logger) + handle_generic_error_with_status(error, request_id, error.status, logger) def handle_json_error(err, data, request_id, logger): try: @@ -435,12 +504,9 @@ def handle_json_error(err, data, request_id, logger): def handle_text_error(err, data, request_id, logger): log_and_reject_error(data, err.status, request_id, logger = logger) -def handle_generic_error(err, request_id, logger): - handle_generic_error(err, request_id, err.status, logger = logger) - -def handle_generic_error(err, request_id, status, logger): +def handle_generic_error_with_status(err, request_id, status, logger): description = SkyflowMessages.Error.GENERIC_API_ERROR.value - log_and_reject_error(description, status, request_id, logger = logger) + log_and_reject_error(description, status, request_id, logger=logger) def encode_column_values(get_request): encoded_column_values = list() diff --git a/skyflow/utils/enums/content_types.py b/skyflow/utils/enums/content_types.py index 362c286a..f2db5b92 100644 --- a/skyflow/utils/enums/content_types.py +++ b/skyflow/utils/enums/content_types.py @@ -5,4 +5,5 @@ class ContentType(Enum): PLAINTEXT = 'text/plain' XML = 'text/xml' URLENCODED = 'application/x-www-form-urlencoded' - FORMDATA = 'multipart/form-data' \ No newline at end of file + FORMDATA = 'multipart/form-data' + HTML = 'text/html' \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 779fdfcc..ae0c2353 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -123,8 +123,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) if is_expired(credentials.get("token"), logger): raise SkyflowError( - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) - if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value, invalid_input_error_code ) elif "api_key" in credentials: @@ -229,9 +229,7 @@ def validate_connection_config(logger, config): ) if "credentials" not in config: - raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", connection_id), invalid_input_error_code) - - validate_credentials(logger, config.get("credentials"), "connection", connection_id) + validate_credentials(logger, config.get("credentials"), "connection", connection_id) return True @@ -390,7 +388,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if hasattr(request, 'wait_time') and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > 64: raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): @@ -414,9 +412,6 @@ def validate_insert_request(logger, request): if key is None or key == "": log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format("INSERT"), logger = logger) - if value is None or value == "": - log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format("INSERT", key), logger = logger) - if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value("INSERT"), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code) @@ -575,7 +570,7 @@ def validate_get_request(logger, request): if column_name and not column_values: log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger) - SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code) + raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUES.value, invalid_input_error_code) if (column_name or column_values) and skyflow_ids: log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format("GET"), logger = logger) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index 2d77330e..45234a40 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -1,3 +1,4 @@ +from skyflow.error import SkyflowError from skyflow.generated.rest.client import Skyflow from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages @@ -86,7 +87,7 @@ def get_bearer_token(self, credentials): if is_expired(self.__bearer_token): self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) return self.__bearer_token diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 83b0ffbd..8aafc586 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -6,6 +6,7 @@ from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW +from skyflow.utils import get_credentials class Connection: @@ -13,15 +14,17 @@ def __init__(self, vault_client): self.__vault_client = vault_client def invoke(self, request: InvokeConnectionRequest): - session = requests.Session() - + log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) config = self.__vault_client.get_config() - bearer_token = self.__vault_client.get_bearer_token(config.get("credentials")) - connection_url = config.get("connection_url") - log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger()) invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) + + credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger()) + + bearer_token = self.__vault_client.get_bearer_token(credentials) + + session = requests.Session() if not 'X-Skyflow-Authorization'.lower() in invoke_connection_request.headers: invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6eaacf47..7ffe93ad 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -394,14 +394,18 @@ def test_parse_invoke_connection_response_successful(self, mock_response): @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): - + """Test that non-JSON content in successful response is returned as string.""" mock_response.status_code = 200 mock_response.content = "Non-JSON Content".encode('utf-8') + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() - with self.assertRaises(SkyflowError) as context: - parse_invoke_connection_response(mock_response) + result = parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Non-JSON Content")) + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Non-JSON Content") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) @patch("requests.Response") def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response): @@ -428,7 +432,9 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message( with self.assertRaises(SkyflowError) as context: parse_invoke_connection_response(mock_response) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Internal Server Error")) + self.assertEqual(context.exception.message, "Internal Server Error") + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") @patch("skyflow.utils._utils.log_and_reject_error") def test_handle_exception_json_error(self, mock_log_and_reject_error): @@ -597,3 +603,650 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self): self.assertEqual(result.text_index.end, 0) self.assertEqual(result.processed_index.start, 0) self.assertEqual(result.processed_index.end, 0) + + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_connect_error(self, mock_log_and_reject_error): + """Test handling httpx.ConnectError.""" + import httpx + mock_error = httpx.ConnectError("Connection refused") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.INVALID_INPUT.value, + None, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error): + """Test handling error without headers attribute.""" + mock_error = Exception("Generic error") + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Generic error", + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error): + """Test handling error without body attribute.""" + mock_error = Mock() + mock_error.headers = {"x-request-id": "12345"} + delattr(mock_error, 'body') + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + self.assertEqual( + mock_log_and_reject_error.call_args[0][1], + SkyflowMessages.ErrorCodes.SERVER_ERROR.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_text_plain_error(self, mock_log_and_reject_error): + """Test handling text/plain content type error.""" + mock_error = Mock() + mock_error.headers = { + 'x-request-id': '1234', + 'content-type': 'text/plain' + } + mock_error.body = "Plain text error message" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Plain text error message", + 500, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error): + """Test handling generic error with unknown content type.""" + mock_error = Mock() + mock_error.headers = { + 'x-request-id': '1234', + 'content-type': 'application/xml' + } + mock_error.body = "XML error" + mock_error.status = 503 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 503, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_no_content_type(self, mock_log_and_reject_error): + """Test handling error without content-type header.""" + mock_error = Mock() + mock_error.headers = {'x-request-id': '1234'} + mock_error.body = "Some error" + mock_error.status = 500 + mock_logger = Mock() + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 500, + "1234", + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_json_string(self, mock_log_and_reject_error): + """Test handling JSON error when data is a JSON string.""" + error_json_string = json.dumps({ + "error": { + "message": "String JSON error", + "http_code": 422, + "http_status": "Unprocessable Entity", + "grpc_code": 3, + "details": ["validation failed"] + } + }) + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-3" + + handle_json_error(mock_error, error_json_string, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "String JSON error", + 422, + request_id, + "Unprocessable Entity", + 3, + ["validation failed"], + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error): + """Test handling JSON decode error.""" + invalid_json = "This is not valid JSON" + mock_error = Mock() + mock_error.status = 500 + mock_logger = Mock() + request_id = "test-request-id-4" + + handle_json_error(mock_error, invalid_json, request_id, mock_logger) + + # Should call with INVALID_JSON_RESPONSE error + mock_log_and_reject_error.assert_called_once() + self.assertEqual( + mock_log_and_reject_error.call_args[0][0], + SkyflowMessages.Error.INVALID_JSON_RESPONSE.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error): + """Test handling JSON error with missing error field.""" + error_dict = { + "message": "Error without error wrapper" + } + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-5" + + handle_json_error(mock_error, error_dict, request_id, mock_logger) + + # Should use defaults for missing fields + mock_log_and_reject_error.assert_called_once() + args = mock_log_and_reject_error.call_args[0] + # Default message when error field is missing + self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value) + # Default status code + self.assertEqual(args[1], 500) + self.assertEqual(args[2], request_id) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_text_error_with_status(self, mock_log_and_reject_error): + """Test handle_text_error extracts status correctly.""" + mock_error = Mock() + mock_error.status = 404 + mock_logger = Mock() + request_id = "test-request-id-6" + error_data = "Resource not found" + + from skyflow.utils._utils import handle_text_error + handle_text_error(mock_error, error_data, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Resource not found", + 404, + request_id, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_generic_error_with_status(self, mock_log_and_reject_error): + """Test handle_generic_error_with_status.""" + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-7" + status = 503 + + from skyflow.utils._utils import handle_generic_error_with_status + handle_generic_error_with_status(mock_error, request_id, status, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + 503, + request_id, + logger=mock_logger + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_none_error(self, mock_log_and_reject_error): + """Test handling None error object.""" + mock_logger = Mock() + + handle_exception(None, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + SkyflowMessages.Error.GENERIC_API_ERROR.value, + SkyflowMessages.ErrorCodes.SERVER_ERROR.value, + None, + logger=mock_logger + ) + + #failed + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error): + """Test handling empty string error.""" + mock_logger = Mock() + mock_error = Mock() + mock_error.headers = None + mock_error.body = None + + handle_exception(mock_error, mock_logger) + + mock_log_and_reject_error.assert_called_once() + # Should use str(error) or default message + self.assertEqual( + mock_log_and_reject_error.call_args[0][1], + SkyflowMessages.ErrorCodes.SERVER_ERROR.value + ) + + @patch("skyflow.utils._utils.log_and_reject_error") + def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error): + """Test handling JSON error when data is bytes.""" + error_dict = { + "error": { + "message": "Bytes error", + "http_code": 401, + "http_status": "Unauthorized" + } + } + error_bytes = json.dumps(error_dict).encode('utf-8') + + mock_error = Mock() + mock_logger = Mock() + request_id = "test-request-id-8" + + handle_json_error(mock_error, error_bytes, request_id, mock_logger) + + mock_log_and_reject_error.assert_called_once_with( + "Bytes error", + 401, + request_id, + "Unauthorized", + None, + [], + logger=mock_logger + ) + + # Add these new test methods to the TestUtils class: + + def test_construct_invoke_connection_request_with_no_headers(self): + """Test construct_invoke_connection_request when headers are None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param1": "value1"} + mock_connection_request.headers = None + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {"query": "test"} + + connection_url = "https://example.com/{param1}/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Headers should be None when not provided + self.assertIsNone(result.headers.get('Content-Type')) + + def test_construct_invoke_connection_request_with_xml_content_type(self): + """Test construct_invoke_connection_request with XML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/xml"} + mock_connection_request.body = {"root": {"child": "value"}} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers['content-type'], 'application/xml') + # Body should be converted to XML + self.assertIn('', result.body) + self.assertIn('value', result.body) + + def test_construct_invoke_connection_request_with_html_content_type(self): + """Test construct_invoke_connection_request with HTML content type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "text/html"} + mock_connection_request.body = {"message": "Hello"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertEqual(result.headers['content-type'], 'text/html') + # Body should be JSON string for HTML + self.assertEqual(result.body, json.dumps({"message": "Hello"})) + + def test_construct_invoke_connection_request_multipart_removes_content_type(self): + """Test that Content-Type is removed for multipart/form-data.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value} + mock_connection_request.body = {"field1": "value1", "field2": "value2"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + # Content-Type should be auto-generated by requests library + self.assertIn('multipart/form-data', result.headers.get('Content-Type', '')) + self.assertIn('boundary=', result.headers.get('Content-Type', '')) + + def test_construct_invoke_connection_request_with_no_body(self): + """Test construct_invoke_connection_request when body is None.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertIsInstance(result, PreparedRequest) + self.assertIsNone(result.body) + + def test_get_data_from_content_type_url_encoded(self): + """Test get_data_from_content_type with URL encoded content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key1": "value1", "key2": "value2"} + content_type = ContentType.URLENCODED.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, "key1=value1&key2=value2") + self.assertEqual(files, {}) + + def test_get_data_from_content_type_form_data(self): + """Test get_data_from_content_type with form data content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"field1": "value1", "field2": "value2"} + content_type = ContentType.FORMDATA.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIsNone(converted_data) + self.assertEqual(files["field1"], (None, "value1")) + self.assertEqual(files["field2"], (None, "value2")) + + def test_get_data_from_content_type_json(self): + """Test get_data_from_content_type with JSON content type.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = ContentType.JSON.value + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_dict(self): + """Test get_data_from_content_type with XML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"root": {"child": "value"}} + content_type = "application/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertIn("", converted_data) + self.assertIn("value", converted_data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_xml_with_string(self): + """Test get_data_from_content_type with XML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "value" + content_type = "text/xml" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_dict(self): + """Test get_data_from_content_type with HTML content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"message": "Hello"} + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_html_with_string(self): + """Test get_data_from_content_type with HTML content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "Hello" + content_type = "text/html" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_dict(self): + """Test get_data_from_content_type with unknown content type and dict data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = {"key": "value"} + content_type = "application/custom" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, json.dumps(data)) + self.assertEqual(files, {}) + + def test_get_data_from_content_type_unknown_type_with_string(self): + """Test get_data_from_content_type with unknown content type and string data.""" + from skyflow.utils._utils import get_data_from_content_type + + data = "plain text data" + content_type = "text/plain" + + converted_data, files = get_data_from_content_type(data, content_type) + + self.assertEqual(converted_data, data) + self.assertEqual(files, {}) + + def test_dict_to_xml_simple_dict(self): + """Test dict_to_xml with simple dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"name": "John", "age": "30"} + result = dict_to_xml(data) + + self.assertIn("John", result) + self.assertIn("30", result) + self.assertTrue(result.startswith("")) + self.assertTrue(result.endswith("")) + + def test_dict_to_xml_nested_dict(self): + """Test dict_to_xml with nested dictionary.""" + from skyflow.utils._utils import dict_to_xml + + data = {"person": {"name": "John", "age": "30"}} + result = dict_to_xml(data) + + self.assertIn("", result) + self.assertIn("John", result) + self.assertIn("30", result) + + def test_dict_to_xml_with_list(self): + """Test dict_to_xml with list values.""" + from skyflow.utils._utils import dict_to_xml + + data = {"items": ["item1", "item2", "item3"]} + result = dict_to_xml(data) + + self.assertIn("item1", result) + self.assertIn("item2", result) + self.assertIn("item3", result) + + @patch("requests.Response") + def test_parse_invoke_connection_response_xml_content(self, mock_response): + """Test parsing XML response content.""" + mock_response.status_code = 200 + mock_response.content = b"success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/xml" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_url_encoded_content(self, mock_response): + """Test parsing URL encoded response content.""" + mock_response.status_code = 200 + mock_response.content = b"card_number=4111111111111111&cvv=123" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/x-www-form-urlencoded" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "card_number=4111111111111111&cvv=123") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_content(self, mock_response): + """Test parsing HTML response content.""" + mock_response.status_code = 200 + mock_response.content = b"Success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "text/html" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_html_error(self, mock_response): + """Test parsing HTML error response.""" + html_error = "

Error 500

" + mock_response.status_code = 500 + mock_response.content = html_error.encode('utf-8') + mock_response.headers = { + "x-request-id": "1234", + "content-type": "text/html" + } + mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error")) + + with self.assertRaises(SkyflowError) as context: + parse_invoke_connection_response(mock_response) + + self.assertEqual(context.exception.message, html_error) + self.assertEqual(context.exception.http_code, 500) + self.assertEqual(context.exception.request_id, "1234") + + @patch("requests.Response") + def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, mock_response): + """Test that JSON decode error falls back to returning string content.""" + mock_response.status_code = 200 + mock_response.content = b"Not valid JSON but still success" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/json" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Not valid JSON but still success") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response): + """Test parsing response with no content-type but valid JSON.""" + mock_response.status_code = 200 + mock_response.content = json.dumps({"success": True}).encode('utf-8') + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, {"success": True}) + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_no_content_type_with_text(self, mock_response): + """Test parsing response with no content-type and non-JSON content.""" + mock_response.status_code = 200 + mock_response.content = b"Plain text response" + mock_response.headers = {"x-request-id": "1234"} + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Plain text response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) + + @patch("requests.Response") + def test_parse_invoke_connection_response_bytes_content(self, mock_response): + """Test parsing response with bytes content.""" + mock_response.status_code = 200 + mock_response.content = b"Binary data response" + mock_response.headers = { + "x-request-id": "1234", + "content-type": "application/octet-stream" + } + mock_response.raise_for_status = Mock() + + result = parse_invoke_connection_response(mock_response) + + self.assertIsInstance(result, InvokeConnectionResponse) + self.assertEqual(result.data, "Binary data response") + self.assertEqual(result.metadata["request_id"], "1234") + self.assertIsNone(result.errors) diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..36b74c20 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -1044,3 +1044,66 @@ def test_validate_detokenize_request_invalid_redaction_type(self): with self.assertRaises(SkyflowError) as context: validate_detokenize_request(self.logger, request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) + + + def test_validate_deidentify_file_request_wait_time_negative(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=-1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_greater_than_64(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=65, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_lower(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=0, + entities=[DetectEntities.SSN] + ) + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_boundary_upper(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_valid_float(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=32.5, + entities=[DetectEntities.SSN] + ) + # Should not raise an error + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_wait_time_float_out_of_range(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time=64.1, + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value) diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 4ccad1c7..e8fb4abe 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import requests from skyflow.error import SkyflowError from skyflow.utils import SkyflowMessages, parse_invoke_connection_response @@ -30,10 +30,16 @@ def setUp(self): self.mock_vault_client = Mock() self.mock_vault_client.get_config.return_value = VAULT_CONFIG self.mock_vault_client.get_bearer_token.return_value = VALID_BEARER_TOKEN + self.mock_vault_client.get_logger.return_value = Mock() + self.mock_vault_client.get_common_skyflow_credentials.return_value = None self.connection = Connection(self.mock_vault_client) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_success(self, mock_send): + def test_invoke_success(self, mock_send, mock_get_credentials): + # Mock get_credentials to return credentials + mock_get_credentials.return_value = {"api_key": "test_api_key"} + # Mocking successful response mock_response = Mock() mock_response.status_code = SUCCESS_STATUS_CODE @@ -60,9 +66,36 @@ def test_invoke_success(self, mock_send): } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() + mock_get_credentials.assert_called_once() + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_invalid_headers(self, mock_send): + def test_invoke_with_x_skyflow_authorization_already_present(self, mock_send, mock_get_credentials): + """Test that X-Skyflow-Authorization is not overwritten if already present in headers.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + custom_auth = "custom_bearer_token" + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers={"x-skyflow-authorization": custom_auth} + ) + + response = self.connection.invoke(request) + + # Verify bearer token from vault_client is NOT used + self.assertIsNotNone(response) + + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_headers(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=VALID_BODY, @@ -75,8 +108,10 @@ def test_invoke_invalid_headers(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) - @patch('requests.Session.send') - def test_invoke_invalid_body(self, mock_send): + @patch('skyflow.vault.controller._connections.get_credentials') + def test_invoke_invalid_body(self, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + request = InvokeConnectionRequest( method="POST", body=INVALID_BODY, @@ -89,8 +124,11 @@ def test_invoke_invalid_body(self, mock_send): self.connection.invoke(request) self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') - def test_invoke_request_error(self, mock_send): + def test_invoke_request_error(self, mock_send, mock_get_credentials): + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_response = Mock() mock_response.status_code = FAILURE_STATUS_CODE mock_response.content = ERROR_RESPONSE_CONTENT @@ -106,10 +144,98 @@ def test_invoke_request_error(self, mock_send): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - self.assertEqual(context.exception.message, f'Skyflow Python SDK {SDK_VERSION} Response {ERROR_RESPONSE_CONTENT} is not valid JSON.') self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT)) self.assertEqual(context.exception.http_code, 400) + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_send_exception(self, mock_send, mock_get_credentials): + """Test handling of generic exception from session.send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_send.side_effect = Exception("Network error") + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_skyflow_error_re_raised(self, mock_send, mock_get_credentials): + """Test that SkyflowError is re-raised without wrapping.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + original_error = SkyflowError("Original error", 401) + mock_send.side_effect = original_error + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + with self.assertRaises(SkyflowError) as context: + self.connection.invoke(request) + # Should be the same original error + self.assertEqual(context.exception.message, "Original error") + self.assertEqual(context.exception.http_code, 401) + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('requests.Session.send') + def test_invoke_session_close_called(self, mock_send, mock_get_credentials): + """Test that session.close() is called after send().""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + with patch('requests.Session.close') as mock_close: + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify close was called + mock_close.assert_called_once() + + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.get_metrics') + @patch('requests.Session.send') + def test_invoke_adds_sky_metadata_header(self, mock_send, mock_get_metrics, mock_get_credentials): + """Test that sky-metadata header is added to request.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + mock_get_metrics.return_value = {"sdk_version": SDK_VERSION} + + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.POST, + body=VALID_BODY, + headers=VALID_HEADERS + ) + + response = self.connection.invoke(request) + + # Verify get_metrics was called + mock_get_metrics.assert_called_once() + self.assertIsNotNone(response) + def test_parse_invoke_connection_response_error_from_client(self): mock_response = Mock(spec=requests.Response) mock_response.status_code = FAILURE_STATUS_CODE @@ -128,3 +254,37 @@ def test_parse_invoke_connection_response_error_from_client(self): self.assertTrue(any(detail.get('error_from_client') == True for detail in exception.details)) self.assertEqual(exception.request_id, '12345') + @patch('skyflow.vault.controller._connections.get_credentials') + @patch('skyflow.vault.controller._connections.construct_invoke_connection_request') + def test_invoke_construct_request_called(self, mock_construct, mock_get_credentials): + """Test that construct_invoke_connection_request is called with correct parameters.""" + mock_get_credentials.return_value = {"api_key": "test_api_key"} + + mock_prepared_request = Mock(spec=requests.PreparedRequest) + mock_prepared_request.headers = {} + mock_construct.return_value = mock_prepared_request + + with patch('requests.Session.send') as mock_send: + mock_response = Mock() + mock_response.status_code = SUCCESS_STATUS_CODE + mock_response.content = SUCCESS_RESPONSE_CONTENT + mock_response.headers = {'x-request-id': 'test-request-id'} + mock_send.return_value = mock_response + + request = InvokeConnectionRequest( + method=RequestMethod.GET, + headers=VALID_HEADERS + ) + + self.connection.invoke(request) + + # Verify construct was called with connection_url from config + mock_construct.assert_called_once_with( + request, + VAULT_CONFIG["connection_url"], + self.mock_vault_client.get_logger() + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 48c5b7b3847e1ed189bb2d5f9f18c9ef455a5e40 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 2 Feb 2026 16:50:30 +0530 Subject: [PATCH 2/4] SK-2522: fix unit tests --- tests/utils/validations/test__validations.py | 2 +- tests/vault/controller/test__connection.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 36b74c20..4f3b5487 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {} diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index e8fb4abe..35a13716 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -131,7 +131,9 @@ def test_invoke_request_error(self, mock_send, mock_get_credentials): mock_response = Mock() mock_response.status_code = FAILURE_STATUS_CODE - mock_response.content = ERROR_RESPONSE_CONTENT + mock_response.content = ERROR_RESPONSE_CONTENT.encode('utf-8') # Convert to bytes + mock_response.headers = {"x-request-id": "test-request-id"} + mock_response.raise_for_status.side_effect = requests.HTTPError("400 Error") mock_send.return_value = mock_response request = InvokeConnectionRequest( @@ -144,8 +146,10 @@ def test_invoke_request_error(self, mock_send, mock_get_credentials): with self.assertRaises(SkyflowError) as context: self.connection.invoke(request) - self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT)) - self.assertEqual(context.exception.http_code, 400) + + self.assertEqual(context.exception.message, ERROR_RESPONSE_CONTENT) + self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE) + self.assertEqual(context.exception.request_id, "test-request-id") @patch('skyflow.vault.controller._connections.get_credentials') @patch('requests.Session.send') From a19fb00c8c2d31912f6f6c1b1fae212de9a99de2 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 2 Feb 2026 17:35:38 +0530 Subject: [PATCH 3/4] SK-2522: add unit tests --- tests/utils/test__utils.py | 167 +++++++++++++++++++++++++++++ tests/vault/client/test__client.py | 86 ++++++++++++++- 2 files changed, 252 insertions(+), 1 deletion(-) diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 7ffe93ad..09195b89 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -1250,3 +1250,170 @@ def test_parse_invoke_connection_response_bytes_content(self, mock_response): self.assertEqual(result.data, "Binary data response") self.assertEqual(result.metadata["request_id"], "1234") self.assertIsNone(result.errors) + + def test_construct_invoke_connection_request_headers_json_error(self): + """Test exception handling when json.dumps fails for headers.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + + class UnserializableObject: + def __repr__(self): + raise TypeError("Object is not JSON serializable") + + mock_connection_request.headers = {"key": UnserializableObject()} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('json.dumps', side_effect=TypeError("Object is not JSON serializable")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_headers_generic_exception(self): + """Test generic exception handling for headers processing.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": "application/json"} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('skyflow.utils._utils.to_lowercase_keys', side_effect=Exception("Generic error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_processing_exception(self): + """Test exception handling when body processing fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = {"key": "value"} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('skyflow.utils._utils.get_data_from_content_type', side_effect=Exception("Body processing error")): + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_json_dumps_exception(self): + """Test exception handling when json.dumps fails in get_data_from_content_type.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + + class UnserializableObject: + pass + + mock_connection_request.body = {"key": UnserializableObject()} + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_invalid_url_exception(self): + """Test exception handling when requests.Request.prepare() fails with invalid URL.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('requests.Request') as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Invalid URL structure") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_URL.value.format(connection_url) + ) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_prepare_exception(self): + """Test exception handling when prepare() method fails.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with patch('requests.Request') as mock_request_class: + mock_request_instance = Mock() + mock_request_instance.prepare.side_effect = Exception("Prepare failed") + mock_request_class.return_value = mock_request_instance + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual( + context.exception.message, + SkyflowMessages.Error.INVALID_URL.value.format(connection_url) + ) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + def test_construct_invoke_connection_request_body_not_dict_raises_error(self): + """Test that non-dict body raises SkyflowError which is caught and re-raised.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {} + mock_connection_request.headers = {"Content-Type": ContentType.JSON.value} + mock_connection_request.body = "not a dict" # Invalid body type + mock_connection_request.method.value = "POST" + mock_connection_request.query_params = {} + + connection_url = "https://example.com/endpoint" + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + + @patch('skyflow.utils._utils.validate_invoke_connection_params') + def test_construct_invoke_connection_request_validation_exception(self, mock_validate): + """Test that validation exceptions are properly propagated.""" + mock_connection_request = Mock() + mock_connection_request.path_params = {"param": "value"} + mock_connection_request.headers = None + mock_connection_request.body = None + mock_connection_request.method.value = "GET" + mock_connection_request.query_params = {"query": "value"} + + connection_url = "https://example.com/endpoint" + + mock_validate.side_effect = SkyflowError("Validation failed", 400) + + with self.assertRaises(SkyflowError) as context: + construct_invoke_connection_request(mock_connection_request, connection_url, logger=None) + + self.assertEqual(context.exception.message, "Validation failed") + self.assertEqual(context.exception.http_code, 400) diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..619c15ec 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,8 @@ import unittest from unittest.mock import patch, MagicMock + +from skyflow.error import SkyflowError +from skyflow.utils import SkyflowMessages from skyflow.vault.client.client import VaultClient CONFIG = { @@ -97,4 +100,85 @@ def test_get_log_level(self): def test_get_logger(self): mock_logger = MagicMock() self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired): + """Test that expired token raises SkyflowError.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.return_value = ("expired_token", None) + mock_is_expired.return_value = True + + with self.assertRaises(SkyflowError) as context: + self.vault_client.get_bearer_token(credentials) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired): + """Test that expired token from credentials string raises SkyflowError.""" + credentials = {"credentials_string": '{"key": "value"}'} + mock_generate_bearer_token_from_creds.return_value = ("expired_token", None) + mock_is_expired.return_value = True + + with self.assertRaises(SkyflowError) as context: + self.vault_client.get_bearer_token(credentials) + + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) + self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) + self.assertTrue(self.vault_client._VaultClient__is_config_updated) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired): + """Test that valid bearer token is reused.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.return_value = ("valid_token", None) + mock_is_expired.return_value = False + + token1 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token1, "valid_token") + mock_generate_bearer_token.assert_called_once() + + token2 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token2, "valid_token") + mock_generate_bearer_token.assert_called_once() + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token") + def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired): + """Test that bearer token is regenerated after config update.""" + credentials = {"path": "/path/to/credentials.json"} + mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)] + mock_is_expired.return_value = False + + token1 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token1, "first_token") + + self.vault_client.update_config({"new_key": "new_value"}) + + token2 = self.vault_client.get_bearer_token(credentials) + self.assertEqual(token2, "second_token") + self.assertEqual(mock_generate_bearer_token.call_count, 2) + + @patch("skyflow.vault.client.client.is_expired") + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired): + """Test get_bearer_token with credentials_string.""" + credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'} + mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None) + mock_is_expired.return_value = False + + token = self.vault_client.get_bearer_token(credentials) + + self.assertEqual(token, "token_from_creds") + mock_generate_bearer_token_from_creds.assert_called_once() + mock_log_info.assert_called_with( + SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value, + None + ) From b6e83c6f199d16653b93182e84e914ad7f57c31e Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Thu, 5 Feb 2026 18:13:20 +0530 Subject: [PATCH 4/4] SK-2522: resolve copilot comments --- skyflow/utils/_skyflow_messages.py | 4 ++-- skyflow/utils/constants.py | 2 ++ skyflow/utils/validations/_validations.py | 4 ++-- skyflow/vault/controller/_connections.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 068a1e23..6a31c078 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -48,7 +48,7 @@ class Error(Enum): INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired." - EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." + EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." @@ -123,7 +123,7 @@ class Error(Enum): INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list." INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed." INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list." - BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time" + BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"{error_prefix} Validation error. Both offset and limit cannot be present at the same time" INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer." INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer." INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean." diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py index 62aa4d11..401bffe5 100644 --- a/skyflow/utils/constants.py +++ b/skyflow/utils/constants.py @@ -35,6 +35,8 @@ class DetectStatus: FAILED = 'FAILED' UNKNOWN = 'UNKNOWN' +class Detect: + WAIT_TIME = 64 class FileExtension: JSON = 'json' diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 74bc26ce..4e3ead8a 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -9,7 +9,7 @@ from skyflow.utils.constants import ( ApiKey, ResponseField, RequestParameter, FileUploadField, - DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField + DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect ) from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ @@ -406,7 +406,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 or request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME: raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index 7634c99c..76dbfaeb 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -5,7 +5,7 @@ parse_invoke_connection_response from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.connection import InvokeConnectionRequest -from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW +from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader from skyflow.utils import get_credentials