Skip to content

Commit 96abbb4

Browse files
committed
fix: add HTTP request timeouts to prevent hung threads
No timeout was set on any session.get/post/delete call, causing threads to block indefinitely on slow or stalled connections. Add connect_timeout (9.05 s) and read_timeout (300 s) to the config and pass them as a (connect, read) timeout tuple to every request made in _send_request. Also add requests.exceptions.Timeout to the set of retryable exceptions.

Signed-off-by: suhr25 <suhridmarwah07@gmail.com>
1 parent da993f7 commit 96abbb4

3 files changed

Lines changed: 24 additions & 4 deletions

File tree

openml/_api_calls.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,16 +370,19 @@ def _send_request( # noqa: C901, PLR0912
370370
# Error to raise in case of retrying too often. Will be set to the last observed exception.
371371
retry_raise_e: Exception | None = None
372372

373+
timeout = (config.connect_timeout, config.read_timeout)
373374
with requests.Session() as session:
374375
# Start at one to have a non-zero multiplier for the sleep
375376
for retry_counter in range(1, n_retries + 1):
376377
try:
377378
if request_method == "get":
378-
response = session.get(url, params=data, headers=_HEADERS)
379+
response = session.get(url, params=data, headers=_HEADERS, timeout=timeout)
379380
elif request_method == "delete":
380-
response = session.delete(url, params=data, headers=_HEADERS)
381+
response = session.delete(url, params=data, headers=_HEADERS, timeout=timeout)
381382
elif request_method == "post":
382-
response = session.post(url, data=data, files=files, headers=_HEADERS)
383+
response = session.post(
384+
url, data=data, files=files, headers=_HEADERS, timeout=timeout
385+
)
383386
else:
384387
raise NotImplementedError()
385388

@@ -424,6 +427,7 @@ def _send_request( # noqa: C901, PLR0912
424427
) from e
425428
retry_raise_e = e
426429
except (
430+
requests.exceptions.Timeout,
427431
requests.exceptions.ChunkedEncodingError,
428432
requests.exceptions.ConnectionError,
429433
requests.exceptions.SSLError,

openml/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class _Config(TypedDict):
3737
retry_policy: Literal["human", "robot"]
3838
connection_n_retries: int
3939
show_progress: bool
40+
connect_timeout: float
41+
read_timeout: float
4042

4143

4244
def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT002
@@ -157,6 +159,8 @@ def _resolve_default_cache_dir() -> Path:
157159
"retry_policy": "human",
158160
"connection_n_retries": 5,
159161
"show_progress": False,
162+
"connect_timeout": 9.05,
163+
"read_timeout": 300.0,
160164
}
161165

162166
# Default values are actually added here in the _setup() function which is
@@ -186,6 +190,8 @@ def get_server_base_url() -> str:
186190

187191
retry_policy: Literal["human", "robot"] = _defaults["retry_policy"]
188192
connection_n_retries: int = _defaults["connection_n_retries"]
193+
connect_timeout: float = _defaults["connect_timeout"]
194+
read_timeout: float = _defaults["read_timeout"]
189195

190196

191197
def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None:
@@ -343,6 +349,8 @@ def _setup(config: _Config | None = None) -> None:
343349
global _root_cache_directory # noqa: PLW0603
344350
global avoid_duplicate_runs # noqa: PLW0603
345351
global show_progress # noqa: PLW0603
352+
global connect_timeout # noqa: PLW0603
353+
global read_timeout # noqa: PLW0603
346354

347355
config_file = determine_config_file_path()
348356
config_dir = config_file.parent
@@ -364,6 +372,8 @@ def _setup(config: _Config | None = None) -> None:
364372
apikey = config["apikey"]
365373
server = config["server"]
366374
show_progress = config["show_progress"]
375+
connect_timeout = float(config.get("connect_timeout", _defaults["connect_timeout"])) # type: ignore[union-attr]
376+
read_timeout = float(config.get("read_timeout", _defaults["read_timeout"])) # type: ignore[union-attr]
367377
n_retries = int(config["connection_n_retries"])
368378

369379
set_retry_policy(config["retry_policy"], n_retries)
@@ -445,6 +455,8 @@ def get_config_as_dict() -> _Config:
445455
"connection_n_retries": connection_n_retries,
446456
"retry_policy": retry_policy,
447457
"show_progress": show_progress,
458+
"connect_timeout": connect_timeout,
459+
"read_timeout": read_timeout,
448460
}
449461

450462

tests/test_openml/test_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ def test_get_config_as_dict(self):
8484
_config["connection_n_retries"] = 20
8585
_config["retry_policy"] = "robot"
8686
_config["show_progress"] = False
87+
_config["connect_timeout"] = openml.config.connect_timeout
88+
_config["read_timeout"] = openml.config.read_timeout
8789
assert isinstance(config, dict)
88-
assert len(config) == 7
90+
assert len(config) == 9
8991
self.assertDictEqual(config, _config)
9092

9193
def test_setup_with_config(self):
@@ -98,6 +100,8 @@ def test_setup_with_config(self):
98100
_config["retry_policy"] = "human"
99101
_config["connection_n_retries"] = 100
100102
_config["show_progress"] = False
103+
_config["connect_timeout"] = 5.0
104+
_config["read_timeout"] = 120.0
101105
orig_config = openml.config.get_config_as_dict()
102106
openml.config._setup(_config)
103107
updated_config = openml.config.get_config_as_dict()

0 commit comments

Comments
 (0)