Skip to content

Commit 8478778

Browse files
tobymao and izeigerman
authored and committed
feat: improve cold start of snapshot cache with multi processing (#3084)
1 parent 570870f commit 8478778

File tree

13 files changed

+186
-99
lines changed

13 files changed

+186
-99
lines changed

Makefile

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ engine-up: engine-mssql-up engine-mysql-up engine-postgres-up engine-spark-up en
9191
engine-down: engine-mssql-down engine-mysql-down engine-postgres-down engine-spark-down engine-trino-down
9292

9393
fast-test:
94-
pytest -n auto -m "fast and not cicdonly"
94+
pytest -n auto -m "fast and not cicdonly" && pytest -m "isolated"
9595

9696
slow-test:
97-
pytest -n auto -m "(fast or slow) and not cicdonly"
97+
pytest -n auto -m "(fast or slow) and not cicdonly" && pytest -m "isolated"
9898

9999
cicd-test:
100-
pytest -n auto -m "fast or slow" --junitxml=test-results/junit-cicd.xml
100+
pytest -n auto -m "fast or slow" --junitxml=test-results/junit-cicd.xml && pytest -m "isolated"
101101

102102
core-fast-test:
103103
pytest -n auto -m "fast and not web and not github and not dbt and not airflow and not jupyter"
@@ -199,4 +199,7 @@ databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard
199199
pytest -n auto -x -m "databricks" --junitxml=test-results/junit-databricks.xml
200200

201201
redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD guard-REDSHIFT_DATABASE engine-redshift-install
202-
pytest -n auto -x -m "redshift" --junitxml=test-results/junit-redshift.xml
202+
pytest -n auto -x -m "redshift" --retries 3 --junitxml=test-results/junit-redshift.xml
203+
204+
clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNAME guard-CLICKHOUSE_CLOUD_PASSWORD engine-clickhouse-install
205+
pytest -n auto -x -m "clickhouse_cloud" --retries 3 --junitxml=test-results/junit-clickhouse-cloud.xml

docs/reference/configuration.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ This section describes the other root level configuration parameters.
1616

1717
Configuration options for SQLMesh project directories.
1818

19-
| Option | Description | Type | Required |
20-
| ----------------- | ------------------------------------------------------------------------------------------------------------------ | :----------: | :------: |
21-
| `ignore_patterns` | Files that match glob patterns specified in this list are ignored when scanning the project folder (Default: `[]`) | list[string] | N |
22-
| `project` | The project name of this config. Used for [multi-repo setups](../guides/multi_repo.md). | string | N |
19+
| Option | Description | Type | Required |
20+
| ------------------ | ------------------------------------------------------------------------------------------------------------------ | :----------: | :------: |
21+
| `ignore_patterns` | Files that match glob patterns specified in this list are ignored when scanning the project folder (Default: `[]`) | list[string] | N |
22+
| `project` | The project name of this config. Used for [multi-repo setups](../guides/multi_repo.md). | string | N |
2323

2424
### Environments
2525

@@ -291,3 +291,8 @@ You can disable collection of anonymized usage information with these methods:
291291

292292
- Set the root `disable_anonymized_analytics: true` key in your SQLMesh project configuration file
293293
- Execute SQLMesh commands with an environment variable `SQLMESH__DISABLE_ANONYMIZED_ANALYTICS` set to `1`, `true`, `t`, `yes`, or `y`
294+
295+
## Parallel loading
296+
SQLMesh by default uses all of your cores when loading models and snapshots. It takes advantage of `fork`, which is not available on Windows. The default is to use the same number of workers as cores on your machine if fork is available.
297+
298+
You can override this setting by setting the environment variable `MAX_FORK_WORKERS`. A value of 1 will disable forking and load things sequentially.

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ markers =
77
docker: test that involves interacting with a Docker container
88
remote: test that involves interacting with a remote DB
99
cicdonly: test that only runs on CI/CD
10+
isolated: tests that need to run sequentially usually because they use fork
1011

1112
# Test Domain Markers
1213
# default: core functionality

sqlmesh/core/constants.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import datetime
4+
import os
5+
import typing as t
46
from pathlib import Path
57

68
SQLMESH = "sqlmesh"
@@ -28,6 +30,21 @@
2830
MAX_MODEL_DEFINITION_SIZE = 10000
2931
"""Maximum number of characters in a model definition"""
3032

33+
34+
# The maximum number of fork processes, used for loading projects
35+
# None means default to process pool, 1 means don't fork, :N is number of processes
36+
# Factors in the number of available CPUs even if the process is bound to a subset of them
37+
# (e.g. via taskset) to avoid oversubscribing the system and causing kill signals
38+
if hasattr(os, "fork"):
39+
try:
40+
MAX_FORK_WORKERS: t.Optional[int] = int(os.getenv("MAX_FORK_WORKERS")) # type: ignore
41+
except TypeError:
42+
MAX_FORK_WORKERS = (
43+
len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else None
44+
)
45+
else:
46+
MAX_FORK_WORKERS = 1
47+
3148
EPOCH = datetime.date(1970, 1, 1)
3249

3350
DEFAULT_MAX_LIMIT = 1000

sqlmesh/core/loader.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import abc
44
import linecache
55
import logging
6-
import multiprocessing as mp
76
import os
87
import typing as t
98
from collections import defaultdict
10-
from concurrent.futures import ProcessPoolExecutor, as_completed
9+
from concurrent.futures import as_completed
1110
from dataclasses import dataclass
1211
from pathlib import Path
1312

@@ -25,10 +24,10 @@
2524
ModelCache,
2625
OptimizedQueryCache,
2726
SeedModel,
28-
SqlModel,
2927
create_external_model,
3028
load_sql_based_model,
3129
)
30+
from sqlmesh.core.model.cache import optimized_query_cache_pool, load_optimized_query_cache
3231
from sqlmesh.core.model import model as model_registry
3332
from sqlmesh.utils import UniqueKeyDict
3433
from sqlmesh.utils.dag import DAG
@@ -549,7 +548,7 @@ def update_model_schemas(
549548
schema = MappingSchema(normalize=False)
550549
optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
551550

552-
if not hasattr(os, "fork") or "PYTEST_CURRENT_TEST" in os.environ:
551+
if c.MAX_FORK_WORKERS == 1:
553552
_update_model_schemas_sequential(dag, models, schema, optimized_query_cache)
554553
else:
555554
_update_model_schemas_parallel(dag, models, schema, optimized_query_cache)
@@ -610,13 +609,9 @@ def process_models(completed_model: t.Optional[Model] = None) -> None:
610609
del graph[name]
611610
model = models[name]
612611
model.update_schema(schema)
613-
futures.add(executor.submit(_load_optimized_query_cache, model))
612+
futures.add(executor.submit(load_optimized_query_cache, model))
614613

615-
with ProcessPoolExecutor(
616-
mp_context=mp.get_context("fork"),
617-
initializer=_init_optimized_query_cache,
618-
initargs=(optimized_query_cache,),
619-
) as executor:
614+
with optimized_query_cache_pool(optimized_query_cache) as executor:
620615
process_models()
621616

622617
while futures:
@@ -629,20 +624,3 @@ def process_models(completed_model: t.Optional[Model] = None) -> None:
629624

630625
_update_schema_with_model(schema, model)
631626
process_models(completed_model=model)
632-
633-
634-
_optimized_query_cache: t.Optional[OptimizedQueryCache] = None
635-
636-
637-
def _init_optimized_query_cache(optimized_query_cache: OptimizedQueryCache) -> None:
638-
global _optimized_query_cache
639-
_optimized_query_cache = optimized_query_cache
640-
641-
642-
def _load_optimized_query_cache(model: Model) -> t.Tuple[str, t.Optional[str]]:
643-
assert _optimized_query_cache
644-
if isinstance(model, SqlModel):
645-
entry_name = _optimized_query_cache.put(model)
646-
else:
647-
entry_name = None
648-
return model.fqn, entry_name

sqlmesh/core/model/cache.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

33
import logging
4+
import multiprocessing as mp
45
import typing as t
6+
from concurrent.futures import ProcessPoolExecutor
57
from pathlib import Path
68

79
from sqlglot import exp
810
from sqlglot.optimizer.simplify import gen
911

12+
from sqlmesh.core import constants as c
1013
from sqlmesh.core.model.definition import Model, SqlModel, _Model
1114
from sqlmesh.utils.cache import FileCache
1215
from sqlmesh.utils.hashing import crc32
@@ -15,6 +18,8 @@
1518

1619
logger = logging.getLogger(__name__)
1720

21+
T = t.TypeVar("T")
22+
1823

1924
class ModelCache:
2025
"""File-based cache implementation for model definitions.
@@ -128,6 +133,49 @@ def _entry_name(model: SqlModel) -> str:
128133
return f"{model.name}_{crc32(hash_data)}"
129134

130135

136+
def optimized_query_cache_pool(optimized_query_cache: OptimizedQueryCache) -> ProcessPoolExecutor:
137+
return ProcessPoolExecutor(
138+
mp_context=mp.get_context("fork"),
139+
initializer=_init_optimized_query_cache,
140+
initargs=(optimized_query_cache,),
141+
max_workers=c.MAX_FORK_WORKERS,
142+
)
143+
144+
145+
@t.overload
146+
def load_optimized_query_cache(
147+
model_or_tuple: t.Tuple[Model, T],
148+
) -> t.Tuple[T, t.Optional[str]]: ...
149+
150+
151+
@t.overload
152+
def load_optimized_query_cache(model_or_tuple: Model) -> t.Tuple[str, t.Optional[str]]: ...
153+
154+
155+
def load_optimized_query_cache(model_or_tuple): # type: ignore
156+
assert _optimized_query_cache
157+
158+
if isinstance(model_or_tuple, _Model):
159+
model = model_or_tuple
160+
key = None
161+
else:
162+
model, key = model_or_tuple
163+
164+
if isinstance(model, SqlModel):
165+
entry_name = _optimized_query_cache.put(model)
166+
else:
167+
entry_name = None
168+
return key or model.fqn, entry_name
169+
170+
171+
_optimized_query_cache: t.Optional[OptimizedQueryCache] = None
172+
173+
174+
def _init_optimized_query_cache(optimized_query_cache: OptimizedQueryCache) -> None:
175+
global _optimized_query_cache
176+
_optimized_query_cache = optimized_query_cache
177+
178+
131179
def _mapping_schema_hash_data(schema: t.Dict[str, t.Any]) -> t.List[str]:
132180
keys = sorted(schema) if all(isinstance(v, dict) for v in schema.values()) else schema
133181

sqlmesh/core/model/definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ def full_depends_on(self) -> t.Set[str]:
960960
if self._full_depends_on is None:
961961
depends_on = self.depends_on_ or set()
962962

963-
query = self.render_query(optimize=False)
963+
query = self.render_query(needs_optimization=False)
964964
if query is not None:
965965
depends_on |= d.find_tables(
966966
query, default_catalog=self.default_catalog, dialect=self.dialect

sqlmesh/core/renderer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def render(
376376
table_mapping: t.Optional[t.Dict[str, str]] = None,
377377
deployability_index: t.Optional[DeployabilityIndex] = None,
378378
expand: t.Iterable[str] = tuple(),
379-
optimize: bool = True,
379+
needs_optimization: bool = True,
380380
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
381381
**kwargs: t.Any,
382382
) -> t.Optional[exp.Query]:
@@ -393,7 +393,8 @@ def render(
393393
expand: Expand referenced models as subqueries. This is used to bypass backfills when running queries
394394
that depend on materialized tables. Model definitions are inlined and can thus be run end to
395395
end on the fly.
396-
optimize: Whether to optimize the query.
396+
needs_optimization: Whether or not an optimization should be attempted
397+
(if passing False, it still may return a cached optimized query).
397398
runtime_stage: Indicates the current runtime stage, for example if we're still loading the project, etc.
398399
kwargs: Additional kwargs to pass to the renderer.
399400
@@ -402,7 +403,7 @@ def render(
402403
"""
403404

404405
should_cache = self._should_cache(
405-
runtime_stage, start, end, execution_time, not optimize, *kwargs.values()
406+
runtime_stage, start, end, execution_time, *kwargs.values()
406407
)
407408

408409
if should_cache and self._optimized_cache:
@@ -417,7 +418,7 @@ def render(
417418
table_mapping=table_mapping,
418419
deployability_index=deployability_index,
419420
runtime_stage=runtime_stage,
420-
normalize_identifiers=optimize,
421+
normalize_identifiers=needs_optimization,
421422
**kwargs,
422423
)
423424
except ParsetimeAdapterCallError:
@@ -439,7 +440,7 @@ def render(
439440
)
440441
raise
441442

442-
if optimize:
443+
if needs_optimization:
443444
deps = d.find_tables(
444445
query, default_catalog=self._default_catalog, dialect=self._dialect
445446
)
@@ -449,7 +450,7 @@ def render(
449450
if should_cache:
450451
self._optimized_cache = query
451452

452-
if optimize:
453+
if needs_optimization:
453454
query = self._resolve_tables(
454455
query,
455456
snapshots=snapshots,

sqlmesh/core/snapshot/cache.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
import typing as t
44

55
from pathlib import Path
6-
from sqlmesh.core.model.cache import OptimizedQueryCache
6+
from sqlmesh.core.model.cache import (
7+
OptimizedQueryCache,
8+
optimized_query_cache_pool,
9+
load_optimized_query_cache,
10+
)
11+
from sqlmesh.core import constants as c
712
from sqlmesh.core.snapshot.definition import Snapshot, SnapshotId
813
from sqlmesh.utils.cache import FileCache
914

@@ -31,12 +36,10 @@ def get_or_load(
3136
"""
3237
snapshots = {}
3338
cache_hits: t.Set[SnapshotId] = set()
39+
3440
for s_id in snapshot_ids:
3541
snapshot = self._snapshot_cache.get(self._entry_name(s_id))
3642
if snapshot:
37-
if snapshot.is_model:
38-
self._optimized_query_cache.with_optimized_query(snapshot.model)
39-
self._update_node_hash_cache(snapshot)
4043
snapshot.intervals = []
4144
snapshot.dev_intervals = []
4245
snapshots[s_id] = snapshot
@@ -46,18 +49,43 @@ def get_or_load(
4649
if snapshot_ids_to_load:
4750
loaded_snapshots = loader(snapshot_ids_to_load)
4851
for snapshot in loaded_snapshots:
49-
self._update_node_hash_cache(snapshot)
50-
self.put(snapshot)
5152
snapshots[snapshot.snapshot_id] = snapshot
5253

54+
if c.MAX_FORK_WORKERS != 1:
55+
with optimized_query_cache_pool(self._optimized_query_cache) as executor:
56+
for key, entry_name in executor.map(
57+
load_optimized_query_cache,
58+
(
59+
(snapshot.model, s_id)
60+
for s_id, snapshot in snapshots.items()
61+
if snapshot.is_model
62+
),
63+
):
64+
if entry_name:
65+
self._optimized_query_cache.with_optimized_query(
66+
snapshots[key].model, entry_name
67+
)
68+
69+
for snapshot in snapshots.values():
70+
self._update_node_hash_cache(snapshot)
71+
72+
if snapshot.is_model and c.MAX_FORK_WORKERS == 1:
73+
self._optimized_query_cache.with_optimized_query(snapshot.model)
74+
75+
self.put(snapshot)
76+
5377
return snapshots, cache_hits
5478

5579
def put(self, snapshot: Snapshot) -> None:
80+
entry_name = self._entry_name(snapshot.snapshot_id)
81+
82+
if self._snapshot_cache.exists(entry_name):
83+
return
84+
5685
if snapshot.is_model:
57-
self._optimized_query_cache.put(snapshot.model)
5886
# make sure we preload full_depends_on
5987
snapshot.model.full_depends_on
60-
self._snapshot_cache.put(self._entry_name(snapshot.snapshot_id), value=snapshot)
88+
self._snapshot_cache.put(entry_name, value=snapshot)
6189

6290
def clear(self) -> None:
6391
self._snapshot_cache.clear()

tests/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from sqlmesh.core import constants as c
12
from sqlmesh.core.analytics import disable_analytics
23

4+
c.MAX_FORK_WORKERS = 1
35
disable_analytics()

0 commit comments

Comments
 (0)