diff --git a/src/controller.py b/src/controller.py index d31ec9c..7f78e60 100644 --- a/src/controller.py +++ b/src/controller.py @@ -1,3 +1,4 @@ +from src.validation import validate_extracted_data from src.file_manipulator import FileManipulator class Controller: @@ -5,7 +6,12 @@ def __init__(self): self.file_manipulator = FileManipulator() def fill_form(self, user_input: str, fields: list, pdf_form_path: str): - return self.file_manipulator.fill_form(user_input, fields, pdf_form_path) + data = self.file_manipulator.fill_form(user_input, fields, pdf_form_path) + + if not validate_extracted_data(data): + raise ValueError("Invalid extracted data") + + return data def create_template(self, pdf_path: str): - return self.file_manipulator.create_template(pdf_path) \ No newline at end of file + return self.file_manipulator.create_template(pdf_path) diff --git a/src/validation.py b/src/validation.py new file mode 100644 index 0000000..0c41938 --- /dev/null +++ b/src/validation.py @@ -0,0 +1,13 @@ +def validate_extracted_data(data: dict) -> bool: + """ + Basic validation for extracted form data. + Ensures required fields are present and non-empty. + """ + + required_fields = ["patient_name", "age", "diagnosis"] + + for field in required_fields: + if field not in data or not data[field]: + return False + + return True diff --git a/tests/test_controller.py b/tests/test_controller.py new file mode 100644 index 0000000..a182333 --- /dev/null +++ b/tests/test_controller.py @@ -0,0 +1,16 @@ +from src.controller import Controller + + +def test_fill_form_validation_fail(monkeypatch): + controller = Controller() + + def mock_fill_form(user_input, fields, pdf_form_path): + return {"patient_name": "", "age": 30, "diagnosis": "Flu"} + + monkeypatch.setattr(controller.file_manipulator, "fill_form", mock_fill_form) + + try: + controller.fill_form("input", [], "file.pdf") + assert False # Should not reach here + except ValueError: + assert True diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..1f30f7a --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,28 @@ +import pytest +from src.validation import validate_extracted_data + + +def test_valid_data(): + data = { + "patient_name": "John Doe", + "age": 30, + "diagnosis": "Flu" + } + assert validate_extracted_data(data) == True + + +def test_missing_field(): + data = { + "patient_name": "John Doe", + "age": 30 + } + assert validate_extracted_data(data) == False + + +def test_empty_field(): + data = { + "patient_name": "", + "age": 30, + "diagnosis": "Flu" + } + assert validate_extracted_data(data) == False