diff --git a/skyflow/service_account/_utils.py b/skyflow/service_account/_utils.py index 3f21ba2..346f5ca 100644 --- a/skyflow/service_account/_utils.py +++ b/skyflow/service_account/_utils.py @@ -33,18 +33,15 @@ def is_expired(token, logger = None): def generate_bearer_token(credentials_file_path, options = None, logger = None): try: log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger) - credentials_file =open(credentials_file_path, 'r') - except Exception: + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) + except OSError: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - - try: - credentials = json.load(credentials_file) - except Exception: - log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger) - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code) - - finally: - credentials_file.close() + result = get_service_account_token(credentials, options, logger) return result @@ -144,19 +141,15 @@ def get_signed_tokens(credentials_obj, options): def generate_signed_data_tokens(credentials_file_path, options): log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value) try: - credentials_file =open(credentials_file_path, 'r') - except Exception: + with open(credentials_file_path, 'r') as credentials_file: + try: + credentials = json.load(credentials_file) + except Exception: + raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), + invalid_input_error_code) + except FileNotFoundError: raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code) - try: - credentials = json.load(credentials_file) - except Exception: - raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), - invalid_input_error_code) - - finally: - credentials_file.close() - return get_signed_tokens(credentials, options) def generate_signed_data_tokens_from_creds(credentials, options): diff --git a/skyflow/vault/controller/_connections.py b/skyflow/vault/controller/_connections.py index ca8c7a1..d2abbc7 100644 --- a/skyflow/vault/controller/_connections.py +++ b/skyflow/vault/controller/_connections.py @@ -30,14 +30,17 @@ def invoke(self, request: InvokeConnectionRequest): log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger()) + response = None try: response = session.send(invoke_connection_request) - session.close() - invoke_connection_response = parse_invoke_connection_response(response) - return invoke_connection_response + return parse_invoke_connection_response(response) except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.INVOKE_CONNECTION_REQUEST_REJECTED.value, self.__vault_client.get_logger()) if isinstance(e, SkyflowError): raise e raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value, - SkyflowMessages.ErrorCodes.SERVER_ERROR.value) \ No newline at end of file + SkyflowMessages.ErrorCodes.SERVER_ERROR.value) + finally: + if response is not None: + response.close() + session.close() diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index c6ef2fb..78ed5f9 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -63,22 +63,26 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): current_wait_time = 1 # Start with 1 second try: while True: - response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data - status = response.status - if status == DetectStatus.IN_PROGRESS: - if current_wait_time >= max_wait_time: - return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) - else: - next_wait_time = current_wait_time * 2 - if next_wait_time >= max_wait_time: - wait_time = max_wait_time - current_wait_time - current_wait_time = max_wait_time + http_response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()) + try: + response = http_response.data + status = response.status + if status == DetectStatus.IN_PROGRESS: + if current_wait_time >= max_wait_time: + return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS) else: - wait_time = next_wait_time - current_wait_time = next_wait_time - time.sleep(wait_time) - elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: - return response + next_wait_time = current_wait_time * 2 + if next_wait_time >= max_wait_time: + wait_time = max_wait_time - current_wait_time + current_wait_time = max_wait_time + else: + wait_time = next_wait_time + current_wait_time = next_wait_time + time.sleep(wait_time) + elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED: + return response + finally: + http_response.close() except Exception as e: raise e @@ -231,9 +235,12 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS], request_options=self.__get_headers() ) - deidentify_text_response = parse_deidentify_text_response(api_response) - log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) - return deidentify_text_response + try: + deidentify_text_response = parse_deidentify_text_response(api_response) + log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) + return deidentify_text_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DEIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger()) @@ -255,9 +262,12 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo format=reidentify_text_body[DeidentifyField.FORMAT], request_options=self.__get_headers() ) - reidentify_text_response = parse_reidentify_text_response(api_response) - log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) - return reidentify_text_response + try: + reidentify_text_response = parse_reidentify_text_response(api_response) + log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger()) + return reidentify_text_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.REIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger()) @@ -272,7 +282,7 @@ def __get_file_from_request(self, request: DeidentifyFileRequest): # Check for file_path if file is not provided if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None: - return open(file_input.file_path, 'rb') + return open(file_input.file_path, 'rb') def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) @@ -282,8 +292,16 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_obj = self.__get_file_from_request(request) file_name = getattr(file_obj, FileUploadField.NAME, None) file_extension = self._get_file_extension(file_name) if file_name else None - file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) + + # Track if we need to close the file (only if it was opened from file_path) + file_needs_closing = hasattr(request.file, 'file_path') and request.file.file_path is not None + + try: + file_content = file_obj.read() + base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8) + finally: + if file_needs_closing and hasattr(file_obj, 'close'): + file_obj.close() try: if file_extension == FileExtension.TXT: @@ -421,16 +439,19 @@ def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) api_response = api_call(**api_kwargs) - run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) + try: + run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None) - processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - if request.output_directory and processed_response.status == DetectStatus.SUCCESS: - name_without_ext, _ = os.path.splitext(file_name) - self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) + processed_response = self.__poll_for_processed_file(run_id, request.wait_time) + if request.output_directory and processed_response.status == DetectStatus.SUCCESS: + name_without_ext, _ = os.path.splitext(file_name) + self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) - parsed_response = self.__parse_deidentify_file_response(processed_response, run_id) - log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger()) - return parsed_response + parsed_response = self.__parse_deidentify_file_response(processed_response, run_id) + log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger()) + return parsed_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value, @@ -446,17 +467,20 @@ def get_detect_run(self, request: GetDetectRunRequest): files_api = self.__vault_client.get_detect_file_api().with_raw_response run_id = request.run_id try: - response = files_api.get_run( + http_response = files_api.get_run( run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers() ) - if response.data.status == DetectStatus.IN_PROGRESS: - parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) - else: - parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) - log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) - return parsed_response + try: + if http_response.data.status == DetectStatus.IN_PROGRESS: + parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)) + else: + parsed_response = self.__parse_deidentify_file_response(http_response.data, run_id, http_response.data.status) + log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) + return parsed_response + finally: + http_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value, self.__vault_client.get_logger()) diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 856a196..883a299 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -112,9 +112,12 @@ def insert(self, request: InsertRequest): api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) - insert_response = parse_insert_response(api_response, request.continue_on_error) - log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) - return insert_response + try: + insert_response = parse_insert_response(api_response, request.continue_on_error) + log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) + return insert_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.INSERT_RECORDS_REJECTED.value, self.__vault_client.get_logger()) @@ -140,9 +143,12 @@ def update(self, request: UpdateRequest): byot=request.token_mode.value, request_options = self.__get_headers() ) - log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) - update_response = parse_update_record_response(api_response) - return update_response + try: + log_info(SkyflowMessages.Info.UPDATE_SUCCESS.value, self.__vault_client.get_logger()) + update_response = parse_update_record_response(api_response) + return update_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.UPDATE_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -161,9 +167,12 @@ def delete(self, request: DeleteRequest): skyflow_ids=request.ids, request_options=self.__get_headers() ) - log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) - delete_response = parse_delete_response(api_response) - return delete_response + try: + log_info(SkyflowMessages.Info.DELETE_SUCCESS.value, self.__vault_client.get_logger()) + delete_response = parse_delete_response(api_response) + return delete_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DELETE_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -191,9 +200,12 @@ def get(self, request: GetRequest): column_values=request.column_values, request_options=self.__get_headers() ) - log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) - get_response = parse_get_response(api_response) - return get_response + try: + log_info(SkyflowMessages.Info.GET_SUCCESS.value, self.__vault_client.get_logger()) + get_response = parse_get_response(api_response) + return get_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.GET_REQUEST_REJECTED.value, self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -211,9 +223,12 @@ def query(self, request: QueryRequest): query=request.query, request_options=self.__get_headers() ) - log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) - query_response = parse_query_response(api_response) - return query_response + try: + log_info(SkyflowMessages.Info.QUERY_SUCCESS.value, self.__vault_client.get_logger()) + query_response = parse_query_response(api_response) + return query_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.QUERY_REQUEST_REJECTED.value, self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -239,9 +254,12 @@ def detokenize(self, request: DetokenizeRequest): continue_on_error = request.continue_on_error, request_options=self.__get_headers() ) - log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) - detokenize_response = parse_detokenize_response(api_response) - return detokenize_response + try: + log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) + detokenize_response = parse_detokenize_response(api_response) + return detokenize_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DETOKENIZE_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -264,9 +282,12 @@ def tokenize(self, request: TokenizeRequest): tokenization_parameters=records_list, request_options=self.__get_headers() ) - tokenize_response = parse_tokenize_response(api_response) - log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) - return tokenize_response + try: + tokenize_response = parse_tokenize_response(api_response) + log_info(SkyflowMessages.Info.TOKENIZE_SUCCESS.value, self.__vault_client.get_logger()) + return tokenize_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.TOKENIZE_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) @@ -287,13 +308,16 @@ def upload_file(self, request: FileUploadRequest): return_file_metadata= False, request_options=self.__get_headers() ) - log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) - log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) - upload_response = FileUploadResponse( - skyflow_id=api_response.data.skyflow_id, - errors=None - ) - return upload_response + try: + log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) + log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger()) + upload_response = FileUploadResponse( + skyflow_id=api_response.data.skyflow_id, + errors=None + ) + return upload_response + finally: + api_response.close() except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.FILE_UPLOAD_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger())