115115 ModelTestMetadata ,
116116 generate_test ,
117117 run_tests ,
118+ filter_tests_by_patterns ,
118119)
119120from sqlmesh .core .user import User
120121from sqlmesh .utils import UniqueKeyDict , Verbosity
146147 from typing_extensions import Literal
147148
148149 from sqlmesh .core .engine_adapter ._typing import (
149- BigframeSession ,
150150 DF ,
151+ BigframeSession ,
151152 PySparkDataFrame ,
152153 PySparkSession ,
153154 SnowparkSession ,
@@ -398,6 +399,10 @@ def __init__(
398399 self ._standalone_audits : UniqueKeyDict [str , StandaloneAudit ] = UniqueKeyDict (
399400 "standaloneaudits"
400401 )
402+ self ._model_test_metadata : t .List [ModelTestMetadata ] = []
403+ self ._model_test_metadata_path_index : t .Dict [Path , t .List [ModelTestMetadata ]] = {}
404+ self ._model_test_metadata_fully_qualified_name_index : t .Dict [str , ModelTestMetadata ] = {}
405+ self ._models_with_tests : t .Set [str ] = set ()
401406 self ._macros : UniqueKeyDict [str , ExecutableOrMacro ] = UniqueKeyDict ("macros" )
402407 self ._metrics : UniqueKeyDict [str , Metric ] = UniqueKeyDict ("metrics" )
403408 self ._jinja_macros = JinjaMacroRegistry ()
@@ -636,6 +641,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
636641 self ._excluded_requirements .clear ()
637642 self ._linters .clear ()
638643 self ._environment_statements = []
644+ self ._model_test_metadata .clear ()
645+ self ._model_test_metadata_path_index .clear ()
646+ self ._model_test_metadata_fully_qualified_name_index .clear ()
647+ self ._models_with_tests .clear ()
639648
640649 for loader , project in zip (self ._loaders , loaded_projects ):
641650 self ._jinja_macros = self ._jinja_macros .merge (project .jinja_macros )
@@ -647,6 +656,15 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
647656 self ._requirements .update (project .requirements )
648657 self ._excluded_requirements .update (project .excluded_requirements )
649658 self ._environment_statements .extend (project .environment_statements )
659+ self ._model_test_metadata .extend (project .model_test_metadata )
660+ for metadata in project .model_test_metadata :
661+ if metadata .path not in self ._model_test_metadata_path_index :
662+ self ._model_test_metadata_path_index [metadata .path ] = []
663+ self ._model_test_metadata_path_index [metadata .path ].append (metadata )
664+ self ._model_test_metadata_fully_qualified_name_index [
665+ metadata .fully_qualified_test_name
666+ ] = metadata
667+ self ._models_with_tests .add (metadata .model_name )
650668
651669 config = loader .config
652670 self ._linters [config .project ] = Linter .from_rules (
@@ -1049,6 +1067,11 @@ def standalone_audits(self) -> MappingProxyType[str, StandaloneAudit]:
10491067 """Returns all registered standalone audits in this context."""
10501068 return MappingProxyType (self ._standalone_audits )
10511069
1070+ @property
1071+ def models_with_tests (self ) -> t .Set [str ]:
1072+ """Returns all models with tests in this context."""
1073+ return self ._models_with_tests
1074+
10521075 @property
10531076 def snapshots (self ) -> t .Dict [str , Snapshot ]:
10541077 """Generates and returns snapshots based on models registered in this context.
@@ -2220,7 +2243,9 @@ def test(
22202243
22212244 pd .set_option ("display.max_columns" , None )
22222245
2223- test_meta = self .load_model_tests (tests = tests , patterns = match_patterns )
2246+ test_meta = self ._select_tests (
2247+ test_meta = self ._model_test_metadata , tests = tests , patterns = match_patterns
2248+ )
22242249
22252250 result = run_tests (
22262251 model_test_metadata = test_meta ,
@@ -2782,6 +2807,33 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
27822807 raise SQLMeshError (f"Gateway '{ gateway } ' not found in the available engine adapters." )
27832808 return self .engine_adapter
27842809
2810+ def _select_tests (
2811+ self ,
2812+ test_meta : t .List [ModelTestMetadata ],
2813+ tests : t .Optional [t .List [str ]] = None ,
2814+ patterns : t .Optional [t .List [str ]] = None ,
2815+ ) -> t .List [ModelTestMetadata ]:
2816+ """Filter pre-loaded test metadata based on tests and patterns."""
2817+
2818+ if tests :
2819+ filtered_tests = []
2820+ for test in tests :
2821+ if "::" in test :
2822+ if test in self ._model_test_metadata_fully_qualified_name_index :
2823+ filtered_tests .append (
2824+ self ._model_test_metadata_fully_qualified_name_index [test ]
2825+ )
2826+ else :
2827+ test_path = Path (test )
2828+ if test_path in self ._model_test_metadata_path_index :
2829+ filtered_tests .extend (self ._model_test_metadata_path_index [test_path ])
2830+ test_meta = filtered_tests
2831+
2832+ if patterns :
2833+ test_meta = filter_tests_by_patterns (test_meta , patterns )
2834+
2835+ return test_meta
2836+
27852837 def _snapshots (
27862838 self , models_override : t .Optional [UniqueKeyDict [str , Model ]] = None
27872839 ) -> t .Dict [str , Snapshot ]:
0 commit comments