From 0985e60b54105739a0a20b3029e9b0ed422ea65a Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 7 May 2025 15:22:15 -0400 Subject: [PATCH 1/2] Init LogClient --- .../jumpstarter/jumpstarter/client/base.py | 17 ++++--- .../jumpstarter/jumpstarter/client/core.py | 49 +++++++++++++------ 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/client/base.py b/packages/jumpstarter/jumpstarter/client/base.py index 1616d4c02..c4d812514 100644 --- a/packages/jumpstarter/jumpstarter/client/base.py +++ b/packages/jumpstarter/jumpstarter/client/base.py @@ -11,10 +11,20 @@ from pydantic import ConfigDict from pydantic.dataclasses import dataclass -from .core import AsyncDriverClient +from .core import AsyncDriverClient, AsyncLogClient from jumpstarter.streams.blocking import BlockingStream +@dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True)) +class LogClient(AsyncLogClient): + portal: BlockingPortal + + @contextmanager + def log_stream(self): + with self.portal.wrap_async_context_manager(self.log_stream_async()): + yield + + @dataclass(kw_only=True, config=ConfigDict(arbitrary_types_allowed=True)) class DriverClient(AsyncDriverClient): """Base class for driver clients @@ -75,11 +85,6 @@ def stream(self, method="connect"): with self.portal.wrap_async_context_manager(self.stream_async(method)) as stream: yield BlockingStream(stream=stream, portal=self.portal) - @contextmanager - def log_stream(self): - with self.portal.wrap_async_context_manager(self.log_stream_async()): - yield - def open_stream(self) -> BlockingStream: """ Open a blocking stream session without a context manager. diff --git a/packages/jumpstarter/jumpstarter/client/core.py b/packages/jumpstarter/jumpstarter/client/core.py index 052be09f0..4a53b76d2 100644 --- a/packages/jumpstarter/jumpstarter/client/core.py +++ b/packages/jumpstarter/jumpstarter/client/core.py @@ -45,6 +45,39 @@ class DriverInvalidArgument(DriverError, ValueError): """ +@dataclass(kw_only=True) +class AsyncLogClient( + jumpstarter_pb2_grpc.ExporterServiceStub, +): + """ + Async log client base class + """ + + channel: Channel + logger: logging.Logger = field( + default_factory=lambda: logging.getLogger("exporter"), + ) + + def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + + jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel) + + @asynccontextmanager + async def log_stream_async(self): + async def log_stream_task(): + async for response in self.LogStream(empty_pb2.Empty()): + self.logger.log(logging.getLevelName(response.severity), response.message) + + async with create_task_group() as tg: + tg.start_soon(log_stream_task) + try: + yield + finally: + tg.cancel_scope.cancel() + + @dataclass(kw_only=True) class AsyncDriverClient( Metadata, @@ -150,19 +183,3 @@ async def resource_async( yield ResourceMetadata(**rstream.extra(MetadataStreamAttributes.metadata)).resource.model_dump( mode="json" ) - - def __log(self, level: int, msg: str): - self.logger.log(level, msg) - - @asynccontextmanager - async def log_stream_async(self): - async def log_stream(): - async for response in self.LogStream(empty_pb2.Empty()): - self.__log(logging.getLevelName(response.severity), response.message) - - async with create_task_group() as tg: - tg.start_soon(log_stream) - try: - yield - finally: - tg.cancel_scope.cancel() From 81d8cd5b2c6eddb836156541de047a596bd9e312 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 7 May 2025 15:48:50 -0400 Subject: [PATCH 2/2] Let jmp shell handle log stream instead of j --- packages/jumpstarter-cli/jumpstarter_cli/j.py | 3 +- .../jumpstarter-cli/jumpstarter_cli/shell.py | 35 +++++++++++++------ .../jumpstarter/client/__init__.py | 12 +++++-- .../jumpstarter/jumpstarter/client/client.py | 24 +++++++++++-- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/packages/jumpstarter-cli/jumpstarter_cli/j.py b/packages/jumpstarter-cli/jumpstarter_cli/j.py index 64f43b587..e484b5247 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/j.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/j.py @@ -17,8 +17,7 @@ async def cli(): async with BlockingPortal() as portal: with ExitStack() as stack: async with env_async(portal, stack) as client: - async with client.log_stream_async(): - await to_thread.run_sync(lambda: client.cli()(standalone_mode=False)) + await to_thread.run_sync(lambda: client.cli()(standalone_mode=False)) try: async with create_task_group() as tg: diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index ac0501a55..890cb88a0 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -2,10 +2,12 @@ from datetime import timedelta import click +from anyio.from_thread import start_blocking_portal from jumpstarter_cli_common.config import opt_config from jumpstarter_cli_common.exceptions import handle_exceptions from .common import opt_duration_partial, opt_selector +from jumpstarter.client import log_client_from_path from jumpstarter.common.utils import launch_shell from jumpstarter.config.client import ClientConfigV1Alpha1 from jumpstarter.config.exporter import ExporterConfigV1Alpha1 @@ -38,26 +40,37 @@ def shell(config, command: tuple[str, ...], lease_name, selector, duration): case ClientConfigV1Alpha1(): exit_code = 0 - with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease: - with lease.serve_unix() as path: - with lease.monitor(): - exit_code = launch_shell( - path, - "remote", - config.drivers.allow, - config.drivers.unsafe, - command=command, - ) + with ( + start_blocking_portal() as portal, + portal.wrap_async_context_manager( + config.lease_async(portal=portal, selector=selector, lease_name=lease_name, duration=duration) + ) as lease, + portal.wrap_async_context_manager(lease.monitor_async()), + portal.wrap_async_context_manager(lease.serve_unix_async()) as path, + portal.wrap_async_context_manager(log_client_from_path(path=path, portal=portal)) as log, + log.log_stream(), + ): + exit_code = launch_shell( + path, + "remote", + config.drivers.allow, + config.drivers.unsafe, + command=command, + ) sys.exit(exit_code) case ExporterConfigV1Alpha1(): + exit_code = 0 + with config.serve_unix() as path: # SAFETY: the exporter config is local thus considered trusted - launch_shell( + exit_code = launch_shell( path, "local", allow=[], unsafe=True, command=command, ) + + sys.exit(exit_code) diff --git a/packages/jumpstarter/jumpstarter/client/__init__.py b/packages/jumpstarter/jumpstarter/client/__init__.py index 8a7941922..cc79b10f5 100644 --- a/packages/jumpstarter/jumpstarter/client/__init__.py +++ b/packages/jumpstarter/jumpstarter/client/__init__.py @@ -1,5 +1,11 @@ -from .base import DriverClient -from .client import client_from_path +from .base import DriverClient, LogClient +from .client import client_from_path, log_client_from_path from .lease import Lease -__all__ = ["DriverClient", "client_from_path", "Lease"] +__all__ = [ + "DriverClient", + "LogClient", + "client_from_path", + "log_client_from_path", + "Lease", +] diff --git a/packages/jumpstarter/jumpstarter/client/client.py b/packages/jumpstarter/jumpstarter/client/client.py index b994cc725..886f539f8 100644 --- a/packages/jumpstarter/jumpstarter/client/client.py +++ b/packages/jumpstarter/jumpstarter/client/client.py @@ -8,18 +8,38 @@ from google.protobuf import empty_pb2 from jumpstarter_protocol import jumpstarter_pb2_grpc -from jumpstarter.client import DriverClient +from jumpstarter.client import DriverClient, LogClient from jumpstarter.common.importlib import import_class @asynccontextmanager -async def client_from_path(path: str, portal: BlockingPortal, stack: ExitStack, allow: list[str], unsafe: bool): +async def client_from_path( + path: str, + portal: BlockingPortal, + stack: ExitStack, + allow: list[str], + unsafe: bool, +): async with grpc.aio.secure_channel( f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) ) as channel: yield await client_from_channel(channel, portal, stack, allow, unsafe) +@asynccontextmanager +async def log_client_from_path( + path: str, + portal: BlockingPortal, +): + async with grpc.aio.secure_channel( + f"unix://{path}", grpc.local_channel_credentials(grpc.LocalConnectionType.UDS) + ) as channel: + yield LogClient( + channel=channel, + portal=portal, + ) + + async def client_from_channel( channel: grpc.aio.Channel, portal: BlockingPortal,