diff --git a/.gitignore b/.gitignore index 3e35a86..7de0ef0 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,7 @@ logs/ site/ coverage.xml +PAK/ +SEN/ +NGA/ +KEN/ diff --git a/src/laser/init/config.py b/src/laser/init/config.py index 49bd462..2e378e5 100644 --- a/src/laser/init/config.py +++ b/src/laser/init/config.py @@ -105,33 +105,51 @@ from laser.init import __version__ -__all__ = ["VERSION", "configuration"] +__all__ = ["VERSION", "configuration", "default_cache_directory", "default_log_directory"] VERSION = __version__ +default_cache_directory = Path.home() / ".laser" / "cache" +default_log_directory = Path.home() / ".laser" / "logs" + configuration = {} -# look for laser_config.[yaml,json] -# prefer the current working directory -# then look in the user's home directory / .laser -for path in [ +candidates = [ Path.cwd() / "laser_config.yaml", Path.cwd() / "laser_config.json", Path.home() / ".laser" / "laser_config.yaml", Path.home() / ".laser" / "laser_config.json", -]: - if path.is_file(): - if path.suffix.lower() == ".yaml": - try: - configuration = yaml.safe_load(path.read_text()) - except yaml.YAMLError as e: - warnings.warn(f"Error parsing YAML configuration file {path}: {e}", stacklevel=2) - configuration = {} - break - elif path.suffix.lower() == ".json": - try: - configuration = json.loads(path.read_text()) - except json.JSONDecodeError as e: - warnings.warn(f"Error parsing JSON configuration file {path}: {e}", stacklevel=2) - configuration = {} - break +] + +# set some defaults: +if any(path.is_file() for path in candidates): + # look for laser_config.[yaml,json] + # prefer the current working directory + # then look in the user's home directory / .laser + for path in candidates: + if path.is_file(): + if path.suffix.lower() == ".yaml": + try: + configuration = yaml.safe_load(path.read_text()) + except yaml.YAMLError as e: + warnings.warn(f"Error parsing YAML configuration file {path}: {e}", stacklevel=2) + configuration = {} + break + elif path.suffix.lower() == ".json": + try: + configuration = json.loads(path.read_text()) + except json.JSONDecodeError as e: + warnings.warn(f"Error parsing JSON configuration file {path}: {e}", stacklevel=2) + configuration = {} + break +else: + warnings.warn("Did not find a laser configuration file. Using default.", stacklevel=2) + configuration = { + "cache_dir": str(default_cache_directory), + "log_dir": str(default_log_directory), + "shape_source": "unocha", + "raster_source": "worldpop", + "stats_source": "unwpp", + # "openai_api_key": "sk-your-key-here", + # "anthropic_api_key": "sk-ant-your-key-here", + } diff --git a/src/laser/init/extractors/gadm.py b/src/laser/init/extractors/gadm.py index b44b4b1..b31943d 100644 --- a/src/laser/init/extractors/gadm.py +++ b/src/laser/init/extractors/gadm.py @@ -8,6 +8,7 @@ from pathlib import Path from ..config import configuration as config +from ..config import default_cache_directory from ..utils import download_file, error, inform @@ -54,7 +55,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None: RuntimeError: If both shapefile and geopackage downloads fail. """ - cache_root = Path(config.get("cache_dir", Path.cwd())) + cache_root = Path(config.get("cache_dir", default_cache_directory)) gadm_path = Path("gadm") / country (cache_root / gadm_path).mkdir(parents=True, exist_ok=True) diff --git a/src/laser/init/extractors/geoboundaries.py b/src/laser/init/extractors/geoboundaries.py index d3652d8..4dd68de 100644 --- a/src/laser/init/extractors/geoboundaries.py +++ b/src/laser/init/extractors/geoboundaries.py @@ -7,6 +7,7 @@ from pathlib import Path from ..config import configuration as config +from ..config import default_cache_directory from ..utils import download_file, error, inform @@ -51,7 +52,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None: # Sample: https://github.com/wmgeolab/geoBoundaries/raw/refs/tags/v6.0.0/releaseData/gbOpen/MCO/ADM1/geoBoundaries-MCO-ADM1-all.zip - cache_root = Path(config.get("cache_dir", Path.cwd())) + cache_root = Path(config.get("cache_dir", default_cache_directory)) geoboundaries_path = Path("geoBoundaries") / country (cache_root / geoboundaries_path).mkdir(parents=True, exist_ok=True) diff --git a/src/laser/init/extractors/unocha.py b/src/laser/init/extractors/unocha.py index 2fcb5e6..59f1db0 100644 --- a/src/laser/init/extractors/unocha.py +++ b/src/laser/init/extractors/unocha.py @@ -7,6 +7,7 @@ from pathlib import Path from ..config import configuration as config +from ..config import default_cache_directory from ..utils import download_file, error, inform @@ -47,7 +48,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None: RuntimeError: If the download fails. """ - cache_root = Path(config.get("cache_dir", Path.cwd())) + cache_root = Path(config.get("cache_dir", default_cache_directory)) unocha_path: Path = Path("UNOCHA") (cache_root / unocha_path).mkdir(parents=True, exist_ok=True) diff --git a/src/laser/init/extractors/unwpp.py b/src/laser/init/extractors/unwpp.py index 8f2fb6f..83f4dc8 100644 --- a/src/laser/init/extractors/unwpp.py +++ b/src/laser/init/extractors/unwpp.py @@ -28,6 +28,7 @@ from pathlib import Path from ..config import configuration as config +from ..config import default_cache_directory from ..utils import download_file, error, inform @@ -84,7 +85,7 @@ def extract(self, country: str, start_year: int, end_year: int) -> Path | None: # UNWPP data is provided as large CSV files covering all countries, so we download the relevant # files and then filter them locally (in the transformer) for the specified country and year range. - cache_root = Path(config.get("cache_dir", Path.cwd())) + cache_root = Path(config.get("cache_dir", default_cache_directory)) unwpp_path = Path("UNWPP") (cache_root / unwpp_path).mkdir(parents=True, exist_ok=True) diff --git a/src/laser/init/extractors/worldpop.py b/src/laser/init/extractors/worldpop.py index 451f9e9..837833a 100644 --- a/src/laser/init/extractors/worldpop.py +++ b/src/laser/init/extractors/worldpop.py @@ -25,6 +25,7 @@ from pathlib import Path from ..config import configuration as config +from ..config import default_cache_directory from ..utils import download_file, error, inform @@ -71,7 +72,7 @@ def extract(self, country: str, year: int) -> Path | None: local_path = None - cache_root = Path(config.get("cache_dir", Path.cwd())) + cache_root = Path(config.get("cache_dir", default_cache_directory)) worldpop_path: Path = Path("WorldPop") (cache_root / worldpop_path).mkdir(parents=True, exist_ok=True) diff --git a/src/laser/init/logger.py b/src/laser/init/logger.py index a7f70f0..8144f31 100644 --- a/src/laser/init/logger.py +++ b/src/laser/init/logger.py @@ -95,6 +95,7 @@ def download_data(url): from pathlib import Path from .config import configuration as config +from .config import default_log_directory __all__ = ["logger"] @@ -109,7 +110,7 @@ def download_data(url): logger.addHandler(console_handler) # File handler (all logs, timestamped file) -log_dir = Path(config.get("log_dir", Path("~").expanduser() / ".laser" / "logs")) +log_dir = Path(config.get("log_dir", default_log_directory)) log_dir.mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = log_dir / f"laser-init_{timestamp}.log" diff --git a/src/laser/init/utils.py b/src/laser/init/utils.py index e7725e5..590b01a 100644 --- a/src/laser/init/utils.py +++ b/src/laser/init/utils.py @@ -22,6 +22,7 @@ from tqdm import tqdm from .config import configuration as config +from .config import default_cache_directory from .french_iso import french_mapping as __french_mapping__ from .logger import logger @@ -294,7 +295,7 @@ def update_local_provenance(output_dir: Path, output_filename: Path, *files: lis Returns: None """ - cache_root = Path(config.get("cache_dir", Path("~").expanduser() / ".laser" / "cache")) + cache_root = Path(config.get("cache_dir", default_cache_directory)) provenance_file = cache_root / "provenance.json" sources = json.loads(provenance_file.read_text()) provenance_local = output_dir / "provenance.json"