Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions aw_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from aw_core.models import Event
from aw_transform.heartbeats import heartbeat_merge

from .config import load_config
from .config import load_config, load_local_server_api_key
from .singleinstance import SingleInstance

# FIXME: This line is probably badly placed
Expand Down Expand Up @@ -88,6 +88,7 @@ def __init__(

server_host = host or server_config["hostname"]
server_port = port or server_config["port"]
self.server_api_key = load_local_server_api_key(str(server_host), server_port)
self.server_address = f"{protocol}://{server_host}:{server_port}"

self.instance = SingleInstance(
Expand All @@ -107,9 +108,15 @@ def __init__(
def _url(self, endpoint: str):
return f"{self.server_address}/api/0/{endpoint}"

def _headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
request_headers = dict(headers or {})
if self.server_api_key:
request_headers.setdefault("Authorization", f"Bearer {self.server_api_key}")
return request_headers

@always_raise_for_request_errors
def _get(self, endpoint: str, params: Optional[dict] = None) -> req.Response:
return req.get(self._url(endpoint), params=params)
return req.get(self._url(endpoint), params=params, headers=self._headers())

@always_raise_for_request_errors
def _post(
Expand All @@ -118,7 +125,9 @@ def _post(
data: Union[List[Any], Dict[str, Any]],
params: Optional[dict] = None,
) -> req.Response:
headers = {"Content-type": "application/json", "charset": "utf-8"}
headers = self._headers(
{"Content-type": "application/json", "charset": "utf-8"}
)
return req.post(
self._url(endpoint),
data=bytes(json.dumps(data), "utf8"),
Expand All @@ -130,7 +139,7 @@ def _post(
def _delete(self, endpoint: str, data: Any = None) -> req.Response:
if data is None:
data = {}
headers = {"Content-type": "application/json"}
headers = self._headers({"Content-type": "application/json"})
return req.delete(self._url(endpoint), data=json.dumps(data), headers=headers)

def get_info(self):
Expand Down
47 changes: 47 additions & 0 deletions aw_client/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import logging
import os
from typing import Optional, Union

import tomlkit
from aw_core import dirs
from aw_core.config import load_config_toml

logger = logging.getLogger(__name__)

default_config = """
[server]
hostname = "127.0.0.1"
Expand All @@ -19,3 +27,42 @@

def load_config():
return load_config_toml("aw-client", default_config)


def load_local_server_api_key(host: str, port: Union[int, str]) -> Optional[str]:
if host not in {"127.0.0.1", "localhost", "::1"}:
return None

try:
requested_port = int(str(port))
except (TypeError, ValueError):
return None

config_dir = dirs.get_config_dir("aw-server-rust")
candidates = (
("config.toml", 5600),
("config-testing.toml", 5666),
)

for filename, default_port in candidates:
config_path = os.path.join(config_dir, filename)
if not os.path.isfile(config_path):
continue

try:
with open(config_path, encoding="utf-8") as f:
config = tomlkit.parse(f.read())
configured_port = int(str(config.get("port", default_port)))
if configured_port != requested_port:
continue

auth_config = config.get("auth", {})
api_key = auth_config.get("api_key")
if api_key:
return str(api_key)
except Exception as e:
logger.warning(
"Failed to read aw-server-rust config %s: %s", config_path, e
)

return None
Loading
Loading