From c140d18a0dee0284fa0d77ebac23236140291b7e Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Tue, 9 Dec 2025 13:59:16 -0500 Subject: [PATCH 1/5] ENH: security for websockets --- bluesky_httpserver/authentication.py | 31 ++++++++++++++++++++++++-- bluesky_httpserver/routers/core_api.py | 30 +++++++++++++++++++++---- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 673ca29..20b951c 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security +from fastapi import APIRouter, Depends, HTTPException, Request, Response, Security, WebSocket from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.responses import JSONResponse from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm, SecurityScopes @@ -202,7 +202,6 @@ def get_current_principal( # otherwise it is None. The original set of API key scopes is used for generating new # API keys. roles, scopes, api_key_scopes = {}, {}, None - if api_key is not None: if authenticators: # Tiled is in a multi-user configuration with authentication providers. @@ -356,6 +355,34 @@ def get_current_principal( return principal +def get_current_principal_websocket( + websocket: WebSocket, + scopes: str, +): + app = websocket.app + security_scopes = SecurityScopes(scopes=scopes or []) + settings = app.dependency_overrides[get_settings]() + authenticators = app.dependency_overrides[get_authenticators]() + api_access_manager = app.dependency_overrides[get_api_access_manager]() + + auth_header = websocket.headers.get("Authorization", "") + access_token, api_key = None, None + if auth_header.startswith("Bearer "): + access_token = auth_header[len("Bearer") :].strip() + if auth_header.startswith("ApiKey "): + api_key = auth_header[len("ApiKey") :].strip() + + return get_current_principal( + request=websocket, + security_scopes=security_scopes, + access_token=access_token, + api_key=api_key, + settings=settings, + authenticators=authenticators, + api_access_manager=api_access_manager, + ) + + def create_session(settings, identity_provider, id, scopes): with get_sessionmaker(settings.database_settings)() as db: # Have we seen this Identity before? diff --git a/bluesky_httpserver/routers/core_api.py b/bluesky_httpserver/routers/core_api.py index 7eaa74e..397972b 100644 --- a/bluesky_httpserver/routers/core_api.py +++ b/bluesky_httpserver/routers/core_api.py @@ -14,7 +14,7 @@ else: from pydantic_settings import BaseSettings -from ..authentication import get_current_principal +from ..authentication import get_current_principal, get_current_principal_websocket from ..console_output import ConsoleOutputEventStream, StreamingResponseFromClass from ..resources import SERVER_RESOURCES as SR from ..settings import get_settings @@ -1139,7 +1139,12 @@ def is_alive(self): @router.websocket("/console_output/ws") -async def console_output_ws(websocket: WebSocket): +async def console_output_ws(websocket: WebSocket, scopes=["read:console"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.console_output_stream.add_queue(websocket) wsmon = WebSocketMonitor(websocket) @@ -1151,6 +1156,8 @@ async def console_output_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: @@ -1158,11 +1165,17 @@ async def console_output_ws(websocket: WebSocket): @router.websocket("/status/ws") -async def status_ws(websocket: WebSocket): +async def status_ws(websocket: WebSocket, scopes=["read:monitor"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.system_info_stream.add_queue_status(websocket) wsmon = WebSocketMonitor(websocket) wsmon.start() + try: while wsmon.is_alive: try: @@ -1170,6 +1183,8 @@ async def status_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: @@ -1177,7 +1192,12 @@ async def status_ws(websocket: WebSocket): @router.websocket("/info/ws") -async def info_ws(websocket: WebSocket): +async def info_ws(websocket: WebSocket, scopes=["read:monitor"]): + principal = get_current_principal_websocket(websocket=websocket, scopes=scopes) + if not principal: + await websocket.close(code=4001, reason="Invalid token") + return + await websocket.accept() q = SR.system_info_stream.add_queue_info(websocket) wsmon = WebSocketMonitor(websocket) @@ -1189,6 +1209,8 @@ async def info_ws(websocket: WebSocket): await websocket.send_text(msg) except asyncio.TimeoutError: pass + except RuntimeError: # 'send' after the client is disconnected + pass except WebSocketDisconnect: pass finally: From a96ed4c51dbbfba9d110e07ddaada463af8d0134 Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Tue, 9 Dec 2025 14:33:14 -0500 Subject: [PATCH 2/5] TST: unit tests for sockets with authentication --- bluesky_httpserver/tests/test_console_output.py | 3 ++- bluesky_httpserver/tests/test_system_info_socket.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bluesky_httpserver/tests/test_console_output.py b/bluesky_httpserver/tests/test_console_output.py index 1b87e53..1f089ec 100644 --- a/bluesky_httpserver/tests/test_console_output.py +++ b/bluesky_httpserver/tests/test_console_output.py @@ -353,7 +353,8 @@ def __init__(self, api_key=API_KEY_FOR_TESTS, **kwargs): def run(self): websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api/console_output/ws" - with connect(websocket_uri) as websocket: + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + with connect(websocket_uri, additional_headers=additional_headers) as websocket: while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) diff --git a/bluesky_httpserver/tests/test_system_info_socket.py b/bluesky_httpserver/tests/test_system_info_socket.py index b20c98c..75f2984 100644 --- a/bluesky_httpserver/tests/test_system_info_socket.py +++ b/bluesky_httpserver/tests/test_system_info_socket.py @@ -35,7 +35,8 @@ def __init__(self, *, endpoint, api_key=API_KEY_FOR_TESTS, **kwargs): def run(self): websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" - with connect(websocket_uri) as websocket: + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + with connect(websocket_uri, additional_headers=additional_headers) as websocket: while not self._exit: try: msg_json = websocket.recv(timeout=0.1, decode=False) From ab7a13fd0bb26da5173354c5fc77f4e2df838117 Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Wed, 10 Dec 2025 16:28:31 -0500 Subject: [PATCH 3/5] TST: unit tests for authenticated websockets --- bluesky_httpserver/authentication.py | 24 ++- .../tests/test_auth_for_websockets.py | 168 ++++++++++++++++++ .../tests/test_system_info_socket.py | 21 ++- 3 files changed, 195 insertions(+), 18 deletions(-) create mode 100644 bluesky_httpserver/tests/test_auth_for_websockets.py diff --git a/bluesky_httpserver/authentication.py b/bluesky_httpserver/authentication.py index 20b951c..9effc80 100644 --- a/bluesky_httpserver/authentication.py +++ b/bluesky_httpserver/authentication.py @@ -372,15 +372,21 @@ def get_current_principal_websocket( if auth_header.startswith("ApiKey "): api_key = auth_header[len("ApiKey") :].strip() - return get_current_principal( - request=websocket, - security_scopes=security_scopes, - access_token=access_token, - api_key=api_key, - settings=settings, - authenticators=authenticators, - api_access_manager=api_access_manager, - ) + principal = None + try: + principal = get_current_principal( + request=websocket, + security_scopes=security_scopes, + access_token=access_token, + api_key=api_key, + settings=settings, + authenticators=authenticators, + api_access_manager=api_access_manager, + ) + except HTTPException as ex: + print(f"WebSocket connection failed: {ex}") + + return principal def create_session(settings, identity_provider, id, scopes): diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py new file mode 100644 index 0000000..e5dd535 --- /dev/null +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -0,0 +1,168 @@ +import json +import pprint +import threading +import time as ttime + +import pytest +from bluesky_queueserver.manager.tests.common import re_manager, re_manager_cmd # noqa F401 +from websockets.sync.client import connect + +from .conftest import fastapi_server_fs # noqa: F401 +from .conftest import ( + SERVER_ADDRESS, + SERVER_PORT, + request_to_json, + setup_server_with_config_file, + wait_for_environment_to_be_closed, + wait_for_environment_to_be_created, +) + +config_toy_test = """ +authentication: + allow_anonymous_access: True + providers: + - provider: toy + authenticator: bluesky_httpserver.authenticators:DictionaryAuthenticator + args: + users_to_passwords: + bob: bob_password + alice: alice_password + cara: cara_password + tom: tom_password +api_access: + policy: bluesky_httpserver.authorization:DictionaryAPIAccessControl + args: + users: + bob: + roles: + - admin + - expert + alice: + roles: advanced + tom: + roles: user +""" + + +class _ReceiveSystemInfoSocket(threading.Thread): + """ + Catch streaming console output by connecting to /console_output/ws socket and + save messages to the buffer. + """ + + def __init__(self, *, endpoint, api_key=None, token=None, **kwargs): + super().__init__(**kwargs) + self.received_data_buffer = [] + self._exit = False + self._api_key = api_key + self._token = token + self._endpoint = endpoint + + def run(self): + websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" + if self._token is not None: + additional_headers = {"Authorization": f"Bearer {self._token}"} + elif self._api_key is not None: + additional_headers = {"Authorization": f"ApiKey {self._api_key}"} + else: + additional_headers = {} + + try: + with connect(websocket_uri, additional_headers=additional_headers) as websocket: + while not self._exit: + try: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: + pass + except Exception as ex: + print(f"Failed to connect to server: {ex}") + + def stop(self): + """ + Call this method to stop the thread. Then send a request to the server so that some output + is printed in ``stdout``. + """ + self._exit = True + + def __del__(self): + self.stop() + + +# fmt: off +@pytest.mark.parametrize("ws_auth_type", ["apikey", "token", "none"]) +# fmt: on +def test_websocket_auth_01( + tmpdir, + monkeypatch, + re_manager_cmd, # noqa: F811 + fastapi_server_fs, # noqa: F811 + ws_auth_type, +): + """ + ``/auth/apikey`` (GET): basic tests. + """ + + # Start RE Manager + params = ["--zmq-publish-console", "ON"] + re_manager_cmd(params) + + setup_server_with_config_file(config_file_str=config_toy_test, tmpdir=tmpdir, monkeypatch=monkeypatch) + fastapi_server_fs() + + resp1 = request_to_json("post", "/auth/provider/toy/token", login=("bob", "bob_password")) + assert "access_token" in pprint.pformat(resp1) + token = resp1["access_token"] + + resp3 = request_to_json( + "post", "/auth/apikey", json={"expires_in": 900, "note": "API key for testing"}, token=token + ) + assert "secret" in resp3, pprint.pformat(resp3) + assert "note" in resp3, pprint.pformat(resp3) + assert resp3["note"] == "API key for testing" + assert resp3["scopes"] == ["inherit"] + api_key = resp3["secret"] + + endpoint = "/status/ws" + if ws_auth_type == "none": + ws_params = {} + elif ws_auth_type == "apikey": + ws_params = {"api_key": api_key} + elif ws_auth_type == "token": + ws_params = {"token": token} + else: + assert False, f"Unknown authentication type: {ws_auth_type!r}" + + rsc = _ReceiveSystemInfoSocket(endpoint=endpoint, **ws_params) + rsc.start() + ttime.sleep(1) # Wait until the client connects to the socket + + resp1 = request_to_json("post", "/environment/open", api_key=api_key) + assert resp1["success"] is True, pprint.pformat(resp1) + + assert wait_for_environment_to_be_created(timeout=10, api_key=api_key) + + resp2b = request_to_json("post", "/environment/close", api_key=api_key) + assert resp2b["success"] is True, pprint.pformat(resp2b) + + assert wait_for_environment_to_be_closed(timeout=10, api_key=api_key) + + # Wait until capture is complete + ttime.sleep(2) + rsc.stop() + rsc.join() + + buffer = rsc.received_data_buffer + if ws_auth_type == "none": + assert len(buffer) == 0 + else: + assert len(buffer) > 0 + for msg in buffer: + assert "time" in msg, msg + assert isinstance(msg["time"], float), msg + assert "msg" in msg + assert isinstance(msg["msg"], dict) diff --git a/bluesky_httpserver/tests/test_system_info_socket.py b/bluesky_httpserver/tests/test_system_info_socket.py index 75f2984..4d25dd6 100644 --- a/bluesky_httpserver/tests/test_system_info_socket.py +++ b/bluesky_httpserver/tests/test_system_info_socket.py @@ -36,17 +36,20 @@ def __init__(self, *, endpoint, api_key=API_KEY_FOR_TESTS, **kwargs): def run(self): websocket_uri = f"ws://{SERVER_ADDRESS}:{SERVER_PORT}/api{self._endpoint}" additional_headers = {"Authorization": f"ApiKey {self._api_key}"} - with connect(websocket_uri, additional_headers=additional_headers) as websocket: - while not self._exit: - try: - msg_json = websocket.recv(timeout=0.1, decode=False) + try: + with connect(websocket_uri, additional_headers=additional_headers) as websocket: + while not self._exit: try: - msg = json.loads(msg_json) - self.received_data_buffer.append(msg) - except json.JSONDecodeError: + msg_json = websocket.recv(timeout=0.1, decode=False) + try: + msg = json.loads(msg_json) + self.received_data_buffer.append(msg) + except json.JSONDecodeError: + pass + except TimeoutError: pass - except TimeoutError: - pass + except Exception as ex: + print(f"Failed to connect to server: {ex}") def stop(self): """ From 023b08d6f3f234586073c5ccf6bf3d1e230ed72a Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Thu, 11 Dec 2025 09:51:58 -0500 Subject: [PATCH 4/5] TST: additional test cases for websocket authentication --- .../tests/test_auth_for_websockets.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bluesky_httpserver/tests/test_auth_for_websockets.py b/bluesky_httpserver/tests/test_auth_for_websockets.py index e5dd535..2a76109 100644 --- a/bluesky_httpserver/tests/test_auth_for_websockets.py +++ b/bluesky_httpserver/tests/test_auth_for_websockets.py @@ -94,7 +94,7 @@ def __del__(self): # fmt: off -@pytest.mark.parametrize("ws_auth_type", ["apikey", "token", "none"]) +@pytest.mark.parametrize("ws_auth_type", ["apikey", "token", "apikey_invalid", "token_invalid", "none"]) # fmt: on def test_websocket_auth_01( tmpdir, @@ -104,7 +104,8 @@ def test_websocket_auth_01( ws_auth_type, ): """ - ``/auth/apikey`` (GET): basic tests. + Test authentication for websockets. The test is run only on ``/status/ws`` websocket. + The other websockets are expected to use the same authentication scheme. """ # Start RE Manager @@ -132,8 +133,12 @@ def test_websocket_auth_01( ws_params = {} elif ws_auth_type == "apikey": ws_params = {"api_key": api_key} + elif ws_auth_type == "apikey_invalid": + ws_params = {"api_key": "InvalidApiKey"} elif ws_auth_type == "token": ws_params = {"token": token} + elif ws_auth_type == "token_invalid": + ws_params = {"token": "InvalidToken"} else: assert False, f"Unknown authentication type: {ws_auth_type!r}" @@ -157,12 +162,14 @@ def test_websocket_auth_01( rsc.join() buffer = rsc.received_data_buffer - if ws_auth_type == "none": + if ws_auth_type in ("none", "apikey_invalid", "token_invalid"): assert len(buffer) == 0 - else: + elif ws_auth_type in ("apikey", "token"): assert len(buffer) > 0 for msg in buffer: assert "time" in msg, msg assert isinstance(msg["time"], float), msg assert "msg" in msg assert isinstance(msg["msg"], dict) + else: + assert False, f"Unknown authentication type: {ws_auth_type!r}" From 2c138d04fbf083c381f74a77448921fc69cec576 Mon Sep 17 00:00:00 2001 From: Dmitri Gavrilov Date: Fri, 12 Dec 2025 15:52:46 -0500 Subject: [PATCH 5/5] ENH: add 'user:apikey' scopes to all default user groups --- bluesky_httpserver/authorization/_defaults.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bluesky_httpserver/authorization/_defaults.py b/bluesky_httpserver/authorization/_defaults.py index ecdf4dc..37448b8 100644 --- a/bluesky_httpserver/authorization/_defaults.py +++ b/bluesky_httpserver/authorization/_defaults.py @@ -73,6 +73,7 @@ "write:plan:control", "write:execute", "write:history:edit", + "user:apikeys", } _DEFAULT_SCOPES_USER = { @@ -91,6 +92,7 @@ "write:plan:control", "write:execute", "write:history:edit", + "user:apikeys", } _DEFAULT_SCOPES_OBSERVER = { @@ -103,6 +105,7 @@ "read:console", "read:lock", "read:testing", + "user:apikeys", } # =============================================================================================