Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/464.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a memory limit for Parsl workers: set the environment variable `MEMORY_LIMIT_PARSL_JOB_GB` to the desired limit (in GB) to constrain the memory available to a Parsl worker.
42 changes: 40 additions & 2 deletions packages/climate-ref/src/climate_ref/executor/hpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@

import os
import re
import resource
import time
from typing import Annotated, Any, Literal
from collections.abc import Callable
from typing import Annotated, Any, Literal, TypeVar, cast

import parsl
from loguru import logger
Expand All @@ -44,6 +46,8 @@
from .local import ExecutionFuture, process_result
from .pbs_scheduler import SmartPBSProvider

F = TypeVar("F", bound=Callable[..., Any])


class SlurmConfig(BaseModel):
"""Slurm Configurations"""
Expand All @@ -61,7 +65,7 @@ class SlurmConfig(BaseModel):
validation: StrictBool = False
walltime: str = "00:30:00"
scheduler_options: str = ""
retries: Annotated[int, Field(strict=True, ge=1, le=3)] = 2
retries: Annotated[int, Field(strict=True, ge=0, le=3)] = 2
max_blocks: Annotated[int, Field(strict=True, ge=1)] = 1 # one block mean one job?
worker_init: str = ""
overrides: str = ""
Expand Down Expand Up @@ -111,7 +115,41 @@ def _validate_walltime(cls, v: str) -> str:
return v


def with_memory_limit(limit_gb: float | Callable[..., float | None]) -> Callable[[F], F]:
"""Set memory limit for a parsl worker"""

def decorator(func: F) -> F:
def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
current_limit = limit_gb(*args, **kwargs) if callable(limit_gb) else limit_gb
except Exception:
current_limit = None

if current_limit is not None and current_limit > 0:
bytes_limit = int(current_limit * 1024 * 1024 * 1024)
_, hard0 = resource.getrlimit(resource.RLIMIT_AS)
soft = min(bytes_limit, hard0) if hard0 > 0 else bytes_limit
resource.setrlimit(resource.RLIMIT_AS, (soft, hard0))
return func(*args, **kwargs)

return cast(F, wrapper)

return decorator


def limit_from_env(*args: Any, **kwargs: Any) -> float | None:
"""Get the memory limits from env variables"""
val = os.getenv("MEMORY_LIMIT_PARSL_JOB_GB")
if not val:
return None
try:
return float(val)
except ValueError:
return None


@python_app
@with_memory_limit(limit_from_env)
def _process_run(definition: ExecutionDefinition, log_level: str) -> ExecutionResult:
"""Run the function on computer nodes"""
# This is a catch-all for any exceptions that occur in the process and need to raise for
Expand Down
44 changes: 43 additions & 1 deletion packages/climate-ref/tests/unit/executor/test_hpc_executor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
import re
import resource
from unittest.mock import MagicMock, patch

import parsl
import pytest
from parsl.dataflow import futures
from pydantic import ValidationError

from climate_ref.executor.hpc import HPCExecutor, SlurmConfig, execute_locally
from climate_ref.executor.hpc import (
HPCExecutor,
SlurmConfig,
execute_locally,
limit_from_env,
with_memory_limit,
)
from climate_ref.executor.local import ExecutionFuture
from climate_ref_core.diagnostics import ExecutionResult
from climate_ref_core.exceptions import DiagnosticError
Expand Down Expand Up @@ -173,3 +181,37 @@ def test_hpc_slurm_missing_required_config(self, missing_config, base_config):
[slurm_cfg_dict.pop(m) for m in missing_config]
with pytest.raises(ValidationError):
SlurmConfig.model_validate(slurm_cfg_dict)

def test_memeory_limit_fixed(self):
    """Test with a fixed numeric memory limit"""
    # Remember the current limits so the rest of the test session is unaffected.
    orig_soft, orig_hard = resource.getrlimit(resource.RLIMIT_AS)

    expected_bytes = int(2.0 * 1024 * 1024 * 1024)
    if orig_hard > 0:
        # The decorator clamps the soft limit to a finite hard limit.
        expected_bytes = min(expected_bytes, orig_hard)

    @with_memory_limit(2.0)
    def test_func():
        # Check if memory limit was set
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        assert soft == expected_bytes
        # The hard limit must be left untouched (it may be -1, i.e. unlimited).
        assert hard == orig_hard

    try:
        # The decorated function has to be called, otherwise nothing is checked.
        test_func()
    finally:
        resource.setrlimit(resource.RLIMIT_AS, (orig_soft, orig_hard))

def test_memory_limit_func(self):
    """Test with a memory limit resolved from the environment at call time"""
    env_key = "MEMORY_LIMIT_PARSL_JOB_GB"

    # Save original state so it can be restored once the test is done
    saved_env = os.environ.pop(env_key, None)
    orig_soft, orig_hard = resource.getrlimit(resource.RLIMIT_AS)

    @with_memory_limit(limit_from_env)
    def unset_func():
        # Without the variable, the limits must be left exactly as they were
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        assert soft == orig_soft
        assert hard == orig_hard

    @with_memory_limit(limit_from_env)
    def set_func():
        soft, hard = resource.getrlimit(resource.RLIMIT_AS)
        expected_bytes = 7 * 1024 * 1024 * 1024
        assert soft == expected_bytes
        assert hard == -1 or hard >= expected_bytes

    try:
        unset_func()
        os.environ[env_key] = "7"
        set_func()
    finally:
        # Do not leak the env variable or the lowered limit into other tests
        if saved_env is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_env
        resource.setrlimit(resource.RLIMIT_AS, (orig_soft, orig_hard))
Loading