diff --git a/modelx/serialize/serializer_6.py b/modelx/serialize/serializer_6.py index 8bfc10ca..e0efd6d2 100644 --- a/modelx/serialize/serializer_6.py +++ b/modelx/serialize/serializer_6.py @@ -36,6 +36,12 @@ IOSpecUnpickler, ModelUnpickler, IOSpecPickler, ModelPickler) +PSEUDO_PYTHON_HEADER = """\ +# modelx: pseudo-python +# This file is part of a modelx model. +# It can be imported as a Python module, but functions defined herein +# are model formulas and may not be executable as standard Python.""" + class TupleID(tuple): @@ -455,7 +461,7 @@ def __init__(self, writer, ) def encode(self): - lines = [] + lines = [PSEUDO_PYTHON_HEADER] if self.model.doc is not None: lines.append("\"\"\"" + self.model.doc + "\"\"\"") @@ -513,7 +519,7 @@ def __init__(self, writer, target, srcpath=None): def encode(self): - lines = [] + lines = [PSEUDO_PYTHON_HEADER] if self.space.doc is not None: lines.append("\"\"\"" + self.space.doc + "\"\"\"") @@ -1037,12 +1043,11 @@ class DocstringParser(BaseNodeParser): """Docstring at module level""" @classmethod def condition(cls, stmt: StatementTokens): - # stmt has 1 element that is STRING and starts at line 1. + # stmt has 1 element that is STRING. if stmt.section == "DEFAULT" and len(stmt) == 1: elm = stmt[0] if elm.type == tokenize.STRING: - if elm.start[0] == 1: - return True + return True return False diff --git a/modelx/tests/serialize/test_serialize.py b/modelx/tests/serialize/test_serialize.py index a3cfe2b2..ae821c21 100644 --- a/modelx/tests/serialize/test_serialize.py +++ b/modelx/tests/serialize/test_serialize.py @@ -280,3 +280,30 @@ def test_false_value(tmp_path, write_method): s.a = False getattr(mx, write_method)(m, tmp_path / "model") m2 = mx.read_model(tmp_path / "model") + + +def test_pseudo_python_header(tmp_path): + from modelx.serialize.serializer_6 import PSEUDO_PYTHON_HEADER + m = mx.new_model("HeaderTest") + s = m.new_space("Space1") + + @mx.defcells + def foo(x): + return x + + mx.write_model(m, tmp_path / "model") + + # Check model __init__.py + model_init = (tmp_path / "model" / "__init__.py").read_text() + assert model_init.startswith(PSEUDO_PYTHON_HEADER) + + # Check space __init__.py + space_init = (tmp_path / "model" / "Space1" / "__init__.py").read_text() + assert space_init.startswith(PSEUDO_PYTHON_HEADER) + + # Verify the model can be read back correctly + m2 = mx.read_model(tmp_path / "model") + assert m2.name == "HeaderTest" + assert m2.Space1.foo(3) == 3 + m2.close() + m.close()