Skip to content

Commit 2e77993

Browse files
authored
Merge branch 'main' into fabric_alter_table_no_op
2 parents ba2c888 + 98998d4 commit 2e77993

File tree

8 files changed

+204
-8
lines changed

8 files changed

+204
-8
lines changed

sqlmesh/core/context.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115
ModelTestMetadata,
116116
generate_test,
117117
run_tests,
118+
filter_tests_by_patterns,
118119
)
119120
from sqlmesh.core.user import User
120121
from sqlmesh.utils import UniqueKeyDict, Verbosity
@@ -146,8 +147,8 @@
146147
from typing_extensions import Literal
147148

148149
from sqlmesh.core.engine_adapter._typing import (
149-
BigframeSession,
150150
DF,
151+
BigframeSession,
151152
PySparkDataFrame,
152153
PySparkSession,
153154
SnowparkSession,
@@ -398,6 +399,10 @@ def __init__(
398399
self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict(
399400
"standaloneaudits"
400401
)
402+
self._model_test_metadata: t.List[ModelTestMetadata] = []
403+
self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {}
404+
self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {}
405+
self._models_with_tests: t.Set[str] = set()
401406
self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros")
402407
self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics")
403408
self._jinja_macros = JinjaMacroRegistry()
@@ -636,6 +641,10 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
636641
self._excluded_requirements.clear()
637642
self._linters.clear()
638643
self._environment_statements = []
644+
self._model_test_metadata.clear()
645+
self._model_test_metadata_path_index.clear()
646+
self._model_test_metadata_fully_qualified_name_index.clear()
647+
self._models_with_tests.clear()
639648

640649
for loader, project in zip(self._loaders, loaded_projects):
641650
self._jinja_macros = self._jinja_macros.merge(project.jinja_macros)
@@ -647,6 +656,15 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]:
647656
self._requirements.update(project.requirements)
648657
self._excluded_requirements.update(project.excluded_requirements)
649658
self._environment_statements.extend(project.environment_statements)
659+
self._model_test_metadata.extend(project.model_test_metadata)
660+
for metadata in project.model_test_metadata:
661+
if metadata.path not in self._model_test_metadata_path_index:
662+
self._model_test_metadata_path_index[metadata.path] = []
663+
self._model_test_metadata_path_index[metadata.path].append(metadata)
664+
self._model_test_metadata_fully_qualified_name_index[
665+
metadata.fully_qualified_test_name
666+
] = metadata
667+
self._models_with_tests.add(metadata.model_name)
650668

651669
config = loader.config
652670
self._linters[config.project] = Linter.from_rules(
@@ -1049,6 +1067,11 @@ def standalone_audits(self) -> MappingProxyType[str, StandaloneAudit]:
10491067
"""Returns all registered standalone audits in this context."""
10501068
return MappingProxyType(self._standalone_audits)
10511069

1070+
@property
def models_with_tests(self) -> t.Set[str]:
    """Names of the models that have at least one test loaded into this context.

    The set held by the context is handed back directly (not a copy).
    """
    with_tests = self._models_with_tests
    return with_tests
1074+
10521075
@property
10531076
def snapshots(self) -> t.Dict[str, Snapshot]:
10541077
"""Generates and returns snapshots based on models registered in this context.
@@ -2220,7 +2243,9 @@ def test(
22202243

22212244
pd.set_option("display.max_columns", None)
22222245

2223-
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
2246+
test_meta = self._select_tests(
2247+
test_meta=self._model_test_metadata, tests=tests, patterns=match_patterns
2248+
)
22242249

22252250
result = run_tests(
22262251
model_test_metadata=test_meta,
@@ -2782,6 +2807,33 @@ def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter:
27822807
raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.")
27832808
return self.engine_adapter
27842809

2810+
def _select_tests(
2811+
self,
2812+
test_meta: t.List[ModelTestMetadata],
2813+
tests: t.Optional[t.List[str]] = None,
2814+
patterns: t.Optional[t.List[str]] = None,
2815+
) -> t.List[ModelTestMetadata]:
2816+
"""Filter pre-loaded test metadata based on tests and patterns."""
2817+
2818+
if tests:
2819+
filtered_tests = []
2820+
for test in tests:
2821+
if "::" in test:
2822+
if test in self._model_test_metadata_fully_qualified_name_index:
2823+
filtered_tests.append(
2824+
self._model_test_metadata_fully_qualified_name_index[test]
2825+
)
2826+
else:
2827+
test_path = Path(test)
2828+
if test_path in self._model_test_metadata_path_index:
2829+
filtered_tests.extend(self._model_test_metadata_path_index[test_path])
2830+
test_meta = filtered_tests
2831+
2832+
if patterns:
2833+
test_meta = filter_tests_by_patterns(test_meta, patterns)
2834+
2835+
return test_meta
2836+
27852837
def _snapshots(
27862838
self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None
27872839
) -> t.Dict[str, Snapshot]:

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,17 @@ def _build_table_properties_exp(
394394
expressions.append(clustered_by_exp)
395395
properties = exp.Properties(expressions=expressions)
396396
return properties
397+
398+
def _build_column_defs(
    self,
    target_columns_to_types: t.Dict[str, exp.DataType],
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    is_view: bool = False,
) -> t.List[exp.ColumnDef]:
    """Build column definitions, keeping column types whenever comments are present.

    Databricks requires column types to be specified when adding column
    comments in CREATE MATERIALIZED VIEW statements. To force the types to be
    included, the parent implementation is called with ``is_view=False``
    whenever a view is requested together with column descriptions.

    Args:
        target_columns_to_types: Mapping of column name to data type.
        column_descriptions: Optional mapping of column name to comment text.
        is_view: Whether the definitions are being built for a view.

    Returns:
        The column definitions produced by the parent implementation.
    """
    # Only flip the flag when both conditions hold; otherwise pass it through untouched.
    effective_is_view = False if (is_view and column_descriptions) else is_view

    return super()._build_column_defs(
        target_columns_to_types, column_descriptions, effective_is_view
    )

sqlmesh/core/linter/rules/builtin.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,21 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
129129
return self.violation()
130130

131131

132+
class NoMissingUnitTest(Rule):
    """All models must have a unit test found in the tests/ directory yaml files"""

    # NOTE(review): the class docstring is kept to a single sentence in case the
    # base Rule surfaces it as the rule description — confirm against Rule.

    def check_model(self, model: Model) -> t.Optional[RuleViolation]:
        """Report a violation when ``model`` has no unit test known to the context.

        Args:
            model: The model being linted.

        Returns:
            ``None`` for external models and for models whose name is in
            ``self.context.models_with_tests``; otherwise a violation naming
            the model.
        """
        # External models cannot have unit tests
        if isinstance(model, ExternalModel):
            return None

        if model.name in self.context.models_with_tests:
            return None

        return self.violation(
            violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory."
        )
145+
146+
132147
class NoMissingExternalModels(Rule):
133148
"""All external models must be registered in the external_models.yaml file"""
134149

sqlmesh/core/loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class LoadedProject:
6464
excluded_requirements: t.Set[str]
6565
environment_statements: t.List[EnvironmentStatements]
6666
user_rules: RuleSet
67+
model_test_metadata: t.List[ModelTestMetadata]
6768

6869

6970
class CacheBase(abc.ABC):
@@ -243,6 +244,8 @@ def load(self) -> LoadedProject:
243244

244245
user_rules = self._load_linting_rules()
245246

247+
model_test_metadata = self.load_model_tests()
248+
246249
project = LoadedProject(
247250
macros=macros,
248251
jinja_macros=jinja_macros,
@@ -254,6 +257,7 @@ def load(self) -> LoadedProject:
254257
excluded_requirements=excluded_requirements,
255258
environment_statements=environment_statements,
256259
user_rules=user_rules,
260+
model_test_metadata=model_test_metadata,
257261
)
258262
return project
259263

sqlmesh/core/test/discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ class ModelTestMetadata(PydanticModel):
2020
def fully_qualified_test_name(self) -> str:
2121
return f"{self.path}::{self.test_name}"
2222

23+
@property
def model_name(self) -> str:
    """Name of the model this test targets, or ``""`` when the body names none."""
    # `or ""` also covers a `model` key that is present but empty/null, so the
    # declared `str` return type holds instead of leaking `None` to callers.
    return self.body.get("model") or ""
26+
2327
def __hash__(self) -> int:
2428
return self.fully_qualified_test_name.__hash__()
2529

tests/core/engine_adapter/test_databricks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,36 @@ def test_materialized_view_properties(mocker: MockFixture, make_mocked_engine_ad
376376
]
377377

378378

379+
def test_materialized_view_with_column_comments(
    mocker: MockFixture, make_mocked_engine_adapter: t.Callable
):
    """A materialized view with column comments must be created with explicit column types."""
    catalog = "test_catalog"
    mocker.patch(
        "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"
    )
    adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog=catalog)
    mocker.patch.object(adapter, "get_current_catalog", return_value=catalog)

    columns_to_types = {
        "a": exp.DataType.build("INT"),
        "b": exp.DataType.build("STRING"),
    }
    descriptions = {
        "a": "column a description",
        "b": "column b description",
    }

    adapter.create_view(
        "test_view",
        parse_one("SELECT a, b FROM source_table"),
        target_columns_to_types=columns_to_types,
        materialized=True,
        column_descriptions=descriptions,
    )

    # Databricks requires column types when column comments are present in materialized views
    expected_sql = "CREATE OR REPLACE MATERIALIZED VIEW `test_view` (`a` INT COMMENT 'column a description', `b` STRING COMMENT 'column b description') AS SELECT `a`, `b` FROM `source_table`"
    assert to_sql_calls(adapter) == [expected_sql]
407+
408+
379409
def test_create_table_clustered_by(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):
380410
mocker.patch(
381411
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog"

tests/core/linter/test_builtin.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,63 @@ def test_no_missing_external_models_with_existing_file_not_ending_in_newline(
172172
)
173173
fix_path = sushi_path / "external_models.yaml"
174174
assert edit.path == fix_path
175+
176+
177+
def test_no_missing_unit_tests(tmp_path, copy_to_temp_path):
178+
"""
179+
Tests that the NoMissingUnitTest linter rule correctly identifies models
180+
without corresponding unit tests in the tests/ directory
181+
182+
This test checks the sushi example project, enables the linter,
183+
and verifies that the linter raises a rule violation for the models
184+
that do not have a unit test
185+
"""
186+
# Work on a temporary copy of the example project so config.py can be rewritten.
sushi_paths = copy_to_temp_path("examples/sushi")
187+
sushi_path = sushi_paths[0]
188+
189+
# Override the config.py to turn on lint
# NOTE(review): `before` is matched with str.replace, so it must equal the text in
# config.py exactly, whitespace included; the `assert after in read_file` below is
# what catches a replacement that did not happen.
190+
with open(sushi_path / "config.py", "r") as f:
191+
read_file = f.read()
192+
193+
before = """ linter=LinterConfig(
194+
enabled=False,
195+
rules=[
196+
"ambiguousorinvalidcolumn",
197+
"invalidselectstarexpansion",
198+
"noselectstar",
199+
"nomissingaudits",
200+
"nomissingowner",
201+
"nomissingexternalmodels",
202+
],
203+
),"""
204+
after = """linter=LinterConfig(enabled=True, rules=["nomissingunittest"]),"""
205+
read_file = read_file.replace(before, after)
206+
assert after in read_file
207+
# read_file is a single str, so writelines() emits it character by character;
# the resulting file content is the same as f.write(read_file).
with open(sushi_path / "config.py", "w") as f:
208+
f.writelines(read_file)
209+
210+
# Load the context with the temporary sushi path
211+
context = Context(paths=[sushi_path])
212+
213+
# Lint the models
214+
lints = context.lint_models(raise_on_error=False)
215+
216+
# At least one violation is expected, since not every model has a test
217+
assert len(lints) >= 1
218+
219+
# Check that we get violations for models without tests
220+
violation_messages = [lint.violation_msg for lint in lints]
221+
assert any("is missing unit test(s)" in msg for msg in violation_messages)
222+
223+
# Check that models with existing tests don't have violations
224+
models_with_tests = ["customer_revenue_by_day", "customer_revenue_lifetime", "order_items"]
225+
226+
# Matching is by substring of the violation message, not by exact model name.
for model_name in models_with_tests:
227+
model_violations = [
228+
lint
229+
for lint in lints
230+
if model_name in lint.violation_msg and "is missing unit test(s)" in lint.violation_msg
231+
]
232+
assert len(model_violations) == 0, (
233+
f"Model {model_name} should not have a violation since it has a test"
234+
)

tests/core/test_test.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,9 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None:
15391539
with open(test_path, "w", encoding="utf-8") as file:
15401540
dump_yaml(test_dict, file)
15411541

1542+
# Re-initialize context to pick up the modified test file
1543+
context = Context(paths=path, config=config)
1544+
15421545
spy_execute = mocker.spy(EngineAdapter, "_execute")
15431546
mocker.patch("sqlmesh.core.test.definition.random_id", return_value="jzngz56a")
15441547

@@ -2448,6 +2451,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
24482451
copy_test_file(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml", i)
24492452
copy_test_file(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml", i)
24502453

2454+
# Re-initialize context to pick up the new test files
2455+
context = Context(paths=tmp_path, config=config)
2456+
24512457
with capture_output() as captured_output:
24522458
context.test()
24532459

@@ -2463,13 +2469,12 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
24632469
"SELECT 1 AS col_1, 2 AS col_2, 3 AS col_3, 4 AS col_4, 5 AS col_5, 6 AS col_6, 7 AS col_7"
24642470
)
24652471

2466-
context.upsert_model(
2467-
_create_model(
2468-
meta="MODEL(name test.test_wide_model)",
2469-
query=wide_model_query,
2470-
default_catalog=context.default_catalog,
2471-
)
2472+
wide_model = _create_model(
2473+
meta="MODEL(name test.test_wide_model)",
2474+
query=wide_model_query,
2475+
default_catalog=context.default_catalog,
24722476
)
2477+
context.upsert_model(wide_model)
24732478

24742479
tests_dir = tmp_path / "tests"
24752480
tests_dir.mkdir()
@@ -2493,6 +2498,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
24932498

24942499
wide_test_file.write_text(wide_test_file_content)
24952500

2501+
context.load()
2502+
context.upsert_model(wide_model)
2503+
24962504
with capture_output() as captured_output:
24972505
context.test()
24982506

@@ -2549,6 +2557,9 @@ def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
25492557
"""
25502558
)
25512559

2560+
# Re-initialize context to pick up the modified test file
2561+
context = Context(paths=tmp_path, config=config)
2562+
25522563
with capture_output() as captured_output:
25532564
context.test()
25542565

@@ -3472,6 +3483,9 @@ def test_cte_failure(tmp_path: Path) -> None:
34723483
"""
34733484
)
34743485

3486+
# Re-initialize context to pick up the new test file
3487+
context = Context(paths=tmp_path, config=config)
3488+
34753489
with capture_output() as captured_output:
34763490
context.test()
34773491

@@ -3498,6 +3512,9 @@ def test_cte_failure(tmp_path: Path) -> None:
34983512
"""
34993513
)
35003514

3515+
# Re-initialize context to pick up the modified test file
3516+
context = Context(paths=tmp_path, config=config)
3517+
35013518
with capture_output() as captured_output:
35023519
context.test()
35033520

0 commit comments

Comments
 (0)