Skip to content

Commit 7c95bc1

Browse files
authored
fix: ensure unit tests dtypes match (#1503)
1 parent 31f5a7e commit 7c95bc1

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

sqlmesh/core/test/definition.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,24 @@ def tearDown(self) -> None:
9090
for table in self.body.get("inputs", {}):
9191
self.engine_adapter.drop_view(table)
9292

93-
def assert_equal(self, df1: pd.DataFrame, df2: pd.DataFrame) -> None:
93+
def assert_equal(self, expected: pd.DataFrame, actual: pd.DataFrame) -> None:
9494
"""Compare two DataFrames"""
95-
df1 = df1.replace({np.nan: None, "nan": None})
96-
df2 = df2.replace({np.nan: None, "nan": None})
95+
expected = expected.astype(actual.dtypes.to_dict())
96+
expected = expected.replace({np.nan: None, "nan": None})
97+
actual = actual.replace({np.nan: None, "nan": None})
9798

9899
try:
99100
pd.testing.assert_frame_equal(
100-
df1.sort_index(axis=1),
101-
df2.sort_index(axis=1),
101+
expected.sort_index(axis=1),
102+
actual.sort_index(axis=1),
102103
check_dtype=False,
103104
check_datetimelike_compat=True,
104105
)
105106
except AssertionError as e:
106107
diff = "\n".join(
107108
difflib.ndiff(
108-
[str(x) for x in df1.to_dict("records")],
109-
[str(x) for x in df2.to_dict("records")],
109+
[str(x) for x in expected.to_dict("records")],
110+
[str(x) for x in actual.to_dict("records")],
110111
)
111112
)
112113
e.args = (f"Data differs\n{diff}",)

0 commit comments

Comments
 (0)