I'm running the training loop with the code below, and noticed that the trainer only processes one trajectory at a time.
```python
for step in range(TRAINING_STEPS):
    print(f"Step {step+1} rollout")
    train_groups = []
    for scenario in scenarios:
        trajectories = await rollout(model.name, vllm_url, scenario)
        rewards = [traj.reward for traj in trajectories]
        print(f"\n{scenario['id']} rewards: {rewards}")
        train_groups.append(art.TrajectoryGroup(trajectories))
    print(f"Step {step+1} training")
    await model.delete_checkpoints()
    await model.train(train_groups, config=art.TrainConfig(learning_rate=1e-5))
```
```
(APIServer pid=1156) Step 19 rollout
(APIServer pid=1156) ....................................................................................................................................................................................
(APIServer pid=1156) s1 rewards: [0.41975308641975306, 0.012345679012345678, 0.1851851851851852, 0.654320987654321, 0.19753086419753085, 0.3950617283950617, 0.1358024691358025, 0.6172839506172839]
(APIServer pid=1156) Step 19 training
(APIServer pid=1156) No "val/reward" metric found in history
(APIServer pid=1156) Deleted checkpoint art/art-rl/models/qwen3-14b-rl/checkpoints/0017
(APIServer pid=1156) Packed 118 trajectories into 118 sequences of length 6144
train: 100% 118/118 [08:05<00:00, 4.02s/it, loss=0.0146, grad_norm=2.29, policy_loss=0.0146, entropy=0.0266]
```
This seems consistent with the code in inputs.py, which bypasses the per_device_train_batch_size config. Is there a reason to limit training to one sample at a time?
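To make the question concrete, here is a minimal sketch of what I would have expected per_device_train_batch_size to do with the numbers from the log above. The batch size of 4 is a hypothetical value, not anything ART actually uses; only the sequence count (118) comes from the run.

```python
import math

# From the log: "Packed 118 trajectories into 118 sequences of length 6144".
num_sequences = 118

# Hypothetical batch size for illustration; the observed behavior
# corresponds to a batch size of 1 (118 optimizer steps, 4.02 s/it).
per_device_train_batch_size = 4

# With real batching, the packed sequences would be grouped into
# ceil(118 / 4) = 30 batches instead of 118 single-sequence steps.
expected_batches = math.ceil(num_sequences / per_device_train_batch_size)
print(expected_batches)  # → 30
```

Even accounting for padding overhead, fewer, larger forward/backward passes would presumably be faster than 118 sequential ones.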