Skip to content
Merged
234 changes: 159 additions & 75 deletions scripts/zwik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from logging.handlers import RotatingFileHandler
from typing import Optional

__version__ = "5.16"
__version__ = "5.17"
min_supported_conda_version = "4.5.4"
max_supported_conda_version = "24.3.0"
min_supported_bootstrap_version = 7
Expand Down Expand Up @@ -820,16 +820,58 @@ def conda_envs_dir(self):

@property
def yaml_hash(self):
    """Return the MD5 fingerprint of the parsed environment definition.

    The digest is built from the parsed ``env_data`` rather than from the
    raw file text, so purely cosmetic edits (comments, blank lines,
    reordering the dependency list) do not invalidate the lock file:

    * every entry of ``env_data["channels"]``, in the order given;
    * the string form of every spec returned by
      ``get_dependencies(additional_dependencies=[])``, sorted.

    The result is cached in ``self._yaml_hash``.

    Raises:
        CondaError: if ``env_data`` is empty or was never loaded.
    """
    if not self._yaml_hash:
        import hashlib

        if not self.env_data:
            # Imported lazily: only the failure path needs conda here.
            from conda import CondaError

            raise CondaError(
                "env_data could not be loaded, "
                "check if zwik_environment.yaml is defined."
            )

        hash_md5 = hashlib.md5()

        # Channel order is kept as written in the environment data.
        # "or ()" tolerates a "channels:" key that is present but null.
        for channel in self.env_data.get("channels") or ():
            hash_md5.update(str(channel).encode("utf-8"))

        if "dependencies" in self.env_data:
            env_deps = self.get_dependencies(additional_dependencies=[])

            # Convert the MatchSpec objects to strings and sort them.
            # This allows the user to reorder the yaml file and not
            # trigger regeneration.
            for dep in sorted(str(spec) for spec in env_deps.specs):
                hash_md5.update(dep.encode("utf-8"))

        self._yaml_hash = hash_md5.hexdigest()

    return self._yaml_hash

def get_legacy_yaml_hash(self):
    """Return the MD5 digest of the raw environment file, legacy style.

    The file at ``self.yaml_path`` is hashed line by line, except that a
    root-level ``allow_unsafe:`` section is left out: the ``allow_unsafe:``
    line itself and every following line up to (not including) the next
    line that starts with a character other than whitespace or ``#``.
    """
    import hashlib

    digest = hashlib.md5()
    skipping = False

    with open(self.yaml_path, "r") as stream:
        for raw_line in stream:
            if raw_line.startswith("allow_unsafe:"):
                # Start of the section to ignore; drop the header too.
                skipping = True
                continue

            if skipping and re.match(r"^[^\s#]", raw_line):
                # Reached a new root-level key, stop skipping
                skipping = False

            if skipping:
                continue

            digest.update(raw_line.encode("utf-8"))

    return digest.hexdigest()

@property
def env_name(self):
return self.lockfile_hash
Expand Down Expand Up @@ -893,9 +935,24 @@ def read_version_lock(self):
raise LockfileError("lock file seems corrupt")
with open(path, "r") as fp:
data = yaml.load(fp)

if "yaml_hash" not in data:
raise LockfileError("lock file seems incomplete")

# 1. Check against the modern hash first
hash_matches = False
if data["yaml_hash"] == self.yaml_hash:
hash_matches = True
else:
# 2. Fallback: Check if it's a legacy lock file (<= 5.16)
from conda.exports import VersionOrder

script_ver = data.get("script_version", "0")
if VersionOrder(str(script_ver)) <= VersionOrder("5.16"):
if data["yaml_hash"] == self.get_legacy_yaml_hash():
hash_matches = True

Comment thread
rudispr marked this conversation as resolved.
if hash_matches:
channel_alias = (
data.get("channel_alias"),
data.get("channel_alias").replace("http:", "https:"),
Expand All @@ -904,20 +961,19 @@ def read_version_lock(self):
raise LockfileError("lock file conda alias mismatch")
lock_file_channels = ";".join(data.get("channels", []))
env_channels = ";".join(self.channels)
# also compare with list of default channels
# for backwards compatibility
def_channels = ";".join(self.settings.default_channels)
if lock_file_channels not in (env_channels, def_channels):
raise LockfileError("lock file conda channel mismatch")
return data

log.info(
"The lock file is not aligned with the actual environment file"
)
return None
log.info("No version lock file found")
return None

def write_version_lock(self, lock_dep, obsolete_pkgs=(), unsafe_pkgs=()):
def write_version_lock(self, lock_dep):
import getpass
import hashlib
import io
Expand All @@ -939,14 +995,6 @@ def write_version_lock(self, lock_dep, obsolete_pkgs=(), unsafe_pkgs=()):
"dependencies": sorted(lock_dep),
}

labels = {}
for p in obsolete_pkgs:
labels[p] = "obsolete"
for p in unsafe_pkgs:
labels[p] = "unsafe"
if labels:
data["labels"] = labels

stream = io.StringIO()
yaml.dump(data, stream)
output = stream.getvalue()
Expand Down Expand Up @@ -1012,16 +1060,17 @@ def create_lockfile(self, additional_dependencies=()):
from conda.exceptions import (
PackagesNotFoundError,
ResolvePackageNotFound,
UnavailableInvalidChannel,
UnsatisfiableError,
)
from conda.exports import subdir

obsolete_pkgs = set()
unsafe_pkgs = set()
unsafe_pkgs = {}
last_exception = None
# First check only the original urls, then also the obsolete labels
# and finally also unsafe labels

for labels in ((), ("obsolete",), ("obsolete", "unsafe")):
log.debug("Checking labels %s", labels)
solver = self.get_solver(dependencies, labels)
try:
link_precs = solver.solve_final_state()
Expand All @@ -1030,35 +1079,37 @@ def create_lockfile(self, additional_dependencies=()):
for prec in link_precs:
split_channel = prec.channel.name.split("/")
if len(split_channel) > 1:
# Format is <channel>/labels/<label>
_, _, label = split_channel
label = split_channel[-1]
if label == "obsolete":
obsolete_pkgs.add(prec.name)
else:
unsafe_pkgs.add(prec.name)
if unsafe_pkgs:
self.handle_unsafe_pkgs(unsafe_pkgs)
elif label == "unsafe":
unsafe_pkgs[prec.name] = prec.version
if unsafe_pkgs:
self.handle_unsafe_pkgs(unsafe_pkgs)
break
except (
PackagesNotFoundError,
ResolvePackageNotFound,
UnsatisfiableError,
UnavailableInvalidChannel,
) as exception:
last_exception = exception
else:
raise last_exception

if obsolete_pkgs:
for pkg_name in obsolete_pkgs:
log.warning(
"WARNING: These packages are marked as obsolete,"
" try to update or find an alternative:\n%s"
", ".join(obsolete_pkgs),
"The package '%s' is marked as OBSOLETE. "
"Consider updating to a newer version.",
pkg_name,
)
if unsafe_pkgs:

for pkg_name in unsafe_pkgs:
log.warning(
"WARNING: Packages below are marked as UNSAFE."
" Client continues because of comment in environment file.\n%s"
", ".join(unsafe_pkgs),
"The package '%s' is marked as UNSAFE. "
"Client continues because it is explicitly allowed "
"in the environment file.",
pkg_name,
)

solved_dep_list = link_precs.item_list
Expand All @@ -1077,11 +1128,7 @@ def create_lockfile(self, additional_dependencies=()):
"dependencies": sorted(self.env_data["dependencies"][:]),
}
else:
self.write_version_lock(
lockfile_deps,
obsolete_pkgs,
unsafe_pkgs,
)
self.write_version_lock(lockfile_deps)
self.lock_data = self.read_version_lock()
assert self.lock_data

Expand Down Expand Up @@ -1136,31 +1183,38 @@ def get_solver(self, dependencies, labels):
)

def handle_unsafe_pkgs(self, unsafe_pkgs):
    """Abort unless every unsafe package is explicitly allowed.

    A package is allowed when the environment data has an ``allow_unsafe``
    section whose ``confirm`` value is exactly
    ``"Risk of unsafe packages is accepted"`` and whose ``packages`` list
    contains a spec with the same package name.

    Args:
        unsafe_pkgs: mapping of package name to version string. The
            version is only used to build the example in the error text.

    Raises:
        CondaError: for the first package that is not allowed; the
            message shows the ``allow_unsafe`` section to add.
    """
    from conda import CondaError
    from conda.exports import MatchSpec

    # env_data can be empty when no environment file was loaded; treat
    # that as "nothing is allowed" instead of failing on ``None.get``.
    allow_unsafe_data = (self.env_data or {}).get("allow_unsafe") or {}
    confirm_msg = allow_unsafe_data.get("confirm", "")
    allowed_pkgs = allow_unsafe_data.get("packages") or []

    # Both values are loop-invariant, so compute them once.
    risk_accepted = confirm_msg == "Risk of unsafe packages is accepted"
    # NOTE(review): only the spec *name* is compared; a version pin in
    # the allow_unsafe list is not enforced here - confirm this is intended.
    allowed_pkg_names = {MatchSpec(p).name for p in allowed_pkgs}

    for pkg_name, pkg_version in unsafe_pkgs.items():
        # Check if the risk is accepted AND the package is listed
        if risk_accepted and pkg_name in allowed_pkg_names:
            continue

        error_msg = (
            f"\nERROR: The package '{pkg_name}' is marked as UNSAFE.\n"
            f"To continue using this package, "
            f"you must explicitly allow it by adding "
            f"the 'allow_unsafe' section to your "
            f"zwik_environment.yaml file.\n\n"
            f"Example:\n\n"
            f"dependencies:\n"
            f"- {pkg_name}\n\n"
            f"allow_unsafe:\n"
            f' confirm: "Risk of unsafe packages is accepted"\n'
            f" packages:\n"
            f" - {pkg_name}=={pkg_version}\n"
        )
        raise CondaError(error_msg)

def partially_update_lockfile(self, update_list):
from conda.exports import MatchSpec
Expand Down Expand Up @@ -1213,38 +1267,68 @@ def create_env(self):
from conda import __version__ as conda_version
from conda.api import SubdirData
from conda.core.link import PrefixSetup, UnlinkLinkTransaction
from conda.exports import MatchSpec

self._check_installation()

specs_to_add = self.lock_data["dependencies"]
channels = self.settings.resolve_channels(
self.lock_data["channels"],
)

# Pre-resolve all channel variations
channels = self.settings.resolve_channels(self.lock_data["channels"])
obsolete_channels = self.settings.resolve_channels(
self.lock_data["channels"], ("obsolete",)
)
unsafe_channels = self.settings.resolve_channels(
self.lock_data["channels"], ("unsafe",)
)

subdirs = [self.lock_data["subdir"], "noarch"]
link_precs = []

for spec in specs_to_add:
spec_name, _ = spec.split("=", maxsplit=1)
search_channels = channels
label = self.lock_data.get("labels", {}).get(spec_name)
if label:
search_channels = self.settings.resolve_channels(
self.lock_data["channels"], ("", label)
)
result = SubdirData.query_all(spec, search_channels, subdirs)

is_obsolete = False
is_unsafe = False

# 1. Try Standard Channels
result = SubdirData.query_all(spec, channels, subdirs)

if not result:
# 2. Try Obsolete Channels
result = SubdirData.query_all(spec, obsolete_channels, subdirs)
if result:
log.warning(
"The package '%s' is obsolete,"
" please review the environment",
spec_name,
)
is_obsolete = True
else:
raise AssertionError("Package not found: {}".format(spec))
# 3. Try Unsafe Channels
result = SubdirData.query_all(spec, unsafe_channels, subdirs)
if result:
is_unsafe = True

if not result:
raise AssertionError(
"Package not found in standard, obsolete, "
"or unsafe channels: {}".format(spec)
)

if is_obsolete:
log.warning(
"The package '%s' is OBSOLETE. "
"Consider updating to a newer version.",
spec_name,
)
elif is_unsafe:
log.warning(
"The package '%s' is UNSAFE. "
"Client continues because it "
"is explicitly allowed in the environment file",
spec_name,
)
Comment thread
rudispr marked this conversation as resolved.

ms = MatchSpec(spec)
pkg_version = str(ms.version) if ms.version else "<version>"
self.handle_unsafe_pkgs({ms.name: pkg_version})

if self._multiple_packages_found(result):
log.warning("Multiple packages found for '%s'.", spec)
result = self._filter_package_from_default_channels(result, spec)
Expand Down
Loading
Loading