|
3 | 3 | import abc |
4 | 4 | import linecache |
5 | 5 | import logging |
| 6 | +import multiprocessing as mp |
6 | 7 | import os |
7 | 8 | import typing as t |
8 | 9 | from collections import defaultdict |
| 10 | +from concurrent.futures import ProcessPoolExecutor, as_completed |
9 | 11 | from dataclasses import dataclass |
10 | 12 | from pathlib import Path |
11 | 13 |
|
|
23 | 25 | ModelCache, |
24 | 26 | OptimizedQueryCache, |
25 | 27 | SeedModel, |
| 28 | + SqlModel, |
26 | 29 | create_external_model, |
27 | 30 | load_sql_based_model, |
28 | 31 | ) |
|
42 | 45 | logger = logging.getLogger(__name__) |
43 | 46 |
|
44 | 47 |
|
45 | | -# TODO: consider moving this to context |
46 | | -def update_model_schemas( |
47 | | - dag: DAG[str], |
48 | | - models: UniqueKeyDict[str, Model], |
49 | | - context_path: Path, |
50 | | -) -> None: |
51 | | - schema = MappingSchema(normalize=False) |
52 | | - optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE) |
53 | | - |
54 | | - for name in dag.sorted: |
55 | | - model = models.get(name) |
56 | | - |
57 | | - # External models don't exist in the context, so we need to skip them |
58 | | - if not model: |
59 | | - continue |
60 | | - |
61 | | - try: |
62 | | - model.update_schema(schema) |
63 | | - optimized_query_cache.with_optimized_query(model) |
64 | | - |
65 | | - columns_to_types = model.columns_to_types |
66 | | - if columns_to_types is not None: |
67 | | - schema.add_table( |
68 | | - model.fqn, columns_to_types, dialect=model.dialect, normalize=False |
69 | | - ) |
70 | | - except SchemaError as e: |
71 | | - if "nesting level:" in str(e): |
72 | | - logger.error( |
73 | | - "SQLMesh requires all model names and references to have the same level of nesting." |
74 | | - ) |
75 | | - raise |
76 | | - |
77 | | - |
78 | 48 | @dataclass |
79 | 49 | class LoadedProject: |
80 | 50 | macros: MacroRegistry |
@@ -568,3 +538,111 @@ def _model_cache_entry_id(self, model_path: Path) -> str: |
568 | 538 | or self._loader._context.config.default_gateway_name, |
569 | 539 | ] |
570 | 540 | ) |
| 541 | + |
| 542 | + |
# TODO: consider moving this to context
def update_model_schemas(
    dag: DAG[str],
    models: UniqueKeyDict[str, Model],
    context_path: Path,
) -> None:
    """Propagate column schemas through *dag* and warm the optimized-query cache.

    Uses fork-based worker processes when the platform supports ``os.fork``;
    otherwise (e.g. Windows), or when running under pytest, falls back to a
    single-process sequential pass.
    """
    schema = MappingSchema(normalize=False)
    optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE)

    can_fork = hasattr(os, "fork")
    under_pytest = "PYTEST_CURRENT_TEST" in os.environ
    if can_fork and not under_pytest:
        _update_model_schemas_parallel(dag, models, schema, optimized_query_cache)
    else:
        _update_model_schemas_sequential(dag, models, schema, optimized_query_cache)
| 557 | + |
def _update_schema_with_model(schema: MappingSchema, model: Model) -> None:
    """Register *model*'s columns in *schema* under its fully qualified name."""
    column_types = model.columns_to_types
    if not column_types:
        # No column information available for this model; nothing to register.
        return
    try:
        schema.add_table(model.fqn, column_types, dialect=model.dialect, normalize=False)
    except SchemaError as e:
        # A mismatched nesting depth (e.g. "db.table" vs "catalog.db.table")
        # surfaces as this specific SchemaError; give the user an actionable hint.
        if "nesting level:" in str(e):
            logger.error(
                "SQLMesh requires all model names and references to have the same level of nesting."
            )
        raise
| 570 | + |
def _update_model_schemas_sequential(
    dag: DAG[str],
    models: UniqueKeyDict[str, Model],
    schema: MappingSchema,
    optimized_query_cache: OptimizedQueryCache,
) -> None:
    """Update each model's schema one at a time, in topological order."""
    for name in dag.sorted:
        # External models don't exist in the context, so we need to skip them
        if not (model := models.get(name)):
            continue

        model.update_schema(schema)
        optimized_query_cache.with_optimized_query(model)
        _update_schema_with_model(schema, model)
| 587 | + |
| 588 | + |
def _update_model_schemas_parallel(
    dag: DAG[str],
    models: UniqueKeyDict[str, Model],
    schema: MappingSchema,
    optimized_query_cache: OptimizedQueryCache,
) -> None:
    """Update model schemas with optimized-query work fanned out to forked workers.

    Schema updates must respect topological order, so this keeps a mutable
    dependency graph: whenever a model's dependencies are all satisfied, its
    ``update_schema`` runs in-process and its (expensive) optimized-query
    computation is submitted to a fork-based process pool. Cache entries
    produced by workers are read back in the parent via the returned entry name.
    """
    futures = set()
    # Restrict the DAG to models we actually loaded; external models are
    # dropped both as nodes and as dependencies so they never block scheduling.
    graph = {
        model: {dep for dep in deps if dep in models}
        for model, deps in dag._dag.items()
        if model in models
    }

    def process_models(completed_model: t.Optional[Model] = None) -> None:
        # Submit every model whose remaining dependencies are now empty.
        # NOTE(review): graph keys come from dag._dag while discard uses
        # completed_model.fqn — assumes DAG nodes are keyed by FQN; confirm.
        for name in list(graph):  # list() copy: we delete keys while scanning
            deps = graph[name]

            if completed_model:
                deps.discard(completed_model.fqn)

            if not deps:
                del graph[name]
                model = models[name]
                # Schema mutation stays in the parent process; only the
                # optimized-query computation is shipped to a worker.
                model.update_schema(schema)
                futures.add(executor.submit(_load_optimized_query_cache, model))

    # "fork" lets workers inherit loaded models/macros without pickling;
    # the initializer stashes the cache in each worker's module global.
    with ProcessPoolExecutor(
        mp_context=mp.get_context("fork"),
        initializer=_init_optimized_query_cache,
        initargs=(optimized_query_cache,),
    ) as executor:
        process_models()

        # as_completed snapshots the current future set; newly submitted
        # futures are picked up by the next iteration of the outer while.
        while futures:
            for future in as_completed(futures):
                futures.remove(future)
                fqn, entry_name = future.result()
                model = models[fqn]
                if entry_name:
                    # Attach the worker-produced cache entry to the parent's model.
                    optimized_query_cache.with_optimized_query(model, entry_name)

                _update_schema_with_model(schema, model)
                process_models(completed_model=model)
| 632 | + |
| 633 | + |
# Per-worker-process handle to the optimized-query cache. Set by the
# ProcessPoolExecutor initializer below; remains None in the parent process.
_optimized_query_cache: t.Optional[OptimizedQueryCache] = None


def _init_optimized_query_cache(optimized_query_cache: OptimizedQueryCache) -> None:
    """Process-pool initializer: stash the shared cache in this worker's globals."""
    global _optimized_query_cache
    _optimized_query_cache = optimized_query_cache
| 640 | + |
| 641 | + |
def _load_optimized_query_cache(model: Model) -> t.Tuple[str, t.Optional[str]]:
    """Warm the optimized-query cache for *model* inside a worker process.

    Returns the model's FQN paired with the cache entry name, or ``None``
    for non-SQL models (which have no optimized query to cache).
    """
    assert _optimized_query_cache
    entry_name = (
        _optimized_query_cache.put(model) if isinstance(model, SqlModel) else None
    )
    return model.fqn, entry_name
0 commit comments