Skip to content

Commit 68ea9f6

Browse files
authored
Feat: add support for YAML dictionaries in unit tests (MVP) (#2264)
* Feat: add support for YAML dictionaries in unit tests
* Replace applymap with map since it got deprecated after pandas 2.1.0
* Revert dict logic
* Refactor
1 parent c73f3e5 commit 68ea9f6

File tree

7 files changed

+62
-46
lines changed

7 files changed

+62
-46
lines changed

sqlmesh/core/dialect.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,6 @@
2525

2626
SQLMESH_MACRO_PREFIX = "@"
2727

28-
JSON_TYPE = exp.DataType.build("json")
29-
3028
TABLES_META = "sqlmesh.tables"
3129

3230

@@ -942,13 +940,12 @@ def _transform(node: exp.Expression) -> exp.Expression:
942940
def transform_values(
943941
values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType]
944942
) -> t.Iterator[t.Any]:
945-
"""Perform transformations on values given columns_to_types.
946-
947-
Currently, the only transformation is wrapping JSON columns with PARSE_JSON().
948-
"""
943+
"""Perform transformations on values given columns_to_types."""
949944
for value, col_type in zip(values, columns_to_types.values()):
950-
if col_type == JSON_TYPE:
945+
if col_type.is_type(exp.DataType.Type.JSON):
951946
yield exp.func("PARSE_JSON", f"'{value}'")
947+
elif isinstance(value, dict) and col_type.is_type(*exp.DataType.STRUCT_TYPES):
948+
yield _dict_to_struct(value)
952949
else:
953950
yield value
954951

@@ -994,3 +991,13 @@ def _unquote_schema(schema: t.Dict) -> t.Dict:
994991
return {
995992
k.strip('"'): _unquote_schema(v) if isinstance(v, dict) else v for k, v in schema.items()
996993
}
994+
995+
996+
def _dict_to_struct(values: t.Dict) -> exp.Struct:
997+
expressions = []
998+
for key, value in values.items():
999+
key = exp.to_identifier(key)
1000+
value = _dict_to_struct(value) if isinstance(value, dict) else exp.convert(value)
1001+
expressions.append(exp.PropertyEQ(this=key, expression=value))
1002+
1003+
return exp.Struct(expressions=expressions)

sqlmesh/core/engine_adapter/base.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1000,20 +1000,17 @@ def insert_append(
10001000
)
10011001

10021002
@t.overload
1003-
@classmethod
1004-
def _escape_json(cls, value: Query) -> Query: ...
1003+
def _escape_json(self, value: Query) -> Query: ...
10051004

10061005
@t.overload
1007-
@classmethod
1008-
def _escape_json(cls, value: str) -> str: ...
1006+
def _escape_json(self, value: str) -> str: ...
10091007

1010-
@classmethod
1011-
def _escape_json(cls, value: Query | str) -> Query | str:
1008+
def _escape_json(self, value: Query | str) -> Query | str:
10121009
"""
10131010
Some engines need to add an extra escape to literals that contain JSON values. By default we don't do this
10141011
though
10151012
"""
1016-
if cls.ESCAPE_JSON:
1013+
if self.ESCAPE_JSON:
10171014
if isinstance(value, str):
10181015
return double_escape(value)
10191016
return t.cast(
@@ -1093,9 +1090,8 @@ def insert_overwrite_by_time_partition(
10931090
)
10941091
self._insert_overwrite_by_condition(table_name, source_queries, columns_to_types, where)
10951092

1096-
@classmethod
10971093
def _values_to_sql(
1098-
cls,
1094+
self,
10991095
values: t.List[PandasNamedTuple],
11001096
columns_to_types: t.Dict[str, exp.DataType],
11011097
batch_start: int,
@@ -1111,7 +1107,7 @@ def _values_to_sql(
11111107
alias=alias,
11121108
)
11131109
if contains_json:
1114-
query = t.cast(exp.Select, cls._escape_json(query))
1110+
query = t.cast(exp.Select, self._escape_json(query))
11151111
return query
11161112

11171113
def _insert_overwrite_by_condition(

sqlmesh/core/test/definition.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -82,10 +82,10 @@ def setUp(self) -> None:
8282
rows = values["rows"]
8383
if not columns_to_types and rows:
8484
for i, v in rows[0].items():
85-
# convert ruamel into python
86-
v = v.real if hasattr(v, "real") else v
8785
v_type = annotate_types(exp.convert(v)).type or type(v).__name__
88-
columns_to_types[i] = exp.maybe_parse(v_type, into=exp.DataType)
86+
columns_to_types[i] = exp.maybe_parse(
87+
v_type, into=exp.DataType, dialect=self.dialect
88+
)
8989

9090
test_fixture_table = _fully_qualified_test_fixture_table(table_name, self.dialect)
9191
if test_fixture_table.db:
@@ -138,6 +138,7 @@ def _to_hashable(x: t.Any) -> t.Any:
138138
actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True)
139139
expected = expected.apply(lambda col: col.map(_to_hashable))
140140
expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True)
141+
141142
try:
142143
pd.testing.assert_frame_equal(
143144
expected,

tests/core/engine_adapter/test_base.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -842,9 +842,9 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
842842
target_table="target",
843843
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
844844
columns_to_types={
845-
"ID": exp.DataType.Type.INT,
846-
"ts": exp.DataType.Type.TIMESTAMP,
847-
"val": exp.DataType.Type.INT,
845+
"ID": exp.DataType.build("int"),
846+
"ts": exp.DataType.build("timestamp"),
847+
"val": exp.DataType.build("int"),
848848
},
849849
unique_key=[exp.to_identifier("ID", quoted=True)],
850850
)
@@ -873,9 +873,9 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
873873
target_table="target",
874874
source_table=parse_one("SELECT id, ts, val FROM source"),
875875
columns_to_types={
876-
"id": exp.DataType.Type.INT,
877-
"ts": exp.DataType.Type.TIMESTAMP,
878-
"val": exp.DataType.Type.INT,
876+
"id": exp.DataType.build("int"),
877+
"ts": exp.DataType.build("timestamp"),
878+
"val": exp.DataType.build("int"),
879879
},
880880
unique_key=[exp.column("id"), exp.column("ts")],
881881
)
@@ -894,9 +894,9 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
894894
target_table="target",
895895
source_table=df,
896896
columns_to_types={
897-
"id": exp.DataType.Type.INT,
898-
"ts": exp.DataType.Type.TIMESTAMP,
899-
"val": exp.DataType.Type.INT,
897+
"id": exp.DataType.build("int"),
898+
"ts": exp.DataType.build("timestamp"),
899+
"val": exp.DataType.build("int"),
900900
},
901901
unique_key=[exp.to_identifier("id")],
902902
)
@@ -911,9 +911,9 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable):
911911
target_table="target",
912912
source_table=df,
913913
columns_to_types={
914-
"id": exp.DataType.Type.INT,
915-
"ts": exp.DataType.Type.TIMESTAMP,
916-
"val": exp.DataType.Type.INT,
914+
"id": exp.DataType.build("int"),
915+
"ts": exp.DataType.build("timestamp"),
916+
"val": exp.DataType.build("int"),
917917
},
918918
unique_key=[exp.to_identifier("id"), exp.to_identifier("ts")],
919919
)
@@ -931,9 +931,9 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e
931931
target_table="target",
932932
source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')),
933933
columns_to_types={
934-
"ID": exp.DataType.Type.INT,
935-
"ts": exp.DataType.Type.TIMESTAMP,
936-
"val": exp.DataType.Type.INT,
934+
"ID": exp.DataType.build("int"),
935+
"ts": exp.DataType.build("timestamp"),
936+
"val": exp.DataType.build("int"),
937937
},
938938
unique_key=[exp.to_identifier("ID", quoted=True)],
939939
when_matched=exp.When(

tests/core/engine_adapter/test_databricks.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,9 @@ def test_replace_query_pandas_exists(mocker: MockFixture, make_mocked_engine_ada
6565
)
6666
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter)
6767
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
68-
adapter.replace_query("test_table", df, {"a": "int", "b": "int"})
68+
adapter.replace_query(
69+
"test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}
70+
)
6971

7072
assert to_sql_calls(adapter) == [
7173
"INSERT OVERWRITE TABLE `test_table` (`a`, `b`) SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)",

tests/core/engine_adapter/test_redshift.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ def test_values_to_sql(adapter: t.Callable, mocker: MockerFixture):
166166
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
167167
result = adapter._values_to_sql(
168168
values=list(df.itertuples(index=False, name=None)),
169-
columns_to_types={"a": "int", "b": "int"},
169+
columns_to_types={"a": exp.DataType.build("int"), "b": exp.DataType.build("int")},
170170
batch_start=0,
171171
batch_end=2,
172172
)

tests/core/test_test.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -667,28 +667,38 @@ def test_source_func() -> None:
667667

668668

669669
def test_nested_data_types() -> None:
670+
raw = _create_model(
671+
"SELECT array::INT[], struct::STRUCT(x INT[], y VARCHAR, z INT, w STRUCT(a INT)) FROM sushi.unknown",
672+
meta="MODEL (name sushi.raw, kind FULL)",
673+
default_catalog="memory",
674+
)
675+
context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")))
676+
context.upsert_model(raw)
677+
670678
result = _create_test(
671679
body=load_yaml(
672680
"""
673681
test_foo:
674682
model: sushi.foo
675683
inputs:
676-
raw:
677-
- value: [1, 2, 3]
678-
- value:
684+
sushi.raw:
685+
- array: [1, 2, 3]
686+
struct: {'x': [1, 2, 3], 'y': 'foo', 'z': 1, 'w': {'a': 5}}
687+
- array:
679688
- 2
680689
- 3
681-
- value: [0, 4, 1]
690+
- array: [0, 4, 1]
682691
outputs:
683692
query:
684-
- value: [0, 4, 1]
685-
- value: [1, 2, 3]
686-
- value: [2, 3]
693+
- array: [0, 4, 1]
694+
- array: [1, 2, 3]
695+
struct: {'x': [1, 2, 3], 'y': 'foo', 'z': 1, 'w': {'a': 5}}
696+
- array: [2, 3]
687697
"""
688698
),
689699
test_name="test_foo",
690-
model=_create_model("SELECT value FROM raw"),
691-
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))),
700+
model=_create_model("SELECT array, struct FROM sushi.raw", default_catalog="memory"),
701+
context=context,
692702
).run()
693703

694704
_check_successful_or_raise(result)

0 commit comments

Comments (0)