Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo_run/core/execution/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ class ResourceRequest:
network: Optional[str] = None
#: Template name to use for Ray jobs (e.g., "ray.sub.j2" or "ray_enroot.sub.j2")
ray_template: str = "ray.sub.j2"
#: When True, a background thread polls squeue --start while the job is pending
#: and prints the estimated start time. Set to False to disable this behaviour.
poll_estimated_start_time: bool = True

#: Set by the executor; cannot be initialized
job_name: str = field(init=False, default="nemo-job")
Expand Down
31 changes: 16 additions & 15 deletions nemo_run/run/torchx_backend/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,21 +257,22 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayReques
# Save metadata
_save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)

# Stop any existing polling thread for this job_id (retry scenario)
if job_id in self._start_time_stop_events:
self._start_time_stop_events.pop(job_id).set()
self._start_time_threads.pop(job_id, None)

stop_event = threading.Event()
self._start_time_stop_events[job_id] = stop_event
thread = threading.Thread(
target=self._poll_job_start_time,
args=(job_id, self.tunnel, stop_event),
daemon=True,
name=f"slurm-start-time-{job_id}",
)
self._start_time_threads[job_id] = thread
thread.start()
if slurm_executor.poll_estimated_start_time:
# Stop any existing polling thread for this job_id (retry scenario)
if job_id in self._start_time_stop_events:
self._start_time_stop_events.pop(job_id).set()
self._start_time_threads.pop(job_id, None)

stop_event = threading.Event()
self._start_time_stop_events[job_id] = stop_event
thread = threading.Thread(
target=self._poll_job_start_time,
args=(job_id, self.tunnel, stop_event),
daemon=True,
name=f"slurm-start-time-{job_id}",
)
self._start_time_threads[job_id] = thread
thread.start()

return job_id

Expand Down
26 changes: 26 additions & 0 deletions test/run/torchx_backend/schedulers/test_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,3 +942,29 @@ def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker):

assert ev.is_set()
assert job_id not in slurm_scheduler._start_time_stop_events


def test_schedule_skips_polling_thread_when_disabled(slurm_scheduler, mocker):
"""When poll_estimated_start_time=False on the executor, no thread is started."""
job_id = "88888"
dryrun_info = mock.MagicMock()
dryrun_info.request.executor.poll_estimated_start_time = False
dryrun_info.request.executor.job_dir = "/tmp/test"
dryrun_info.request.executor.tunnel = mock.MagicMock()
dryrun_info.request.executor.dependencies = []
dryrun_info.request.executor.job_name = "test-job"
dryrun_info.request.executor.job_details.ls_term = ""

mock_tunnel = mock.MagicMock()
mock_tunnel.run.return_value.stdout = job_id
slurm_scheduler.tunnel = mock_tunnel

mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
poll_mock = mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time")

slurm_scheduler.schedule(dryrun_info)

poll_mock.assert_not_called()
assert job_id not in slurm_scheduler._start_time_threads
assert job_id not in slurm_scheduler._start_time_stop_events
Loading