Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ logs/
site/

coverage.xml
PAK/
SEN/
NGA/
KEN/
60 changes: 39 additions & 21 deletions src/laser/init/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,33 +105,51 @@

from laser.init import __version__

__all__ = ["VERSION", "configuration"]
__all__ = ["VERSION", "configuration", "default_cache_directory", "default_log_directory"]

VERSION = __version__

default_cache_directory = Path.home() / ".laser" / "cache"
default_log_directory = Path.home() / ".laser" / "logs"

configuration = {}

# look for laser_config.[yaml,json]
# prefer the current working directory
# then look in the user's home directory / .laser
for path in [
candidates = [
Path.cwd() / "laser_config.yaml",
Path.cwd() / "laser_config.json",
Path.home() / ".laser" / "laser_config.yaml",
Path.home() / ".laser" / "laser_config.json",
]:
if path.is_file():
if path.suffix.lower() == ".yaml":
try:
configuration = yaml.safe_load(path.read_text())
except yaml.YAMLError as e:
warnings.warn(f"Error parsing YAML configuration file {path}: {e}", stacklevel=2)
configuration = {}
break
elif path.suffix.lower() == ".json":
try:
configuration = json.loads(path.read_text())
except json.JSONDecodeError as e:
warnings.warn(f"Error parsing JSON configuration file {path}: {e}", stacklevel=2)
configuration = {}
break
]

# set some defaults:
if any(path.is_file() for path in candidates):
# look for laser_config.[yaml,json]
# prefer the current working directory
# then look in the user's home directory / .laser
for path in candidates:
if path.is_file():
if path.suffix.lower() == ".yaml":
try:
configuration = yaml.safe_load(path.read_text())
except yaml.YAMLError as e:
warnings.warn(f"Error parsing YAML configuration file {path}: {e}", stacklevel=2)
configuration = {}
break
elif path.suffix.lower() == ".json":
try:
configuration = json.loads(path.read_text())
except json.JSONDecodeError as e:
warnings.warn(f"Error parsing JSON configuration file {path}: {e}", stacklevel=2)
configuration = {}
break
else:
warnings.warn("Did not find a laser configuration file. Using default.", stacklevel=2)
configuration = {
"cache_dir": str(default_cache_directory),
"log_dir": str(default_log_directory),
"shape_source": "unocha",
"raster_source": "worldpop",
"stats_source": "unwpp",
# "openai_api_key": "sk-your-key-here",
# "anthropic_api_key": "sk-ant-your-key-here",
}
3 changes: 2 additions & 1 deletion src/laser/init/extractors/gadm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path

from ..config import configuration as config
from ..config import default_cache_directory
from ..utils import download_file, error, inform


Expand Down Expand Up @@ -54,7 +55,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None:
RuntimeError: If both shapefile and geopackage downloads fail.
"""

cache_root = Path(config.get("cache_dir", Path.cwd()))
cache_root = Path(config.get("cache_dir", default_cache_directory))
gadm_path = Path("gadm") / country
(cache_root / gadm_path).mkdir(parents=True, exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/extractors/geoboundaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

from ..config import configuration as config
from ..config import default_cache_directory
from ..utils import download_file, error, inform


Expand Down Expand Up @@ -51,7 +52,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None:

# Sample: https://github.com/wmgeolab/geoBoundaries/raw/refs/tags/v6.0.0/releaseData/gbOpen/MCO/ADM1/geoBoundaries-MCO-ADM1-all.zip

cache_root = Path(config.get("cache_dir", Path.cwd()))
cache_root = Path(config.get("cache_dir", default_cache_directory))
geoboundaries_path = Path("geoBoundaries") / country
(cache_root / geoboundaries_path).mkdir(parents=True, exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/extractors/unocha.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

from ..config import configuration as config
from ..config import default_cache_directory
from ..utils import download_file, error, inform


Expand Down Expand Up @@ -47,7 +48,7 @@ def extract(self, country: str, level: int, year: int) -> Path | None:
RuntimeError: If the download fails.
"""

cache_root = Path(config.get("cache_dir", Path.cwd()))
cache_root = Path(config.get("cache_dir", default_cache_directory))
unocha_path: Path = Path("UNOCHA")
(cache_root / unocha_path).mkdir(parents=True, exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/extractors/unwpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pathlib import Path

from ..config import configuration as config
from ..config import default_cache_directory
from ..utils import download_file, error, inform


Expand Down Expand Up @@ -84,7 +85,7 @@ def extract(self, country: str, start_year: int, end_year: int) -> Path | None:
# UNWPP data is provided as large CSV files covering all countries, so we download the relevant
# files and then filter them locally (in the transformer) for the specified country and year range.

cache_root = Path(config.get("cache_dir", Path.cwd()))
cache_root = Path(config.get("cache_dir", default_cache_directory))
unwpp_path = Path("UNWPP")
(cache_root / unwpp_path).mkdir(parents=True, exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/extractors/worldpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pathlib import Path

from ..config import configuration as config
from ..config import default_cache_directory
from ..utils import download_file, error, inform


Expand Down Expand Up @@ -71,7 +72,7 @@ def extract(self, country: str, year: int) -> Path | None:

local_path = None

cache_root = Path(config.get("cache_dir", Path.cwd()))
cache_root = Path(config.get("cache_dir", default_cache_directory))
worldpop_path: Path = Path("WorldPop")
(cache_root / worldpop_path).mkdir(parents=True, exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def download_data(url):
from pathlib import Path

from .config import configuration as config
from .config import default_log_directory

__all__ = ["logger"]

Expand All @@ -109,7 +110,7 @@ def download_data(url):
logger.addHandler(console_handler)

# File handler (all logs, timestamped file)
log_dir = Path(config.get("log_dir", Path("~").expanduser() / ".laser" / "logs"))
log_dir = Path(config.get("log_dir", default_log_directory))
log_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"laser-init_{timestamp}.log"
Expand Down
3 changes: 2 additions & 1 deletion src/laser/init/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tqdm import tqdm

from .config import configuration as config
from .config import default_cache_directory
from .french_iso import french_mapping as __french_mapping__
from .logger import logger

Expand Down Expand Up @@ -294,7 +295,7 @@ def update_local_provenance(output_dir: Path, output_filename: Path, *files: lis
Returns:
None
"""
cache_root = Path(config.get("cache_dir", Path("~").expanduser() / ".laser" / "cache"))
cache_root = Path(config.get("cache_dir", default_cache_directory))
provenance_file = cache_root / "provenance.json"
sources = json.loads(provenance_file.read_text())
provenance_local = output_dir / "provenance.json"
Expand Down
Loading