[BUG] SliceSampler doesn't work as expected when collecting data from parallel environment? #3194

@briannnyee

Hi TorchRL devs,

I currently have the following setup (simplified for clarity):

sampler = SliceSampler(
    slice_len=4,
    end_key=None,
    traj_key=("collector", "traj_ids"),
    truncated_key=None,
    strict_length=True,
)

...some code...

frames_per_batch = num_envs * num_steps_per_env
collector = SyncDataCollectorWrapper(
    create_env_fn=env,
    policy=actor_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    init_random_frames=init_random_frames,
    exploration_type=ExplorationType.RANDOM,
    device=self.device,
)

...some code...

data = next(collector_iter)
self.replay_buffer.extend(data.reshape(-1))

batch = self.replay_buffer.sample()
# RuntimeError: Did not find a single trajectory with sufficient length (length range: 1 - 1 / required=4))

After spending some time investigating this, I think the problem is that SliceSampler expects a different layout of traj_key than the one my collector produces. Say we have num_envs=2 and num_steps_per_env=1. SliceSampler seems to expect each trajectory to be stored contiguously (episode by episode), e.g. traj_key=[0,0,0,...,0,1,1,1,...,1]. In reality, the flattened batches are appended in collection order, so the environments are interleaved: traj_key=[0,1,0,1,...,0,1]. That would explain why every trajectory the sampler finds has length 1.
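
To make the mismatch concrete, here is a minimal pure-Python sketch. It is only a simplified model, not TorchRL's actual implementation: it assumes the sampler treats a trajectory as a contiguous run of equal traj_ids in storage order. `run_lengths` is a hypothetical helper I wrote for illustration, and the values of `num_envs` and `num_batches` are made up.

```python
from itertools import groupby


def run_lengths(traj_ids):
    """Return the lengths of contiguous runs of identical trajectory ids."""
    return [len(list(group)) for _, group in groupby(traj_ids)]


num_envs, num_batches = 2, 4  # illustrative values; one step per env per batch

# Each collected batch holds one step from every env. Flattening it and
# appending batch after batch interleaves the envs in storage order.
interleaved = [env for _ in range(num_batches) for env in range(num_envs)]

# Layout a contiguity-based sampler would need: each trajectory in one block.
# sorted() is stable, so the time order within each trajectory id is kept.
grouped = sorted(interleaved)

interleaved_runs = run_lengths(interleaved)
grouped_runs = run_lengths(grouped)

print(interleaved, interleaved_runs)
print(grouped, grouped_runs)
```

In the interleaved layout every run has length 1, which is consistent with the "length range: 1 - 1" in the error above, while the grouped layout gives one run of 4 steps per environment, enough for slice_len=4.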

Did I do something wrong here? If not, is there a workaround, or is this a bug that needs a patch?

My torchrl version is 0.8. Let me know if I need to provide more info. Thanks in advance!
