diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py
index 2b52d292..2a794e30 100644
--- a/nemo_run/core/execution/slurm.py
+++ b/nemo_run/core/execution/slurm.py
@@ -346,6 +346,9 @@ class ResourceRequest:
     network: Optional[str] = None
     #: Template name to use for Ray jobs (e.g., "ray.sub.j2" or "ray_enroot.sub.j2")
     ray_template: str = "ray.sub.j2"
+    #: When True, a background thread polls squeue --start while the job is pending
+    #: and prints the estimated start time. Set to False to disable this behaviour.
+    poll_estimated_start_time: bool = True
     #: Set by the executor; cannot be initialized
     job_name: str = field(init=False, default="nemo-job")
 
diff --git a/nemo_run/run/torchx_backend/schedulers/slurm.py b/nemo_run/run/torchx_backend/schedulers/slurm.py
index 912bca43..9af9483a 100644
--- a/nemo_run/run/torchx_backend/schedulers/slurm.py
+++ b/nemo_run/run/torchx_backend/schedulers/slurm.py
@@ -257,21 +257,22 @@ def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest | SlurmRayRequest]
 
         # Save metadata
         _save_job_dir(job_id, job_dir, tunnel, slurm_executor.job_details.ls_term)
-        # Stop any existing polling thread for this job_id (retry scenario)
-        if job_id in self._start_time_stop_events:
-            self._start_time_stop_events.pop(job_id).set()
-            self._start_time_threads.pop(job_id, None)
-
-        stop_event = threading.Event()
-        self._start_time_stop_events[job_id] = stop_event
-        thread = threading.Thread(
-            target=self._poll_job_start_time,
-            args=(job_id, self.tunnel, stop_event),
-            daemon=True,
-            name=f"slurm-start-time-{job_id}",
-        )
-        self._start_time_threads[job_id] = thread
-        thread.start()
+        if slurm_executor.poll_estimated_start_time:
+            # Stop any existing polling thread for this job_id (retry scenario)
+            if job_id in self._start_time_stop_events:
+                self._start_time_stop_events.pop(job_id).set()
+                self._start_time_threads.pop(job_id, None)
+
+            stop_event = threading.Event()
+            self._start_time_stop_events[job_id] = stop_event
+            thread = threading.Thread(
+                target=self._poll_job_start_time,
+                args=(job_id, self.tunnel, stop_event),
+                daemon=True,
+                name=f"slurm-start-time-{job_id}",
+            )
+            self._start_time_threads[job_id] = thread
+            thread.start()
 
         return job_id
 
diff --git a/test/run/torchx_backend/schedulers/test_slurm.py b/test/run/torchx_backend/schedulers/test_slurm.py
index 3e99a93f..857d7a66 100644
--- a/test/run/torchx_backend/schedulers/test_slurm.py
+++ b/test/run/torchx_backend/schedulers/test_slurm.py
@@ -942,3 +942,29 @@ def test_cancel_stops_polling_thread_for_job(slurm_scheduler, mocker):
 
     assert ev.is_set()
     assert job_id not in slurm_scheduler._start_time_stop_events
+
+
+def test_schedule_skips_polling_thread_when_disabled(slurm_scheduler, mocker):
+    """When poll_estimated_start_time=False on the executor, no thread is started."""
+    job_id = "88888"
+    dryrun_info = mock.MagicMock()
+    dryrun_info.request.executor.poll_estimated_start_time = False
+    dryrun_info.request.executor.job_dir = "/tmp/test"
+    dryrun_info.request.executor.tunnel = mock.MagicMock()
+    dryrun_info.request.executor.dependencies = []
+    dryrun_info.request.executor.job_name = "test-job"
+    dryrun_info.request.executor.job_details.ls_term = ""
+
+    mock_tunnel = mock.MagicMock()
+    mock_tunnel.run.return_value.stdout = job_id
+    slurm_scheduler.tunnel = mock_tunnel
+
+    mocker.patch.object(SlurmTunnelScheduler, "_initialize_tunnel")
+    mocker.patch("nemo_run.run.torchx_backend.schedulers.slurm._save_job_dir")
+    poll_mock = mocker.patch.object(SlurmTunnelScheduler, "_poll_job_start_time")
+
+    slurm_scheduler.schedule(dryrun_info)
+
+    poll_mock.assert_not_called()
+    assert job_id not in slurm_scheduler._start_time_threads
+    assert job_id not in slurm_scheduler._start_time_stop_events