|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import contextlib |
| 4 | +import shutil |
| 5 | +import urllib |
| 6 | +import urllib.parse |
| 7 | +import zipfile |
| 8 | +from collections.abc import Callable |
3 | 9 | from pathlib import Path |
4 | 10 | from typing import TYPE_CHECKING, Any |
5 | 11 | from urllib.parse import urlencode, urljoin, urlparse |
6 | 12 |
|
| 13 | +import minio |
7 | 14 | import requests |
8 | 15 | from requests import Response |
| 16 | +from urllib3 import ProxyManager |
9 | 17 |
|
10 | 18 | from openml.__version__ import __version__ |
11 | 19 | from openml._api.config import settings |
12 | 20 |
|
13 | 21 | if TYPE_CHECKING: |
14 | 22 | from openml._api.config import APIConfig |
15 | 23 |
|
| 24 | +import openml.config |
| 25 | +from openml.utils import ProgressBar |
| 26 | + |
16 | 27 |
|
17 | 28 | class CacheMixin: |
18 | 29 | @property |
@@ -149,3 +160,143 @@ def delete( |
149 | 160 | use_api_key=True, |
150 | 161 | **request_kwargs, |
151 | 162 | ) |
| 163 | + |
| 164 | + def download( |
| 165 | + self, |
| 166 | + url: str, |
| 167 | + handler: Callable[[Response, Path, str], Path], |
| 168 | + encoding: str = "utf-8", |
| 169 | + ) -> Path: |
| 170 | + response = self.get(url) |
| 171 | + dir_path = self._get_cache_dir(url, {}) |
| 172 | + dir_path = dir_path.expanduser() |
| 173 | + if handler is not None: |
| 174 | + return handler(response, dir_path, encoding) |
| 175 | + |
| 176 | + return self._text_handler(response, dir_path, encoding, url) |
| 177 | + |
| 178 | + def _text_handler(self, response: Response, path: Path, encoding: str) -> Path: |
| 179 | + if path.is_dir(): |
| 180 | + path = path / "response.txt" |
| 181 | + path.parent.mkdir(parents=True, exist_ok=True) |
| 182 | + with path.open("w", encoding=encoding) as f: |
| 183 | + f.write(response.text) |
| 184 | + return path |
| 185 | + |
| 186 | + |
| 187 | +class MinIOClient(CacheMixin): |
| 188 | + def __init__(self) -> None: |
| 189 | + self.headers: dict[str, str] = {"user-agent": f"openml-python/{__version__}"} |
| 190 | + |
| 191 | + def download_minio_file( |
| 192 | + self, |
| 193 | + source: str, |
| 194 | + destination: str | Path | None = None, |
| 195 | + exists_ok: bool = True, # noqa: FBT002 |
| 196 | + proxy: str | None = "auto", |
| 197 | + ) -> str: |
| 198 | + """Download file ``source`` from a MinIO Bucket and store it at ``destination``. |
| 199 | +
|
| 200 | + Parameters |
| 201 | + ---------- |
| 202 | + source : str |
| 203 | + URL to a file in a MinIO bucket. |
| 204 | + destination : str | Path |
| 205 | + Path to store the file to, if a directory is provided the original filename is used. |
| 206 | + exists_ok : bool, optional (default=True) |
| 207 | + If False, raise FileExists if a file already exists in ``destination``. |
| 208 | + proxy: str, optional (default = "auto") |
| 209 | + The proxy server to use. By default it's "auto" which uses ``requests`` to |
| 210 | + automatically find the proxy to use. Pass None or the environment variable |
| 211 | + ``no_proxy="*"`` to disable proxies. |
| 212 | + """ |
| 213 | + destination = self._get_cache_dir(source, {}) if destination is None else Path(destination) |
| 214 | + parsed_url = urllib.parse.urlparse(source) |
| 215 | + |
| 216 | + # expect path format: /BUCKET/path/to/file.ext |
| 217 | + bucket, object_name = parsed_url.path[1:].split("/", maxsplit=1) |
| 218 | + if destination.is_dir(): |
| 219 | + destination = Path(destination, object_name) |
| 220 | + if destination.is_file() and not exists_ok: |
| 221 | + raise FileExistsError(f"File already exists in {destination}.") |
| 222 | + |
| 223 | + destination = destination.expanduser() |
| 224 | + destination.parent.mkdir(parents=True, exist_ok=True) |
| 225 | + |
| 226 | + if proxy == "auto": |
| 227 | + resolved_proxies = requests.utils.get_environ_proxies(parsed_url.geturl()) |
| 228 | + proxy = requests.utils.select_proxy(parsed_url.geturl(), resolved_proxies) # type: ignore |
| 229 | + |
| 230 | + proxy_client = ProxyManager(proxy) if proxy else None |
| 231 | + |
| 232 | + client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client) |
| 233 | + try: |
| 234 | + client.fget_object( |
| 235 | + bucket_name=bucket, |
| 236 | + object_name=object_name, |
| 237 | + file_path=str(destination), |
| 238 | + progress=ProgressBar() if openml.config.show_progress else None, |
| 239 | + request_headers=self.headers, |
| 240 | + ) |
| 241 | + if destination.is_file() and destination.suffix == ".zip": |
| 242 | + with zipfile.ZipFile(destination, "r") as zip_ref: |
| 243 | + zip_ref.extractall(destination.parent) |
| 244 | + |
| 245 | + except minio.error.S3Error as e: |
| 246 | + if e.message is not None and e.message.startswith("Object does not exist"): |
| 247 | + raise FileNotFoundError(f"Object at '{source}' does not exist.") from e |
| 248 | + # e.g. permission error, or a bucket does not exist (which is also interpreted as a |
| 249 | + # permission error on minio level). |
| 250 | + raise FileNotFoundError("Bucket does not exist or is private.") from e |
| 251 | + |
| 252 | + return str(destination) |
| 253 | + |
| 254 | + def download_minio_bucket(self, source: str, destination: str | Path) -> None: |
| 255 | + """Download file ``source`` from a MinIO Bucket and store it at ``destination``. |
| 256 | +
|
| 257 | + Does not redownload files which already exist. |
| 258 | +
|
| 259 | + Parameters |
| 260 | + ---------- |
| 261 | + source : str |
| 262 | + URL to a MinIO bucket. |
| 263 | + destination : str | Path |
| 264 | + Path to a directory to store the bucket content in. |
| 265 | + """ |
| 266 | + destination = self._get_cache_dir(source, {}) if destination is None else Path(destination) |
| 267 | + parsed_url = urllib.parse.urlparse(source) |
| 268 | + |
| 269 | + # expect path format: /BUCKET/path/to/file.ext |
| 270 | + _, bucket, *prefixes, _file = parsed_url.path.split("/") |
| 271 | + prefix = "/".join(prefixes) |
| 272 | + |
| 273 | + client = minio.Minio(endpoint=parsed_url.netloc, secure=False) |
| 274 | + |
| 275 | + for file_object in client.list_objects(bucket, prefix=prefix, recursive=True): |
| 276 | + if file_object.object_name is None: |
| 277 | + raise ValueError(f"Object name is None for object {file_object!r}") |
| 278 | + if file_object.etag is None: |
| 279 | + raise ValueError(f"Object etag is None for object {file_object!r}") |
| 280 | + |
| 281 | + marker = destination / file_object.etag |
| 282 | + if marker.exists(): |
| 283 | + continue |
| 284 | + |
| 285 | + file_destination = destination / file_object.object_name.rsplit("/", 1)[1] |
| 286 | + if (file_destination.parent / file_destination.stem).exists(): |
| 287 | + # Marker is missing but archive exists means the server archive changed |
| 288 | + # force a refresh |
| 289 | + shutil.rmtree(file_destination.parent / file_destination.stem) |
| 290 | + |
| 291 | + with contextlib.suppress(FileExistsError): |
| 292 | + self.download_minio_file( |
| 293 | + source=source.rsplit("/", 1)[0] |
| 294 | + + "/" |
| 295 | + + file_object.object_name.rsplit("/", 1)[1], |
| 296 | + destination=file_destination, |
| 297 | + exists_ok=False, |
| 298 | + ) |
| 299 | + |
| 300 | + if file_destination.is_file() and file_destination.suffix == ".zip": |
| 301 | + file_destination.unlink() |
| 302 | + marker.touch() |
0 commit comments