diff --git a/dagshub/common/config.py b/dagshub/common/config.py index e82b0063..8e5e8949 100644 --- a/dagshub/common/config.py +++ b/dagshub/common/config.py @@ -21,6 +21,7 @@ HTTP_TIMEOUT_KEY = "DAGSHUB_HTTP_TIMEOUT" DAGSHUB_QUIET_KEY = "DAGSHUB_QUIET" DISABLE_TRACEPARENT_KEY = "DAGSHUB_DISABLE_TRACEPARENT" +TRUE_VALUES = {"1", "true", "yes", "on"} def set_host(new_host: str): @@ -32,6 +33,13 @@ def set_host(new_host: str): hostname, host, parsed_host = _hostname, _host, _parsed_host +def _get_boolean_env(key: str, default: bool = False) -> bool: + value = os.environ.get(key) + if value is None: + return default + return value.strip().lower() in TRUE_VALUES + + hostname = "" host = "" parsed_host = "" @@ -47,9 +55,9 @@ def set_host(new_host: str): http_timeout = os.environ.get(HTTP_TIMEOUT_KEY, 30) REPO_INFO_URL = "api/v1/repos/{owner}/{reponame}" -quiet = bool(os.environ.get(DAGSHUB_QUIET_KEY, False)) +quiet = _get_boolean_env(DAGSHUB_QUIET_KEY) -disable_traceparent = bool(os.environ.get(DISABLE_TRACEPARENT_KEY, False)) +disable_traceparent = _get_boolean_env(DISABLE_TRACEPARENT_KEY) # DVC config templates CONFIG_GITIGNORE = "/config.local\n/tmp\n/cache" diff --git a/tests/common/test_config.py b/tests/common/test_config.py new file mode 100644 index 00000000..aaf9e602 --- /dev/null +++ b/tests/common/test_config.py @@ -0,0 +1,51 @@ +import importlib.util +import os +import sys +import types +import unittest +from pathlib import Path +from unittest.mock import patch + +CONFIG_PATH = Path(__file__).resolve().parents[2] / "dagshub" / "common" / "config.py" + + +def load_config_module(): + fake_dagshub = types.ModuleType("dagshub") + fake_dagshub.__version__ = "test-version" + + fake_appdirs = types.ModuleType("appdirs") + fake_appdirs.user_cache_dir = lambda app_name: f"/tmp/{app_name}" + + fake_httpx = types.ModuleType("httpx") + fake_httpx_client = types.ModuleType("httpx._client") + fake_httpx_client.USER_AGENT = "test-agent" + + spec = importlib.util.spec_from_file_location("test_dagshub_common_config", CONFIG_PATH) + module = importlib.util.module_from_spec(spec) + with patch.dict( + sys.modules, + { + "dagshub": fake_dagshub, + "appdirs": fake_appdirs, + "httpx": fake_httpx, + "httpx._client": fake_httpx_client, + }, + ): + spec.loader.exec_module(module) + return module + + +class ConfigTestCase(unittest.TestCase): + def test_boolean_env_flags_are_parsed_from_string_values(self): + with patch.dict( + os.environ, + { + "DAGSHUB_QUIET": "false", + "DAGSHUB_DISABLE_TRACEPARENT": "1", + }, + clear=False, + ): + config = load_config_module() + + self.assertFalse(config.quiet) + self.assertTrue(config.disable_traceparent)