Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"wandb",
"pillow",
"tqdm",
"simplejpeg",
]
readme = "README.md"
maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }]
Expand Down
26 changes: 20 additions & 6 deletions src/agents/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import dataclasses
from dataclasses import asdict
from multiprocessing import shared_memory
Expand All @@ -6,8 +7,9 @@
import json_numpy
import numpy as np
import rpyc
import simplejpeg

from agents.policies import Act, Agent, Obs, SharedMemoryPayload
from agents.policies import Act, Agent, CameraDataType, Obs, SharedMemoryPayload


def dataclass_from_dict(klass, d):
Expand All @@ -20,7 +22,7 @@ def dataclass_from_dict(klass, d):


class RemoteAgent(Agent):
def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False):
def __init__(self, host: str, port: int, model: str, on_same_machine: bool = False, jpeg_encoding: bool = False):
"""Connect to a remote agent service.

Args:
Expand All @@ -29,15 +31,18 @@ def __init__(self, host: str, port: int, model: str, on_same_machine: bool = Fal
model (str): Name of the model to connect to.
on_same_machine (bool, optional): If True, assumes the agent is running on the same machine and uses
shared memory for more efficient communication. Defaults to False.
jpeg_encoding (bool, optional): If True the image data is jpeg encoded for smaller transfer size.
Defaults to False.
"""
self.on_same_machine = on_same_machine
self.jpeg_encoding = jpeg_encoding
self._shm: dict[str, shared_memory.SharedMemory] = {}
self.c = rpyc.connect(
host, port, config={"allow_pickle": True, "allow_public_attrs": True, "sync_request_timeout": 300}
)
assert model == self.c.root.name()

def _to_shared_memory(self, obs: Obs) -> Obs:
def _process(self, obs: Obs) -> Obs:
if self.on_same_machine:
camera_dict = {}
for camera_name, camera_data in obs.cameras.items():
Expand All @@ -54,17 +59,26 @@ def _to_shared_memory(self, obs: Obs) -> Obs:
dtype=camera_data.dtype.name,
)
obs.cameras = camera_dict
obs.camera_data_in_shared_memory = True
obs.camera_data_type = CameraDataType.SHARED_MEMORY
elif self.jpeg_encoding:
camera_dict = {}
for camera_name, camera_data in obs.cameras.items():
assert isinstance(camera_data, np.ndarray)
camera_dict[camera_name] = base64.urlsafe_b64encode(
simplejpeg.encode_jpeg(np.ascontiguousarray(camera_data))
).decode("utf-8")
obs.cameras = camera_dict
obs.camera_data_type = CameraDataType.JPEG_ENCODED
return obs

def act(self, obs: Obs) -> Act:
obs = self._to_shared_memory(obs)
obs = self._process(obs)
obs = json_numpy.dumps(asdict(obs))
# action, done, info
return dataclass_from_dict(Act, json_numpy.loads(self.c.root.act(obs)))

def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
obs = self._to_shared_memory(obs)
obs = self._process(obs)
obs_dict = asdict(obs)
# info
return json_numpy.loads(self.c.root.reset(json_numpy.dumps((obs_dict, instruction, kwargs))))
Expand Down
26 changes: 20 additions & 6 deletions src/agents/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any, Union

import numpy as np
import simplejpeg
from PIL import Image


Expand All @@ -22,12 +23,18 @@ class SharedMemoryPayload:
dtype: str = "uint8"


class CameraDataType:
SHARED_MEMORY = "shared_memory"
JPEG_ENCODED = "jpeg_encoded"
RAW = "raw"


@dataclass(kw_only=True)
class Obs:
cameras: dict[str, np.ndarray | SharedMemoryPayload] = field(default_factory=dict)
cameras: dict[str, np.ndarray | SharedMemoryPayload | str] = field(default_factory=dict)
camera_data_type: str = CameraDataType.RAW
gripper: float | None = None
info: dict[str, Any] = field(default_factory=dict)
camera_data_in_shared_memory: bool = False


@dataclass(kw_only=True)
Expand All @@ -53,9 +60,9 @@ def initialize(self):
# heavy initialization, e.g. loading models
pass

def _from_shared_memory(self, obs: Obs) -> Obs:
def _to_numpy(self, obs: Obs) -> Obs:
"""transparently uses shared memory if configured and modifies obs in place"""
if obs.camera_data_in_shared_memory:
if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
camera_dict = {}
for camera_name, camera_data in obs.cameras.items():
assert isinstance(camera_data, SharedMemoryPayload)
Expand All @@ -65,12 +72,19 @@ def _from_shared_memory(self, obs: Obs) -> Obs:
camera_data.shape, dtype=camera_data.dtype, buffer=self._shm[camera_data.shm_name].buf
)
obs.cameras = camera_dict
elif obs.camera_data_type == CameraDataType.JPEG_ENCODED:
camera_dict = {}
for camera_name, camera_data in obs.cameras.items():
assert isinstance(camera_data, str)
camera_dict[camera_name] = simplejpeg.decode_jpeg(base64.urlsafe_b64decode(camera_data))
obs.cameras = camera_dict
obs.camera_data_type = CameraDataType.RAW
return obs

def act(self, obs: Obs) -> Act:
assert self.instruction is not None, "forgot reset?"
self.step += 1
self._from_shared_memory(obs)
self._to_numpy(obs)

return Act(action=np.zeros(7, dtype=np.float32), done=False, info={})

Expand All @@ -79,7 +93,7 @@ def reset(self, obs: Obs, instruction: Any, **kwargs) -> dict[str, Any]:
self.step = 0
self.episode += 1
self.instruction = instruction
self._from_shared_memory(obs)
self._to_numpy(obs)
# info
return {}

Expand Down
6 changes: 3 additions & 3 deletions src/agents/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import rpyc

from agents.client import dataclass_from_dict
from agents.policies import Agent, Obs, SharedMemoryPayload
from agents.policies import Agent, CameraDataType, Obs, SharedMemoryPayload

logging.basicConfig(
format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
Expand Down Expand Up @@ -51,7 +51,7 @@ def act(self, obs_bytes: bytes) -> str:
assert self._is_initialized, "AgentService not initialized, wait until is_initialized is True"
# action, done, info
obs = typing.cast(Obs, dataclass_from_dict(Obs, json_numpy.loads(obs_bytes)))
if obs.camera_data_in_shared_memory:
if obs.camera_data_type == CameraDataType.SHARED_MEMORY:
obs.cameras = {
camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
for camera_name, camera_data in obs.cameras.items()
Expand All @@ -64,7 +64,7 @@ def reset(self, args: bytes) -> str:
# info
obs, instruction, kwargs = json_numpy.loads(args)
obs_dclass = typing.cast(Obs, dataclass_from_dict(Obs, obs))
if obs_dclass.camera_data_in_shared_memory:
if obs_dclass.camera_data_type == CameraDataType.SHARED_MEMORY:
obs_dclass.cameras = {
camera_name: dataclass_from_dict(SharedMemoryPayload, camera_data)
for camera_name, camera_data in obs_dclass.cameras.items()
Expand Down
22 changes: 22 additions & 0 deletions src/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ def _test_connection(agent: RemoteAgent):
assert not a1.done


def _test_connection_jpeg(agent: RemoteAgent):
data = np.zeros((256, 256, 3), dtype=np.uint8)
obs = Obs(cameras=dict(rgb_side=data))
instruction = "do something"
reset_info = agent.reset(obs, instruction)
assert reset_info["instruction"] == instruction
assert reset_info["shapes"] == {"rgb_side": [256, 256, 3]}
assert reset_info["dtype"] == {"rgb_side": "uint8"}
assert (reset_info["data"]["rgb_side"] == data).all()


def test_connection_numpy_serialization():
with start_server("test", {}, 8080, "localhost") as p:
sleep(2)
Expand All @@ -56,3 +67,14 @@ def test_connection_numpy_shm():
sleep(0.1)
_test_connection(agent)
p.send_signal(subprocess.signal.SIGINT)


def test_connection_numpy_jpeg():
with start_server("test", {}, 8080, "localhost") as p:
sleep(2)
agent = RemoteAgent("localhost", 8080, "test", jpeg_encoding=True)
with agent:
while not agent.is_initialized():
sleep(0.1)
_test_connection_jpeg(agent)
p.send_signal(subprocess.signal.SIGINT)