From f60cd1c68ec722e8bc697a0d41b8aa897512f462 Mon Sep 17 00:00:00 2001 From: Paul Prescod Date: Fri, 26 Nov 2021 13:27:37 -0800 Subject: [PATCH] Make failure to close a stream an error, as it would be by default. --- snowfakery/api.py | 13 +++---------- snowfakery/output_streams.py | 18 +++++++++++++++--- tests/test_embedding.py | 21 ++++++++------------- tests/test_output_streams.py | 24 ++++++++++++++++++++++++ 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/snowfakery/api.py b/snowfakery/api.py index 7524a0b2..5e5ec761 100644 --- a/snowfakery/api.py +++ b/snowfakery/api.py @@ -224,16 +224,9 @@ def configure_output_stream( try: yield output_stream finally: - try: - messages = output_stream.close() - except Exception as e: - messages = None - parent_application.echo( - f"Could not close {output_stream}: {str(e)}", err=True - ) - if messages: - for message in messages: - parent_application.echo(message) + messages = output_stream.close() or [] + for message in messages: + parent_application.echo(message) @contextmanager diff --git a/snowfakery/output_streams.py b/snowfakery/output_streams.py index c95f9a99..64e01b04 100644 --- a/snowfakery/output_streams.py +++ b/snowfakery/output_streams.py @@ -125,7 +125,7 @@ def close(self) -> Optional[Sequence[str]]: Return a list of messages to print out. """ - return super().close() + raise NotImplementedError() def __enter__(self, *args): return self @@ -550,7 +550,7 @@ def _render(self, dotfile, outfile): assert dotfile.exists() try: out = subprocess.Popen( - ["dot", "-T" + self.format, dotfile, "-o" + str(outfile)], + ["dot", "-T" + self.format, str(dotfile), "-o" + str(outfile)], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) @@ -578,8 +578,20 @@ def write_row(self, tablename: str, row_with_references: Dict) -> None: stream.write_row(tablename, row_with_references) def close(self) -> Optional[Sequence[str]]: + all_messages = [] + closing_errors = [] for stream in self.outputstreams: - stream.close() + try: + messages = stream.close() or [] + all_messages.extend(messages) + except Exception as e: + closing_errors.append(e) + + if len(closing_errors) == 1: + raise closing_errors[0] + elif closing_errors: + raise IOError(f"Could not close streams: {closing_errors}") + return all_messages def write_single_row(self, tablename: str, row: Dict) -> None: return super().write_single_row(tablename, row) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 062d406c..af22318b 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -63,7 +63,7 @@ def test_continuation_as_open_file(self): with mapping_file.open() as f: assert yaml.safe_load(f) - def test_parent_application__echo(self): + def test_parent_application__exception_raised(self): called = False class MyEmbedder(SnowfakeryApplication): @@ -74,10 +74,10 @@ def echo(self, *args, **kwargs): meth = "snowfakery.output_streams.DebugOutputStream.close" with mock.patch(meth) as close: close.side_effect = AssertionError - generate_data( - yaml_file="examples/company.yml", parent_application=MyEmbedder() - ) - assert called + with pytest.raises(AssertionError): + generate_data( + yaml_file="examples/company.yml", parent_application=MyEmbedder() + ) def test_parent_application__early_finish(self, generated_rows): class MyEmbedder(SnowfakeryApplication): @@ -89,14 +89,9 @@ def check_if_finished(self, idmanager): assert self.__class__.count < 100, "Runaway recipe!" return idmanager["Employee"] >= 10 - meth = "snowfakery.output_streams.DebugOutputStream.close" - with mock.patch(meth) as close: - close.side_effect = AssertionError - generate_data( - yaml_file="examples/company.yml", parent_application=MyEmbedder() - ) - # called 5 times, after generating 2 employees each - assert MyEmbedder.count == 5 + generate_data(yaml_file="examples/company.yml", parent_application=MyEmbedder()) + # called 5 times, after generating 2 employees each + assert MyEmbedder.count == 5 def test_embedding__cannot_infer_output_format(self): with pytest.raises(exc.DataGenError, match="No format"): diff --git a/tests/test_output_streams.py b/tests/test_output_streams.py index b4b0b313..7ed8e4b5 100644 --- a/tests/test_output_streams.py +++ b/tests/test_output_streams.py @@ -6,6 +6,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from contextlib import redirect_stdout +from unittest import mock import pytest @@ -375,3 +376,26 @@ def test_external_output_stream__failure(self): generate_cli.callback( yaml_file=sample_yaml, output_format="no.such.output.Stream" ) + + +class TestMultiplexOutputStream: + @mock.patch("snowfakery.output_streams.DebugOutputStream.close") + def test_cannot_close_multiple_streams(self, close): + close.side_effect = AssertionError + with TemporaryDirectory() as t: + files = [Path(t) / "1.txt", Path(t) / "2.txt"] + with pytest.raises(IOError) as e: + generate_cli.callback( + yaml_file="examples/company.yml", output_files=files + ) + assert "Could not close streams:" in str(e.value) + + @mock.patch("snowfakery.output_streams.DebugOutputStream.close") + def test_cannot_close_one_stream(self, close): + close.side_effect = AssertionError + with TemporaryDirectory() as t: + files = [Path(t) / "1.txt", Path(t) / "2.jpg"] + with pytest.raises(AssertionError): + generate_cli.callback( + yaml_file="examples/company.yml", output_files=files + )