-
Notifications
You must be signed in to change notification settings - Fork 4
Enhance telemetry performance #60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,20 @@ | ||
| import importlib | ||
| import os | ||
| import platform | ||
| import sys | ||
| import uuid | ||
| from dataclasses import dataclass, asdict, field | ||
| from datetime import datetime, timezone | ||
| from importlib.metadata import version, PackageNotFoundError | ||
| from functools import lru_cache | ||
| from typing import Any, Dict, Literal, Optional | ||
| from pynvml import ( | ||
| nvmlInit, | ||
| nvmlDeviceGetCount, | ||
| nvmlDeviceGetHandleByIndex, | ||
| nvmlDeviceGetName, | ||
| nvmlShutdown, | ||
| ) | ||
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| from .runtime import get_execution_context | ||
| from .state import get_property, set_property | ||
|
|
||
|
|
@@ -79,7 +88,7 @@ def _get_sklearn_version() -> str: | |
| Returns: | ||
| str: Version string if scikit-learn is installed. | ||
| """ | ||
| return _get_package_version("sklearn") | ||
| return _get_package_version("scikit-learn", "sklearn") | ||
|
|
||
|
|
||
| @lru_cache(maxsize=1) | ||
|
|
@@ -112,23 +121,71 @@ def _get_autogluon_version() -> str: | |
| return _get_package_version("autogluon.core") | ||
|
|
||
|
|
||
| @lru_cache(maxsize=None) | ||
| def _get_gpu_type() -> Optional[str]: | ||
| """Detect a local GPU using NVML (falling back to PyTorch) and return its | ||
| human-readable name. | ||
|
|
||
| Returns: | ||
| Optional[str]: Human-readable name of the GPU if available. | ||
| """ | ||
| # First, we try to use the NVML library to get the GPU names | ||
| # This is the preferred method as it is faster than using PyTorch | ||
| nvml_initialized = False | ||
|
|
||
| try: | ||
| import torch # type: ignore[import] | ||
| except ImportError: | ||
| return None | ||
| nvmlInit() | ||
|
|
||
| # Set a flag to indicate that the NVML library was initialized | ||
| # if running on CPU, this part of the code will not be reached | ||
| nvml_initialized = True | ||
|
|
||
| counts = nvmlDeviceGetCount() | ||
| if counts == 0: | ||
| return None | ||
|
|
||
| # Retrieve the names of the devices | ||
| devices: list[str] = [ | ||
| nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(counts) | ||
| ] | ||
|
|
||
| # Because NVML runs very fast, we just return the device name | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still worth caching, since it will run on every event?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking of the same thing, however in some rare cases, GPUs might be attached or detached to a VM, so we'd have to cache this information on-disk with a TTL. Anyway, given that NVML runs within 20-30 milliseconds, not really worth it ATM. |
||
| # without caching it. We return the first device name and assume | ||
| # that the VM has the same GPU type for all devices. | ||
| return devices[0] | ||
| except Exception: | ||
| pass | ||
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| finally: | ||
| # Shutdown the NVML library | ||
| if nvml_initialized: | ||
| nvmlShutdown() | ||
|
|
||
| # Only then, as an alternative, we try to use PyTorch to get the GPU name | ||
| # This is the slowest method as it requires importing the PyTorch library | ||
| # so we do not prefer this over the previous methods | ||
| return _get_torch_gpu_type() | ||
|
|
||
|
|
||
| @lru_cache(maxsize=1) | ||
| def _get_torch_gpu_type() -> Optional[str]: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder whether we could get this information eagerly at import time instead of lazily at event creation. Maybe we could even get the info straight from tabpfn.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We might - an interesting area to explore in the future. |
||
| """Get the type of GPU using PyTorch. | ||
|
|
||
| Returns: | ||
| Optional[str]: Type of GPU if available. | ||
| """ | ||
| # First, we try to load the torch module from the sys.modules | ||
| # because we can assume it is already loaded and we avoid | ||
| # re-initializing it. | ||
| torch = sys.modules.get("torch") | ||
| if torch is None: | ||
| try: | ||
| import torch # type: ignore[import] | ||
| except ImportError: | ||
| return None | ||
|
|
||
| try: | ||
| if torch.cuda.is_available(): | ||
| name = torch.cuda.get_device_name(0) | ||
| return name or "unknown" | ||
| name = torch.cuda.get_device_name(0) or "unknown" | ||
| return name | ||
|
|
||
| except Exception: # noqa: BLE001 | ||
| pass | ||
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -137,29 +194,34 @@ def _get_gpu_type() -> Optional[str]: | |
| try: | ||
| if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): | ||
| # torch doesn't expose an MPS "model string" | ||
| return "Apple M-series GPU (MPS)" | ||
| name = "Apple M-series GPU (MPS)" | ||
| return name | ||
| except Exception: | ||
| pass | ||
safaricd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return None | ||
|
|
||
|
|
||
| def _get_package_version(package_name: str) -> str: | ||
| """Get the version of a package if it's installed. | ||
| def _get_package_version(dist_name: str, module_name: Optional[str] = None) -> str: | ||
| """Read the package distribution metadata and return the version. | ||
|
|
||
| Args: | ||
| package_name: Name of the package to import (e.g., "torch", "tabpfn"). | ||
| dist_name: Name of the package to read the metadata for. | ||
| module_name: Name of the module to read the version for, if different | ||
| from the distribution name. | ||
|
|
||
| Returns: | ||
| str: Version string if the package is installed, "unknown" otherwise. | ||
| str: Version string if the package is installed, "unknown" otherwise. | ||
| """ | ||
| try: | ||
| import importlib | ||
|
|
||
| module = importlib.import_module(package_name) # type: ignore[import] | ||
| return getattr(module, "__version__", "unknown") | ||
| except ImportError: | ||
| return "unknown" | ||
| return version(dist_name) | ||
| except PackageNotFoundError: | ||
| # Fallback to importing the package and getting the version | ||
| try: | ||
| module = importlib.import_module(module_name or dist_name) | ||
| return getattr(module, "__version__", "unknown") | ||
| except ImportError: | ||
| return "unknown" | ||
|
|
||
|
|
||
| @lru_cache(maxsize=1) | ||
|
|
||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.