Skip to content

Commit 570870f

Browse files
tobymao and izeigerman
committed
feat: use multiprocessing to speed up loading (#3077)
Co-authored-by: Iaroslav Zeigerman <zeigerman.ia@gmail.com>
1 parent 4c2f843 commit 570870f

File tree

4 files changed

+166
-44
lines changed

4 files changed

+166
-44
lines changed

sqlmesh/core/loader.py

Lines changed: 111 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import abc
44
import linecache
55
import logging
6+
import multiprocessing as mp
67
import os
78
import typing as t
89
from collections import defaultdict
10+
from concurrent.futures import ProcessPoolExecutor, as_completed
911
from dataclasses import dataclass
1012
from pathlib import Path
1113

@@ -23,6 +25,7 @@
2325
ModelCache,
2426
OptimizedQueryCache,
2527
SeedModel,
28+
SqlModel,
2629
create_external_model,
2730
load_sql_based_model,
2831
)
@@ -42,39 +45,6 @@
4245
logger = logging.getLogger(__name__)
4346

4447

45-
# TODO: consider moving this to context
46-
def update_model_schemas(
47-
dag: DAG[str],
48-
models: UniqueKeyDict[str, Model],
49-
context_path: Path,
50-
) -> None:
51-
schema = MappingSchema(normalize=False)
52-
optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
53-
54-
for name in dag.sorted:
55-
model = models.get(name)
56-
57-
# External models don't exist in the context, so we need to skip them
58-
if not model:
59-
continue
60-
61-
try:
62-
model.update_schema(schema)
63-
optimized_query_cache.with_optimized_query(model)
64-
65-
columns_to_types = model.columns_to_types
66-
if columns_to_types is not None:
67-
schema.add_table(
68-
model.fqn, columns_to_types, dialect=model.dialect, normalize=False
69-
)
70-
except SchemaError as e:
71-
if "nesting level:" in str(e):
72-
logger.error(
73-
"SQLMesh requires all model names and references to have the same level of nesting."
74-
)
75-
raise
76-
77-
7848
@dataclass
7949
class LoadedProject:
8050
macros: MacroRegistry
@@ -568,3 +538,111 @@ def _model_cache_entry_id(self, model_path: Path) -> str:
568538
or self._loader._context.config.default_gateway_name,
569539
]
570540
)
541+
542+
543+
# TODO: consider moving this to context
544+
def update_model_schemas(
545+
dag: DAG[str],
546+
models: UniqueKeyDict[str, Model],
547+
context_path: Path,
548+
) -> None:
549+
schema = MappingSchema(normalize=False)
550+
optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)
551+
552+
if not hasattr(os, "fork") or "PYTEST_CURRENT_TEST" in os.environ:
553+
_update_model_schemas_sequential(dag, models, schema, optimized_query_cache)
554+
else:
555+
_update_model_schemas_parallel(dag, models, schema, optimized_query_cache)
556+
557+
558+
def _update_schema_with_model(schema: MappingSchema, model: Model) -> None:
559+
columns_to_types = model.columns_to_types
560+
if columns_to_types:
561+
try:
562+
schema.add_table(model.fqn, columns_to_types, dialect=model.dialect, normalize=False)
563+
except SchemaError as e:
564+
if "nesting level:" in str(e):
565+
logger.error(
566+
"SQLMesh requires all model names and references to have the same level of nesting."
567+
)
568+
raise
569+
570+
571+
def _update_model_schemas_sequential(
572+
dag: DAG[str],
573+
models: UniqueKeyDict[str, Model],
574+
schema: MappingSchema,
575+
optimized_query_cache: OptimizedQueryCache,
576+
) -> None:
577+
for name in dag.sorted:
578+
model = models.get(name)
579+
580+
# External models don't exist in the context, so we need to skip them
581+
if not model:
582+
continue
583+
584+
model.update_schema(schema)
585+
optimized_query_cache.with_optimized_query(model)
586+
_update_schema_with_model(schema, model)
587+
588+
589+
def _update_model_schemas_parallel(
590+
dag: DAG[str],
591+
models: UniqueKeyDict[str, Model],
592+
schema: MappingSchema,
593+
optimized_query_cache: OptimizedQueryCache,
594+
) -> None:
595+
futures = set()
596+
graph = {
597+
model: {dep for dep in deps if dep in models}
598+
for model, deps in dag._dag.items()
599+
if model in models
600+
}
601+
602+
def process_models(completed_model: t.Optional[Model] = None) -> None:
603+
for name in list(graph):
604+
deps = graph[name]
605+
606+
if completed_model:
607+
deps.discard(completed_model.fqn)
608+
609+
if not deps:
610+
del graph[name]
611+
model = models[name]
612+
model.update_schema(schema)
613+
futures.add(executor.submit(_load_optimized_query_cache, model))
614+
615+
with ProcessPoolExecutor(
616+
mp_context=mp.get_context("fork"),
617+
initializer=_init_optimized_query_cache,
618+
initargs=(optimized_query_cache,),
619+
) as executor:
620+
process_models()
621+
622+
while futures:
623+
for future in as_completed(futures):
624+
futures.remove(future)
625+
fqn, entry_name = future.result()
626+
model = models[fqn]
627+
if entry_name:
628+
optimized_query_cache.with_optimized_query(model, entry_name)
629+
630+
_update_schema_with_model(schema, model)
631+
process_models(completed_model=model)
632+
633+
634+
_optimized_query_cache: t.Optional[OptimizedQueryCache] = None
635+
636+
637+
def _init_optimized_query_cache(optimized_query_cache: OptimizedQueryCache) -> None:
638+
global _optimized_query_cache
639+
_optimized_query_cache = optimized_query_cache
640+
641+
642+
def _load_optimized_query_cache(model: Model) -> t.Tuple[str, t.Optional[str]]:
643+
assert _optimized_query_cache
644+
if isinstance(model, SqlModel):
645+
entry_name = _optimized_query_cache.put(model)
646+
else:
647+
entry_name = None
648+
return model.fqn, entry_name

sqlmesh/core/model/cache.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,17 @@ def __init__(self, path: Path):
7272
path, prefix="optimized_query"
7373
)
7474

75-
def with_optimized_query(self, model: Model) -> bool:
75+
def with_optimized_query(self, model: Model, name: t.Optional[str] = None) -> bool:
7676
"""Adds an optimized query to the model's in-memory cache.
7777
7878
Args:
7979
model: The model to add the optimized query to.
80+
name: The cache entry name of the model.
8081
"""
8182
if not isinstance(model, SqlModel):
8283
return False
8384

84-
name = self._entry_name(model)
85+
name = self._entry_name(model) if name is None else name
8586
cache_entry = self._file_cache.get(name)
8687
if cache_entry:
8788
try:
@@ -101,15 +102,17 @@ def with_optimized_query(self, model: Model) -> bool:
101102
self._put(name, model)
102103
return False
103104

104-
def put(self, model: Model) -> None:
105+
def put(self, model: Model) -> t.Optional[str]:
105106
if not isinstance(model, SqlModel):
106-
return
107+
return None
107108

108109
name = self._entry_name(model)
110+
109111
if self._file_cache.exists(name):
110-
return
112+
return name
111113

112114
self._put(name, model)
115+
return name
113116

114117
def _put(self, name: str, model: SqlModel) -> None:
115118
optimized_query = model.render_query()

sqlmesh/core/renderer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _render(
7777
table_mapping: t.Optional[t.Dict[str, str]] = None,
7878
deployability_index: t.Optional[DeployabilityIndex] = None,
7979
runtime_stage: RuntimeStage = RuntimeStage.LOADING,
80+
normalize_identifiers: bool = True,
8081
**kwargs: t.Any,
8182
) -> t.List[t.Optional[exp.Expression]]:
8283
"""Renders a expression, expanding macros with provided kwargs
@@ -89,14 +90,15 @@ def _render(
8990
table_mapping: Table mapping of physical locations. Takes precedence over snapshot mappings.
9091
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
9192
runtime_stage: Indicates the current runtime stage, for example if we're still loading the project, etc.
93+
normalize_identifiers: Whether or not to normalize and quote identifiers.
9294
kwargs: Additional kwargs to pass to the renderer.
9395
9496
Returns:
9597
The rendered expressions.
9698
"""
9799

98100
should_cache = self._should_cache(
99-
runtime_stage, start, end, execution_time, *kwargs.values()
101+
runtime_stage, start, end, execution_time, not normalize_identifiers, *kwargs.values()
100102
)
101103

102104
if should_cache and self._cache:
@@ -193,7 +195,7 @@ def _render(
193195
raise_config_error(f"Failed to resolve macro for expression. {ex}", self._path)
194196

195197
for expression in t.cast(t.List[exp.Expression], transformed_expressions):
196-
with self._normalize_and_quote(expression) as expression:
198+
with self._normalize_and_quote(expression, normalize_identifiers) as expression:
197199
if hasattr(expression, "selects"):
198200
for select in expression.selects:
199201
if not isinstance(select, exp.Alias) and select.output_name not in (
@@ -295,8 +297,8 @@ def _expand(node: exp.Expression) -> exp.Expression:
295297
return expression
296298

297299
@contextmanager
298-
def _normalize_and_quote(self, query: E) -> t.Iterator[E]:
299-
if self._normalize_identifiers:
300+
def _normalize_and_quote(self, query: E, normalize_identifiers: bool = True) -> t.Iterator[E]:
301+
if self._normalize_identifiers and normalize_identifiers:
300302
with d.normalize_and_quote(
301303
query, self._dialect, self._default_catalog, quote=self._quote_identifiers
302304
) as query:
@@ -400,10 +402,10 @@ def render(
400402
"""
401403

402404
should_cache = self._should_cache(
403-
runtime_stage, start, end, execution_time, *kwargs.values()
405+
runtime_stage, start, end, execution_time, not optimize, *kwargs.values()
404406
)
405407

406-
if should_cache and self._optimized_cache and optimize:
408+
if should_cache and self._optimized_cache:
407409
query = self._optimized_cache
408410
else:
409411
try:
@@ -415,6 +417,7 @@ def render(
415417
table_mapping=table_mapping,
416418
deployability_index=deployability_index,
417419
runtime_stage=runtime_stage,
420+
normalize_identifiers=optimize,
418421
**kwargs,
419422
)
420423
except ParsetimeAdapterCallError:

tests/core/test_model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5677,3 +5677,41 @@ def test_cache():
56775677
model = load_sql_based_model(expressions)
56785678
assert model.depends_on == {'"y"'}
56795679
assert model.copy(update={"depends_on_": {'"z"'}}).depends_on == {'"z"', '"y"'}
5680+
5681+
5682+
def test_parallel_load(assert_exp_eq, mocker):
5683+
import os
5684+
from sqlmesh.core import loader
5685+
5686+
pytest_current_test = os.environ.pop("PYTEST_CURRENT_TEST")
5687+
try:
5688+
spy = mocker.spy(loader, "_update_model_schemas_parallel")
5689+
context = Context(paths="examples/sushi")
5690+
5691+
if hasattr(os, "fork"):
5692+
spy.assert_called()
5693+
5694+
assert_exp_eq(
5695+
context.render("sushi.customers"),
5696+
"""
5697+
WITH "current_marketing" AS (
5698+
SELECT
5699+
"marketing"."customer_id" AS "customer_id",
5700+
"marketing"."status" AS "status"
5701+
FROM "memory"."sushi"."marketing" AS "marketing"
5702+
WHERE
5703+
"marketing"."valid_to" IS NULL
5704+
)
5705+
SELECT DISTINCT
5706+
CAST("o"."customer_id" AS INT) AS "customer_id", /* this comment should not be registered */
5707+
"m"."status" AS "status",
5708+
"d"."zip" AS "zip"
5709+
FROM "memory"."sushi"."orders" AS "o"
5710+
LEFT JOIN "current_marketing" AS "m"
5711+
ON "m"."customer_id" = "o"."customer_id"
5712+
LEFT JOIN "memory"."raw"."demographics" AS "d"
5713+
ON "d"."customer_id" = "o"."customer_id"
5714+
""",
5715+
)
5716+
finally:
5717+
os.environ["PYTEST_CURRENT_TEST"] = pytest_current_test

0 commit comments

Comments
 (0)