diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py
index 21665972..6a31c078 100644
--- a/skyflow/utils/_skyflow_messages.py
+++ b/skyflow/utils/_skyflow_messages.py
@@ -42,12 +42,13 @@ class Error(Enum):
EMPTY_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Specify a valid file path."
EMPTY_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Specify a valid file path."
INVALID_CREDENTIAL_FILE_PATH_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials for {{}} with id {{}}. Expected file path to be a string."
- INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a string."
+ INVALID_CREDENTIAL_FILE_PATH = f"{error_prefix} Initialization failed. Invalid credentials. Expected file path to be a valid file path."
EMPTY_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid token for {{}} with id {{}}.Specify a valid credentials token."
EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token."
INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string."
INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string."
- EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
+ EXPIRED_BEARER_TOKEN = f"{error_prefix} Initialization failed. Bearer token is invalid or expired."
+ EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key."
EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key."
INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string."
@@ -118,10 +119,11 @@ class Error(Enum):
INVALID_IDS_TYPE = f"{error_prefix} Validation error. 'ids' has a value of type {{}}. Specify 'ids' as list."
INVALID_REDACTION_TYPE = f"{error_prefix} Validation error. 'redaction' has a value of type {{}}. Specify 'redaction' as type Skyflow.RedactionType."
- INVALID_COLUMN_NAME = f"{error_prefix} Validation error. 'column' has a value of type {{}}. Specify 'column' as a string."
- INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. columnValues key has a value of type {{}}. Specify columnValues key as list."
+ INVALID_COLUMN_NAME = f"{error_prefix} Validation error. column_name has a value of type {{}}. Specify 'column' as a string."
+ INVALID_COLUMN_VALUE = f"{error_prefix} Validation error. column_values key has a value of type {{}}. Specify column_values key as list."
+ INVALID_COLUMN_VALUES = f"{error_prefix} Validation error. column_values key is an empty list. Specify at least one column value when column_name is passed."
INVALID_FIELDS_VALUE = f"{error_prefix} Validation error. fields key has a value of type{{}}. Specify fields key as list."
- BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"${error_prefix} Validation error. Both offset and limit cannot be present at the same time"
+ BOTH_OFFSET_AND_LIMIT_SPECIFIED = f"{error_prefix} Validation error. Both offset and limit cannot be present at the same time"
INVALID_OFF_SET_VALUE = f"{error_prefix} Validation error. offset key has a value of type {{}}. Specify offset key as integer."
INVALID_LIMIT_VALUE = f"{error_prefix} Validation error. limit key has a value of type {{}}. Specify limit key as integer."
INVALID_DOWNLOAD_URL_VALUE = f"{error_prefix} Validation error. download_url key has a value of type {{}}. Specify download_url key as boolean."
@@ -366,7 +368,7 @@ class ErrorLogs(Enum):
SKYFLOW_ID_IS_REQUIRED = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id is required."
EMPTY_SKYFLOW_ID = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Skyflow Id can not be empty."
- COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. ColumnValues are required."
+ COLUMN_VALUES_IS_REQUIRED_TOKENIZE = f"{ERROR}: [{error_prefix}] Invalid {{}} request. column_values are required."
EMPTY_COLUMN_GROUP_IN_COLUMN_VALUES = f"{ERROR}: [{error_prefix}] Invalid {{}} request. Column group can not be null or empty in column values at index %s2."
EMPTY_QUERY= f"{ERROR}: [{error_prefix}] Invalid {{}} request. Query can not be empty."
diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py
index 83c93b0c..567227f7 100644
--- a/skyflow/utils/_utils.py
+++ b/skyflow/utils/_utils.py
@@ -106,27 +106,41 @@ def convert_detected_entity_to_entity_info(detected_entity):
def construct_invoke_connection_request(request, connection_url, logger) -> PreparedRequest:
url = parse_path_params(connection_url.rstrip('/'), request.path_params)
- try:
- if isinstance(request.headers, dict):
- header = to_lowercase_keys(json.loads(
- json.dumps(request.headers)))
- else:
- raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
- except Exception:
- raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
+ header = None
+ content_type = None
- if not HttpHeader.CONTENT_TYPE.lower() in header:
- header[HttpHeader.CONTENT_TYPE_LOWERCASE] = ContentType.JSON.value
+ if request.headers is not None:
+ try:
+ if isinstance(request.headers, dict):
+ header = to_lowercase_keys(json.loads(
+ json.dumps(request.headers)))
+
+ content_type = header.get(HttpHeader.CONTENT_TYPE_LOWERCASE)
+ else:
+ raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
+ except SkyflowError:
+ raise
+ except Exception:
+ raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value, invalid_input_error_code)
- try:
- if isinstance(request.body, dict):
- json_data, files = get_data_from_content_type(
- request.body, header[HttpHeader.CONTENT_TYPE_LOWERCASE]
- )
- else:
+ json_data = None
+ files = {}
+
+ if request.body is not None:
+ try:
+ if isinstance(request.body, dict):
+ json_data, files = get_data_from_content_type(
+ request.body, content_type
+ )
+ else:
+ raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code)
+ except SkyflowError:
+ raise
+ except Exception as e:
raise SkyflowError(SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code)
- except Exception as e:
- raise SkyflowError( SkyflowMessages.Error.INVALID_REQUEST_BODY.value, invalid_input_error_code)
+
+ if files and header and content_type == ContentType.FORMDATA.value:
+ header.pop(HttpHeader.CONTENT_TYPE_LOWERCASE, None)
validate_invoke_connection_params(logger, request.query_params, request.path_params)
@@ -176,16 +190,54 @@ def render_key(parents):
def get_data_from_content_type(data, content_type):
converted_data = data
files = {}
+
if content_type == ContentType.URLENCODED.value:
converted_data = http_build_query(data)
elif content_type == ContentType.FORMDATA.value:
- converted_data = r_urlencode(list(), dict(), data)
- files = {(None, None)}
+ print("Hello")
+ converted_data = None
+ files = {}
+ for key, value in data.items():
+ files[key] = (None, str(value))
elif content_type == ContentType.JSON.value:
converted_data = json.dumps(data)
+ elif content_type == ContentType.XML.value or content_type == 'application/xml' or content_type == 'text/xml':
+ if isinstance(data, dict):
+ converted_data = dict_to_xml(data)
+ else:
+ converted_data = str(data)
+ elif content_type == ContentType.HTML.value or content_type == 'text/html':
+ if isinstance(data, dict):
+ converted_data = json.dumps(data)
+ else:
+ converted_data = str(data)
+ else:
+ if isinstance(data, dict):
+ converted_data = json.dumps(data)
+ else:
+ converted_data = str(data)
return converted_data, files
+def dict_to_xml(data, root_tag='root'):
+ def build_xml(d, tag='item'):
+ if isinstance(d, dict):
+ xml_parts = [f'<{tag}>']
+ for key, value in d.items():
+ xml_parts.append(build_xml(value, key))
+ xml_parts.append(f'{tag}>')
+ return ''.join(xml_parts)
+ elif isinstance(d, list):
+ return ''.join([build_xml(item, tag) for item in d])
+ else:
+ return f'<{tag}>{d}{tag}>'
+
+ xml_parts = [f'<{root_tag}>']
+ for key, value in data.items():
+ xml_parts.append(build_xml(value, key))
+ xml_parts.append(f'{root_tag}>')
+ return ''.join(xml_parts)
+
def get_metrics():
sdk_name_version = SdkPrefix.SKYFLOW_PYTHON + SDK_VERSION
@@ -347,39 +399,50 @@ def parse_invoke_connection_response(api_response: requests.Response):
content = api_response.content
if isinstance(content, bytes):
content = content.decode(EncodingType.UTF_8)
+
try:
api_response.raise_for_status()
- try:
- data = json.loads(content)
- metadata = {}
- if HttpHeader.X_REQUEST_ID in api_response.headers:
- metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID]
+
+ content_type = api_response.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE, '').lower()
+
+ if ContentTypeConstants.APPLICATION_JSON in content_type or not content_type:
+ try:
+ data = json.loads(content)
+ except json.JSONDecodeError:
+ data = content
+ else:
+ data = content
+
+ metadata = {}
+ if HttpHeader.X_REQUEST_ID in api_response.headers:
+ metadata[ResponseField.REQUEST_ID] = api_response.headers[HttpHeader.X_REQUEST_ID]
- return InvokeConnectionResponse(data=data, metadata=metadata, errors=None)
- except Exception as e:
- raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code)
+ return InvokeConnectionResponse(data=data, metadata=metadata, errors=None)
+
except HTTPError:
message = SkyflowMessages.Error.API_ERROR.value.format(status_code)
+ request_id = api_response.headers.get(HttpHeader.X_REQUEST_ID)
+
try:
- error_response = json.loads(content)
- request_id = api_response.headers[HttpHeader.X_REQUEST_ID]
+ error_response = json.loads(content)
error_from_client = api_response.headers.get(HttpHeader.ERROR_FROM_CLIENT)
- status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, 500) # Default to 500 if not found
+ status_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_CODE, status_code)
http_status = error_response.get(ResponseField.ERROR, {}).get(ResponseField.HTTP_STATUS)
grpc_code = error_response.get(ResponseField.ERROR, {}).get(ResponseField.GRPC_CODE)
details = error_response.get(ResponseField.ERROR, {}).get(ResponseField.DETAILS)
message = error_response.get(ResponseField.ERROR, {}).get(ResponseField.MESSAGE, SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value)
if error_from_client is not None:
- if details is None: details = []
+ if details is None:
+ details = []
error_from_client_bool = error_from_client.lower() == BooleanString.TRUE
details.append({ResponseField.ERROR_FROM_CLIENT: error_from_client_bool})
raise SkyflowError(message, status_code, request_id, grpc_code, http_status, details)
+
except json.JSONDecodeError:
- message = SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content)
- raise SkyflowError(message, status_code)
+ raise SkyflowError(content if content else message, status_code, request_id)
def parse_deidentify_text_response(api_response: DeidentifyStringResponse):
entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities]
@@ -397,9 +460,15 @@ def log_and_reject_error(description, status_code, request_id, http_status=None,
raise SkyflowError(description, status_code, request_id, grpc_code, http_status, details)
def handle_exception(error, logger):
- # handle invalid cluster ID error scenario
- if (isinstance(error, httpx.ConnectError)):
- handle_generic_error(error, None, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, logger)
+ if isinstance(error, httpx.ConnectError):
+ description = SkyflowMessages.Error.GENERIC_API_ERROR.value
+ log_and_reject_error(description, SkyflowMessages.ErrorCodes.INVALID_INPUT.value, None, logger=logger)
+ return
+
+ if not hasattr(error, 'headers') or not hasattr(error, 'body') or error.headers is None or error.body is None:
+ description = str(error) if error else SkyflowMessages.Error.GENERIC_API_ERROR.value
+ log_and_reject_error(description, SkyflowMessages.ErrorCodes.SERVER_ERROR.value, None, logger=logger)
+ return
request_id = error.headers.get(HttpHeader.X_REQUEST_ID, ErrorDefaults.UNKNOWN_REQUEST_ID)
content_type = error.headers.get(HttpHeader.CONTENT_TYPE_LOWERCASE)
@@ -411,9 +480,9 @@ def handle_exception(error, logger):
elif ContentTypeConstants.TEXT_PLAIN in content_type:
handle_text_error(error, data, request_id, logger)
else:
- handle_generic_error(error, request_id, logger)
+ handle_generic_error_with_status(error, request_id, error.status, logger)
else:
- handle_generic_error(error, request_id, logger)
+ handle_generic_error_with_status(error, request_id, error.status, logger)
def handle_json_error(err, data, request_id, logger):
try:
@@ -436,12 +505,9 @@ def handle_json_error(err, data, request_id, logger):
def handle_text_error(err, data, request_id, logger):
log_and_reject_error(data, err.status, request_id, logger = logger)
-def handle_generic_error(err, request_id, logger):
- handle_generic_error(err, request_id, err.status, logger = logger)
-
-def handle_generic_error(err, request_id, status, logger):
+def handle_generic_error_with_status(err, request_id, status, logger):
description = SkyflowMessages.Error.GENERIC_API_ERROR.value
- log_and_reject_error(description, status, request_id, logger = logger)
+ log_and_reject_error(description, status, request_id, logger=logger)
def encode_column_values(get_request):
encoded_column_values = list()
diff --git a/skyflow/utils/constants.py b/skyflow/utils/constants.py
index 62aa4d11..401bffe5 100644
--- a/skyflow/utils/constants.py
+++ b/skyflow/utils/constants.py
@@ -35,6 +35,8 @@ class DetectStatus:
FAILED = 'FAILED'
UNKNOWN = 'UNKNOWN'
+class Detect:
+ WAIT_TIME = 64
class FileExtension:
JSON = 'json'
diff --git a/skyflow/utils/enums/content_types.py b/skyflow/utils/enums/content_types.py
index 362c286a..f2db5b92 100644
--- a/skyflow/utils/enums/content_types.py
+++ b/skyflow/utils/enums/content_types.py
@@ -5,4 +5,5 @@ class ContentType(Enum):
PLAINTEXT = 'text/plain'
XML = 'text/xml'
URLENCODED = 'application/x-www-form-urlencoded'
- FORMDATA = 'multipart/form-data'
\ No newline at end of file
+ FORMDATA = 'multipart/form-data'
+ HTML = 'text/html'
\ No newline at end of file
diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py
index 2ac5783c..4e3ead8a 100644
--- a/skyflow/utils/validations/_validations.py
+++ b/skyflow/utils/validations/_validations.py
@@ -9,7 +9,7 @@
from skyflow.utils.constants import (
ApiKey, ResponseField, RequestParameter,
FileUploadField,
- DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField
+ DeidentifyFileRequestField, RequestOperation, ConfigType, SqlCommand, ConfigField, OptionField, CredentialField, Detect
)
from skyflow.utils.logger import log_info, log_error_log
from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \
@@ -142,8 +142,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
)
if is_expired(credentials.get(CredentialField.TOKEN), logger):
raise SkyflowError(
- SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id)
- if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value,
+ SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value
+ if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value,
invalid_input_error_code
)
elif CredentialField.API_KEY in credentials:
@@ -247,10 +247,8 @@ def validate_connection_config(logger, config):
SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format(connection_id)
)
- if ConfigField.CREDENTIALS not in config:
- raise SkyflowError(SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format(ConfigType.CONNECTION, connection_id), invalid_input_error_code)
-
- validate_credentials(logger, config.get(ConfigField.CREDENTIALS), ConfigType.CONNECTION, connection_id)
+ if "credentials" in config:
+ validate_credentials(logger, config.get("credentials"), "connection", connection_id)
return True
@@ -408,7 +406,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest):
if hasattr(request, DeidentifyFileRequestField.WAIT_TIME) and request.wait_time is not None:
if not isinstance(request.wait_time, (int, float)):
raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code)
- if request.wait_time < 0 and request.wait_time > 64: # noqa: PLR2004
+ if request.wait_time < 0 or request.wait_time > Detect.WAIT_TIME:
raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code)
def validate_insert_request(logger, request):
@@ -432,9 +430,6 @@ def validate_insert_request(logger, request):
if key is None or key == "":
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_KEY_IN_VALUES.value.format(RequestOperation.INSERT), logger = logger)
- if value is None or value == "":
- log_error_log(SkyflowMessages.ErrorLogs.EMPTY_OR_NULL_VALUE_IN_VALUES.value.format(RequestOperation.INSERT, key), logger = logger)
-
if request.upsert is not None and (not isinstance(request.upsert, str) or not request.upsert.strip()):
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_UPSERT.value(RequestOperation.INSERT), logger = logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_UPSERT_OPTIONS_TYPE.value, invalid_input_error_code)
@@ -592,8 +587,8 @@ def validate_get_request(logger, request):
raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(column_values)), invalid_input_error_code)
if column_name and not column_values:
- log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format(RequestOperation.GET), logger = logger)
- SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_NAME.value.format(type(column_name)), invalid_input_error_code)
+ log_error_log(SkyflowMessages.ErrorLogs.COLUMN_NAME_IS_REQUIRED.value.format("GET"), logger = logger)
+ raise SkyflowError(SkyflowMessages.Error.INVALID_COLUMN_VALUES.value, invalid_input_error_code)
if (column_name or column_values) and skyflow_ids:
log_error_log(SkyflowMessages.ErrorLogs.BOTH_IDS_AND_COLUMN_NAME_PASSED.value.format(RequestOperation.GET), logger = logger)
diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py
index 2d77330e..45234a40 100644
--- a/skyflow/vault/client/client.py
+++ b/skyflow/vault/client/client.py
@@ -1,3 +1,4 @@
+from skyflow.error import SkyflowError
from skyflow.generated.rest.client import Skyflow
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
@@ -86,7 +87,7 @@ def get_bearer_token(self, credentials):
if is_expired(self.__bearer_token):
self.__is_config_updated = True
- raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+ raise SkyflowError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
return self.__bearer_token
diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py
index ca8c7a1d..76dbfaeb 100644
--- a/skyflow/vault/controller/_connections.py
+++ b/skyflow/vault/controller/_connections.py
@@ -6,6 +6,7 @@
from skyflow.utils.logger import log_info, log_error_log
from skyflow.vault.connection import InvokeConnectionRequest
from skyflow.utils.constants import SKY_META_DATA_HEADER, SKYFLOW, HttpHeader
+from skyflow.utils import get_credentials
class Connection:
@@ -13,15 +14,17 @@ def __init__(self, vault_client):
self.__vault_client = vault_client
def invoke(self, request: InvokeConnectionRequest):
- session = requests.Session()
-
+ log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger())
config = self.__vault_client.get_config()
- bearer_token = self.__vault_client.get_bearer_token(config.get("credentials"))
-
connection_url = config.get("connection_url")
- log_info(SkyflowMessages.Info.VALIDATING_INVOKE_CONNECTION_REQUEST.value, self.__vault_client.get_logger())
invoke_connection_request = construct_invoke_connection_request(request, connection_url, self.__vault_client.get_logger())
log_info(SkyflowMessages.Info.INVOKE_CONNECTION_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
+
+ credentials = get_credentials(config.get("credentials"), self.__vault_client.get_common_skyflow_credentials(), self.__vault_client.get_logger())
+
+ bearer_token = self.__vault_client.get_bearer_token(credentials)
+
+ session = requests.Session()
if not HttpHeader.X_SKYFLOW_AUTHORIZATION_HEADER.lower() in invoke_connection_request.headers:
invoke_connection_request.headers[SKYFLOW.X_SKYFLOW_AUTHORIZATION] = bearer_token
diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py
index 6eaacf47..09195b89 100644
--- a/tests/utils/test__utils.py
+++ b/tests/utils/test__utils.py
@@ -394,14 +394,18 @@ def test_parse_invoke_connection_response_successful(self, mock_response):
@patch("requests.Response")
def test_parse_invoke_connection_response_json_decode_error(self, mock_response):
-
+ """Test that non-JSON content in successful response is returned as string."""
mock_response.status_code = 200
mock_response.content = "Non-JSON Content".encode('utf-8')
+ mock_response.headers = {"x-request-id": "1234"}
+ mock_response.raise_for_status = Mock()
- with self.assertRaises(SkyflowError) as context:
- parse_invoke_connection_response(mock_response)
+ result = parse_invoke_connection_response(mock_response)
- self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Non-JSON Content"))
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "Non-JSON Content")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
@patch("requests.Response")
def test_parse_invoke_connection_response_http_error_with_json_error_message(self, mock_response):
@@ -428,7 +432,9 @@ def test_parse_invoke_connection_response_http_error_without_json_error_message(
with self.assertRaises(SkyflowError) as context:
parse_invoke_connection_response(mock_response)
- self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format("Internal Server Error"))
+ self.assertEqual(context.exception.message, "Internal Server Error")
+ self.assertEqual(context.exception.http_code, 500)
+ self.assertEqual(context.exception.request_id, "1234")
@patch("skyflow.utils._utils.log_and_reject_error")
def test_handle_exception_json_error(self, mock_log_and_reject_error):
@@ -597,3 +603,817 @@ def test__convert_detected_entity_to_entity_info_with_minimal_data(self):
self.assertEqual(result.text_index.end, 0)
self.assertEqual(result.processed_index.start, 0)
self.assertEqual(result.processed_index.end, 0)
+
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_connect_error(self, mock_log_and_reject_error):
+ """Test handling httpx.ConnectError."""
+ import httpx
+ mock_error = httpx.ConnectError("Connection refused")
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ SkyflowMessages.Error.GENERIC_API_ERROR.value,
+ SkyflowMessages.ErrorCodes.INVALID_INPUT.value,
+ None,
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_no_headers_attribute(self, mock_log_and_reject_error):
+ """Test handling error without headers attribute."""
+ mock_error = Exception("Generic error")
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ "Generic error",
+ SkyflowMessages.ErrorCodes.SERVER_ERROR.value,
+ None,
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_no_body_attribute(self, mock_log_and_reject_error):
+ """Test handling error without body attribute."""
+ mock_error = Mock()
+ mock_error.headers = {"x-request-id": "12345"}
+ delattr(mock_error, 'body')
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once()
+ self.assertEqual(
+ mock_log_and_reject_error.call_args[0][1],
+ SkyflowMessages.ErrorCodes.SERVER_ERROR.value
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_text_plain_error(self, mock_log_and_reject_error):
+ """Test handling text/plain content type error."""
+ mock_error = Mock()
+ mock_error.headers = {
+ 'x-request-id': '1234',
+ 'content-type': 'text/plain'
+ }
+ mock_error.body = "Plain text error message"
+ mock_error.status = 500
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ "Plain text error message",
+ 500,
+ "1234",
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_generic_error_with_status(self, mock_log_and_reject_error):
+ """Test handling generic error with unknown content type."""
+ mock_error = Mock()
+ mock_error.headers = {
+ 'x-request-id': '1234',
+ 'content-type': 'application/xml'
+ }
+ mock_error.body = "XML error"
+ mock_error.status = 503
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ SkyflowMessages.Error.GENERIC_API_ERROR.value,
+ 503,
+ "1234",
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_no_content_type(self, mock_log_and_reject_error):
+ """Test handling error without content-type header."""
+ mock_error = Mock()
+ mock_error.headers = {'x-request-id': '1234'}
+ mock_error.body = "Some error"
+ mock_error.status = 500
+ mock_logger = Mock()
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ SkyflowMessages.Error.GENERIC_API_ERROR.value,
+ 500,
+ "1234",
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_json_error_with_json_string(self, mock_log_and_reject_error):
+ """Test handling JSON error when data is a JSON string."""
+ error_json_string = json.dumps({
+ "error": {
+ "message": "String JSON error",
+ "http_code": 422,
+ "http_status": "Unprocessable Entity",
+ "grpc_code": 3,
+ "details": ["validation failed"]
+ }
+ })
+
+ mock_error = Mock()
+ mock_logger = Mock()
+ request_id = "test-request-id-3"
+
+ handle_json_error(mock_error, error_json_string, request_id, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ "String JSON error",
+ 422,
+ request_id,
+ "Unprocessable Entity",
+ 3,
+ ["validation failed"],
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_json_error_with_invalid_json(self, mock_log_and_reject_error):
+ """Test handling JSON decode error."""
+ invalid_json = "This is not valid JSON"
+ mock_error = Mock()
+ mock_error.status = 500
+ mock_logger = Mock()
+ request_id = "test-request-id-4"
+
+ handle_json_error(mock_error, invalid_json, request_id, mock_logger)
+
+ # Should call with INVALID_JSON_RESPONSE error
+ mock_log_and_reject_error.assert_called_once()
+ self.assertEqual(
+ mock_log_and_reject_error.call_args[0][0],
+ SkyflowMessages.Error.INVALID_JSON_RESPONSE.value
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_json_error_missing_error_field(self, mock_log_and_reject_error):
+ """Test handling JSON error with missing error field."""
+ error_dict = {
+ "message": "Error without error wrapper"
+ }
+
+ mock_error = Mock()
+ mock_logger = Mock()
+ request_id = "test-request-id-5"
+
+ handle_json_error(mock_error, error_dict, request_id, mock_logger)
+
+ # Should use defaults for missing fields
+ mock_log_and_reject_error.assert_called_once()
+ args = mock_log_and_reject_error.call_args[0]
+ # Default message when error field is missing
+ self.assertEqual(args[0], SkyflowMessages.Error.UNKNOWN_ERROR_DEFAULT_MESSAGE.value)
+ # Default status code
+ self.assertEqual(args[1], 500)
+ self.assertEqual(args[2], request_id)
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_text_error_with_status(self, mock_log_and_reject_error):
+ """Test handle_text_error extracts status correctly."""
+ mock_error = Mock()
+ mock_error.status = 404
+ mock_logger = Mock()
+ request_id = "test-request-id-6"
+ error_data = "Resource not found"
+
+ from skyflow.utils._utils import handle_text_error
+ handle_text_error(mock_error, error_data, request_id, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ "Resource not found",
+ 404,
+ request_id,
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_generic_error_with_status(self, mock_log_and_reject_error):
+ """Test handle_generic_error_with_status."""
+ mock_error = Mock()
+ mock_logger = Mock()
+ request_id = "test-request-id-7"
+ status = 503
+
+ from skyflow.utils._utils import handle_generic_error_with_status
+ handle_generic_error_with_status(mock_error, request_id, status, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ SkyflowMessages.Error.GENERIC_API_ERROR.value,
+ 503,
+ request_id,
+ logger=mock_logger
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_with_none_error(self, mock_log_and_reject_error):
+ """Test handling None error object."""
+ mock_logger = Mock()
+
+ handle_exception(None, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ SkyflowMessages.Error.GENERIC_API_ERROR.value,
+ SkyflowMessages.ErrorCodes.SERVER_ERROR.value,
+ None,
+ logger=mock_logger
+ )
+
+ #failed
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_exception_with_empty_string_error(self, mock_log_and_reject_error):
+ """Test handling empty string error."""
+ mock_logger = Mock()
+ mock_error = Mock()
+ mock_error.headers = None
+ mock_error.body = None
+
+ handle_exception(mock_error, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once()
+ # Should use str(error) or default message
+ self.assertEqual(
+ mock_log_and_reject_error.call_args[0][1],
+ SkyflowMessages.ErrorCodes.SERVER_ERROR.value
+ )
+
+ @patch("skyflow.utils._utils.log_and_reject_error")
+ def test_handle_json_error_with_bytes_data(self, mock_log_and_reject_error):
+ """Test handling JSON error when data is bytes."""
+ error_dict = {
+ "error": {
+ "message": "Bytes error",
+ "http_code": 401,
+ "http_status": "Unauthorized"
+ }
+ }
+ error_bytes = json.dumps(error_dict).encode('utf-8')
+
+ mock_error = Mock()
+ mock_logger = Mock()
+ request_id = "test-request-id-8"
+
+ handle_json_error(mock_error, error_bytes, request_id, mock_logger)
+
+ mock_log_and_reject_error.assert_called_once_with(
+ "Bytes error",
+ 401,
+ request_id,
+ "Unauthorized",
+ None,
+ [],
+ logger=mock_logger
+ )
+
+ # Add these new test methods to the TestUtils class:
+
+ def test_construct_invoke_connection_request_with_no_headers(self):
+ """Test construct_invoke_connection_request when headers are None."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {"param1": "value1"}
+ mock_connection_request.headers = None
+ mock_connection_request.body = {"key": "value"}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {"query": "test"}
+
+ connection_url = "https://example.com/{param1}/endpoint"
+
+ result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertIsInstance(result, PreparedRequest)
+ # Headers should be None when not provided
+ self.assertIsNone(result.headers.get('Content-Type'))
+
+ def test_construct_invoke_connection_request_with_xml_content_type(self):
+ """Test construct_invoke_connection_request with XML content type."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": "application/xml"}
+ mock_connection_request.body = {"root": {"child": "value"}}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertIsInstance(result, PreparedRequest)
+ self.assertEqual(result.headers['content-type'], 'application/xml')
+ # Body should be converted to XML
+ self.assertIn('', result.body)
+ self.assertIn('value', result.body)
+
+ def test_construct_invoke_connection_request_with_html_content_type(self):
+ """Test construct_invoke_connection_request with HTML content type."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": "text/html"}
+ mock_connection_request.body = {"message": "Hello"}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertIsInstance(result, PreparedRequest)
+ self.assertEqual(result.headers['content-type'], 'text/html')
+ # Body should be JSON string for HTML
+ self.assertEqual(result.body, json.dumps({"message": "Hello"}))
+
+ def test_construct_invoke_connection_request_multipart_removes_content_type(self):
+ """Test that Content-Type is removed for multipart/form-data."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.FORMDATA.value}
+ mock_connection_request.body = {"field1": "value1", "field2": "value2"}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertIsInstance(result, PreparedRequest)
+ # Content-Type should be auto-generated by requests library
+ self.assertIn('multipart/form-data', result.headers.get('Content-Type', ''))
+ self.assertIn('boundary=', result.headers.get('Content-Type', ''))
+
+ def test_construct_invoke_connection_request_with_no_body(self):
+ """Test construct_invoke_connection_request when body is None."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ result = construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertIsInstance(result, PreparedRequest)
+ self.assertIsNone(result.body)
+
+ def test_get_data_from_content_type_url_encoded(self):
+ """Test get_data_from_content_type with URL encoded content type."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"key1": "value1", "key2": "value2"}
+ content_type = ContentType.URLENCODED.value
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, "key1=value1&key2=value2")
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_form_data(self):
+ """Test get_data_from_content_type with form data content type."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"field1": "value1", "field2": "value2"}
+ content_type = ContentType.FORMDATA.value
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertIsNone(converted_data)
+ self.assertEqual(files["field1"], (None, "value1"))
+ self.assertEqual(files["field2"], (None, "value2"))
+
+ def test_get_data_from_content_type_json(self):
+ """Test get_data_from_content_type with JSON content type."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"key": "value"}
+ content_type = ContentType.JSON.value
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, json.dumps(data))
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_xml_with_dict(self):
+ """Test get_data_from_content_type with XML content type and dict data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"root": {"child": "value"}}
+ content_type = "application/xml"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertIn("", converted_data)
+ self.assertIn("value", converted_data)
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_xml_with_string(self):
+ """Test get_data_from_content_type with XML content type and string data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = "value"
+ content_type = "text/xml"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, data)
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_html_with_dict(self):
+ """Test get_data_from_content_type with HTML content type and dict data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"message": "Hello"}
+ content_type = "text/html"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, json.dumps(data))
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_html_with_string(self):
+ """Test get_data_from_content_type with HTML content type and string data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = "Hello"
+ content_type = "text/html"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, data)
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_unknown_type_with_dict(self):
+ """Test get_data_from_content_type with unknown content type and dict data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = {"key": "value"}
+ content_type = "application/custom"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, json.dumps(data))
+ self.assertEqual(files, {})
+
+ def test_get_data_from_content_type_unknown_type_with_string(self):
+ """Test get_data_from_content_type with unknown content type and string data."""
+ from skyflow.utils._utils import get_data_from_content_type
+
+ data = "plain text data"
+ content_type = "text/plain"
+
+ converted_data, files = get_data_from_content_type(data, content_type)
+
+ self.assertEqual(converted_data, data)
+ self.assertEqual(files, {})
+
+ def test_dict_to_xml_simple_dict(self):
+ """Test dict_to_xml with simple dictionary."""
+ from skyflow.utils._utils import dict_to_xml
+
+ data = {"name": "John", "age": "30"}
+ result = dict_to_xml(data)
+
+ self.assertIn("John", result)
+ self.assertIn("30", result)
+ self.assertTrue(result.startswith(""))
+ self.assertTrue(result.endswith(""))
+
+ def test_dict_to_xml_nested_dict(self):
+ """Test dict_to_xml with nested dictionary."""
+ from skyflow.utils._utils import dict_to_xml
+
+ data = {"person": {"name": "John", "age": "30"}}
+ result = dict_to_xml(data)
+
+ self.assertIn("", result)
+ self.assertIn("John", result)
+ self.assertIn("30", result)
+
+ def test_dict_to_xml_with_list(self):
+ """Test dict_to_xml with list values."""
+ from skyflow.utils._utils import dict_to_xml
+
+ data = {"items": ["item1", "item2", "item3"]}
+ result = dict_to_xml(data)
+
+ self.assertIn("item1", result)
+ self.assertIn("item2", result)
+ self.assertIn("item3", result)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_xml_content(self, mock_response):
+ """Test parsing XML response content."""
+ mock_response.status_code = 200
+ mock_response.content = b"success"
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "application/xml"
+ }
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "success")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_url_encoded_content(self, mock_response):
+ """Test parsing URL encoded response content."""
+ mock_response.status_code = 200
+ mock_response.content = b"card_number=4111111111111111&cvv=123"
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "application/x-www-form-urlencoded"
+ }
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "card_number=4111111111111111&cvv=123")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_html_content(self, mock_response):
+ """Test parsing HTML response content."""
+ mock_response.status_code = 200
+ mock_response.content = b"Success"
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "text/html"
+ }
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "Success")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_html_error(self, mock_response):
+ """Test parsing HTML error response."""
+ html_error = "Error 500
"
+ mock_response.status_code = 500
+ mock_response.content = html_error.encode('utf-8')
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "text/html"
+ }
+ mock_response.raise_for_status = Mock(side_effect=HTTPError("500 Error"))
+
+ with self.assertRaises(SkyflowError) as context:
+ parse_invoke_connection_response(mock_response)
+
+ self.assertEqual(context.exception.message, html_error)
+ self.assertEqual(context.exception.http_code, 500)
+ self.assertEqual(context.exception.request_id, "1234")
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_json_decode_falls_back_to_string(self, mock_response):
+ """Test that JSON decode error falls back to returning string content."""
+ mock_response.status_code = 200
+ mock_response.content = b"Not valid JSON but still success"
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "application/json"
+ }
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "Not valid JSON but still success")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_no_content_type_with_json(self, mock_response):
+ """Test parsing response with no content-type but valid JSON."""
+ mock_response.status_code = 200
+ mock_response.content = json.dumps({"success": True}).encode('utf-8')
+ mock_response.headers = {"x-request-id": "1234"}
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, {"success": True})
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_no_content_type_with_text(self, mock_response):
+ """Test parsing response with no content-type and non-JSON content."""
+ mock_response.status_code = 200
+ mock_response.content = b"Plain text response"
+ mock_response.headers = {"x-request-id": "1234"}
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "Plain text response")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ @patch("requests.Response")
+ def test_parse_invoke_connection_response_bytes_content(self, mock_response):
+ """Test parsing response with bytes content."""
+ mock_response.status_code = 200
+ mock_response.content = b"Binary data response"
+ mock_response.headers = {
+ "x-request-id": "1234",
+ "content-type": "application/octet-stream"
+ }
+ mock_response.raise_for_status = Mock()
+
+ result = parse_invoke_connection_response(mock_response)
+
+ self.assertIsInstance(result, InvokeConnectionResponse)
+ self.assertEqual(result.data, "Binary data response")
+ self.assertEqual(result.metadata["request_id"], "1234")
+ self.assertIsNone(result.errors)
+
+ def test_construct_invoke_connection_request_headers_json_error(self):
+ """Test exception handling when json.dumps fails for headers."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+
+ class UnserializableObject:
+ def __repr__(self):
+ raise TypeError("Object is not JSON serializable")
+
+ mock_connection_request.headers = {"key": UnserializableObject()}
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with patch('json.dumps', side_effect=TypeError("Object is not JSON serializable")):
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_headers_generic_exception(self):
+ """Test generic exception handling for headers processing."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": "application/json"}
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with patch('skyflow.utils._utils.to_lowercase_keys', side_effect=Exception("Generic error")):
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_body_processing_exception(self):
+ """Test exception handling when body processing fails."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
+ mock_connection_request.body = {"key": "value"}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with patch('skyflow.utils._utils.get_data_from_content_type', side_effect=Exception("Body processing error")):
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_body_json_dumps_exception(self):
+ """Test exception handling when json.dumps fails in get_data_from_content_type."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
+
+ class UnserializableObject:
+ pass
+
+ mock_connection_request.body = {"key": UnserializableObject()}
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_invalid_url_exception(self):
+ """Test exception handling when requests.Request.prepare() fails with invalid URL."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = None
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with patch('requests.Request') as mock_request_class:
+ mock_request_instance = Mock()
+ mock_request_instance.prepare.side_effect = Exception("Invalid URL structure")
+ mock_request_class.return_value = mock_request_instance
+
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(
+ context.exception.message,
+ SkyflowMessages.Error.INVALID_URL.value.format(connection_url)
+ )
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_prepare_exception(self):
+ """Test exception handling when prepare() method fails."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with patch('requests.Request') as mock_request_class:
+ mock_request_instance = Mock()
+ mock_request_instance.prepare.side_effect = Exception("Prepare failed")
+ mock_request_class.return_value = mock_request_instance
+
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(
+ context.exception.message,
+ SkyflowMessages.Error.INVALID_URL.value.format(connection_url)
+ )
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ def test_construct_invoke_connection_request_body_not_dict_raises_error(self):
+ """Test that non-dict body raises SkyflowError which is caught and re-raised."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {}
+ mock_connection_request.headers = {"Content-Type": ContentType.JSON.value}
+ mock_connection_request.body = "not a dict" # Invalid body type
+ mock_connection_request.method.value = "POST"
+ mock_connection_request.query_params = {}
+
+ connection_url = "https://example.com/endpoint"
+
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+
+ @patch('skyflow.utils._utils.validate_invoke_connection_params')
+ def test_construct_invoke_connection_request_validation_exception(self, mock_validate):
+ """Test that validation exceptions are properly propagated."""
+ mock_connection_request = Mock()
+ mock_connection_request.path_params = {"param": "value"}
+ mock_connection_request.headers = None
+ mock_connection_request.body = None
+ mock_connection_request.method.value = "GET"
+ mock_connection_request.query_params = {"query": "value"}
+
+ connection_url = "https://example.com/endpoint"
+
+ mock_validate.side_effect = SkyflowError("Validation failed", 400)
+
+ with self.assertRaises(SkyflowError) as context:
+ construct_invoke_connection_request(mock_connection_request, connection_url, logger=None)
+
+ self.assertEqual(context.exception.message, "Validation failed")
+ self.assertEqual(context.exception.http_code, 400)
diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py
index 48332a55..4f3b5487 100644
--- a/tests/utils/validations/test__validations.py
+++ b/tests/utils/validations/test__validations.py
@@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self):
with patch('skyflow.service_account.is_expired', return_value=True):
with self.assertRaises(SkyflowError) as context:
validate_credentials(self.logger, credentials)
- self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value)
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_BEARER_TOKEN.value)
def test_validate_credentials_empty_credentials(self):
credentials = {}
@@ -1044,3 +1044,66 @@ def test_validate_detokenize_request_invalid_redaction_type(self):
with self.assertRaises(SkyflowError) as context:
validate_detokenize_request(self.logger, request)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid"))))
+
+
+ def test_validate_deidentify_file_request_wait_time_negative(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=-1,
+ entities=[DetectEntities.SSN]
+ )
+ with self.assertRaises(SkyflowError) as context:
+ validate_deidentify_file_request(self.logger, request)
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value)
+
+ def test_validate_deidentify_file_request_wait_time_greater_than_64(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=65,
+ entities=[DetectEntities.SSN]
+ )
+ with self.assertRaises(SkyflowError) as context:
+ validate_deidentify_file_request(self.logger, request)
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value)
+
+ def test_validate_deidentify_file_request_wait_time_valid_boundary_lower(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=0,
+ entities=[DetectEntities.SSN]
+ )
+ validate_deidentify_file_request(self.logger, request)
+
+ def test_validate_deidentify_file_request_wait_time_valid_boundary_upper(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=64,
+ entities=[DetectEntities.SSN]
+ )
+ # Should not raise an error
+ validate_deidentify_file_request(self.logger, request)
+
+ def test_validate_deidentify_file_request_wait_time_valid_float(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=32.5,
+ entities=[DetectEntities.SSN]
+ )
+ # Should not raise an error
+ validate_deidentify_file_request(self.logger, request)
+
+ def test_validate_deidentify_file_request_wait_time_float_out_of_range(self):
+ file_input = FileInput(file_path=self.temp_file_path)
+ request = DeidentifyFileRequest(
+ file=file_input,
+ wait_time=64.1,
+ entities=[DetectEntities.SSN]
+ )
+ with self.assertRaises(SkyflowError) as context:
+ validate_deidentify_file_request(self.logger, request)
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value)
diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py
index 565b1e6f..619c15ec 100644
--- a/tests/vault/client/test__client.py
+++ b/tests/vault/client/test__client.py
@@ -1,5 +1,8 @@
import unittest
from unittest.mock import patch, MagicMock
+
+from skyflow.error import SkyflowError
+from skyflow.utils import SkyflowMessages
from skyflow.vault.client.client import VaultClient
CONFIG = {
@@ -97,4 +100,85 @@ def test_get_log_level(self):
def test_get_logger(self):
mock_logger = MagicMock()
self.vault_client.set_logger("INFO", mock_logger)
- self.assertEqual(self.vault_client.get_logger(), mock_logger)
\ No newline at end of file
+ self.assertEqual(self.vault_client.get_logger(), mock_logger)
+
+ @patch("skyflow.vault.client.client.is_expired")
+ @patch("skyflow.vault.client.client.generate_bearer_token")
+ def test_get_bearer_token_expired_token_raises_error(self, mock_generate_bearer_token, mock_is_expired):
+ """Test that expired token raises SkyflowError."""
+ credentials = {"path": "/path/to/credentials.json"}
+ mock_generate_bearer_token.return_value = ("expired_token", None)
+ mock_is_expired.return_value = True
+
+ with self.assertRaises(SkyflowError) as context:
+ self.vault_client.get_bearer_token(credentials)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+ self.assertTrue(self.vault_client._VaultClient__is_config_updated)
+
+ @patch("skyflow.vault.client.client.is_expired")
+ @patch("skyflow.vault.client.client.generate_bearer_token_from_creds")
+ def test_get_bearer_token_expired_token_from_creds_string_raises_error(self, mock_generate_bearer_token_from_creds, mock_is_expired):
+ """Test that expired token from credentials string raises SkyflowError."""
+ credentials = {"credentials_string": '{"key": "value"}'}
+ mock_generate_bearer_token_from_creds.return_value = ("expired_token", None)
+ mock_is_expired.return_value = True
+
+ with self.assertRaises(SkyflowError) as context:
+ self.vault_client.get_bearer_token(credentials)
+
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)
+ self.assertTrue(self.vault_client._VaultClient__is_config_updated)
+
+ @patch("skyflow.vault.client.client.is_expired")
+ @patch("skyflow.vault.client.client.generate_bearer_token")
+ def test_get_bearer_token_reuses_valid_token(self, mock_generate_bearer_token, mock_is_expired):
+ """Test that valid bearer token is reused."""
+ credentials = {"path": "/path/to/credentials.json"}
+ mock_generate_bearer_token.return_value = ("valid_token", None)
+ mock_is_expired.return_value = False
+
+ token1 = self.vault_client.get_bearer_token(credentials)
+ self.assertEqual(token1, "valid_token")
+ mock_generate_bearer_token.assert_called_once()
+
+ token2 = self.vault_client.get_bearer_token(credentials)
+ self.assertEqual(token2, "valid_token")
+ mock_generate_bearer_token.assert_called_once()
+
+ @patch("skyflow.vault.client.client.is_expired")
+ @patch("skyflow.vault.client.client.generate_bearer_token")
+ def test_get_bearer_token_regenerates_after_config_update(self, mock_generate_bearer_token, mock_is_expired):
+ """Test that bearer token is regenerated after config update."""
+ credentials = {"path": "/path/to/credentials.json"}
+ mock_generate_bearer_token.side_effect = [("first_token", None), ("second_token", None)]
+ mock_is_expired.return_value = False
+
+ token1 = self.vault_client.get_bearer_token(credentials)
+ self.assertEqual(token1, "first_token")
+
+ self.vault_client.update_config({"new_key": "new_value"})
+
+ token2 = self.vault_client.get_bearer_token(credentials)
+ self.assertEqual(token2, "second_token")
+ self.assertEqual(mock_generate_bearer_token.call_count, 2)
+
+ @patch("skyflow.vault.client.client.is_expired")
+ @patch("skyflow.vault.client.client.generate_bearer_token_from_creds")
+ @patch("skyflow.vault.client.client.log_info")
+ def test_get_bearer_token_with_credentials_string(self, mock_log_info, mock_generate_bearer_token_from_creds, mock_is_expired):
+ """Test get_bearer_token with credentials_string."""
+ credentials = {"credentials_string": '{"clientID": "test", "clientName": "test"}'}
+ mock_generate_bearer_token_from_creds.return_value = ("token_from_creds", None)
+ mock_is_expired.return_value = False
+
+ token = self.vault_client.get_bearer_token(credentials)
+
+ self.assertEqual(token, "token_from_creds")
+ mock_generate_bearer_token_from_creds.assert_called_once()
+ mock_log_info.assert_called_with(
+ SkyflowMessages.Info.GENERATE_BEARER_TOKEN_FROM_CREDENTIALS_STRING_TRIGGERED.value,
+ None
+ )
diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py
index 4ccad1c7..35a13716 100644
--- a/tests/vault/controller/test__connection.py
+++ b/tests/vault/controller/test__connection.py
@@ -1,5 +1,5 @@
import unittest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, patch, MagicMock
import requests
from skyflow.error import SkyflowError
from skyflow.utils import SkyflowMessages, parse_invoke_connection_response
@@ -30,10 +30,16 @@ def setUp(self):
self.mock_vault_client = Mock()
self.mock_vault_client.get_config.return_value = VAULT_CONFIG
self.mock_vault_client.get_bearer_token.return_value = VALID_BEARER_TOKEN
+ self.mock_vault_client.get_logger.return_value = Mock()
+ self.mock_vault_client.get_common_skyflow_credentials.return_value = None
self.connection = Connection(self.mock_vault_client)
+ @patch('skyflow.vault.controller._connections.get_credentials')
@patch('requests.Session.send')
- def test_invoke_success(self, mock_send):
+ def test_invoke_success(self, mock_send, mock_get_credentials):
+ # Mock get_credentials to return credentials
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
# Mocking successful response
mock_response = Mock()
mock_response.status_code = SUCCESS_STATUS_CODE
@@ -60,9 +66,36 @@ def test_invoke_success(self, mock_send):
}
self.assertEqual(vars(response), expected_response)
self.mock_vault_client.get_bearer_token.assert_called_once()
+ mock_get_credentials.assert_called_once()
+ @patch('skyflow.vault.controller._connections.get_credentials')
@patch('requests.Session.send')
- def test_invoke_invalid_headers(self, mock_send):
+ def test_invoke_with_x_skyflow_authorization_already_present(self, mock_send, mock_get_credentials):
+ """Test that X-Skyflow-Authorization is not overwritten if already present in headers."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
+ mock_response = Mock()
+ mock_response.status_code = SUCCESS_STATUS_CODE
+ mock_response.content = SUCCESS_RESPONSE_CONTENT
+ mock_response.headers = {'x-request-id': 'test-request-id'}
+ mock_send.return_value = mock_response
+
+ custom_auth = "custom_bearer_token"
+ request = InvokeConnectionRequest(
+ method=RequestMethod.POST,
+ body=VALID_BODY,
+ headers={"x-skyflow-authorization": custom_auth}
+ )
+
+ response = self.connection.invoke(request)
+
+ # Verify bearer token from vault_client is NOT used
+ self.assertIsNotNone(response)
+
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ def test_invoke_invalid_headers(self, mock_get_credentials):
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
request = InvokeConnectionRequest(
method="POST",
body=VALID_BODY,
@@ -75,8 +108,10 @@ def test_invoke_invalid_headers(self, mock_send):
self.connection.invoke(request)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_HEADERS.value)
- @patch('requests.Session.send')
- def test_invoke_invalid_body(self, mock_send):
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ def test_invoke_invalid_body(self, mock_get_credentials):
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
request = InvokeConnectionRequest(
method="POST",
body=INVALID_BODY,
@@ -89,11 +124,16 @@ def test_invoke_invalid_body(self, mock_send):
self.connection.invoke(request)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REQUEST_BODY.value)
+ @patch('skyflow.vault.controller._connections.get_credentials')
@patch('requests.Session.send')
- def test_invoke_request_error(self, mock_send):
+ def test_invoke_request_error(self, mock_send, mock_get_credentials):
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
mock_response = Mock()
mock_response.status_code = FAILURE_STATUS_CODE
- mock_response.content = ERROR_RESPONSE_CONTENT
+ mock_response.content = ERROR_RESPONSE_CONTENT.encode('utf-8') # Convert to bytes
+ mock_response.headers = {"x-request-id": "test-request-id"}
+ mock_response.raise_for_status.side_effect = requests.HTTPError("400 Error")
mock_send.return_value = mock_response
request = InvokeConnectionRequest(
@@ -106,9 +146,99 @@ def test_invoke_request_error(self, mock_send):
with self.assertRaises(SkyflowError) as context:
self.connection.invoke(request)
- self.assertEqual(context.exception.message, f'Skyflow Python SDK {SDK_VERSION} Response {ERROR_RESPONSE_CONTENT} is not valid JSON.')
- self.assertEqual(context.exception.message, SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(ERROR_RESPONSE_CONTENT))
- self.assertEqual(context.exception.http_code, 400)
+
+ self.assertEqual(context.exception.message, ERROR_RESPONSE_CONTENT)
+ self.assertEqual(context.exception.http_code, FAILURE_STATUS_CODE)
+ self.assertEqual(context.exception.request_id, "test-request-id")
+
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ @patch('requests.Session.send')
+ def test_invoke_session_send_exception(self, mock_send, mock_get_credentials):
+ """Test handling of generic exception from session.send()."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
+ mock_send.side_effect = Exception("Network error")
+
+ request = InvokeConnectionRequest(
+ method=RequestMethod.POST,
+ body=VALID_BODY,
+ headers=VALID_HEADERS
+ )
+
+ with self.assertRaises(SkyflowError) as context:
+ self.connection.invoke(request)
+ self.assertEqual(context.exception.message, SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value)
+ self.assertEqual(context.exception.http_code, SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
+
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ @patch('requests.Session.send')
+ def test_invoke_skyflow_error_re_raised(self, mock_send, mock_get_credentials):
+ """Test that SkyflowError is re-raised without wrapping."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
+ original_error = SkyflowError("Original error", 401)
+ mock_send.side_effect = original_error
+
+ request = InvokeConnectionRequest(
+ method=RequestMethod.POST,
+ body=VALID_BODY,
+ headers=VALID_HEADERS
+ )
+
+ with self.assertRaises(SkyflowError) as context:
+ self.connection.invoke(request)
+ # Should be the same original error
+ self.assertEqual(context.exception.message, "Original error")
+ self.assertEqual(context.exception.http_code, 401)
+
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ @patch('requests.Session.send')
+ def test_invoke_session_close_called(self, mock_send, mock_get_credentials):
+ """Test that session.close() is called after send()."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
+ mock_response = Mock()
+ mock_response.status_code = SUCCESS_STATUS_CODE
+ mock_response.content = SUCCESS_RESPONSE_CONTENT
+ mock_response.headers = {'x-request-id': 'test-request-id'}
+ mock_send.return_value = mock_response
+
+ with patch('requests.Session.close') as mock_close:
+ request = InvokeConnectionRequest(
+ method=RequestMethod.GET,
+ headers=VALID_HEADERS
+ )
+
+ response = self.connection.invoke(request)
+
+ # Verify close was called
+ mock_close.assert_called_once()
+
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ @patch('skyflow.vault.controller._connections.get_metrics')
+ @patch('requests.Session.send')
+ def test_invoke_adds_sky_metadata_header(self, mock_send, mock_get_metrics, mock_get_credentials):
+ """Test that sky-metadata header is added to request."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+ mock_get_metrics.return_value = {"sdk_version": SDK_VERSION}
+
+ mock_response = Mock()
+ mock_response.status_code = SUCCESS_STATUS_CODE
+ mock_response.content = SUCCESS_RESPONSE_CONTENT
+ mock_response.headers = {'x-request-id': 'test-request-id'}
+ mock_send.return_value = mock_response
+
+ request = InvokeConnectionRequest(
+ method=RequestMethod.POST,
+ body=VALID_BODY,
+ headers=VALID_HEADERS
+ )
+
+ response = self.connection.invoke(request)
+
+ # Verify get_metrics was called
+ mock_get_metrics.assert_called_once()
+ self.assertIsNotNone(response)
def test_parse_invoke_connection_response_error_from_client(self):
mock_response = Mock(spec=requests.Response)
@@ -128,3 +258,37 @@ def test_parse_invoke_connection_response_error_from_client(self):
self.assertTrue(any(detail.get('error_from_client') == True for detail in exception.details))
self.assertEqual(exception.request_id, '12345')
+ @patch('skyflow.vault.controller._connections.get_credentials')
+ @patch('skyflow.vault.controller._connections.construct_invoke_connection_request')
+ def test_invoke_construct_request_called(self, mock_construct, mock_get_credentials):
+ """Test that construct_invoke_connection_request is called with correct parameters."""
+ mock_get_credentials.return_value = {"api_key": "test_api_key"}
+
+ mock_prepared_request = Mock(spec=requests.PreparedRequest)
+ mock_prepared_request.headers = {}
+ mock_construct.return_value = mock_prepared_request
+
+ with patch('requests.Session.send') as mock_send:
+ mock_response = Mock()
+ mock_response.status_code = SUCCESS_STATUS_CODE
+ mock_response.content = SUCCESS_RESPONSE_CONTENT
+ mock_response.headers = {'x-request-id': 'test-request-id'}
+ mock_send.return_value = mock_response
+
+ request = InvokeConnectionRequest(
+ method=RequestMethod.GET,
+ headers=VALID_HEADERS
+ )
+
+ self.connection.invoke(request)
+
+ # Verify construct was called with connection_url from config
+ mock_construct.assert_called_once_with(
+ request,
+ VAULT_CONFIG["connection_url"],
+ self.mock_vault_client.get_logger()
+ )
+
+
+if __name__ == '__main__':
+ unittest.main()
\ No newline at end of file