Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 17 additions & 50 deletions compute_modules/client/internal_query_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@

import json
import os
import ssl
import time
import traceback
from typing import Any, Callable, Dict, Iterable, List
from urllib.parse import urlparse

import requests

Expand All @@ -38,11 +36,6 @@
POST_RESTART_MAX_ATTEMPTS = 5


def _extract_path_from_url(url: str) -> str:
parsed_url = urlparse(url)
return parsed_url.path


class InternalQueryService:
def __init__(
self,
Expand All @@ -57,16 +50,12 @@ def __init__(
self.function_schema_conversions = function_schema_conversions
self.is_function_context_typed = is_function_context_typed
self.streaming = streaming
self.host = os.environ["RUNTIME_HOST"]
self.port = int(os.environ["RUNTIME_PORT"])
self.get_job_path = _extract_path_from_url(os.environ["GET_JOB_URI"])
self.post_result_path = _extract_path_from_url(os.environ["POST_RESULT_URI"])
self.post_schema_path = _extract_path_from_url(os.environ["POST_SCHEMA_URI"])
self.post_restart_path = _extract_path_from_url(os.environ["RESTART_NOTIFICATION_URI"])
self._initialize_auth_token()
self.get_job_url = os.environ["GET_JOB_URI_V2"]
self.post_result_url = os.environ["POST_RESULT_URI_V2"]
self.post_schema_url = os.environ["POST_SCHEMA_URI_V2"]
# self.post_restart_url = os.environ["RESTART_NOTIFICATION_URI_V2"]
self.post_restart_url = f"{os.environ['RUNTIME_API_V2']}/restart-notify"
self._initialize_headers()
self.certPath = os.environ["CONNECTIONS_TO_OTHER_PODS_CA_PATH"]
self.context = ssl.create_default_context(cafile=self.certPath)
self.connection_refused_count: int = 0
self.concurrency = int(os.environ.get("MAX_CONCURRENT_TASKS", 1))
self.logger = get_internal_logger()
Expand All @@ -84,22 +73,9 @@ def _set_logger_process_id(self, process_id: int) -> None:
"""Set the process_id for internal & public logger"""
COMPUTE_MODULES_ADAPTER_MANAGER.update_process_id(process_id=process_id)

def _initialize_auth_token(self) -> None:
try:
with open(os.environ["MODULE_AUTH_TOKEN"], "r", encoding="utf-8") as f:
self.moduleAuthToken = f.read()
except Exception as e:
self.logger.error(f"Failed to read auth token: {str(e)}")
raise

def _initialize_headers(self) -> None:
self.get_job_headers = {"Module-Auth-Token": self.moduleAuthToken}
self.post_result_headers = {
"Content-Type": "application/octet-stream",
"Module-Auth-Token": self.moduleAuthToken,
}
self.post_schema_headers = {"Content-Type": "application/json", "Module-Auth-Token": self.moduleAuthToken}
self.post_restart_headers = {"Module-Auth-Token": self.moduleAuthToken}
self.post_result_headers = {"Content-Type": "application/octet-stream"}
self.post_schema_headers = {"Content-Type": "application/json"}

def _iterable_to_json_generator(self, iterable: Iterable[Any]) -> Iterable[bytes]:
self.logger.debug("iterating over result")
Expand All @@ -111,20 +87,17 @@ def init_session(self) -> None:
"""Initialize requests.Session"""
self.session = requests.Session()

def build_url(self, path: str) -> str:
return f"https://{self.host}:{self.port}{path}"

def post_query_schemas(self) -> None:
"""Post the function schemas of the Compute Module"""
self.logger.debug(f"Posting function schemas: {self.function_schemas}")
self.logger.debug(f"post_schema_url: {self.post_schema_url}")
for i in range(POST_SCHEMAS_MAX_ATTEMPTS):
try:
with self.session.request(
method="POST",
url=self.build_url(self.post_schema_path),
url=self.post_schema_url,
json=self.function_schemas,
headers=self.post_schema_headers,
verify=self.certPath,
) as response:
self.logger.debug(
f"POST /schemas response status: {response.status_code} reason: {response.reason}"
Expand All @@ -143,9 +116,7 @@ def get_job_or_none(self) -> Any:
try:
with self.session.request(
method="GET",
url=self.build_url(self.get_job_path),
headers=self.get_job_headers,
verify=self.certPath,
url=self.get_job_url,
) as response:
result = None
if response.status_code == 200:
Expand Down Expand Up @@ -174,9 +145,9 @@ def report_job_result_failed(self, post_result_url: str, error: str) -> None:
url=post_result_url,
headers=self.post_result_headers,
data=json.dumps({"error": error}).encode("utf-8"),
verify=self.certPath,
) as response:
if response.status_code == 204:
# HTTP version returns 202 while witchcraft returns 204
if response.status_code in (202, 204):
self.logger.debug("Successfully reported that job result posting has failed")
return
else:
Expand All @@ -189,8 +160,7 @@ def report_job_result_failed(self, post_result_url: str, error: str) -> None:
raise RuntimeError(f"Unable to report that post result has failed after {POST_ERROR_MAX_ATTEMPTS} attempts")

def report_job_result(self, job_id: str, body: Any) -> None:
post_result_path = f"{self.post_result_path}/{job_id}"
post_result_url = self.build_url(post_result_path)
post_result_url = f"{self.post_result_url}/{job_id}"
self.logger.debug(f"Posting result to {post_result_url}")
for _ in range(POST_RESULT_MAX_ATTEMPTS):
try:
Expand All @@ -199,9 +169,9 @@ def report_job_result(self, job_id: str, body: Any) -> None:
url=post_result_url,
headers=self.post_result_headers,
data=body,
verify=self.certPath,
) as response:
if response.status_code == 204:
# HTTP version returns 202 while witchcraft returns 204
if response.status_code in (202, 204):
self.logger.debug("Successfully reported job result")
return
else:
Expand Down Expand Up @@ -287,16 +257,13 @@ def get_failed_query(exception: Exception) -> Dict[str, str]:
return {"exception": f"{str(exception)}: {traceback.format_exc()}"}

def report_restart(self) -> None:
post_restart_url = self.build_url(self.post_restart_path)
self.logger.debug(f"Reporting restart to {post_restart_url}")
self.logger.debug(f"Reporting restart to {self.post_restart_url}")

for _ in range(POST_RESTART_MAX_ATTEMPTS):
try:
with self.session.request(
method="POST",
url=post_restart_url,
headers=self.post_restart_headers,
verify=self.certPath,
url=self.post_restart_url,
) as response:
self.logger.debug(
f"Reporting restart response status: {response.status_code} reason: {response.reason}"
Expand Down