diff --git a/Tsimulation/README.md b/Tsimulation/README.md index 7c667ed52..9189aefed 100644 --- a/Tsimulation/README.md +++ b/Tsimulation/README.md @@ -109,6 +109,9 @@ Metadata in `store.attrs`: `embodiment="pushshapes_sim"`, `task_name="pushshapes"`, `task_description` (JSON of env args), `total_frames`, `fps`, `features`. +New bulk-written collections use the compact one-chunk-per-array Zarr layout, +matching the output of `scripts/rechunk_zarr_dataset.py`. + Full schema rationale is in [`SCHEMA_NOTES.md`](SCHEMA_NOTES.md). ## Plugging into EgoVerse training diff --git a/Tsimulation/SCHEMA_NOTES.md b/Tsimulation/SCHEMA_NOTES.md index 303201da2..388a3002d 100644 --- a/Tsimulation/SCHEMA_NOTES.md +++ b/Tsimulation/SCHEMA_NOTES.md @@ -46,13 +46,13 @@ Each `episode_{idx}.zarr/` is opened as a Zarr v3 group. Two kinds of arrays: 1. **Numeric arrays** — one per feature, shape `(T, ...)`. Examples used in the repo: `observations.state`, `actions`, `actions_joints`, - `actions_cartesian`. Chunks `(chunk_timesteps, ...)` (default 100), with - sharding aligned to the padded length. + `actions_cartesian`. Bulk-written episodes are stored as one chunk per + array so new collections match the compact rechunked layout. 2. **Image arrays** — one per camera, shape `(T,)` of `zarr.core.dtype.VariableLengthBytes`. Each element is a JPEG-encoded frame produced by `simplejpeg.encode_jpeg(img, quality=85, colorspace="RGB")`. - Chunks are `(1,)` so frames can be decoded independently. Image keys are - prefixed `observations.images.{cam}` by convention. + Bulk-written episodes are stored as one chunk for the whole image array. + Image keys are prefixed `observations.images.{cam}` by convention. ## Required metadata (store.attrs) @@ -71,7 +71,8 @@ self.metadata["features"]`, so every store must populate: ## Image format - Encoding: JPEG via `simplejpeg.encode_jpeg(..., quality=85, colorspace="RGB")`. -- Storage: `zarr.core.dtype.VariableLengthBytes()` dtype, one frame per chunk. +- Storage: `zarr.core.dtype.VariableLengthBytes()` dtype, one chunk per + episode for bulk-written collections. - Shape recorded in `features[key]["shape"]` as `[H, W, 3]`, dimension names `["height", "width", "channel"]`, `dtype` string `"jpeg"`. - Most EgoVerse data is high-resolution (e.g. 480x640); for PushShapes we use @@ -81,11 +82,9 @@ self.metadata["features"]`, so every store must populate: ## Compression -EgoVerse uses **Zarr v3 sharding** plus JPEG for images. The task prompt -called for `numcodecs.Blosc(cname='lz4')` — that is a Zarr v2 convention and -does not apply here. The existing `ZarrWriter` already picks the right -codecs by relying on Zarr v3's defaults; we delegate to it rather than -re-deciding. +EgoVerse uses Zarr v3 plus JPEG for images. The compact bulk-write path now +matches the layout produced by `scripts/rechunk_zarr_dataset.py`, which keeps +per-episode file counts low while preserving the same data values. ## Embodiment diff --git a/Tsimulation/collect/balance.py b/Tsimulation/collect/balance.py new file mode 100644 index 000000000..eda144eb9 --- /dev/null +++ b/Tsimulation/collect/balance.py @@ -0,0 +1,169 @@ +"""Coverage-balanced collection helpers. + +Buckets each episode by (object-start quadrant, goal quadrant), giving 16 +cells (4 x 4). Used by the mouse / scripted collectors to reject episodes +whose bucket is already at quota, so the saved dataset has roughly equal +coverage across all (start, goal) combinations. + +Convention: pygame / pymunk image coords, origin at top-left, +y down. +So "top-right" on screen is high-x, low-y. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Callable + +import zarr + +N_QUADRANTS = 4 +N_BUCKETS = N_QUADRANTS * N_QUADRANTS # 16 (object-quad x goal-quad) +N_PUSHER_BUCKETS = N_QUADRANTS # 4 (pusher quadrant only) + +# Visual labels for buckets when printing the histogram. +_QUADRANT_LABELS = ("TL", "TR", "BL", "BR") + +BucketFn = Callable[[dict, float], int] + + +def quadrant(x: float, y: float, world_size: float) -> int: + """0=TL, 1=TR, 2=BL, 3=BR for a point in a world_size x world_size arena.""" + mid = world_size / 2.0 + qx = 1 if x >= mid else 0 + qy = 1 if y >= mid else 0 + return qy * 2 + qx + + +def bucket_for(episode_init: dict, world_size: float) -> int: + """Bucket id 0..15 = object_quadrant * 4 + goal_quadrant.""" + obj_x, obj_y = episode_init["object_pose"][:2] + goal_x, goal_y = episode_init["goal_pose"][:2] + obj_q = quadrant(float(obj_x), float(obj_y), world_size) + goal_q = quadrant(float(goal_x), float(goal_y), world_size) + return obj_q * N_QUADRANTS + goal_q + + +def bucket_pusher_quad(episode_init: dict, world_size: float) -> int: + """Bucket id 0..3 = pusher quadrant only. + + Intended for scenarios where the object/goal positions are constrained + (e.g. object starts on the goal) and the only meaningful spatial signal + is where the pusher starts.""" + ax, ay = episode_init["agent_pos"][:2] + return quadrant(float(ax), float(ay), world_size) + + +def count_existing_buckets( + folder: Path, + world_size: float, + bucket_fn: BucketFn = bucket_for, + num_buckets: int = N_BUCKETS, + filename_contains: str | None = None, +) -> list[int]: + """Scan ``folder`` for ``*.zarr`` episode groups and tally per-bucket counts + from each one's stored ``episode_init`` attribute. + + ``bucket_fn`` decides which bucket an episode falls in. + ``num_buckets`` sizes the returned list (must agree with bucket_fn's range). + ``filename_contains``, if set, restricts the scan to entries whose name + contains that substring (e.g. ``"ontarget"`` to count only tagged files).""" + counts = [0] * num_buckets + folder = Path(folder) + if not folder.exists(): + return counts + for entry in sorted(folder.iterdir()): + if not entry.is_dir() or not entry.name.endswith(".zarr"): + continue + if filename_contains is not None and filename_contains not in entry.name: + continue + try: + group = zarr.open_group(str(entry), mode="r") + raw = group.attrs.get("episode_init") + if raw is None: + continue + ep_init = json.loads(raw) + counts[bucket_fn(ep_init, world_size)] += 1 + except Exception: + continue + return counts + + +class BucketTracker: + """Per-bucket counter with an acceptance test and a printable histogram. + + Defaults to 16 buckets laid out as a 4x4 (object-quadrant x goal-quadrant) + grid for the standard collect scenario. Pass ``num_buckets=N_PUSHER_BUCKETS`` + for the 4-bucket pusher-only scheme used by on-target collection.""" + + def __init__( + self, + target_per_bucket: int, + initial_counts: list[int] | None = None, + num_buckets: int = N_BUCKETS, + ): + if target_per_bucket < 1: + raise ValueError("target_per_bucket must be >= 1") + if num_buckets not in (N_BUCKETS, N_PUSHER_BUCKETS): + raise ValueError( + f"num_buckets must be {N_BUCKETS} or {N_PUSHER_BUCKETS}, got {num_buckets}" + ) + self.target = int(target_per_bucket) + self.num_buckets = int(num_buckets) + if initial_counts is None: + self.counts = [0] * num_buckets + else: + if len(initial_counts) != num_buckets: + raise ValueError( + f"initial_counts must have {num_buckets} entries, got {len(initial_counts)}" + ) + self.counts = [int(c) for c in initial_counts] + + def has_room(self, b: int) -> bool: + return self.counts[b] < self.target + + def increment(self, b: int) -> None: + self.counts[b] += 1 + + @property + def total(self) -> int: + return sum(self.counts) + + @property + def goal_total(self) -> int: + return self.target * self.num_buckets + + @property + def filled(self) -> bool: + return all(c >= self.target for c in self.counts) + + def histogram(self) -> str: + if self.num_buckets == N_BUCKETS: + return self._histogram_4x4() + return self._histogram_pusher_row() + + def _histogram_4x4(self) -> str: + """4x4 grid: rows = object-start quadrant, cols = goal quadrant.""" + header = " goal: " + " ".join(f"{l:>3}" for l in _QUADRANT_LABELS) + rows = [header] + for i, obj_lbl in enumerate(_QUADRANT_LABELS): + cells = [] + for j in range(N_QUADRANTS): + c = self.counts[i * N_QUADRANTS + j] + marker = "*" if c >= self.target else " " + cells.append(f"{c:>2}{marker}") + rows.append(f"object {obj_lbl}: " + " ".join(cells)) + rows.append(f" total: {self.total}/{self.goal_total} (target {self.target}/bucket)") + return "\n".join(rows) + + def _histogram_pusher_row(self) -> str: + """Single-row 4-bucket layout for pusher-quadrant balancing.""" + header = "pusher quad: " + " ".join(f"{l:>4}" for l in _QUADRANT_LABELS) + cells = [] + for i in range(N_PUSHER_BUCKETS): + c = self.counts[i] + marker = "*" if c >= self.target else " " + cells.append(f"{c:>3}{marker}") + body = " " + " ".join(cells) + footer = f" total: {self.total}/{self.goal_total} (target {self.target}/bucket)" + return "\n".join([header, body, footer]) diff --git a/Tsimulation/collect/gympusht_collect.py b/Tsimulation/collect/gympusht_collect.py new file mode 100644 index 000000000..d2b9fe9fa --- /dev/null +++ b/Tsimulation/collect/gympusht_collect.py @@ -0,0 +1,405 @@ +"""Mouse-driven demonstration collection on the upstream huggingface/gym-pusht +``PushT-v0`` environment, saving to our zarr schema. + +This script does NOT modify or reimplement gym-pusht — it drives the env via +the public ``gymnasium`` API (``gym.make`` + ``reset`` + ``step`` + ``render``). +gym-pusht requires pymunk <7, so run this from the isolated venv that has +that constraint:: + + PYTHONPATH=. .venv-gympusht/bin/python -m Tsimulation.collect.gympusht_collect \\ + --output-dir data/pushshapes_demos/gympusht_circle \\ + --num-episodes 50 + +Hotkeys (pygame window must have focus): + SPACE start / pause recording in the current episode + S commit the current episode and reset for the next + R abort the current episode (discard buffer) and reset + Q / X flush and exit + +Zarr schema written matches existing PushShapesEnv collections so downstream +training / replay code works unchanged: + + observations.state (T, 5) [agent_x, agent_y, obj_x, obj_y, obj_theta] + observations.images.front_img_1 (T,) JPEG-encoded 96x96x3 + observations.pusher_cmd_pose (T, 2) commanded mouse target + actions (T, 2) same as cmd (gym-pusht action == target) + reward (T, 1) env reward + goal_pose (T, 3) constant per-episode, broadcast + annotations (0,) empty (parity with mouse_collect) + +env_args saved in ``task_description`` mark these episodes as ``collector="gympusht"`` +so a downstream consumer can route them differently from PushShapesEnv data. +""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path + +import numpy as np +import pygame + +import gym_pusht # noqa: F401 — side-effect registers gym_pusht/PushT-v0 +import gymnasium as gym + +from Tsimulation.collect.balance import ( + N_BUCKETS, + BucketTracker, + bucket_for, + count_existing_buckets, +) +from Tsimulation.collect.zarr_writer import ZarrDemoWriter + +ENV_ID = "gym_pusht/PushT-v0" +WORLD_SIZE = 512 # gym-pusht's action range and image-space arena + +WINDOW_SCALE = 2 # display window = WORLD_SIZE * WINDOW_SCALE +WINDOW_SIZE = WORLD_SIZE * WINDOW_SCALE + +OVERLAY_COLOR = (20, 20, 20) +RECORDING_COLOR = (210, 60, 60) +PAUSED_COLOR = (180, 140, 0) +OVERLAY_HEIGHT = 92 +OVERLAY_BG = (255, 255, 255, 200) + +_BALANCE_MAX_REDRAWS = 200 + + +def _draw_overlay( + screen: pygame.Surface, + font: pygame.font.Font, + *, + saved: int, + target: int, + step: int, + coverage: float, + recording: bool, + output_path: Path, + next_idx: int, +) -> None: + panel = pygame.Surface((WINDOW_SIZE, OVERLAY_HEIGHT), pygame.SRCALPHA) + panel.fill(OVERLAY_BG) + screen.blit(panel, (0, 0)) + + status, color = ("REC", RECORDING_COLOR) if recording else ("PAUSED", PAUSED_COLOR) + badge = font.render(status, True, color) + screen.blit(badge, (WINDOW_SIZE - badge.get_width() - 10, 6)) + + lines = [ + f"saved {saved}/{target} next idx={next_idx:06d}", + f"step {step} coverage {coverage * 100:5.1f}%", + f"out: {output_path}", + "[SPACE] record [S] save [R] abort [Q] quit", + ] + for i, line in enumerate(lines): + screen.blit(font.render(line, True, OVERLAY_COLOR), (8, 6 + i * 20)) + + +def _state_from(obs: dict, info: dict) -> np.ndarray: + """Pack agent_pos (from obs) + block_pose (from info) into our 5-vec + `observations.state` layout `[agent_x, agent_y, obj_x, obj_y, obj_theta]`.""" + agent = np.asarray(obs["agent_pos"], dtype=np.float64).reshape(-1) + block = np.asarray(info["block_pose"], dtype=np.float64).reshape(-1) + return np.concatenate([agent[:2], block[:3]], axis=0) + + +def _goal_pose_from(info: dict) -> np.ndarray: + return np.asarray(info["goal_pose"], dtype=np.float64).reshape(-1) + + +def _episode_init_from(obs: dict, info: dict) -> dict: + """Stable JSON-serializable dict mirroring PushShapesEnv.get_episode_init(). + + Useful for replay-init resume parity (matching new_circle_2's episode_init + layout) and for documenting what env produced this episode.""" + agent = np.asarray(obs["agent_pos"], dtype=np.float64).tolist() + block = np.asarray(info["block_pose"], dtype=np.float64).tolist() + goal = np.asarray(info["goal_pose"], dtype=np.float64).tolist() + return { + "agent_pos": list(agent[:2]), + "object_pose": list(block[:3]), + "goal_pose": list(goal[:3]), + "object_shape": "T", # gym-pusht is T-only + "pusher_shape": "circle", # gym-pusht's pusher is a fixed circle + "obstacle_level": 0, # gym-pusht has no obstacles + "obstacles": [], + "reset_seed": int(info["reset_seed"]) if "reset_seed" in info else None, + } + + +def _resize_render(rgb: np.ndarray, target: int) -> pygame.Surface: + """Render frame -> pygame Surface scaled to ``target`` square.""" + # rgb is (H, W, 3) uint8. pygame wants (W, H, 3) for make_surface. + surf = pygame.surfarray.make_surface(np.transpose(rgb, (1, 0, 2))) + if surf.get_width() != target: + surf = pygame.transform.scale(surf, (target, target)) + return surf + + +def _reset_with_balance( + env: gym.Env, + tracker: BucketTracker | None, + seed: int, +) -> tuple[dict, dict, int]: + """Reset (possibly multiple times if balanced) until landing in a bucket + with room. Returns (obs, info, bucket_id).""" + obs, info = env.reset(seed=seed) + if tracker is None: + return obs, info, -1 + for _ in range(_BALANCE_MAX_REDRAWS): + ep_init = _episode_init_from(obs, info) + b = bucket_for(ep_init, float(WORLD_SIZE)) + if tracker.has_room(b): + return obs, info, b + seed += 1 + obs, info = env.reset(seed=seed) + ep_init = _episode_init_from(obs, info) + return obs, info, bucket_for(ep_init, float(WORLD_SIZE)) + + +def run(args: argparse.Namespace) -> int: + if not args.output_dir: + print("error: --output-dir is required", file=sys.stderr) + return 2 + + pygame.init() + pygame.display.init() + pygame.font.init() + pygame.display.set_caption(f"gym-pusht mouse collect [{ENV_ID}]") + screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE)) + clock = pygame.time.Clock() + font = pygame.font.Font(None, 22) + + # gym-pusht: pixels_agent_pos gives 96x96 image + agent_pos in one obs dict; + # info carries block_pose / goal_pose / coverage that our schema needs. + env = gym.make( + ENV_ID, + obs_type="pixels_agent_pos", + render_mode="rgb_array", + ) + + env_args = { + "env_id": ENV_ID, + "object_shape": "T", + "pusher_shape": "circle", + "obstacle_level": 0, + "image_size": args.image_size, + "fps": args.fps, + "collector": "gympusht", + "mode": "standard", + } + output_dir = Path(args.output_dir) + + writer = ZarrDemoWriter( + path=output_dir, + env_args=env_args, + image_size=args.image_size, + fps=args.fps, + tag=args.tag, + ) + + tracker: BucketTracker | None = None + if args.balance: + per_bucket = args.per_bucket + if per_bucket is None: + per_bucket = max(1, -(-args.num_episodes // N_BUCKETS)) # ceil div + initial_counts = count_existing_buckets( + output_dir, + float(WORLD_SIZE), + bucket_fn=bucket_for, + num_buckets=N_BUCKETS, + filename_contains=args.tag, + ) + tracker = BucketTracker( + target_per_bucket=per_bucket, + initial_counts=initial_counts, + num_buckets=N_BUCKETS, + ) + target_total = tracker.goal_total + print( + f"balance: target {per_bucket}/bucket x {N_BUCKETS} buckets " + f"= {target_total} episodes (found {sum(initial_counts)} pre-existing " + f"in {output_dir})" + ) + if sum(initial_counts): + print(tracker.histogram()) + else: + target_total = args.num_episodes + + seed_counter = int(args.seed) + obs, info, current_bucket = _reset_with_balance(env, tracker, seed_counter) + seed_counter += 1 + coverage = info.get("coverage", 0.0) + writer.start_episode(init_state=_episode_init_from(obs, info)) + recording = True + saved = 0 + running = True + + while running: + if tracker is not None and tracker.filled: + break + if tracker is None and saved >= target_total: + break + + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + elif event.type == pygame.KEYDOWN: + if event.key in (pygame.K_q, pygame.K_x): + running = False + elif event.key == pygame.K_SPACE: + recording = not recording + elif event.key == pygame.K_r: + writer.abort_episode() + obs, info, current_bucket = _reset_with_balance(env, tracker, seed_counter) + seed_counter += 1 + coverage = info.get("coverage", 0.0) + writer.start_episode(init_state=_episode_init_from(obs, info)) + recording = True + elif event.key == pygame.K_s: + if writer.steps_in_episode > 0: + idx = writer.commit_episode() + if idx >= 0: + saved += 1 + if tracker is not None: + tracker.increment(current_bucket) + print( + f"saved episode {idx:06d} bucket={current_bucket:>2} " + f"({saved}/{target_total})" + ) + if tracker is not None: + print(tracker.histogram()) + obs, info, current_bucket = _reset_with_balance(env, tracker, seed_counter) + seed_counter += 1 + coverage = info.get("coverage", 0.0) + writer.start_episode(init_state=_episode_init_from(obs, info)) + recording = True + + # Mouse XY -> action in env coords [0, WORLD_SIZE]. + mx, my = pygame.mouse.get_pos() + action = np.array( + [ + float(np.clip(mx / WINDOW_SCALE, 0.0, float(WORLD_SIZE))), + float(np.clip(my / WINDOW_SCALE, 0.0, float(WORLD_SIZE))), + ], + dtype=np.float32, + ) + + # Pre-step image (96x96 from obs) and state, so (state[t], action[t]) + # are aligned: state is the state BEFORE action[t] is applied. + pre_pixels = np.asarray(obs["pixels"], dtype=np.uint8) + pre_state = _state_from(obs, info) + pre_goal = _goal_pose_from(info) + + obs, reward, terminated, truncated, info = env.step(action) + coverage = info.get("coverage", 0.0) + + if recording: + writer.add_step( + image=pre_pixels, + pusher_obs_pose=pre_state[:2], + object_obs_pose=pre_state[2:5], + pusher_cmd_pose=action.astype(np.float64), + action=action.astype(np.float64), + reward=float(reward), + goal_pose=pre_goal, + ) + + # Display: render the env at its native 680x680, scale to window. + render_rgb = env.render() + if render_rgb is not None: + surf = _resize_render(render_rgb, WINDOW_SIZE) + screen.blit(surf, (0, 0)) + _draw_overlay( + screen, + font, + saved=saved, + target=target_total, + step=writer.steps_in_episode, + coverage=coverage, + recording=recording, + output_path=output_dir, + next_idx=writer.next_episode_index, + ) + pygame.display.flip() + clock.tick(args.fps) + + if terminated or truncated: + if writer.steps_in_episode > 0 and terminated: + idx = writer.commit_episode() + if idx >= 0: + saved += 1 + if tracker is not None: + tracker.increment(current_bucket) + print( + f"auto-saved episode {idx:06d} bucket={current_bucket:>2} " + f"({saved}/{target_total})" + ) + if tracker is not None: + print(tracker.histogram()) + else: + writer.abort_episode() + more_to_collect = ( + not tracker.filled if tracker is not None else saved < target_total + ) + if more_to_collect: + obs, info, current_bucket = _reset_with_balance(env, tracker, seed_counter) + seed_counter += 1 + coverage = info.get("coverage", 0.0) + writer.start_episode(init_state=_episode_init_from(obs, info)) + recording = True + + writer.close() + env.close() + pygame.display.quit() + pygame.quit() + print(f"done. saved {saved} episodes to {output_dir}") + if tracker is not None: + print(tracker.histogram()) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--output-dir", + required=True, + help="directory to write episode_*.zarr stores into", + ) + p.add_argument("--num-episodes", type=int, default=50) + p.add_argument( + "--image-size", + type=int, + default=96, + help="must match gym-pusht's pixels obs (96 in upstream). Stored as-is.", + ) + p.add_argument("--fps", type=int, default=30) + p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--balance", + action="store_true", + help="balance saved episodes across the 16 (object_quadrant, goal_quadrant)" + " buckets via rejection-sampling at each reset", + ) + p.add_argument( + "--per-bucket", + type=int, + default=None, + help="with --balance: episodes per bucket. Default ceil(num_episodes/16)", + ) + p.add_argument( + "--tag", + default=None, + help="alphanumeric tag inserted into saved filenames (e.g. 'gympusht')", + ) + return p + + +def main(argv: list[str] | None = None) -> int: + return run(build_parser().parse_args(argv)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/Tsimulation/collect/mouse_collect.py b/Tsimulation/collect/mouse_collect.py index e19bb5e99..83acc028e 100644 --- a/Tsimulation/collect/mouse_collect.py +++ b/Tsimulation/collect/mouse_collect.py @@ -24,15 +24,32 @@ from __future__ import annotations import argparse +import json +import math import sys from pathlib import Path import numpy as np import pygame - +import zarr + +from Tsimulation.collect.balance import ( + N_BUCKETS, + N_PUSHER_BUCKETS, + BucketTracker, + bucket_for, + bucket_pusher_quad, + count_existing_buckets, +) from Tsimulation.collect.zarr_writer import ZarrDemoWriter from Tsimulation.pushshapes.env import PushShapesEnv +_BALANCE_MAX_REDRAWS = 200 # cap rejection-sampling per reset before bailing + +_MODE_STANDARD = "standard" +_MODE_ON_TARGET = "on-target" +_DEFAULT_ONTARGET_TAG = "ontarget" + WORLD_SIZE = 512 WINDOW_SCALE = 2 WINDOW_SIZE = WORLD_SIZE * WINDOW_SCALE @@ -80,7 +97,180 @@ def _episode_output_dir(root: Path, pusher: str, obstacles: int) -> Path: return root / pusher / str(obstacles) +def _apply_on_target( + env: PushShapesEnv, + rng: np.random.Generator, + min_angle_rad: float, +) -> None: + """Override the object pose so it spawns at the goal's (x, y) with an + intentionally-bad orientation: angle drawn uniformly in [-pi, pi] but + rejected if the wrap-aware delta to the goal angle is below + ``min_angle_rad``. Pusher and goal are left as the env sampled them.""" + init = env.get_episode_init() + goal_x, goal_y, goal_theta = init["goal_pose"] + # Rejection sample a "sufficiently wrong" angle. + for _ in range(64): + candidate = float(rng.uniform(-math.pi, math.pi)) + delta = abs(((candidate - goal_theta + math.pi) % (2 * math.pi)) - math.pi) + if delta >= min_angle_rad: + new_theta = candidate + break + else: + # extremely unlikely; fall back to +180deg from goal + new_theta = goal_theta + math.pi + env.set_state( + object_pose=(float(goal_x), float(goal_y), float(new_theta)) + ) + + +_REPLAY_SOURCE_KEY = "_replay_source" + + +def _load_replay_inits(source_dir: Path) -> list[dict]: + """Read ``episode_init`` from every ``*.zarr`` under ``source_dir`` (sorted + by filename). Each returned init has ``_replay_source`` injected with the + source episode's filename so resume can skip already-recorded ones.""" + inits: list[dict] = [] + for entry in sorted(source_dir.iterdir()): + if not entry.is_dir() or not entry.name.endswith(".zarr"): + continue + try: + g = zarr.open_group(str(entry), mode="r") + raw = g.attrs.get("episode_init") + if raw is None: + continue + init = json.loads(raw) + init[_REPLAY_SOURCE_KEY] = entry.name + inits.append(init) + except Exception: + continue + return inits + + +def _init_pose_key(init: dict) -> tuple: + """Resume key for output episodes without a ``_replay_source`` marker. + + Includes ``obstacle_level`` so an obs0 output episode never spuriously + matches a non-obs0 source episode that happens to share the same pose + (the source dataset contains several such cross-level pose duplicates). + 6-decimal rounding absorbs JSON round-trip jitter without collapsing + genuinely distinct sampled poses.""" + def _r(xs): + return tuple(round(float(v), 6) for v in xs) + return ( + _r(init["agent_pos"]), + _r(init["object_pose"]), + _r(init["goal_pose"]), + int(init.get("obstacle_level", -1)), + ) + + +def _collected_resume_keys(output_dir: Path) -> tuple[set[str], set[tuple]]: + """Scan ``output_dir`` for already-collected replay episodes. + + Returns ``(source_names, pose_keys)`` — an episode is considered already + done if either set matches a source init. Pose-key fallback covers + output episodes written before the ``_replay_source`` marker existed.""" + names: set[str] = set() + pose_keys: set[tuple] = set() + if not output_dir.exists(): + return names, pose_keys + for entry in sorted(output_dir.iterdir()): + if not entry.is_dir() or not entry.name.endswith(".zarr"): + continue + try: + g = zarr.open_group(str(entry), mode="r") + raw = g.attrs.get("episode_init") + if raw is None: + continue + init = json.loads(raw) + saved = init.get(_REPLAY_SOURCE_KEY) + if saved: + names.add(saved) + pose_keys.add(_init_pose_key(init)) + except Exception: + continue + return names, pose_keys + + +def _reset_to_init(env: PushShapesEnv, init: dict) -> tuple[dict, dict]: + """Reset, then override the env's sampled poses with the saved init so the + episode starts from the exact (agent, object, goal) configuration.""" + _obs, info = env.reset() + env.set_state( + agent_pos=tuple(init["agent_pos"]), + object_pose=tuple(init["object_pose"]), + goal_pose=tuple(init["goal_pose"]), + ) + return env._get_obs(), info + + +def _reset_with_balance( + env: PushShapesEnv, + tracker: BucketTracker | None, + bucket_fn, + on_target_rng: np.random.Generator | None = None, + min_angle_rad: float = 0.0, +) -> tuple[dict, dict, int]: + """Call env.reset() (repeatedly if balanced) until an episode lands in a + bucket that still has room. ``bucket_fn`` decides how to bucket the + post-reset state. If ``on_target_rng`` is provided, the object is moved + onto the goal with a bad orientation after each reset, BEFORE the bucket + check (so balancing reflects the actual saved state).""" + obs, info = env.reset() + if on_target_rng is not None: + _apply_on_target(env, on_target_rng, min_angle_rad) + obs = env._get_obs() + if tracker is None: + b = bucket_fn(env.get_episode_init(), float(env.WORLD_SIZE)) + return obs, info, b + for _ in range(_BALANCE_MAX_REDRAWS): + b = bucket_fn(env.get_episode_init(), float(env.WORLD_SIZE)) + if tracker.has_room(b): + return obs, info, b + obs, info = env.reset() + if on_target_rng is not None: + _apply_on_target(env, on_target_rng, min_angle_rad) + obs = env._get_obs() + # All redraws landed in full buckets — accept whatever we have so the UI + # doesn't wedge. Caller's bucket counter will overshoot at most by one. + b = bucket_fn(env.get_episode_init(), float(env.WORLD_SIZE)) + return obs, info, b + + def run(args: argparse.Namespace) -> int: + if not args.output and not args.output_dir: + print("error: provide either --output or --output-dir", file=sys.stderr) + return 2 + + replay_inits: list[dict] | None = None + if args.replay_init_from is not None: + if args.balance: + print( + "error: --balance is incompatible with --replay-init-from " + "(replay uses the exact saved poses)", + file=sys.stderr, + ) + return 2 + if args.mode != _MODE_STANDARD: + print( + "error: --replay-init-from only supports --mode standard", + file=sys.stderr, + ) + return 2 + src = Path(args.replay_init_from) + if not src.is_dir(): + print(f"error: --replay-init-from {src} is not a directory", file=sys.stderr) + return 2 + replay_inits = _load_replay_inits(src) + if not replay_inits: + print( + f"error: no episodes with stored episode_init found under {src}", + file=sys.stderr, + ) + return 2 + print(f"replay-init: loaded {len(replay_inits)} inits from {src}") + pygame.init() pygame.display.init() pygame.font.init() @@ -107,26 +297,191 @@ def run(args: argparse.Namespace) -> int: "image_size": args.image_size, "fps": args.fps, "collector": "mouse", + "mode": args.mode, } - output_dir = _episode_output_dir(Path(args.output), args.pusher, args.obstacles) + output_dir = ( + Path(args.output_dir) + if args.output_dir + else _episode_output_dir(Path(args.output), args.pusher, args.obstacles) + ) + + # Replay-init filter + resume. + if replay_inits is not None: + # 1. Drop source inits whose obstacle_level doesn't match the live env. + # This is the safety net for the "leaked non-obs0 inits got replayed in + # an obs0 env" failure mode — without it, iterating an unfiltered source + # dataset against a fixed --obstacles will silently mis-record poses. + before = len(replay_inits) + replay_inits = [ + init + for init in replay_inits + if int(init.get("obstacle_level", -1)) == int(args.obstacles) + ] + wrong_level = before - len(replay_inits) + if wrong_level: + print( + f"replay-init: dropped {wrong_level} source inits whose " + f"obstacle_level != {args.obstacles} (env's --obstacles)" + ) + if not replay_inits: + print("nothing to collect after obstacle-level filter — exiting.") + return 0 + + # 2. Resume: skip inits already covered in output_dir, by source-name + # marker OR by pose-key fallback (output episodes written before the + # _replay_source marker existed are matched by pose alone). + saved_names, saved_pose_keys = _collected_resume_keys(output_dir) + if saved_names or saved_pose_keys: + before = len(replay_inits) + replay_inits = [ + init + for init in replay_inits + if init[_REPLAY_SOURCE_KEY] not in saved_names + and _init_pose_key(init) not in saved_pose_keys + ] + skipped = before - len(replay_inits) + print( + f"replay-init resume: skipping {skipped} inits already saved " + f"in {output_dir}; {len(replay_inits)} remaining" + ) + if not replay_inits: + print("nothing left to collect — exiting.") + return 0 + + # Mode-specific knobs: bucketing function, writer tag, on-target rng/angle. + if args.mode == _MODE_ON_TARGET: + bucket_fn = bucket_pusher_quad + num_buckets = N_PUSHER_BUCKETS + writer_tag = args.tag if args.tag is not None else _DEFAULT_ONTARGET_TAG + on_target_rng = np.random.default_rng(args.seed) + min_angle_rad = math.radians(args.min_angle_deg) + env_args["min_angle_deg"] = args.min_angle_deg + else: + bucket_fn = bucket_for + num_buckets = N_BUCKETS + writer_tag = args.tag # may be None + on_target_rng = None + min_angle_rad = 0.0 + + # --fixed-goal: parse X,Y,THETA and apply to every reset so the goal pose + # is constant across the dataset (matches single-goal benchmarks like + # Diffusion Policy's PushT). Random agent/object positions are preserved. + fixed_goal: tuple[float, float, float] | None = None + if args.fixed_goal is not None: + try: + parts = [float(x) for x in args.fixed_goal.split(",")] + except ValueError: + print( + f"error: --fixed-goal must be three comma-separated floats X,Y,THETA, got {args.fixed_goal!r}", + file=sys.stderr, + ) + return 2 + if len(parts) != 3: + print( + f"error: --fixed-goal must be exactly X,Y,THETA, got {len(parts)} values", + file=sys.stderr, + ) + return 2 + fixed_goal = (parts[0], parts[1], parts[2]) + env_args["fixed_goal"] = list(fixed_goal) + if args.balance and args.mode == _MODE_STANDARD: + print( + "warning: --fixed-goal + --balance --mode standard collapses the " + "goal_quadrant axis of the 16 bucket grid (all goals → one quadrant); " + "effective balancing reduces to 4 object-quadrant buckets.", + file=sys.stderr, + ) + writer = ZarrDemoWriter( path=output_dir, env_args=env_args, image_size=args.image_size, fps=args.fps, + tag=writer_tag, ) - obs, info = env.reset() + init_idx = 0 # only used in replay mode; advances on commit, stays on abort. + + def _do_reset() -> tuple[dict, dict, int]: + """Reset the env using whichever mode is active. + + Returns ``(obs, info, current_bucket)``. ``current_bucket`` is -1 in + replay mode (no bucketing) — the caller's tracker is also None then. + + If ``--fixed-goal`` is set, the env's randomized goal is overridden + with the fixed pose AFTER the (possibly bucket-rejection-sampled) + reset, so agent + object stays random but goal is constant.""" + if replay_inits is not None: + obs, info = _reset_to_init(env, replay_inits[init_idx]) + return obs, info, -1 + obs, info, bucket = _reset_with_balance( + env, tracker, bucket_fn, on_target_rng, min_angle_rad + ) + if fixed_goal is not None: + env.set_state(goal_pose=fixed_goal) + obs = env._get_obs() + return obs, info, bucket + + def _make_init_state() -> dict: + """Init dict to hand to the writer at start_episode. In replay mode, + stamps the source episode filename onto the saved init so resume can + skip it on a later run.""" + init = env.get_episode_init() + if replay_inits is not None and init_idx < len(replay_inits): + init[_REPLAY_SOURCE_KEY] = replay_inits[init_idx][_REPLAY_SOURCE_KEY] + return init + + tracker: BucketTracker | None = None + if args.balance: + per_bucket = args.per_bucket + if per_bucket is None: + per_bucket = max(1, -(-args.num_episodes // num_buckets)) # ceil div + # Count only same-tag files so on-target and standard episodes have + # independent bucket budgets in a mixed folder. + filename_filter = writer_tag if writer_tag is not None else None + initial_counts = count_existing_buckets( + output_dir, + float(env.WORLD_SIZE), + bucket_fn=bucket_fn, + num_buckets=num_buckets, + filename_contains=filename_filter, + ) + tracker = BucketTracker( + target_per_bucket=per_bucket, + initial_counts=initial_counts, + num_buckets=num_buckets, + ) + target_total = tracker.goal_total + existing = sum(initial_counts) + tag_msg = f" [tag={writer_tag}]" if writer_tag else "" + print( + f"balance{tag_msg}: target {per_bucket}/bucket x {num_buckets} buckets " + f"= {target_total} episodes (found {existing} pre-existing in {output_dir})" + ) + if existing > 0: + print(tracker.histogram()) + elif replay_inits is not None: + target_total = len(replay_inits) + else: + target_total = args.num_episodes + + obs, info, current_bucket = _do_reset() coverage = info.get("coverage", 0.0) # Auto-start recording so a successful push is never lost because the # user forgot to press SPACE before moving the shape. - writer.start_episode(init_state=env.get_episode_init()) + writer.start_episode(init_state=_make_init_state()) recording = True saved = 0 running = True - while running and saved < args.num_episodes: + while running: + if tracker is not None and tracker.filled: + break + if replay_inits is not None and init_idx >= len(replay_inits): + break + if tracker is None and replay_inits is None and saved >= target_total: + break for event in pygame.event.get(): if event.type == pygame.QUIT: running = False @@ -137,21 +492,32 @@ def run(args: argparse.Namespace) -> int: recording = not recording elif event.key == pygame.K_r: writer.abort_episode() - obs, info = env.reset() + # Replay mode: stay on the same init so the user can retry. + obs, info, current_bucket = _do_reset() coverage = info.get("coverage", 0.0) - writer.start_episode(init_state=env.get_episode_init()) + writer.start_episode(init_state=_make_init_state()) recording = True elif event.key == pygame.K_s: if writer.steps_in_episode > 0: idx = writer.commit_episode() if idx >= 0: saved += 1 + if tracker is not None: + tracker.increment(current_bucket) + if replay_inits is not None: + init_idx += 1 print( - f"saved episode {idx:06d} ({saved}/{args.num_episodes})" + f"saved episode {idx:06d} bucket={current_bucket:>2} " + f"({saved}/{target_total})" ) - obs, info = env.reset() + if tracker is not None: + print(tracker.histogram()) + if replay_inits is not None and init_idx >= len(replay_inits): + running = False + break + obs, info, current_bucket = _do_reset() coverage = info.get("coverage", 0.0) - writer.start_episode(init_state=env.get_episode_init()) + writer.start_episode(init_state=_make_init_state()) recording = True # Action = mouse pos in world coords. Window is scaled up from the @@ -204,15 +570,28 @@ def run(args: argparse.Namespace) -> int: idx = writer.commit_episode() if idx >= 0: saved += 1 + if tracker is not None: + tracker.increment(current_bucket) + if replay_inits is not None: + init_idx += 1 print( - f"auto-saved episode {idx:06d} ({saved}/{args.num_episodes})" + f"auto-saved episode {idx:06d} bucket={current_bucket:>2} " + f"({saved}/{target_total})" ) + if tracker is not None: + print(tracker.histogram()) else: writer.abort_episode() - if saved < args.num_episodes: - obs, info = env.reset() + if tracker is not None: + more_to_collect = not tracker.filled + elif replay_inits is not None: + more_to_collect = init_idx < len(replay_inits) + else: + more_to_collect = saved < target_total + if more_to_collect: + obs, info, current_bucket = _do_reset() coverage = info.get("coverage", 0.0) - writer.start_episode(init_state=env.get_episode_init()) + writer.start_episode(init_state=_make_init_state()) recording = True writer.close() @@ -220,6 +599,8 @@ def run(args: argparse.Namespace) -> int: pygame.display.quit() pygame.quit() print(f"done. saved {saved} episodes to {output_dir}") + if tracker is not None: + print(tracker.histogram()) return 0 @@ -227,16 +608,81 @@ def build_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser(description=__doc__) p.add_argument( "--output", - required=True, - help="dataset root; demos are stored under ///", + default=None, + help="dataset root; demos are stored under ///." + " Ignored if --output-dir is set.", + ) + p.add_argument( + "--output-dir", + default=None, + help="write episodes directly into this exact directory, bypassing the" + " / subpath. Use this to keep custom folder names.", ) p.add_argument("--object", default="T", choices=["T", "U", "Z"]) - p.add_argument("--pusher", default="circle", choices=["circle", "stick"]) - p.add_argument("--obstacles", type=int, default=0, choices=[0, 1, 2, 3]) + p.add_argument( + "--pusher", + default="circle", + choices=["circle", "circle_small", "stick"], + ) + p.add_argument("--obstacles", type=int, default=0, choices=list(range(30))) p.add_argument("--num-episodes", type=int, default=50) p.add_argument("--image-size", type=int, default=96) p.add_argument("--fps", type=int, default=30) p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--balance", + action="store_true", + help="balance saved episodes across the 16 (object-quadrant, goal-quadrant)" + " buckets via rejection sampling at each reset", + ) + p.add_argument( + "--per-bucket", + type=int, + default=None, + help="with --balance, episodes per bucket. Default is" + " ceil(num_episodes/N) where N is 16 in standard mode and 4 in on-target mode", + ) + p.add_argument( + "--mode", + choices=[_MODE_STANDARD, _MODE_ON_TARGET], + default=_MODE_STANDARD, + help="standard: random object + goal placements (4x4=16 buckets)." + " on-target: object spawns AT goal xy with a bad orientation" + " (4 pusher-quadrant buckets); useful for collecting recovery-rotation demos", + ) + p.add_argument( + "--tag", + default=None, + help="alphanumeric tag inserted into saved filenames. on-target mode" + f" defaults to '{_DEFAULT_ONTARGET_TAG}' so it stays distinct from" + " standard episodes in the same folder. Tagged + untagged sequences" + " are numbered independently", + ) + p.add_argument( + "--fixed-goal", + default=None, + help="X,Y,THETA (radians) — override the env's randomized goal_pose" + " with this fixed pose after each reset. Useful for matching" + " benchmarks that use a single goal configuration" + " (e.g. Diffusion Policy PushT: 256,256,0.7853981633974483).", + ) + p.add_argument( + "--replay-init-from", + default=None, + help="path to an existing dataset directory. Iterates its .zarr" + " episodes in filename order; for each one, calls env.reset() then" + " env.set_state() to force the saved (agent_pos, object_pose," + " goal_pose). Disables --balance and on-target mode. Stops once" + " every source init has been recorded once.", + ) + p.add_argument( + "--min-angle-deg", + type=float, + default=30.0, + help="on-target mode: minimum |delta| (degrees) between object and" + " goal orientation. Below this, the spawn is redrawn so the rotation" + " recovery task is non-trivial", + ) return p diff --git a/Tsimulation/collect/scripted_collect.py b/Tsimulation/collect/scripted_collect.py index 6a145384a..c00b9ec55 100644 --- a/Tsimulation/collect/scripted_collect.py +++ b/Tsimulation/collect/scripted_collect.py @@ -31,9 +31,16 @@ import argparse import os import sys +from pathlib import Path import numpy as np +from Tsimulation.collect.balance import ( + N_BUCKETS, + BucketTracker, + bucket_for, + count_existing_buckets, +) from Tsimulation.collect.zarr_writer import ZarrDemoWriter from Tsimulation.pushshapes.env import PushShapesEnv @@ -41,6 +48,41 @@ PUSH_LOOKAHEAD = 80.0 CONTACT_RADIUS = 25.0 +# Default per-frame drift tolerance (world units, L-inf over the 5-vec +# [pusher_x, pusher_y, obj_x, obj_y, obj_theta]) when replaying a recorded +# episode through the live env. Episodes that exceed this are rejected. +DEFAULT_REPLAY_DRIFT_THRESHOLD = 0.5 + + +def _replay_validate( + env: PushShapesEnv, + ep_init: dict, + actions: np.ndarray, + recorded_states: np.ndarray, + drift_threshold: float, +) -> tuple[bool, float, int]: + """Replay ``actions`` through ``env`` from ``ep_init`` and reject if + any post-step drift between env and ``recorded_states[t+1]`` exceeds + ``drift_threshold``. Returns ``(ok, max_drift, frame)``. + + Delegates the step-loop to ``replay_zarr._replay_step_loop`` so there + is a single source of truth for the replay logic shared with + ``Tsimulation.examples.replay_zarr.replay_one`` and the dataloader's + coverage filter. + """ + from Tsimulation.examples.replay_zarr import _replay_step_loop + env.reset(seed=ep_init.get("reset_seed")) + env.set_state( + agent_pos=tuple(ep_init["agent_pos"]), + object_pose=tuple(ep_init["object_pose"]), + goal_pose=tuple(ep_init["goal_pose"]), + ) + metrics = _replay_step_loop(env, actions, recorded_states, + early_stop_drift=drift_threshold) + if metrics["early_stop_frame"] is not None: + return False, metrics["drift_max"], int(metrics["early_stop_frame"]) + return True, metrics["drift_max"], len(actions) + def scripted_action( *, @@ -72,13 +114,33 @@ def run_episode( rng: np.random.Generator, jitter: float, seed: int | None, -) -> tuple[bool, float, int]: - """Roll out one scripted episode. Returns (committed, final_coverage, steps).""" + tracker: BucketTracker | None = None, + replay_validate: bool = True, + replay_drift_threshold: float = DEFAULT_REPLAY_DRIFT_THRESHOLD, +) -> tuple[bool, float, int, int, str]: + """Roll out one scripted episode. + + Returns ``(committed, final_coverage, steps, bucket, reject_reason)``. + ``reject_reason`` is ``""`` on commit; ``"bucket_full"`` / + ``"low_coverage"`` / ``"replay_drift"`` when the episode is rejected. + + With ``replay_validate=True`` (default), the recorded actions are + replayed through a fresh env from the captured ``episode_init`` after + the rollout completes; if any frame's L-inf drift in the + ``[pusher_xy, obj_xy, obj_theta]`` 5-vec exceeds + ``replay_drift_threshold``, the episode is aborted instead of committed. + """ obs, info = env.reset(seed=seed) - writer.start_episode(init_state=env.get_episode_init()) + ep_init = env.get_episode_init() + bucket = bucket_for(ep_init, float(env.WORLD_SIZE)) + if tracker is not None and not tracker.has_room(bucket): + return False, 0.0, 0, bucket, "bucket_full" + writer.start_episode(init_state=ep_init) coverage = float(info.get("coverage", 0.0)) final_coverage = coverage steps = 0 + recorded_actions: list[np.ndarray] = [] + recorded_states: list[np.ndarray] = [] for _ in range(max_steps): agent_xy = np.asarray(obs["agent_pos"], dtype=np.float64) object_xy = np.asarray(obs["object_pose"][:2], dtype=np.float64) @@ -105,6 +167,13 @@ def run_episode( reward=float(reward), goal_pose=obs["goal_pose"], ) + recorded_actions.append(np.asarray(action, dtype=np.float64)) + recorded_states.append( + np.concatenate([ + np.asarray(obs["agent_pos"], dtype=np.float64), + np.asarray(obs["object_pose"], dtype=np.float64), + ]) + ) obs = next_obs steps += 1 @@ -112,13 +181,24 @@ def run_episode( if terminated or truncated: break - committed = False - if final_coverage >= success_threshold and writer.steps_in_episode > 0: - idx = writer.commit_episode() - committed = idx >= 0 - else: + if final_coverage < success_threshold or writer.steps_in_episode == 0: writer.abort_episode() - return committed, final_coverage, steps + return False, final_coverage, steps, bucket, "low_coverage" + + if replay_validate and len(recorded_actions) > 0: + actions_arr = np.stack(recorded_actions, axis=0) + states_arr = np.stack(recorded_states, axis=0) + ok, drift, frame = _replay_validate( + env, ep_init, actions_arr, states_arr, replay_drift_threshold, + ) + if not ok: + writer.abort_episode() + return False, final_coverage, steps, bucket, f"replay_drift={drift:.3f}@f{frame}" + + idx = writer.commit_episode() + if idx < 0: + return False, final_coverage, steps, bucket, "commit_failed" + return True, final_coverage, steps, bucket, "" def run(args: argparse.Namespace) -> int: @@ -151,13 +231,46 @@ def run(args: argparse.Namespace) -> int: ) rng = np.random.default_rng(args.seed) + tracker: BucketTracker | None = None + if args.balance: + per_bucket = args.per_bucket + if per_bucket is None: + per_bucket = max(1, -(-args.num_episodes // N_BUCKETS)) # ceil div + initial_counts = count_existing_buckets( + Path(args.output), float(env.WORLD_SIZE) + ) + tracker = BucketTracker( + target_per_bucket=per_bucket, initial_counts=initial_counts + ) + target_total = tracker.goal_total + existing = sum(initial_counts) + print( + f"balance: target {per_bucket}/bucket x {N_BUCKETS} buckets " + f"= {target_total} episodes (found {existing} pre-existing in {args.output})" + ) + if existing > 0: + print(tracker.histogram()) + else: + target_total = args.num_episodes + saved = 0 attempts = 0 - max_attempts = args.num_episodes * args.max_tries - while saved < args.num_episodes and attempts < max_attempts: + drift_rejects = 0 + max_attempts = ( + (target_total - (tracker.total if tracker is not None else 0)) + * args.max_tries + ) + max_attempts = max(max_attempts, args.max_tries) + while True: + if tracker is not None and tracker.filled: + break + if tracker is None and saved >= target_total: + break + if attempts >= max_attempts: + break attempts += 1 ep_seed = (args.seed + attempts) if args.seed is not None else None - committed, coverage, steps = run_episode( + committed, coverage, steps, bucket, reject_reason = run_episode( env, writer, max_steps=args.max_steps, @@ -165,26 +278,38 @@ def run(args: argparse.Namespace) -> int: rng=rng, jitter=args.jitter, seed=ep_seed, + tracker=tracker, + replay_validate=args.replay_validate, + replay_drift_threshold=args.replay_drift_threshold, ) if committed: saved += 1 + if tracker is not None: + tracker.increment(bucket) print( - f"saved ep {saved:03d}/{args.num_episodes} " + f"saved ep {saved:03d}/{target_total} bucket={bucket:>2} " f"steps={steps:3d} coverage={coverage:.3f} (attempt {attempts})" ) - elif args.verbose: - print( - f"discard attempt {attempts:3d} " - f"steps={steps:3d} coverage={coverage:.3f}" - ) + if tracker is not None and tracker.filled: + break + else: + if reject_reason.startswith("replay_drift"): + drift_rejects += 1 + if args.verbose or reject_reason.startswith("replay_drift"): + print( + f"discard attempt {attempts:3d} bucket={bucket:>2} " + f"steps={steps:3d} coverage={coverage:.3f} reason={reject_reason}" + ) writer.close() env.close() print( - f"done. saved {saved}/{args.num_episodes} in {attempts} attempts " - f"-> {args.output}" + f"done. saved {saved}/{target_total} in {attempts} attempts " + f"(replay_drift_rejected={drift_rejects}) -> {args.output}" ) - return 0 if saved == args.num_episodes else 1 + if tracker is not None: + print(tracker.histogram()) + return 0 if saved == target_total else 1 def build_parser() -> argparse.ArgumentParser: @@ -193,8 +318,12 @@ def build_parser() -> argparse.ArgumentParser: "--output", required=True, help="directory for episode_*.zarr stores" ) parser.add_argument("--object", default="T", choices=["T", "U", "Z"]) - parser.add_argument("--pusher", default="circle", choices=["circle", "stick"]) - parser.add_argument("--obstacles", type=int, default=0, choices=[0, 1, 2, 3]) + parser.add_argument( + "--pusher", + default="circle", + choices=["circle", "circle_small", "stick"], + ) + parser.add_argument("--obstacles", type=int, default=0, choices=list(range(30))) parser.add_argument("--num-episodes", type=int, default=20) parser.add_argument("--image-size", type=int, default=96) parser.add_argument("--max-steps", type=int, default=300) @@ -210,7 +339,20 @@ def build_parser() -> argparse.ArgumentParser: "--max-tries", type=int, default=4, - help="give up after num_episodes * max_tries attempts", + help="give up after target_total * max_tries attempts (target_total is" + " num_episodes, or per_bucket*16 when --balance is set)", + ) + parser.add_argument( + "--balance", + action="store_true", + help="balance saved episodes across the 16 (object-quadrant, goal-quadrant)" + " buckets via rejection sampling", + ) + parser.add_argument( + "--per-bucket", + type=int, + default=None, + help="with --balance, episodes per bucket (default: ceil(num_episodes/16))", ) parser.add_argument( "--jitter", @@ -231,6 +373,29 @@ def build_parser() -> argparse.ArgumentParser: help="render to a real display (requires a windowed environment)", ) parser.add_argument("--verbose", action="store_true") + parser.add_argument( + "--replay-validate", + dest="replay_validate", + action="store_true", + default=True, + help="replay recorded actions through the env after each rollout and" + " reject episodes that drift from the recorded states (default: True)", + ) + parser.add_argument( + "--no-replay-validate", + dest="replay_validate", + action="store_false", + help="skip replay-drift validation (faster but allows non-deterministic" + " episodes to land on disk)", + ) + parser.add_argument( + "--replay-drift-threshold", + type=float, + default=DEFAULT_REPLAY_DRIFT_THRESHOLD, + help="max per-frame L-inf drift (world units) in [pusher_xy, obj_xy," + f" obj_theta] tolerated by replay validation (default:" + f" {DEFAULT_REPLAY_DRIFT_THRESHOLD})", + ) return parser diff --git a/Tsimulation/collect/zarr_writer.py b/Tsimulation/collect/zarr_writer.py index 89197442a..8eb391ad8 100644 --- a/Tsimulation/collect/zarr_writer.py +++ b/Tsimulation/collect/zarr_writer.py @@ -6,7 +6,8 @@ Episode commits are atomic: data is buffered in memory and only written to disk in `commit_episode`. Crash mid-episode leaves the on-disk store -untouched. +untouched. Because commits are bulk writes, the resulting episode layout +matches the repo's compact one-chunk-per-array format. """ from __future__ import annotations @@ -27,7 +28,11 @@ REWARD_KEY = "reward" GOAL_KEY = "goal_pose" -_EPISODE_RE = re.compile(r"^episode_[A-Za-z0-9]+_[A-Za-z0-9]+_obs\d+_(\d+)\.zarr$") +# Non-greedy ``.+?`` for the object_shape + pusher_shape prefix so pusher names +# containing ``_`` (e.g. ``circle_small``) still match. Anchored by ``_obs\d+_`` +# on the right so we don't over-match into the obstacle level / tag / index. +_EPISODE_RE = re.compile(r"^episode_.+?_obs\d+_(\d+)\.zarr$") +_TAGGED_EPISODE_RE_TMPL = r"^episode_.+?_obs\d+_{tag}_(\d+)\.zarr$" class ZarrDemoWriter: @@ -42,6 +47,7 @@ def __init__( task_name: str = "pushshapes", embodiment: str = "pushshapes_sim", chunk_timesteps: int = 100, + tag: str | None = None, ): self.path = Path(path) self.path.mkdir(parents=True, exist_ok=True) @@ -53,6 +59,14 @@ def __init__( self._object_shape = env_args.get("object_shape", "obj") self._pusher_shape = env_args.get("pusher_shape", "pusher") self._obstacle_level = env_args.get("obstacle_level", 0) + # Optional filename tag — keeps mixed-purpose datasets distinguishable + # and indexed independently in the same folder. + if tag is not None: + if not re.match(r"^[A-Za-z0-9]+$", tag): + raise ValueError( + f"tag must be alphanumeric ([A-Za-z0-9]+), got {tag!r}" + ) + self._tag = tag self.task_description = json.dumps( {"env_args": env_args, "version": "0.3"}, separators=(",", ":") ) @@ -130,9 +144,10 @@ def commit_episode(self) -> int: self._buffer = None return -1 + tag_part = f"_{self._tag}" if self._tag else "" ep_path = self.path / ( f"episode_{self._object_shape}_{self._pusher_shape}" - f"_obs{self._obstacle_level}_{self._episode_idx:06d}.zarr" + f"_obs{self._obstacle_level}{tag_part}_{self._episode_idx:06d}.zarr" ) images = np.stack(self._buffer[IMAGE_KEY], axis=0) numeric = { @@ -182,14 +197,21 @@ def _find_next_index(self) -> int: Only directories matching the regex are considered — stray files or unrelated subdirs are ignored. Returns 0 on a fresh output dir. + When ``tag`` was supplied, only episodes carrying the same tag in + their filename are counted (so tagged + untagged sequences advance + independently in a mixed-purpose folder). """ if not self.path.exists(): return 0 + if self._tag is not None: + pattern = re.compile(_TAGGED_EPISODE_RE_TMPL.format(tag=re.escape(self._tag))) + else: + pattern = _EPISODE_RE max_idx = -1 for entry in self.path.iterdir(): if not entry.is_dir(): continue - m = _EPISODE_RE.match(entry.name) + m = pattern.match(entry.name) if m: max_idx = max(max_idx, int(m.group(1))) return max_idx + 1 diff --git a/Tsimulation/examples/play_random.py b/Tsimulation/examples/play_random.py index f9a720b47..8980910a2 100644 --- a/Tsimulation/examples/play_random.py +++ b/Tsimulation/examples/play_random.py @@ -20,7 +20,11 @@ def main() -> int: p = argparse.ArgumentParser() p.add_argument("--object", default="T", choices=["T", "U", "Z"]) - p.add_argument("--pusher", default="circle", choices=["circle", "stick"]) + p.add_argument( + "--pusher", + default="circle", + choices=["circle", "circle_small", "stick"], + ) p.add_argument("--obstacles", type=int, default=0, choices=[0, 1, 2, 3]) p.add_argument("--steps", type=int, default=100) p.add_argument("--seed", type=int, default=0) diff --git a/Tsimulation/examples/replay_zarr.py b/Tsimulation/examples/replay_zarr.py index 78a7c4d6c..7e5e95c3c 100644 --- a/Tsimulation/examples/replay_zarr.py +++ b/Tsimulation/examples/replay_zarr.py @@ -83,30 +83,59 @@ def replay_one(episode_path: Path, tol: float) -> dict: op = (float(s0[2]), float(s0[3]), float(s0[4])) gp = (float(goal_pose[0]), float(goal_pose[1]), float(goal_pose[2])) env.set_state(agent_pos=ap, object_pose=op, goal_pose=gp) + metrics = _replay_step_loop(env, actions, states) + env.close() + stored_max = float(reward.max()) + return { + "name": episode_path.name, + "T": len(actions), + "stored_cov": stored_max, + "replay_cov": metrics["replay_cov"], + "drift_mean": metrics["drift_mean"], + "drift_max": metrics["drift_max"], + "ok": metrics["replay_cov"] >= stored_max - tol, + } + - drift = [] +def _replay_step_loop( + env, + actions: np.ndarray, + recorded_states: np.ndarray, + *, + early_stop_drift: float | None = None, +) -> dict: + """Step ``actions`` through ``env`` (already reset + set_state'd) and + track per-step drift between the post-step env state and + ``recorded_states[t+1]`` (L2 norm on the 5-vec + ``[pusher_xy, obj_xy, obj_theta]``). + + Optional ``early_stop_drift`` short-circuits the loop on the first + frame where drift exceeds the threshold — used by the collector's + pre-commit validation. + + Returns ``{drift_max, drift_mean, replay_cov, early_stop_frame}``. + """ + drifts: list[float] = [] max_cov = 0.0 + early_stop_frame: int | None = None for i in range(len(actions)): obs, _, term, _, info = env.step(actions[i]) max_cov = max(max_cov, info["coverage"]) - if i + 1 < len(states): + if i + 1 < len(recorded_states): live = np.concatenate([obs["agent_pos"], obs["object_pose"]]) - drift.append(float(np.linalg.norm(states[i + 1] - live))) + d = float(np.linalg.norm(recorded_states[i + 1] - live)) + drifts.append(d) + if early_stop_drift is not None and d > early_stop_drift: + early_stop_frame = i + break if term: break - - env.close() - drift = np.asarray(drift) if drift else np.zeros(1) - stored_max = float(reward.max()) - + drift_arr = np.asarray(drifts) if drifts else np.zeros(1) return { - "name": episode_path.name, - "T": len(actions), - "stored_cov": stored_max, - "replay_cov": max_cov, - "drift_mean": float(drift.mean()), - "drift_max": float(drift.max()), - "ok": max_cov >= stored_max - tol, + "drift_max": float(drift_arr.max()), + "drift_mean": float(drift_arr.mean()), + "replay_cov": float(max_cov), + "early_stop_frame": early_stop_frame, } diff --git a/Tsimulation/pushshapes/env.py b/Tsimulation/pushshapes/env.py index 7fa46f050..060942a21 100644 --- a/Tsimulation/pushshapes/env.py +++ b/Tsimulation/pushshapes/env.py @@ -34,12 +34,14 @@ to_image_obs, ) from Tsimulation.pushshapes.shapes import ( - PUSHER_RADIUS, SHAPES, make_object, make_pusher, + pusher_radius, ) +_VALID_PUSHERS = ("circle", "circle_small", "stick") + # Tunables not exposed via __init__ — surfaced here for visibility. _MIN_TARGET_DIST = 1e-3 # below this, treat pusher as on-target _MIN_STICK_TURN_DIST = 1.0 # stick only re-orients when moving meaningfully @@ -74,9 +76,9 @@ def __init__( if object_shape not in SHAPES: raise ValueError(f"object_shape {object_shape!r} not in {list(SHAPES)}") - if pusher_shape not in ("circle", "stick"): + if pusher_shape not in _VALID_PUSHERS: raise ValueError( - f"pusher_shape {pusher_shape!r} not in ('circle', 'stick')" + f"pusher_shape {pusher_shape!r} not in {_VALID_PUSHERS}" ) if obstacle_level not in OBSTACLE_LEVELS: raise ValueError( @@ -475,7 +477,7 @@ def _sample_pusher_pos( object_pos: tuple[float, float], ) -> tuple[float, float]: m = self.SPAWN_MARGIN - radius = PUSHER_RADIUS + 5.0 + radius = pusher_radius(self.pusher_shape) + 5.0 for _ in range(_SPAWN_MAX_TRIES): x = float(self._np_random.uniform(m, self.WORLD_SIZE - m)) y = float(self._np_random.uniform(m, self.WORLD_SIZE - m)) diff --git a/Tsimulation/pushshapes/obstacles.py b/Tsimulation/pushshapes/obstacles.py index 6fb8d3eaf..10928bdfa 100644 --- a/Tsimulation/pushshapes/obstacles.py +++ b/Tsimulation/pushshapes/obstacles.py @@ -4,16 +4,387 @@ arena. Segments are added to the space's static body as `pymunk.Segment` shapes. Level 0 is empty; higher levels progressively constrain the routes the pusher can take. + +Levels 4..19 were designed to be solvable: every corridor that an object +must traverse is at least ``_MIN_CORRIDOR`` units wide so the T-shape +(120x120 AABB) plus pusher (30 diameter) can fit through. The accompanying +:func:`verify_level_solvable` helper probes each level by sampling many +(object, goal) pairs and reporting spawn-fallback rate and per-quadrant +reachability. """ from __future__ import annotations +from typing import Iterable + import pymunk WALL_RADIUS = 4.0 WALL_FRICTION = 0.7 -OBSTACLE_LEVELS: dict[int, list[tuple[tuple[float, float], tuple[float, float]]]] = { +# Minimum gap that a T-shape (120x120 AABB) + circle pusher (30 diameter) +# can navigate through with a safety margin. Used to size every corridor +# in the level designs below. +_MIN_CORRIDOR = 150.0 + +# 512x512 arena. Levels assume this and place obstacles inside. +_W = 512.0 + +Segment = tuple[tuple[float, float], tuple[float, float]] + + +# ---------------------------------------------------------------------- # +# Composable obstacle primitives. +# Each helper returns a list of Segments. Higher levels compose these. +# ---------------------------------------------------------------------- # + + +def _wall(p1: tuple[float, float], p2: tuple[float, float]) -> list[Segment]: + return [(p1, p2)] + + +def _box(cx: float, cy: float, w: float, h: float) -> list[Segment]: + """Closed axis-aligned box (4 segments).""" + hw, hh = w / 2.0, h / 2.0 + p1 = (cx - hw, cy - hh) + p2 = (cx + hw, cy - hh) + p3 = (cx + hw, cy + hh) + p4 = (cx - hw, cy + hh) + return [(p1, p2), (p2, p3), (p3, p4), (p4, p1)] + + +def _l_shape( + corner: tuple[float, float], + arm_x: float, + arm_y: float, +) -> list[Segment]: + """L-shape with the elbow at ``corner``. Positive arm goes right/down; + negative arm goes left/up.""" + cx, cy = corner + return [ + (corner, (cx + arm_x, cy)), + (corner, (cx, cy + arm_y)), + ] + + +def _cross(cx: float, cy: float, arm: float) -> list[Segment]: + """Plus / cross: two crossing segments of total length ``2*arm``.""" + return [ + ((cx - arm, cy), (cx + arm, cy)), + ((cx, cy - arm), (cx, cy + arm)), + ] + + +def _diamond(cx: float, cy: float, r: float) -> list[Segment]: + """Rotated square (diamond) outline.""" + return [ + ((cx, cy - r), (cx + r, cy)), + ((cx + r, cy), (cx, cy + r)), + ((cx, cy + r), (cx - r, cy)), + ((cx - r, cy), (cx, cy - r)), + ] + + +def _hexagon(cx: float, cy: float, r: float) -> list[Segment]: + """Regular hexagon (6 short segments) — pointy-top orientation.""" + import math + + verts = [ + (cx + r * math.cos(math.pi / 2 + i * math.pi / 3), + cy + r * math.sin(math.pi / 2 + i * math.pi / 3)) + for i in range(6) + ] + return [(verts[i], verts[(i + 1) % 6]) for i in range(6)] + + +def _triangle( + p1: tuple[float, float], + p2: tuple[float, float], + p3: tuple[float, float], +) -> list[Segment]: + return [(p1, p2), (p2, p3), (p3, p1)] + + +def _arc( + cx: float, + cy: float, + r: float, + t_start: float, + t_end: float, + n: int = 16, +) -> list[Segment]: + """Circular arc approximated by ``n`` straight segments. + + Angles use math convention (radians) but the env is +y-down (pygame), so + ``t = 0`` is right, ``t = pi/2`` is down, ``t = pi`` is left, ``t = 3*pi/2`` + is up. Increase ``n`` for smoother curvature on larger radii.""" + import math + + pts = [ + ( + cx + r * math.cos(t_start + (t_end - t_start) * i / n), + cy + r * math.sin(t_start + (t_end - t_start) * i / n), + ) + for i in range(n + 1) + ] + return [(pts[i], pts[i + 1]) for i in range(n)] + + +# ---------------------------------------------------------------------- # +# Level layouts. +# ---------------------------------------------------------------------- # + + +def _level_3_chicane() -> list[Segment]: + """Two staggered horizontal walls leaving a wide alternating gap. Fixes + the original level-3 which had a 72-wide corridor (unsolvable for T).""" + # Top wall: x in [0, 280], y = 180 + # Bottom wall: x in [232, 512], y = 332 + # Vertical corridor between them is ~150 wide (332-180=152), the lateral + # detour at top requires going right of x=280, at bottom left of x=232. + return [ + ((0.0, 180.0), (280.0, 180.0)), + ((232.0, 332.0), (_W, 332.0)), + ] + + +def _level_4_single_box() -> list[Segment]: + return _box(256.0, 256.0, 80.0, 80.0) + + +def _level_5_corner_l() -> list[Segment]: + """Free-standing L-shaped obstacle in upper-left arena region. Arms point + inward (right + down) and never touch the arena edges so nothing is + enclosed.""" + return _l_shape(corner=(180.0, 180.0), arm_x=100.0, arm_y=100.0) + + +def _level_6_t_obstacle() -> list[Segment]: + """T-shaped wall — horizontal top + vertical stem (entirely above + center so the bottom half is open and any object/goal can be placed).""" + return [ + ((180.0, 130.0), (332.0, 130.0)), + ((256.0, 130.0), (256.0, 256.0)), + ] + + +def _level_7_central_cross() -> list[Segment]: + """Small plus in center. Corridors around it are >150.""" + return _cross(256.0, 256.0, 70.0) + + +def _level_8_diamond() -> list[Segment]: + return _diamond(256.0, 256.0, 90.0) + + +def _level_9_two_boxes() -> list[Segment]: + """Two boxes on the diagonal leaving a wide zigzag path.""" + return _box(170.0, 170.0, 80.0, 80.0) + _box(342.0, 342.0, 80.0, 80.0) + + +def _level_10_zigzag_walls() -> list[Segment]: + """Two partial horizontal walls with alternating side openings; the + walls are 200 apart vertically so the T (120 tall) has a clear + horizontal band between them.""" + return [ + ((0.0, 156.0), (240.0, 156.0)), # opening on right: x[240,512] = 272 wide + ((272.0, 356.0), (_W, 356.0)), # opening on left: x[0,272] = 272 wide + ] + + +def _level_11_corner_ls() -> list[Segment]: + """Two free-standing L-shapes in opposite arena regions. Arms point + inward so neither L touches the arena edges.""" + return ( + _l_shape(corner=(170.0, 170.0), arm_x=100.0, arm_y=100.0) + + _l_shape(corner=(342.0, 342.0), arm_x=-100.0, arm_y=-100.0) + ) + + +def _level_12_central_cross_plus_boxes() -> list[Segment]: + """Central cross + one diagonal satellite box. Smaller features so the + T can navigate around either obstacle in any direction.""" + return ( + _cross(256.0, 256.0, 40.0) + + _box(140.0, 380.0, 40.0, 40.0) + ) + + +def _level_13_dual_diamonds() -> list[Segment]: + """Two free-standing diamonds on the diagonal, sized so the T can + navigate fully around either one.""" + return _diamond(170.0, 170.0, 50.0) + _diamond(342.0, 342.0, 50.0) + + +def _level_14_hexagon_center() -> list[Segment]: + return _hexagon(256.0, 256.0, 75.0) + + +def _level_15_hexagon_with_walls() -> list[Segment]: + """Hexagon plus two short interior walls — none of which touch the arena + edges, so no region is enclosed.""" + return ( + _hexagon(256.0, 256.0, 65.0) + + _wall((100.0, 130.0), (260.0, 130.0)) + + _wall((252.0, 382.0), (412.0, 382.0)) + ) + + +def _level_16_three_boxes_row() -> list[Segment]: + """Three small boxes in a horizontal row with wide vertical lanes around.""" + return ( + _box(140.0, 256.0, 50.0, 50.0) + + _box(256.0, 256.0, 50.0, 50.0) + + _box(372.0, 256.0, 50.0, 50.0) + ) + + +def _level_17_triangle_and_box() -> list[Segment]: + """Triangle wedge plus a box, both pulled away from arena edges.""" + return ( + _triangle((190.0, 190.0), (260.0, 190.0), (225.0, 250.0)) + + _box(360.0, 360.0, 60.0, 60.0) + ) + + +def _level_18_double_zigzag() -> list[Segment]: + """S-shaped slalom (two staggered horizontal walls with alternating + side openings, 200 apart vertically).""" + return [ + ((0.0, 156.0), (240.0, 156.0)), # opening on right + ((272.0, 356.0), (_W, 356.0)), # opening on left + ] + + +def _level_19_maze() -> list[Segment]: + """Densest solvable layout: small central diamond + 4 free-standing L's + pointing inward from each corner region. Tuned so all corridors stay + in a single navigable component.""" + return ( + _diamond(256.0, 256.0, 45.0) + + _l_shape(corner=(170.0, 170.0), arm_x=60.0, arm_y=60.0) + + _l_shape(corner=(342.0, 170.0), arm_x=-60.0, arm_y=60.0) + + _l_shape(corner=(170.0, 342.0), arm_x=60.0, arm_y=-60.0) + + _l_shape(corner=(342.0, 342.0), arm_x=-60.0, arm_y=-60.0) + ) + + +# ---------------------------------------------------------------------- # +# Curve-based levels (20..29). All composed from `_arc` primitives; radii +# and gaps tuned so the T (120 AABB) + circle pusher can still navigate. +# ---------------------------------------------------------------------- # + + +def _level_20_c_curve() -> list[Segment]: + """Single C-shaped arc opening to the right (180-deg sweep on left half + of center).""" + import math + + return _arc(256.0, 256.0, 80.0, math.pi / 2, 3 * math.pi / 2, n=24) + + +def _level_21_s_curve() -> list[Segment]: + """S-curve: top half-circle bulging down + bottom half-circle bulging up, + diagonally offset so neither arc closes off any region.""" + import math + + return ( + _arc(180.0, 180.0, 55.0, 0.0, math.pi, n=18) + + _arc(342.0, 342.0, 55.0, math.pi, 2 * math.pi, n=18) + ) + + +def _level_22_open_ring() -> list[Segment]: + """Ring with a single 60-degree gap on the right side.""" + import math + + return _arc(256.0, 256.0, 80.0, math.pi / 6, 2 * math.pi - math.pi / 6, n=28) + + +def _level_23_four_quarter_corners() -> list[Segment]: + """Quarter-arc in each corner curving inward toward the arena center.""" + import math + + r = 70.0 + inset = 130.0 + return ( + # top-left: arc from right (0) sweeping down (pi/2) + _arc(inset, inset, r, 0.0, math.pi / 2, n=10) + # top-right: arc from down (pi/2) sweeping left (pi) + + _arc(_W - inset, inset, r, math.pi / 2, math.pi, n=10) + # bottom-right: arc from left (pi) sweeping up (3pi/2) + + _arc(_W - inset, _W - inset, r, math.pi, 3 * math.pi / 2, n=10) + # bottom-left: arc from up (3pi/2) sweeping right (2pi) + + _arc(inset, _W - inset, r, 3 * math.pi / 2, 2 * math.pi, n=10) + ) + + +def _level_24_arch_top() -> list[Segment]: + """Top arch: half-circle hanging from the top edge of the playable area + (sits well below the wall so a T can squeeze through underneath).""" + import math + + return _arc(256.0, 130.0, 90.0, 0.0, math.pi, n=20) + + +def _level_25_bowl_bottom() -> list[Segment]: + """Bottom bowl: half-circle resting near the bottom of the arena. Mirror + of level 24 — visually similar primitive, different placement.""" + import math + + return _arc(256.0, 382.0, 90.0, math.pi, 2 * math.pi, n=20) + + +def _level_26_wave_floor() -> list[Segment]: + """Three small half-circle bumps spaced across the middle (wavy ridge), + each bumping downward. Wide vertical lanes above and below.""" + import math + + r = 35.0 + y = 256.0 + return ( + _arc(130.0, y, r, math.pi, 2 * math.pi, n=12) + + _arc(256.0, y, r, math.pi, 2 * math.pi, n=12) + + _arc(382.0, y, r, math.pi, 2 * math.pi, n=12) + ) + + +def _level_27_two_loops() -> list[Segment]: + """Two open loops side-by-side (figure-8-ish silhouette): each is an arc + with a small gap facing outward so neither encloses anything.""" + import math + + return ( + _arc(180.0, 256.0, 60.0, math.pi / 6, 2 * math.pi - math.pi / 6, n=22) + + _arc(332.0, 256.0, 60.0, + math.pi - math.pi / 6, 3 * math.pi - math.pi / 6, n=22) + ) + + +def _level_28_corner_quarter_pipe() -> list[Segment]: + """Quarter-pipe in the upper-left corner (arc) + a short diagonal wall + in the opposite area for asymmetry.""" + import math + + return ( + _arc(120.0, 120.0, 110.0, 0.0, math.pi / 2, n=14) + + _wall((320.0, 360.0), (420.0, 420.0)) + ) + + +def _level_29_concentric_rings() -> list[Segment]: + """Two concentric arcs (inner + outer), each with a gap, gaps on opposite + sides so a path winds between them like a circular maze.""" + import math + + # Outer ring with gap on the right. + outer = _arc(256.0, 256.0, 130.0, math.pi / 5, 2 * math.pi - math.pi / 5, n=32) + # Inner ring with gap on the left. + inner = _arc(256.0, 256.0, 55.0, math.pi + math.pi / 4, 3 * math.pi - math.pi / 4, n=20) + return outer + inner + + +OBSTACLE_LEVELS: dict[int, list[Segment]] = { 0: [], 1: [ ((180.0, 180.0), (180.0, 332.0)), @@ -22,13 +393,33 @@ ((180.0, 100.0), (180.0, 280.0)), ((332.0, 232.0), (332.0, 412.0)), ], - 3: [ - # Narrow vertical corridor down the middle plus a top deflector, - # forcing pushes around the obstacles to reach goals on either side. - ((220.0, 0.0), (220.0, 280.0)), - ((292.0, 232.0), (292.0, 512.0)), - ((100.0, 140.0), (220.0, 140.0)), - ], + 3: _level_3_chicane(), + 4: _level_4_single_box(), + 5: _level_5_corner_l(), + 6: _level_6_t_obstacle(), + 7: _level_7_central_cross(), + 8: _level_8_diamond(), + 9: _level_9_two_boxes(), + 10: _level_10_zigzag_walls(), + 11: _level_11_corner_ls(), + 12: _level_12_central_cross_plus_boxes(), + 13: _level_13_dual_diamonds(), + 14: _level_14_hexagon_center(), + 15: _level_15_hexagon_with_walls(), + 16: _level_16_three_boxes_row(), + 17: _level_17_triangle_and_box(), + 18: _level_18_double_zigzag(), + 19: _level_19_maze(), + 20: _level_20_c_curve(), + 21: _level_21_s_curve(), + 22: _level_22_open_ring(), + 23: _level_23_four_quarter_corners(), + 24: _level_24_arch_top(), + 25: _level_25_bowl_bottom(), + 26: _level_26_wave_floor(), + 27: _level_27_two_loops(), + 28: _level_28_corner_quarter_pipe(), + 29: _level_29_concentric_rings(), } @@ -47,3 +438,7 @@ def build_obstacles(space: pymunk.Space, level: int) -> list[pymunk.Segment]: if segments: space.add(*segments) return segments + + +def all_levels() -> Iterable[int]: + return sorted(OBSTACLE_LEVELS) diff --git a/Tsimulation/pushshapes/render.py b/Tsimulation/pushshapes/render.py index cd2fc7dd4..de8213084 100644 --- a/Tsimulation/pushshapes/render.py +++ b/Tsimulation/pushshapes/render.py @@ -16,10 +16,10 @@ import pymunk from Tsimulation.pushshapes.shapes import ( - PUSHER_RADIUS, SHAPES, STICK_HALF_LEN, STICK_HALF_THICK, + pusher_radius, ) BG_COLOR = (240, 240, 240) @@ -130,9 +130,12 @@ def _draw_pusher( angle: float, ) -> None: px, py = pos - if pusher_shape == "circle": + if pusher_shape in ("circle", "circle_small"): pygame.draw.circle( - surface, PUSHER_COLOR, (int(px), int(py)), int(PUSHER_RADIUS) + surface, + PUSHER_COLOR, + (int(px), int(py)), + max(1, int(round(pusher_radius(pusher_shape)))), ) return diff --git a/Tsimulation/pushshapes/shapes.py b/Tsimulation/pushshapes/shapes.py index 81050d298..97f402535 100644 --- a/Tsimulation/pushshapes/shapes.py +++ b/Tsimulation/pushshapes/shapes.py @@ -56,9 +56,27 @@ OBJECT_DENSITY = 0.30 OBJECT_FRICTION = 0.6 PUSHER_RADIUS = 15.0 +PUSHER_RADIUS_SMALL = 5.0 # circle_small: 3x smaller than the standard circle STICK_HALF_LEN = 30.0 STICK_HALF_THICK = 5.0 +# Per-shape effective pusher radius — used by env spawn-clearance and renderer. +# Stick uses its end-cap radius (the largest contact circle on its body). +_PUSHER_RADII: dict[str, float] = { + "circle": PUSHER_RADIUS, + "circle_small": PUSHER_RADIUS_SMALL, + "stick": STICK_HALF_THICK, +} + + +def pusher_radius(shape: str) -> float: + """Effective contact radius for ``shape``. Raises on unknown shapes.""" + if shape not in _PUSHER_RADII: + raise ValueError( + f"unknown pusher shape '{shape}', valid: {list(_PUSHER_RADII)}" + ) + return _PUSHER_RADII[shape] + def _rect_verts(cx: float, cy: float, w: float, h: float) -> list[tuple[float, float]]: hw, hh = w / 2.0, h / 2.0 @@ -96,7 +114,7 @@ def make_object( def make_pusher( - shape: Literal["circle", "stick"], + shape: Literal["circle", "circle_small", "stick"], space: pymunk.Space, position: tuple[float, float], ) -> tuple[pymunk.Body, list[pymunk.Shape]]: @@ -104,8 +122,8 @@ def make_pusher( body = pymunk.Body(body_type=pymunk.Body.KINEMATIC) body.position = position - if shape == "circle": - s = pymunk.Circle(body, PUSHER_RADIUS) + if shape in ("circle", "circle_small"): + s = pymunk.Circle(body, pusher_radius(shape)) s.friction = OBJECT_FRICTION space.add(body, s) return body, [s] @@ -124,7 +142,9 @@ def make_pusher( space.add(body, rect, end_a, end_b) return body, [rect, end_a, end_b] - raise ValueError(f"unknown pusher shape '{shape}', valid: ['circle', 'stick']") + raise ValueError( + f"unknown pusher shape '{shape}', valid: {list(_PUSHER_RADII)}" + ) def aabb(shape: str) -> tuple[float, float, float, float]: