diff --git a/compute_modules/client/internal_query_client.py b/compute_modules/client/internal_query_client.py index 5cece41..fafa05c 100644 --- a/compute_modules/client/internal_query_client.py +++ b/compute_modules/client/internal_query_client.py @@ -15,11 +15,9 @@ import json import os -import ssl import time import traceback from typing import Any, Callable, Dict, Iterable, List -from urllib.parse import urlparse import requests @@ -38,11 +36,6 @@ POST_RESTART_MAX_ATTEMPTS = 5 -def _extract_path_from_url(url: str) -> str: - parsed_url = urlparse(url) - return parsed_url.path - - class InternalQueryService: def __init__( self, @@ -57,16 +50,12 @@ def __init__( self.function_schema_conversions = function_schema_conversions self.is_function_context_typed = is_function_context_typed self.streaming = streaming - self.host = os.environ["RUNTIME_HOST"] - self.port = int(os.environ["RUNTIME_PORT"]) - self.get_job_path = _extract_path_from_url(os.environ["GET_JOB_URI"]) - self.post_result_path = _extract_path_from_url(os.environ["POST_RESULT_URI"]) - self.post_schema_path = _extract_path_from_url(os.environ["POST_SCHEMA_URI"]) - self.post_restart_path = _extract_path_from_url(os.environ["RESTART_NOTIFICATION_URI"]) - self._initialize_auth_token() + self.get_job_url = os.environ["GET_JOB_URI_V2"] + self.post_result_url = os.environ["POST_RESULT_URI_V2"] + self.post_schema_url = os.environ["POST_SCHEMA_URI_V2"] + # self.post_restart_url = os.environ["RESTART_NOTIFICATION_URI_V2"] + self.post_restart_url = f"{os.environ['RUNTIME_API_V2']}/restart-notify" self._initialize_headers() - self.certPath = os.environ["CONNECTIONS_TO_OTHER_PODS_CA_PATH"] - self.context = ssl.create_default_context(cafile=self.certPath) self.connection_refused_count: int = 0 self.concurrency = int(os.environ.get("MAX_CONCURRENT_TASKS", 1)) self.logger = get_internal_logger() @@ -84,22 +73,9 @@ def _set_logger_process_id(self, process_id: int) -> None: """Set the process_id for internal & public logger""" COMPUTE_MODULES_ADAPTER_MANAGER.update_process_id(process_id=process_id) - def _initialize_auth_token(self) -> None: - try: - with open(os.environ["MODULE_AUTH_TOKEN"], "r", encoding="utf-8") as f: - self.moduleAuthToken = f.read() - except Exception as e: - self.logger.error(f"Failed to read auth token: {str(e)}") - raise - def _initialize_headers(self) -> None: - self.get_job_headers = {"Module-Auth-Token": self.moduleAuthToken} - self.post_result_headers = { - "Content-Type": "application/octet-stream", - "Module-Auth-Token": self.moduleAuthToken, - } - self.post_schema_headers = {"Content-Type": "application/json", "Module-Auth-Token": self.moduleAuthToken} - self.post_restart_headers = {"Module-Auth-Token": self.moduleAuthToken} + self.post_result_headers = {"Content-Type": "application/octet-stream"} + self.post_schema_headers = {"Content-Type": "application/json"} def _iterable_to_json_generator(self, iterable: Iterable[Any]) -> Iterable[bytes]: self.logger.debug("iterating over result") @@ -111,20 +87,17 @@ def init_session(self) -> None: """Initialize requests.Session""" self.session = requests.Session() - def build_url(self, path: str) -> str: - return f"https://{self.host}:{self.port}{path}" - def post_query_schemas(self) -> None: """Post the function schemas of the Compute Module""" self.logger.debug(f"Posting function schemas: {self.function_schemas}") + self.logger.debug(f"post_schema_url: {self.post_schema_url}") for i in range(POST_SCHEMAS_MAX_ATTEMPTS): try: with self.session.request( method="POST", - url=self.build_url(self.post_schema_path), + url=self.post_schema_url, json=self.function_schemas, headers=self.post_schema_headers, - verify=self.certPath, ) as response: self.logger.debug( f"POST /schemas response status: {response.status_code} reason: {response.reason}" @@ -143,9 +116,7 @@ def get_job_or_none(self) -> Any: try: with self.session.request( method="GET", - url=self.build_url(self.get_job_path), - headers=self.get_job_headers, - verify=self.certPath, + url=self.get_job_url, ) as response: result = None if response.status_code == 200: @@ -174,9 +145,9 @@ def report_job_result_failed(self, post_result_url: str, error: str) -> None: url=post_result_url, headers=self.post_result_headers, data=json.dumps({"error": error}).encode("utf-8"), - verify=self.certPath, ) as response: - if response.status_code == 204: + # HTTP version returns 202 while witchcraft returns 204 + if response.status_code in (202, 204): self.logger.debug("Successfully reported that job result posting has failed") return else: @@ -189,8 +160,7 @@ def report_job_result_failed(self, post_result_url: str, error: str) -> None: raise RuntimeError(f"Unable to report that post result has failed after {POST_ERROR_MAX_ATTEMPTS} attempts") def report_job_result(self, job_id: str, body: Any) -> None: - post_result_path = f"{self.post_result_path}/{job_id}" - post_result_url = self.build_url(post_result_path) + post_result_url = f"{self.post_result_url}/{job_id}" self.logger.debug(f"Posting result to {post_result_url}") for _ in range(POST_RESULT_MAX_ATTEMPTS): try: @@ -199,9 +169,9 @@ def report_job_result(self, job_id: str, body: Any) -> None: url=post_result_url, headers=self.post_result_headers, data=body, - verify=self.certPath, ) as response: - if response.status_code == 204: + # HTTP version returns 202 while witchcraft returns 204 + if response.status_code in (202, 204): self.logger.debug("Successfully reported job result") return else: @@ -287,16 +257,13 @@ def get_failed_query(exception: Exception) -> Dict[str, str]: return {"exception": f"{str(exception)}: {traceback.format_exc()}"} def report_restart(self) -> None: - post_restart_url = self.build_url(self.post_restart_path) - self.logger.debug(f"Reporting restart to {post_restart_url}") + self.logger.debug(f"Reporting restart to {self.post_restart_url}") for _ in range(POST_RESTART_MAX_ATTEMPTS): try: with self.session.request( method="POST", - url=post_restart_url, - headers=self.post_restart_headers, - verify=self.certPath, + url=self.post_restart_url, ) as response: self.logger.debug( f"Reporting restart response status: {response.status_code} reason: {response.reason}"