Skip to content

Commit a61fed9

Browse files
committed
Implement multiprocess support in the Redis mode
Signed-off-by: Stefano Rivera <stefano@rivera.za.net>
1 parent 2e79562 commit a61fed9

4 files changed

Lines changed: 524 additions & 85 deletions

File tree

prometheus_client/redis.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import os
2+
from datetime import timedelta
3+
from threading import Event, Thread
4+
from typing import Any
5+
from urllib.parse import urlsplit
6+
7+
from redis import Redis
8+
9+
# For testing, a pool of otherwise anonymous FakeRedis instances are made
# available by ID (keyed by the redis database number parsed from the
# PROMETHEUS_REDIS_URL path in redis_client()).
_fake_redis_pool: dict[int, Redis] = {}
12+
13+
14+
def redis_client() -> Redis:
    """
    Create a redis client for PROMETHEUS_REDIS_URL.

    Configure the redis database via a URL in PROMETHEUS_REDIS_URL of the form
    redis://localhost:6379/0

    A ``fakeredis://`` scheme returns a shared in-memory FakeRedis instance
    for testing, pooled per database number.
    """
    parsed_url = urlsplit(os.environ["PROMETHEUS_REDIS_URL"])
    path = parsed_url.path
    assert path.startswith("/")
    assert path[1:].isdigit()
    db = int(path[1:])
    port = parsed_url.port or 6379

    if parsed_url.scheme == "fakeredis":
        from fakeredis import FakeRedis

        # Pool FakeRedis instances so repeated calls share state per db.
        if db not in _fake_redis_pool:
            _fake_redis_pool[db] = FakeRedis()
        return _fake_redis_pool[db]

    assert parsed_url.scheme == "redis"
    assert parsed_url.hostname
    return Redis(host=parsed_url.hostname, port=port, db=db)
37+
38+
39+
# For each process identifier, the set of Redis keys that should be kept
# from expiring while this process is alive (refreshed by its
# KeepMetricsAliveThread, cleared by mark_process_dead()).
_live_metrics: dict[str, set[str]] = {}
41+
42+
43+
def _key_expiry() -> timedelta:
44+
"""Return the configured expiry for multiprocess keys."""
45+
return timedelta(seconds=int(os.environ.get("PROMETHEUS_REDIS_REFRESH_TTL", 20)))
46+
47+
48+
class KeepMetricsAliveThread(Thread):
    """A daemon thread that keeps metrics from expiring as long as we live.

    Periodically refreshes the TTL of every Redis key registered in the
    module-level _live_metrics mapping under this thread's identifier.
    """

    # Signalled (by mark_process_dead) to terminate the refresh loop.
    stop: Event
    # Process identifier whose keys this thread keeps alive.
    identifier: str

    def __init__(self, identifier: str, *args: Any, **kwargs: Any) -> None:
        self.stop = Event()
        self.identifier = identifier
        super().__init__(*args, **kwargs)

    def run(self) -> None:
        """Refresh key TTLs every PROMETHEUS_REDIS_REFRESH_FREQUENCY seconds."""
        delay = int(os.environ.get("PROMETHEUS_REDIS_REFRESH_FREQUENCY", 10))
        expiry = _key_expiry()
        client = redis_client()
        while not self.stop.wait(delay):
            # Iterate a snapshot: _keep_key_from_expiring() may add keys to
            # this set from another thread while we iterate, which would
            # otherwise raise "RuntimeError: Set changed size during
            # iteration".
            for key in tuple(_live_metrics[self.identifier]):
                client.expire(key, expiry)
66+
67+
68+
# For each process identifier, the running KeepMetricsAliveThread refreshing
# its keys; entries are added by _keep_key_from_expiring() and removed by
# mark_process_dead().
_daemon_threads: dict[str, KeepMetricsAliveThread] = {}
69+
70+
71+
def _keep_key_from_expiring(identifier: str, key: str) -> None:
    """Stop key for process identifier from expiring as long as we are alive."""
    live_keys = _live_metrics.setdefault(identifier, set())
    live_keys.add(key)
    if identifier in _daemon_threads:
        return
    # First key registered for this identifier: spawn its refresher thread.
    refresher = KeepMetricsAliveThread(identifier=identifier, daemon=True)
    refresher.start()
    _daemon_threads[identifier] = refresher
78+
79+
80+
def mark_process_dead(identifier: str | int) -> None:
    """Immediately expire all live* metrics for process identifier."""
    ident = str(identifier)

    # Shut down the refresher thread first, so it stops renewing the keys
    # we are about to delete.
    refresher = _daemon_threads.pop(ident, None)
    if refresher is not None:
        refresher.stop.set()
        refresher.join()

    keys = _live_metrics.pop(ident, None)
    if keys:
        redis_client().delete(*keys)

prometheus_client/redis_collector.py

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,12 @@
1-
from collections.abc import Iterable
21
import json
3-
import os
4-
from urllib.parse import urlsplit
2+
from collections.abc import Iterable
3+
from typing import cast
54

65
from .metrics_core import Metric
6+
from .redis import redis_client
77
from .registry import Collector, CollectorRegistry
88
from .samples import Sample
9-
10-
fake_redis_pool = {}
11-
12-
13-
def redis_client():
14-
"""
15-
Create a redis client for PROMETHEUS_REDIS_URL.
16-
17-
Configure the redis database via a URL in PROMETHEUS_REDIS_URL of the form
18-
redis://localhost:6379/0
19-
"""
20-
from redis import Redis
21-
22-
parsed_url = urlsplit(os.environ["PROMETHEUS_REDIS_URL"])
23-
assert parsed_url.path.startswith("/")
24-
assert parsed_url.path[1:].isdigit()
25-
port = parsed_url.port or 6379
26-
db = int(parsed_url.path[1:])
27-
28-
if parsed_url.scheme == "fakeredis":
29-
from fakeredis import FakeRedis
30-
31-
if db not in fake_redis_pool:
32-
fake_redis_pool[db] = FakeRedis()
33-
return fake_redis_pool[db]
34-
35-
assert parsed_url.scheme == "redis"
36-
return Redis(host=parsed_url.hostname, port=port, db=db)
9+
from .values import MULTIPROCESS_MODE_T
3710

3811

3912
class RedisCollector(Collector):
@@ -56,30 +29,78 @@ def _iter_values(self) -> Iterable[tuple[bytes, str]]:
5629
    def collect(self) -> Iterable[Metric]:
        """Assemble Metric objects from the raw key/value pairs in Redis.

        Keys have the form ``value:<type>:<multiprocess_mode>:<mmap_key>``,
        where mmap_key is a JSON list of
        [metric_name, sample_name, labels, help_text].
        """
        metrics: dict[str, Metric] = {}
        # Names of histogram-typed metrics needing _fix_histogram().
        histograms: set[str] = set()
        # Gauge-typed metrics needing cross-process accumulation, by mode.
        multiprocess: dict[str, MULTIPROCESS_MODE_T] = {}

        for key, value_s in self._iter_values():
            # FIXME: Catch ValueError here, just in case?
            prefix_b, typ_b, multiprocess_mode_b, mmap_key = key.split(b":", 3)
            assert prefix_b == b"value"
            value = float(value_s)

            metric_name, name, labels, help_text = json.loads(mmap_key)

            metric = metrics.get(metric_name)
            if metric is None:
                # First sample for this metric: create it and record any
                # post-processing it needs (histogram fix-up, multiprocess
                # accumulation) once per metric.
                typ = typ_b.decode()
                metric = Metric(metric_name, help_text, typ)
                metrics[metric_name] = metric

                if typ in ("histogram", "gaugehistogram"):
                    histograms.add(metric_name)

                multiprocess_mode = cast(
                    MULTIPROCESS_MODE_T, multiprocess_mode_b.decode()
                )
                if typ in ("gauge", "gaugehistogram") and multiprocess_mode:
                    multiprocess[metric_name] = multiprocess_mode

            metric.add_sample(name, labels, value)

        for name, multiprocess_mode in multiprocess.items():
            self._accumulate_multiprocess(metrics[name], multiprocess_mode)

        for name in histograms:
            self._fix_histogram(metrics[name])

        return metrics.values()
8266

67+
def _accumulate_multiprocess(
68+
self, metric: Metric, multiprocess_mode: MULTIPROCESS_MODE_T
69+
) -> None:
70+
"""Merge metrics from multiple processes using multiprocess_mode."""
71+
# We deal with live/dead with Redis expiry
72+
if multiprocess_mode.startswith("live"):
73+
multiprocess_mode = cast(
74+
MULTIPROCESS_MODE_T, multiprocess_mode[len("live") :]
75+
)
76+
if multiprocess_mode == "all":
77+
return
78+
79+
by_label: dict[tuple[tuple[str, ...], str], Sample] = {}
80+
81+
for sample in metric.samples:
82+
labels = sample.labels
83+
if "pid" in sample.labels:
84+
labels = labels.copy()
85+
labels.pop("pid")
86+
key = (tuple(labels.values()), sample.name)
87+
value = sample.value
88+
if key in by_label:
89+
current_value = by_label[key].value
90+
if multiprocess_mode == "min" and value > current_value:
91+
continue
92+
if multiprocess_mode == "max" and value < current_value:
93+
continue
94+
if multiprocess_mode == "sum":
95+
value += current_value
96+
if multiprocess_mode == "mostrecent":
97+
raise NotImplementedError(
98+
"The 'mostrecent' modes are not supported in RedisCollector"
99+
)
100+
by_label[key] = Sample(sample.name, labels, value)
101+
102+
metric.samples = list(by_label.values())
103+
83104
def _fix_histogram(self, metric: Metric) -> None:
84105
"""
85106
Fix-up histogram samples.

0 commit comments

Comments
 (0)