Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"filelock>=3.19.1",
"requests>=2.32.5",
"ruff>=0.11.6",
"nvidia-ml-py>=13.590.48",
]

[project.optional-dependencies]
Expand Down
13 changes: 4 additions & 9 deletions src/tabpfn_common_utils/telemetry/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import logging
import os

from datetime import datetime, timezone
from typing import Any, Dict

import requests
Expand All @@ -20,28 +19,24 @@
logger = logging.getLogger(__name__)


@ttl_cache(ttl_seconds=60 * 5)
@ttl_cache(ttl_seconds=60 * 60)
def download_config() -> Dict[str, Any]:
"""Download the configuration from server.

Returns:
Dict[str, Any]: The configuration.
"""
# Bust the cache
params = {
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
}

# The default configuration
default = {"enabled": False}
default = {"enabled": True}

# This is a public URL anyone can and should read from
url = os.environ.get(
"TABPFN_TELEMETRY_CONFIG_URL",
"https://storage.googleapis.com/prior-labs-tabpfn-public/config/telemetry.json",
)
try:
resp = requests.get(url, params=params)
# We use a very short timeout to avoid blocking the main thread
resp = requests.get(url, timeout=0.25)
except Exception:
logger.debug(f"Failed to download telemetry config: {url}")
return default
Expand Down
98 changes: 80 additions & 18 deletions src/tabpfn_common_utils/telemetry/core/events.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import importlib
import os
import platform
import sys
import uuid
from dataclasses import dataclass, asdict, field
from datetime import datetime, timezone
from importlib.metadata import version, PackageNotFoundError
from functools import lru_cache
from typing import Any, Dict, Literal, Optional
from pynvml import (
nvmlInit,
nvmlDeviceGetCount,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetName,
nvmlShutdown,
)
from .runtime import get_execution_context
from .state import get_property, set_property

Expand Down Expand Up @@ -79,7 +88,7 @@ def _get_sklearn_version() -> str:
Returns:
str: Version string if scikit-learn is installed.
"""
return _get_package_version("sklearn")
return _get_package_version("scikit-learn", "sklearn")


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -112,23 +121,71 @@ def _get_autogluon_version() -> str:
return _get_package_version("autogluon.core")


def _get_gpu_type() -> Optional[str]:
    """Return a human-readable name for a locally attached GPU, if any.

    Prefers NVML (fast, runs in tens of milliseconds) and falls back to
    PyTorch-based detection when NVML is unavailable or reports no devices.

    Returns:
        Optional[str]: Name of the first visible GPU, or None if no GPU
            is detected.
    """
    # Track whether nvmlInit() succeeded so we only shut down what we started.
    nvml_initialized = False

    try:
        nvmlInit()

        # Set a flag to indicate that the NVML library was initialized;
        # if running on CPU, this line is not reached.
        nvml_initialized = True

        count = nvmlDeviceGetCount()
        if count == 0:
            return None

        # Retrieve the names of all visible devices.
        devices: list[str] = [
            nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(count)
        ]

        # Because NVML runs very fast, the result is deliberately NOT cached:
        # on some VMs GPUs can be attached/detached at runtime. We return the
        # first device name and assume all devices share the same GPU type.
        return devices[0]
    except Exception:
        # Best-effort: any NVML failure (missing driver, missing library, ...)
        # falls through to the slower PyTorch-based detection below.
        pass
    finally:
        # Shut down the NVML library if (and only if) we initialized it.
        if nvml_initialized:
            nvmlShutdown()

    # Fallback: PyTorch-based detection. This is the slowest method because
    # it may require importing the PyTorch library, so NVML is preferred.
    return _get_torch_gpu_type()


@lru_cache(maxsize=1)
def _get_torch_gpu_type() -> Optional[str]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether we could get this information eagerly at import time instead of lazily at event creation. Maybe we could even get the info straight from tabpfn.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might - an interesting area to explore in the future.

"""Get the type of GPU using PyTorch.

Returns:
Optional[str]: Type of GPU if available.
"""
# First, we try to load the torch module from the sys.modules
# because we can assume it is already loaded and we avoid
# re-initializing it.
torch = sys.modules.get("torch")
if torch is None:
try:
import torch # type: ignore[import]
except ImportError:
return None

try:
if torch.cuda.is_available():
name = torch.cuda.get_device_name(0)
return name or "unknown"
name = torch.cuda.get_device_name(0) or "unknown"
return name

except Exception: # noqa: BLE001
pass
Expand All @@ -137,29 +194,34 @@ def _get_gpu_type() -> Optional[str]:
try:
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
# torch doesn't expose an MPS "model string"
return "Apple M-series GPU (MPS)"
name = "Apple M-series GPU (MPS)"
return name
except Exception:
pass

return None


def _get_package_version(package_name: str) -> str:
"""Get the version of a package if it's installed.
def _get_package_version(dist_name: str, module_name: Optional[str] = None) -> str:
"""Read the package distribution metadata and return the version.

Args:
package_name: Name of the package to import (e.g., "torch", "tabpfn").
dist_name: Name of the package to read the metadata for.
module_name: Name of the module to read the version for, if different
from the distribution name.

Returns:
str: Version string if the package is installed, "unknown" otherwise.
str: Version string if the package is installed, None otherwise.
"""
try:
import importlib

module = importlib.import_module(package_name) # type: ignore[import]
return getattr(module, "__version__", "unknown")
except ImportError:
return "unknown"
return version(dist_name)
except PackageNotFoundError:
# Fallback to importing the package and getting the version
try:
module = importlib.import_module(module_name or dist_name)
return getattr(module, "__version__", "unknown")
except ImportError:
return "unknown"


@lru_cache(maxsize=1)
Expand Down
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.