From 8398849b089e1d764b70cbffdd9335279c83e768 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20J=C3=BClg?=
Date: Thu, 15 Jan 2026 19:43:18 +0100
Subject: [PATCH] feat: add transparent jpeg support

---
NOTE(review): the copy of this patch under review had its line breaks and
indentation collapsed. The layout below was rebuilt from the old/new line
counts in the hunk headers; no token of the diff payload was changed.
Run `git apply --check` before relying on it.

Review notes on the change itself:
* policies.py: the new JPEG branch calls base64.urlsafe_b64decode, but the
  visible hunks only add `import simplejpeg` to that file. Confirm that
  `base64` is already imported above line 12 of policies.py.
* policies.py: the `_to_numpy` docstring still only mentions shared memory,
  although the method now also decodes JPEG payloads.
* client.py: `_process` is an if/elif, so `on_same_machine=True` wins and
  `jpeg_encoding` is silently ignored in that case. Worth documenting.
* client.py: `_process` replaces `obs.cameras` on the caller's object.
  Passing the same Obs to `act`/`reset` twice with jpeg_encoding=True would
  presumably trip the `isinstance(camera_data, np.ndarray)` assert.
* client.py: only ndarray-ness is asserted before simplejpeg.encode_jpeg;
  presumably it needs uint8 HxWxC input - TODO confirm against its docs.
* tests: JPEG is lossy; `_test_connection_jpeg` compares pixels for exact
  equality, which presumably only holds because the image is all zeros.

Whitespace that could not be recovered from the collapsed copy (guesses):
* client.py: wrap point of the two-line `on_same_machine` docstring entry.
* policies.py: indentation of `obs.camera_data_type = CameraDataType.RAW`
  (shown inside the elif, by analogy with client.py `_process`; if that is
  right, the shared-memory branch leaves the tag unchanged).
* policies.py: position of the blank context lines in the `act` and `reset`
  hunks.
* test_connection.py: indentation of `p.send_signal(...)`.
* pyproject.toml: indentation of the dependency entries.

 pyproject.toml               |  1 +
 src/agents/client.py         | 26 ++++++++++++++++++++------
 src/agents/policies.py       | 26 ++++++++++++++++++++------
 src/agents/server.py         |  6 +++---
 src/tests/test_connection.py | 22 ++++++++++++++++++++++
 5 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 854f312..6c4070b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
     "wandb",
     "pillow",
     "tqdm",
+    "simplejpeg",
 ]
 readme = "README.md"
 maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }]
diff --git a/src/agents/client.py b/src/agents/client.py
index 607b6aa..4fd8e8a 100644
--- a/src/agents/client.py
+++ b/src/agents/client.py
@@ -1,3 +1,4 @@
+import base64
 import dataclasses
 from dataclasses import asdict
 from multiprocessing import shared_memory
@@ -6,8 +7,9 @@
 import json_numpy
 import numpy as np
 import rpyc
+import simplejpeg
 
-from agents.policies import Act, Agent, Obs, SharedMemoryPayload
+from agents.policies import Act, Agent, CameraDataType, Obs, SharedMemoryPayload
 
 
 def dataclass_from_dict(klass, d):
@@ -20,7 +22,7 @@ def dataclass_from_dict(klass, d):
 
 
 class RemoteAgent(Agent):
-    def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False):
+    def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False, jpeg_encoding: bool = False):
         """Connect to a remote agent service.
 
         Args:
@@ -29,15 +31,18 @@ def __init__(self, host: str, port: int, model: str, on_same_machine: bool = Fal
             model (str): Name of the model to connect to.
             on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
                 shared memory for more efficient communication. Defaults to False.
+            jpeg_encoding (bool, optional): If True the image data is jpeg encoded for smaller transfer size.
+                Defaults to False.
         """
         self.on_same_machine = on_same_machine
+        self.jpeg_encoding = jpeg_encoding
         self._shm: dict[str, shared_memory.SharedMemory] = {}
         self.c = rpyc.connect(
             host, port, config={"allow_pickle": True, "allow_public_attrs": True, "sync_request_timeout": 300}
         )
         assert model == self.c.root.name()
 
-    def _to_shared_memory(self, obs: Obs) -> Obs:
+    def _process(self, obs: Obs) -> Obs:
         if self.on_same_machine:
             camera_dict = {}
             for camera_name, camera_data in obs.cameras.items():
@@ -54,17 +59,26 @@ def _to_shared_memory(self, obs: Obs) -> Obs:
                     dtype=camera_data.dtype.name,
                 )
             obs.cameras = camera_dict
-            obs.camera_data_in_shared_memory = True
+            obs.camera_data_type = CameraDataType.SHARED_MEMORY
+        elif self.jpeg_encoding:
+            camera_dict = {}
+            for camera_name, camera_data in obs.cameras.items():
+                assert isinstance(camera_data, np.ndarray)
+                camera_dict[camera_name] = base64.urlsafe_b64encode(
+                    simplejpeg.encode_jpeg(np.ascontiguousarray(camera_data))
+                ).decode("utf-8")
+            obs.cameras = camera_dict
+            obs.camera_data_type = CameraDataType.JPEG_ENCODED
         return obs
 
     def act(self, obs: Obs) -> Act:
-        obs = self._to_shared_memory(obs)
+        obs = self._process(obs)
         obs = json_numpy.dumps(asdict(obs))
         # action, done, info
         return dataclass_from_dict(Act, json_numpy.loads(self.c.root.act(obs)))
 
     def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
-        obs = self._to_shared_memory(obs)
+        obs = self._process(obs)
         obs_dict = asdict(obs)
         # info
         return json_numpy.loads(self.c.root.reset(json_numpy.dumps((obs_dict, instruction, kwargs))))
diff --git a/src/agents/policies.py b/src/agents/policies.py
index a56a52e..19c09a9 100644
--- a/src/agents/policies.py
+++ b/src/agents/policies.py
@@ -12,6 +12,7 @@
 from typing import Any, Union
 
 import numpy as np
+import simplejpeg
 from PIL import Image
 
 
@@ -22,12 +23,18 @@ class SharedMemoryPayload:
     dtype: str = "uint8"
 
 
+class CameraDataType:
+    SHARED_MEMORY = "shared_memory"
+    JPEG_ENCODED = "jpeg_encoded"
+    RAW = "raw"
+
+
 @dataclass(kw_only=True)
 class Obs:
-    cameras: dict[str, np.ndarray | SharedMemoryPayload] = field(default_factory=dict)
+    cameras: dict[str, np.ndarray | SharedMemoryPayload | str] = field(default_factory=dict)
+    camera_data_type: str = CameraDataType.RAW
     gripper: float | None = None
     info: dict[str, Any] = field(default_factory=dict)
-    camera_data_in_shared_memory: bool = False
 
 
 @dataclass(kw_only=True)
@@ -53,9 +60,9 @@ def initialize(self):
         # heavy initialization, e.g. loading models
         pass
 
-    def _from_shared_memory(self, obs: Obs) -> Obs:
+    def _to_numpy(self, obs: Obs) -> Obs:
         """transparently uses shared memory if configured and modifies obs in place"""
-        if obs.camera_data_in_shared_memory:
+        if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
             camera_dict = {}
             for camera_name, camera_data in obs.cameras.items():
                 assert isinstance(camera_data, SharedMemoryPayload)
@@ -65,12 +72,19 @@ def _from_shared_memory(self, obs: Obs) -> Obs:
                     camera_data.shape, dtype=camera_data.dtype, buffer=self._shm[camera_data.shm_name].buf
                 )
             obs.cameras = camera_dict
+        elif obs.camera_data_type == CameraDataType.JPEG_ENCODED:
+            camera_dict = {}
+            for camera_name, camera_data in obs.cameras.items():
+                assert isinstance(camera_data, str)
+                camera_dict[camera_name] = simplejpeg.decode_jpeg(base64.urlsafe_b64decode(camera_data))
+            obs.cameras = camera_dict
+            obs.camera_data_type = CameraDataType.RAW
         return obs
 
     def act(self, obs: Obs) -> Act:
         assert self.instruction is not None, "forgot reset?"
         self.step += 1
-        self._from_shared_memory(obs)
+        self._to_numpy(obs)
 
         return Act(action=np.zeros(7, dtype=np.float32), done=False, info={})
 
@@ -79,7 +93,7 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
         self.step = 0
         self.episode += 1
         self.instruction = instruction
-        self._from_shared_memory(obs)
+        self._to_numpy(obs)
 
         # info
         return {}
diff --git a/src/agents/server.py b/src/agents/server.py
index f16a463..bb8b3cf 100644
--- a/src/agents/server.py
+++ b/src/agents/server.py
@@ -10,7 +10,7 @@
 import rpyc
 
 from agents.client import dataclass_from_dict
-from agents.policies import Agent, Obs, SharedMemoryPayload
+from agents.policies import Agent, CameraDataType, Obs, SharedMemoryPayload
 
 logging.basicConfig(
     format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
@@ -51,7 +51,7 @@ def act(self, obs_bytes: bytes) -> str:
         assert self._is_initialized, "AgentService not initialized, wait until is_initialized is True"
         # action, done, info
         obs = typing.cast(Obs, dataclass_from_dict(Obs, json_numpy.loads(obs_bytes)))
-        if obs.camera_data_in_shared_memory:
+        if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
             obs.cameras = {
                 camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
                 for camera_name, camera_data in obs.cameras.items()
@@ -64,7 +64,7 @@ def reset(self, args: bytes) -> str:
         # info
         obs, instruction, kwargs = json_numpy.loads(args)
         obs_dclass = typing.cast(Obs, dataclass_from_dict(Obs, obs))
-        if obs_dclass.camera_data_in_shared_memory:
+        if obs_dclass.camera_data_type == CameraDataType.SHARED_MEMORY:
             obs_dclass.cameras = {
                 camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
                 for camera_name, camera_data in obs_dclass.cameras.items()
diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py
index 6be9823..dea4e20 100644
--- a/src/tests/test_connection.py
+++ b/src/tests/test_connection.py
@@ -36,6 +36,17 @@ def _test_connection(agent: RemoteAgent):
     assert not a1.done
 
 
+def _test_connection_jpeg(agent: RemoteAgent):
+    data = np.zeros((256, 256, 3), dtype=np.uint8)
+    obs = Obs(cameras=dict(rgb_side=data))
+    instruction = "do something"
+    reset_info = agent.reset(obs, instruction)
+    assert reset_info["instruction"] == instruction
+    assert reset_info["shapes"] == {"rgb_side": [256, 256, 3]}
+    assert reset_info["dtype"] == {"rgb_side": "uint8"}
+    assert (reset_info["data"]["rgb_side"] == data).all()
+
+
 def test_connection_numpy_serialization():
     with start_server("test", {}, 8080, "localhost") as p:
         sleep(2)
@@ -56,3 +67,14 @@ def test_connection_numpy_shm():
                 sleep(0.1)
             _test_connection(agent)
         p.send_signal(subprocess.signal.SIGINT)
+
+
+def test_connection_numpy_jpeg():
+    with start_server("test", {}, 8080, "localhost") as p:
+        sleep(2)
+        agent = RemoteAgent("localhost", 8080, "test", jpeg_encoding=True)
+        with agent:
+            while not agent.is_initialized():
+                sleep(0.1)
+            _test_connection_jpeg(agent)
+        p.send_signal(subprocess.signal.SIGINT)