Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion utils/UpgradeInstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import asyncio
import aiohttp
import logging
import os
from logging import StreamHandler, Formatter
from packaging.requirements import Requirement
from utils.PyPiUtils import GetPyPiInfo
Expand All @@ -15,6 +16,8 @@
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet
from utils.SGTUtils import SGTFormatter
from utils.ConfigUtils import parse_requirements
from dotenv import load_dotenv
try:
from zoneinfo import ZoneInfo # Python 3.9+
except ImportError:
Expand All @@ -27,6 +30,24 @@
logger.addHandler(handler)
logger.propagate = False # Avoid duplicate logs from root logger

# cache of current package versions from requirements file
_CURRENT_VERSIONS: dict[str, str] | None = None


def _load_current_versions() -> dict[str, str]:
"""Load current package versions from the requirements file."""
global _CURRENT_VERSIONS
if _CURRENT_VERSIONS is None:
load_dotenv(dotenv_path=".env")
req_file = os.getenv("REQUIREMENTS_FILE", "src/requirements_full_list.txt")
try:
mapping = parse_requirements(req_file)
_CURRENT_VERSIONS = {k.lower(): v for k, v in mapping.items()}
except Exception as e: # pragma: no cover - robustness
logger.warning(f"Failed to load requirements from {req_file}: {e}")
_CURRENT_VERSIONS = {}
return _CURRENT_VERSIONS

def _extract_min_version(req: Requirement) -> str | None:
"""
Return the minimal version that satisfies the requirement specifier.
Expand Down Expand Up @@ -153,10 +174,32 @@ def generate_upgrade_instruction(base_package: str, target_version: str) -> dict

# Use asyncio.run to avoid 'event loop already running' issues
SafeVersions = asyncio.run(get_safe_dependency_versions(requires_dist))
current_versions = _load_current_versions()

dependencies: list[str] = []
for dep in requires_dist:
try:
req = Requirement(dep)
except Exception as e: # pragma: no cover - unexpected formats
logger.warning(f"Failed to parse dependency {dep}: {e}")
continue

cur = current_versions.get(req.name.lower())
if cur:
try:
if req.specifier.contains(Version(cur), prereleases=True):
# already within required range; skip
continue
except InvalidVersion:
pass

safe = SafeVersions.get(req.name)
if safe:
dependencies.append(f"{req.name}=={safe}")

instruction = {
"base_package": f"{base_package}=={target_version}",
"dependencies": [f"{k}=={v}" for k, v in SafeVersions.items() if v]
"dependencies": dependencies,
}
return instruction

Expand Down