diff --git a/.flake8 b/.flake8 deleted file mode 100644 index e0ea542fd..000000000 --- a/.flake8 +++ /dev/null @@ -1,3 +0,0 @@ -[flake8] -max-line-length = 88 -extend-ignore = E203 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 8c598538c..000000000 --- a/Dockerfile +++ /dev/null @@ -1,91 +0,0 @@ -FROM continuumio/miniconda3 -LABEL maintainer="lachlan.grose@monash.edu" -#This docker image has been adapted from the lavavu dockerfile -# install things - -RUN apt-get update -qq && \ - DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \ - gcc \ - g++ \ - libc-dev \ - gfortran \ - openmpi-bin \ - libopenmpi-dev \ - make -# RUN conda install -c conda-forge python=3.9 -y -RUN conda install -c conda-forge -c loop3d\ - pip \ - map2model\ - hjson\ - owslib\ - beartype\ - gdal=3.5.2\ - rasterio=1.2.10 \ - meshio\ - scikit-learn \ - cython \ - numpy \ - pandas \ - scipy \ - pymc3 \ - jupyter \ - pyamg \ - # arviz==0.11.0 \ - pygraphviz \ - geopandas \ - shapely \ - ipywidgets \ - ipyleaflet \ - folium \ - jupyterlab \ - nodejs \ - rasterio\ - geopandas\ - -y - -RUN pip install ipyfilechooser -RUN jupyter nbextension enable --py --sys-prefix ipyleaflet -RUN pip install lavavu-osmesa mplstereonet - -ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/lavavu/ - - -ENV NB_USER jovyan -ENV NB_UID 1000 -ENV HOME /home/${NB_USER} - -RUN adduser --disabled-password \ - --gecos "Default user" \ - --uid ${NB_UID} \ - ${NB_USER} -WORKDIR ${HOME} - -USER root -RUN chown -R ${NB_UID} ${HOME} - -RUN pip install snakeviz - -# Add Tini -ENV TINI_VERSION v0.19.0 -ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /tini -RUN chmod +x /tini -ENTRYPOINT ["/tini", "--"] - -USER ${NB_USER} - -RUN mkdir notebooks -RUN git clone https://github.com/Loop3D/map2loop-2.git map2loop -RUN git clone https://github.com/Loop3D/LoopProjectFile.git -RUN git clone https://github.com/TOMOFAST/Tomofast-x.git -RUN pip install LoopStructural -RUN pip install -e map2loop -RUN pip install -e LoopProjectFile -# WORKDIR Tomofast-x -# RUN make -WORKDIR ${HOME}/notebooks - -# RUN pip install -e LoopStructural -CMD ["jupyter", "lab", "--ip='0.0.0.0'", "--NotebookApp.token=''", "--no-browser" ] - -EXPOSE 8050 -EXPOSE 8080:8090 \ No newline at end of file diff --git a/DockerfileDev b/DockerfileDev deleted file mode 100644 index 7bfb6aade..000000000 --- a/DockerfileDev +++ /dev/null @@ -1,80 +0,0 @@ -FROM continuumio/miniconda3 -LABEL maintainer="lachlan.grose@monash.edu" -#This docker image has been adapted from the lavavu dockerfile -# install things - -RUN apt-get update -qq && \ - DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends \ - gcc \ - g++ \ - libc-dev \ - gfortran \ - openmpi-bin \ - libopenmpi-dev \ - make -# RUN conda install -c conda-forge python=3.9 -y -RUN conda install -c conda-forge "python<=3.9" \ - pip \ - scikit-learn \ - cython \ - numpy \ - pandas \ - scipy \ - pymc3 \ - jupyter \ - pyamg \ - # arviz==0.11.0 \ - pygraphviz \ - geopandas \ - shapely \ - ipywidgets \ - ipyleaflet \ - folium \ - jupyterlab \ - nodejs \ - rasterio\ - -y - -RUN pip install ipyfilechooser -RUN jupyter nbextension enable --py --sys-prefix ipyleaflet -RUN pip install lavavu-osmesa==1.8.32 pyevtk - -ENV NB_USER jovyan -ENV NB_UID 1000 -ENV HOME /home/${NB_USER} - -RUN adduser --disabled-password \ - --gecos "Default user" \ - --uid ${NB_UID} \ - ${NB_USER} -WORKDIR ${HOME} - -USER root -RUN chown -R ${NB_UID} ${HOME} - -RUN pip install snakeviz - -# Add Tini -ENV TINI_VERSION v0.19.0 -ADD https://github.com/krallin/tini/releases/download/${TINI_VERSION}/tini /tini -RUN chmod +x /tini -ENTRYPOINT ["/tini", "--"] - -USER ${NB_USER} - -RUN mkdir notebooks -RUN git clone https://github.com/Loop3D/LoopStructural.git -RUN git clone https://github.com/Loop3D/map2loop-2.git map2loop -RUN git clone https://github.com/Loop3D/LoopProjectFile.git -RUN git clone https://github.com/TOMOFAST/Tomofast-x.git -RUN pip install -e LoopStructural -RUN pip install -e map2loop -RUN pip install -e LoopProjectFile -# WORKDIR Tomofast-x -# RUN make -WORKDIR ${HOME} -# RUN pip install -e LoopStructural -CMD ["jupyter", "notebook", "--ip='0.0.0.0'", "--NotebookApp.token=''", "--no-browser" ] - -EXPOSE 8050 -EXPOSE 8080:8090 diff --git a/LoopStructural/__init__.py b/LoopStructural/__init__.py index ca6c3855f..2fc1e7596 100644 --- a/LoopStructural/__init__.py +++ b/LoopStructural/__init__.py @@ -20,6 +20,7 @@ loggers = {} from .modelling.core.geological_model import GeologicalModel from .modelling.core.stratigraphic_column import StratigraphicColumn +from .modelling.core.fault_topology import FaultTopology from .interpolators._api import LoopInterpolator from .interpolators import InterpolatorBuilder from .datatypes import BoundingBox diff --git a/LoopStructural/modelling/core/fault_topology.py b/LoopStructural/modelling/core/fault_topology.py new file mode 100644 index 000000000..33ab88932 --- /dev/null +++ b/LoopStructural/modelling/core/fault_topology.py @@ -0,0 +1,234 @@ +from ..features.fault import FaultSegment +from ...utils import Observable +from .stratigraphic_column import StratigraphicColumn +import enum +import numpy as np +class FaultRelationshipType(enum.Enum): + ABUTTING = "abutting" + FAULTED = "faulted" + NONE = "none" + +class FaultTopology(Observable['FaultTopology']): + """A graph representation of the relationships between faults and the + relationship with stratigraphic units. + """ + def __init__(self, stratigraphic_column: 'StratigraphicColumn'): + super().__init__() + self.faults = [] + self.stratigraphic_column = stratigraphic_column + self.adjacency = {} + self.stratigraphy_fault_relationships = {} + def add_fault(self, fault: FaultSegment): + """ + Adds a fault to the fault topology. + """ + if not isinstance(fault, str): + raise TypeError("Expected a fault name.") + + self.faults.append(fault) + self.notify('fault_added', fault=fault) + + def remove_fault(self, fault: str): + """ + Removes a fault from the fault topology. + """ + if fault not in self.faults: + raise ValueError(f"Fault {fault} not found in the topology.") + + self.faults.remove(fault) + # Remove any relationships involving this fault + self.adjacency = {k: v for k, v in self.adjacency.items() if fault not in k} + self.stratigraphy_fault_relationships = { + k: v for k, v in self.stratigraphy_fault_relationships.items() if k[1] != fault + } + self.notify('fault_removed', fault=fault) + + def add_abutting_relationship(self, fault_name: str, abutting_fault: str): + """ + Adds an abutting relationship between two faults. + """ + if fault_name not in self.faults or abutting_fault not in self.faults: + raise ValueError("Both faults must be part of the fault topology.") + + if fault_name not in self.adjacency: + self.adjacency[fault_name] = [] + + self.adjacency[(fault_name, abutting_fault)] = FaultRelationshipType.ABUTTING + self.notify('abutting_relationship_added', {'fault': fault_name, 'abutting_fault': abutting_fault}) + def add_stratigraphy_fault_relationship(self, unit_name:str, fault_name: str): + """ + Adds a relationship between a stratigraphic unit and a fault. + """ + if fault_name not in self.faults: + raise ValueError("Fault must be part of the fault topology.") + + if unit_name is None: + raise ValueError(f"No stratigraphic group found for unit name: {unit_name}") + self.stratigraphy_fault_relationships[(unit_name,fault_name)] = True + + self.notify('stratigraphy_fault_relationship_added', {'unit': unit_name, 'fault': fault_name}) + def add_faulted_relationship(self, fault_name: str, faulted_fault_name: str): + """ + Adds a faulted relationship between two faults. + """ + if fault_name not in self.faults or faulted_fault_name not in self.faults: + raise ValueError("Both faults must be part of the fault topology.") + + if fault_name not in self.adjacency: + self.adjacency[fault_name] = [] + + self.adjacency[(fault_name, faulted_fault_name)] = FaultRelationshipType.FAULTED + self.notify('faulted_relationship_added', {'fault': fault_name, 'faulted_fault': faulted_fault_name}) + def remove_fault_relationship(self, fault_name: str, related_fault_name: str): + """ + Removes a relationship between two faults. + """ + if (fault_name, related_fault_name) in self.adjacency: + del self.adjacency[(fault_name, related_fault_name)] + elif (related_fault_name, fault_name) in self.adjacency: + del self.adjacency[(related_fault_name, fault_name)] + else: + raise ValueError(f"No relationship found between {fault_name} and {related_fault_name}.") + self.notify('fault_relationship_removed', {'fault': fault_name, 'related_fault': related_fault_name}) + def update_fault_relationship(self, fault_name: str, related_fault_name: str, new_relationship_type: FaultRelationshipType): + if new_relationship_type == FaultRelationshipType.NONE: + self.adjacency.pop((fault_name, related_fault_name), None) + else: + self.adjacency[(fault_name, related_fault_name)] = new_relationship_type + self.notify('fault_relationship_updated', {'fault': fault_name, 'related_fault': related_fault_name, 'new_relationship_type': new_relationship_type}) + def change_relationship_type(self, fault_name: str, related_fault_name: str, new_relationship_type: FaultRelationshipType): + """ + Changes the relationship type between two faults. + """ + if (fault_name, related_fault_name) in self.adjacency: + self.adjacency[(fault_name, related_fault_name)] = new_relationship_type + + else: + raise ValueError(f"No relationship found between {fault_name} and {related_fault_name}.") + self.notify('relationship_type_changed', {'fault': fault_name, 'related_fault': related_fault_name, 'new_relationship_type': new_relationship_type}) + def get_fault_relationships(self, fault_name: str): + """ + Returns a list of relationships for a given fault. + """ + relationships = [] + for (f1, f2), relationship_type in self.adjacency.items(): + if f1 == fault_name or f2 == fault_name: + relationships.append((f1, f2, relationship_type)) + return relationships + def get_fault_relationship(self, fault_name: str, related_fault_name: str): + """ + Returns the relationship type between two faults. + """ + return self.adjacency.get((fault_name, related_fault_name), FaultRelationshipType.NONE) + def get_faults(self): + """ + Returns a list of all faults in the topology. + """ + return self.faults + + def get_stratigraphy_fault_relationships(self): + """ + Returns a dictionary of stratigraphic unit to fault relationships. + """ + return self.stratigraphy_fault_relationships + def get_fault_stratigraphic_unit_relationships(self): + units_group_pairs = self.stratigraphic_column.get_group_unit_pairs() + matrix = np.zeros((len(self.faults), len(units_group_pairs)), dtype=int) + for i, fault in enumerate(self.faults): + for j, (unit_name, _group) in enumerate(units_group_pairs): + if (unit_name, fault) in self.stratigraphy_fault_relationships: + matrix[i, j] = 1 + + return matrix + def get_fault_stratigraphic_relationship(self, unit_name: str, fault:str) -> bool: + """ + Returns a dictionary of fault to stratigraphic unit relationships. + """ + if unit_name is None: + raise ValueError(f"No stratigraphic group found for unit name: {unit_name}") + if (unit_name, fault) not in self.stratigraphy_fault_relationships: + return False + return self.stratigraphy_fault_relationships[(unit_name, fault)] + + def update_fault_stratigraphy_relationship(self, unit_name: str, fault_name: str, flag: bool = True): + """ + Updates the relationship between a stratigraphic unit and a fault. + """ + if not flag: + if (unit_name, fault_name) in self.stratigraphy_fault_relationships: + del self.stratigraphy_fault_relationships[(unit_name, fault_name)] + else: + self.stratigraphy_fault_relationships[(unit_name, fault_name)] = flag + + self.notify('stratigraphy_fault_relationship_updated', {'unit': unit_name, 'fault': fault_name}) + + def remove_fault_stratigraphy_relationship(self, unit_name: str, fault_name: str): + """ + Removes a relationship between a stratigraphic unit and a fault. + """ + if (unit_name, fault_name) not in self.stratigraphy_fault_relationships: + raise ValueError(f"No relationship found between unit {unit_name} and fault {fault_name}.") + else: + self.stratigraphy_fault_relationships.pop((unit_name, fault_name), None) + + self.notify('stratigraphy_fault_relationship_removed', {'unit': unit_name, 'fault': fault_name}) + def get_matrix(self): + """ + Returns a matrix representation of the fault relationships. + """ + matrix = np.zeros((len(self.faults), len(self.faults)), dtype=int) + for (fault_name, related_fault_name), relationship_type in self.adjacency.items(): + fault_index = self.faults.index(next(f for f in self.faults if f == fault_name)) + related_fault_index = self.faults.index(next(f for f in self.faults if f == related_fault_name)) + if relationship_type == FaultRelationshipType.ABUTTING: + matrix[fault_index, related_fault_index] = 1 + elif relationship_type == FaultRelationshipType.FAULTED: + matrix[fault_index, related_fault_index] = 2 + return matrix + + def to_dict(self): + """ + Returns a dictionary representation of the fault topology. + """ + return { + "faults": self.faults, + "adjacency": self.adjacency, + "stratigraphy_fault_relationships": self.stratigraphy_fault_relationships, + } + + def update_from_dict(self, data): + """ + Updates the fault topology from a dictionary representation. + """ + with self.freeze_notifications(): + self.faults.extend(data.get("faults", [])) + adjacency = data.get("adjacency", {}) + stratigraphy_fault_relationships = data.get("stratigraphy_fault_relationships", {}) + for (fault,abutting_fault) in adjacency.values(): + if fault not in self.faults: + self.add_fault(fault) + if abutting_fault not in self.faults: + self.add_fault(abutting_fault) + self.add_abutting_relationship(fault, abutting_fault) + for unit_name, fault_names in stratigraphy_fault_relationships.items(): + for fault_name in fault_names: + if fault_name not in self.faults: + self.add_fault(fault_name) + self.add_stratigraphy_fault_relationship(unit_name, fault_name) + + @classmethod + def from_dict(cls, data): + """ + Creates a FaultTopology instance from a dictionary representation. + """ + from .stratigraphic_column import StratigraphicColumn + stratigraphic_column = data.get("stratigraphic_column",None) + if not isinstance(stratigraphic_column, StratigraphicColumn): + if isinstance(stratigraphic_column, dict): + stratigraphic_column = StratigraphicColumn.from_dict(stratigraphic_column) + elif not isinstance(stratigraphic_column, StratigraphicColumn): + raise TypeError("Expected 'stratigraphic_column' to be a StratigraphicColumn instance or dict.") + + topology = cls(stratigraphic_column) + topology.update_from_dict(data) + return topology diff --git a/LoopStructural/modelling/core/geological_model.py b/LoopStructural/modelling/core/geological_model.py index c7ec03651..b45619017 100644 --- a/LoopStructural/modelling/core/geological_model.py +++ b/LoopStructural/modelling/core/geological_model.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from typing import List, Optional +from typing import List, Optional, Union, Dict import pathlib from ...modelling.features.fault import FaultSegment @@ -123,8 +123,7 @@ def __init__(self, *args): self.feature_name_index = {} self._data = pd.DataFrame() # None - self.stratigraphic_column = StratigraphicColumn() - + self._stratigraphic_column = StratigraphicColumn() self.tol = 1e-10 * np.max(self.bounding_box.maximum - self.bounding_box.origin) self._dtm = None @@ -187,7 +186,6 @@ def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: ].astype(float) return data - if "type" in data: logger.warning("'type' is deprecated replace with 'feature_name' \n") data.rename(columns={"type": "feature_name"}, inplace=True) @@ -409,7 +407,6 @@ def fault_names(self): """ return [f.name for f in self.faults] - def to_file(self, file): """Save a model to a pickle file requires dill @@ -506,10 +503,34 @@ def data(self, data: pd.DataFrame): self._data = data.copy() # self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy()) - def set_model_data(self, data): logger.warning("deprecated method. Model data can now be set using the data attribute") self.data = data.copy() + @property + def stratigraphic_column(self): + """Get the stratigraphic column of the model + + Returns + ------- + StratigraphicColumn + the stratigraphic column of the model + """ + return self._stratigraphic_column + @stratigraphic_column.setter + def stratigraphic_column(self, stratigraphic_column: Union[StratigraphicColumn,Dict]): + """Set the stratigraphic column of the model + + Parameters + ---------- + stratigraphic_column : StratigraphicColumn + the stratigraphic column to set + """ + if isinstance(stratigraphic_column, dict): + self.set_stratigraphic_column(stratigraphic_column) + return + elif not isinstance(stratigraphic_column, StratigraphicColumn): + raise ValueError("stratigraphic_column must be a StratigraphicColumn object") + self._stratigraphic_column = stratigraphic_column def set_stratigraphic_column(self, stratigraphic_column, cmap="tab20"): """ @@ -1400,7 +1421,6 @@ def rescale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: return self.bounding_box.reproject(points, inplace=inplace) - # TODO move scale to bounding box/transformer def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: """Take points in UTM coordinates and reproject @@ -1419,7 +1439,6 @@ def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: """ return self.bounding_box.project(np.array(points).astype(float), inplace=inplace) - def regular_grid(self, *, nsteps=None, shuffle=True, rescale=False, order="C"): """ Return a regular grid within the model bounding box @@ -1494,22 +1513,18 @@ def evaluate_model(self, xyz: np.ndarray, *, scale: bool = True) -> np.ndarray: if self.stratigraphic_column is None: logger.warning("No stratigraphic column defined") return strat_id - for group in reversed(self.stratigraphic_column.keys()): - if group == "faults": - continue - feature_id = self.feature_name_index.get(group, -1) + + s_id = 0 + for g in reversed(self.stratigraphic_column.get_groups()): + feature_id = self.feature_name_index.get(g.name, -1) if feature_id >= 0: - feature = self.features[feature_id] - vals = feature.evaluate_value(xyz) - for series in self.stratigraphic_column[group].values(): - strat_id[ - np.logical_and( - vals < series.get("max", feature.max()), - vals > series.get("min", feature.min()), - ) - ] = series["id"] + vals = self.features[feature_id].evaluate_value(xyz) + for u in g.units: + strat_id[np.logical_and(vals < u.max, vals > u.min)] = s_id + s_id += 1 if feature_id == -1: - logger.error(f"Model does not contain {group}") + logger.error(f"Model does not contain {g.name}") + return strat_id def evaluate_model_gradient(self, points: np.ndarray, *, scale: bool = True) -> np.ndarray: diff --git a/LoopStructural/modelling/core/stratigraphic_column.py b/LoopStructural/modelling/core/stratigraphic_column.py index 02c8caf53..44394fd89 100644 --- a/LoopStructural/modelling/core/stratigraphic_column.py +++ b/LoopStructural/modelling/core/stratigraphic_column.py @@ -1,8 +1,7 @@ import enum -from typing import Dict +from typing import Dict, Optional, List, Tuple import numpy as np -from LoopStructural.utils import rng, getLogger - +from LoopStructural.utils import rng, getLogger, Observable logger = getLogger(__name__) logger.info("Imported LoopStructural Stratigraphic Column module") class UnconformityType(enum.Enum): @@ -154,7 +153,7 @@ def __init__(self, name=None, units=None): self.units = units if units is not None else [] -class StratigraphicColumn: +class StratigraphicColumn(Observable['StratigraphicColumn']): """ A class to represent a stratigraphic column, which is a vertical section of the Earth's crust showing the sequence of rock layers and their relationships. @@ -164,6 +163,7 @@ def __init__(self): """ Initializes the StratigraphicColumn with a name and a list of layers. """ + super().__init__() self.order = [StratigraphicUnit(name='Basement', colour='grey', thickness=np.inf),StratigraphicUnconformity(name='Base Unconformity', unconformity_type=UnconformityType.ERODE)] self.group_mapping = {} def clear(self,basement=True): @@ -175,7 +175,7 @@ def clear(self,basement=True): else: self.order = [] self.group_mapping = {} - + self.notify('column_cleared') def add_unit(self, name,*, colour=None, thickness=None, where='top'): unit = StratigraphicUnit(name=name, colour=colour, thickness=thickness) @@ -185,7 +185,7 @@ def add_unit(self, name,*, colour=None, thickness=None, where='top'): self.order.insert(0, unit) else: raise ValueError("Invalid 'where' argument. Use 'top' or 'bottom'.") - + self.notify('unit_added', unit=unit) return unit def remove_unit(self, uuid): @@ -195,7 +195,9 @@ def remove_unit(self, uuid): for i, element in enumerate(self.order): if element.uuid == uuid: del self.order[i] + self.notify('unit_removed', uuid=uuid) return True + return False def add_unconformity(self, name, *, unconformity_type=UnconformityType.ERODE, where='top' ): @@ -209,6 +211,7 @@ def add_unconformity(self, name, *, unconformity_type=UnconformityType.ERODE, wh self.order.insert(0, unconformity) else: raise ValueError("Invalid 'where' argument. Use 'top' or 'bottom'.") + self.notify('unconformity_added', unconformity=unconformity) return unconformity def get_element_by_index(self, index): @@ -228,6 +231,7 @@ def get_unit_by_name(self, name): return unit return None + def get_unconformity_by_name(self, name): """ Retrieves an unconformity by its name from the stratigraphic column. @@ -245,6 +249,15 @@ def get_element_by_uuid(self, uuid): if element.uuid == uuid: return element raise KeyError(f"No element found with uuid: {uuid}") + + def get_group_for_unit_name(self, unit_name:str) -> Optional[StratigraphicGroup]: + """ + Retrieves the group for a given unit name. + """ + for group in self.get_groups(): + if any(unit.name == unit_name for unit in group.units): + return group + return None def add_element(self, element): """ Adds a StratigraphicColumnElement to the stratigraphic column. @@ -296,7 +309,18 @@ def get_unitname_groups(self): group = [u.name for u in g.units if isinstance(u, StratigraphicUnit)] groups_list.append(group) return groups_list - + + def get_group_unit_pairs(self) -> List[Tuple[str,str]]: + """ + Returns a list of tuples containing group names and unit names. + """ + groups = self.get_groups() + group_unit_pairs = [] + for g in groups: + for u in g.units: + if isinstance(u, StratigraphicUnit): + group_unit_pairs.append((g.name, u.name)) + return group_unit_pairs def __getitem__(self, uuid): """ @@ -316,6 +340,7 @@ def update_order(self, new_order): self.order = [ self.__getitem__(uuid) for uuid in new_order if self.__getitem__(uuid) is not None ] + self.notify('order_updated', new_order=self.order) def update_element(self, unit_data: Dict): """ @@ -334,6 +359,7 @@ def update_element(self, unit_data: Dict): element.unconformity_type = UnconformityType( unit_data.get('unconformity_type', element.unconformity_type.value) ) + self.notify('element_updated', element=element) def __str__(self): """ @@ -354,14 +380,15 @@ def update_from_dict(self, data): """ if not isinstance(data, dict): raise TypeError("Data must be a dictionary") - self.clear(basement=False) - elements_data = data.get("elements", []) - for element_data in elements_data: - if "unconformity_type" in element_data: - element = StratigraphicUnconformity.from_dict(element_data) - else: - element = StratigraphicUnit.from_dict(element_data) - self.add_element(element) + with self.freeze_notifications(): + self.clear(basement=False) + elements_data = data.get("elements", []) + for element_data in elements_data: + if "unconformity_type" in element_data: + element = StratigraphicUnconformity.from_dict(element_data) + else: + element = StratigraphicUnit.from_dict(element_data) + self.add_element(element) @classmethod def from_dict(cls, data): """ diff --git a/LoopStructural/utils/__init__.py b/LoopStructural/utils/__init__.py index fab47c92e..0aaab4099 100644 --- a/LoopStructural/utils/__init__.py +++ b/LoopStructural/utils/__init__.py @@ -38,3 +38,4 @@ from ._surface import LoopIsosurfacer, surface_list from .colours import random_colour, random_hex_colour +from .observer import Callback, Disposable, Observable \ No newline at end of file diff --git a/LoopStructural/utils/_surface.py b/LoopStructural/utils/_surface.py index efbf21038..5af1d7e2b 100644 --- a/LoopStructural/utils/_surface.py +++ b/LoopStructural/utils/_surface.py @@ -115,12 +115,17 @@ def fit( values, ) logger.info(f'Isosurfacing at values: {isovalues}') + individual_names = False if name is None: names = ["surface"] * len(isovalues) if isinstance(name, str): names = [name] * len(isovalues) + if len(isovalues) == 1: + individual_names = True if isinstance(name, list): names = name + if len(names) == len(isovalues): + individual_names = True if colours is None: colours = [None] * len(isovalues) for name, isovalue, colour in zip(names, isovalues, colours): @@ -151,7 +156,7 @@ def fit( vertices=verts, triangles=faces, normals=normals, - name=f"{name}_{isovalue}", + name=name if individual_names else f"{name}_{isovalue}", values=values, colour=colour, ) diff --git a/LoopStructural/utils/observer.py b/LoopStructural/utils/observer.py new file mode 100644 index 000000000..77657c0b8 --- /dev/null +++ b/LoopStructural/utils/observer.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from collections.abc import Callable +from contextlib import contextmanager +from typing import Any, Generic, Protocol, TypeAlias, TypeVar, runtime_checkable +import threading +import weakref + +__all__ = ["Observer", "Observable", "Disposable"] + + +@runtime_checkable +class Observer(Protocol): + """Objects implementing an *update* method can subscribe.""" + + def update(self, observable: "Observable", event: str, *args: Any, **kwargs: Any) -> None: + """Receive a notification.""" + + +Callback: TypeAlias = Callable[["Observable", str, Any], None] +T = TypeVar("T", bound="Observable") + + +class Disposable: + """A small helper that detaches an observer when disposed.""" + + __slots__ = ("_detach",) + + def __init__(self, detach: Callable[[], None]): + self._detach = detach + + def dispose(self) -> None: + """Detach the associated observer immediately.""" + + self._detach() + + # Allow use as a context‑manager for temporary subscriptions + def __enter__(self) -> "Disposable": + return self + + def __exit__(self, exc_type, exc, tb): + self.dispose() + return False # do not swallow exceptions + + +class Observable(Generic[T]): + """Base‑class that provides Observer pattern plumbing.""" + + #: Internal storage: mapping *event* → WeakSet[Callback] + _observers: dict[str, weakref.WeakSet[Callback]] + _any_observers: weakref.WeakSet[Callback] + + def __init__(self) -> None: + self._lock = threading.RLock() + self._observers = {} + self._any_observers = weakref.WeakSet() + self._frozen = 0 + self._pending: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = [] + + # ‑‑‑ subscription api -------------------------------------------------- + def attach(self, listener: Observer | Callback, event: str | None = None) -> Disposable: + """Register *listener* for *event* (all events if *event* is None). + + Returns a :class:`Disposable` so the caller can easily detach again. + """ + callback: Callback = ( + listener.update # type: ignore[attr‑defined] + if isinstance(listener, Observer) # type: ignore[misc] + else listener # already a callable + ) + + with self._lock: + if event is None: + self._any_observers.add(callback) + else: + self._observers.setdefault(event, weakref.WeakSet()).add(callback) + + return Disposable(lambda: self.detach(listener, event)) + + def detach(self, listener: Observer | Callback, event: str | None = None) -> None: + """Unregister a previously attached *listener*.""" + + callback: Callback = ( + listener.update # type: ignore[attr‑defined] + if isinstance(listener, Observer) # type: ignore[misc] + else listener + ) + + with self._lock: + if event is None: + self._any_observers.discard(callback) + for s in self._observers.values(): + s.discard(callback) + else: + self._observers.get(event, weakref.WeakSet()).discard(callback) + def __getstate__(self): + state = self.__dict__.copy() + state.pop('_lock', None) # RLock cannot be pickled + state.pop('_observers', None) # WeakSet cannot be pickled + state.pop('_any_observers', None) + return state + def __setstate__(self, state): + self.__dict__.update(state) + self._lock = threading.RLock() + self._observers = {} + self._any_observers = weakref.WeakSet() + self._frozen = 0 + # ‑‑‑ notification api -------------------------------------------------- + def notify(self: T, event: str, *args: Any, **kwargs: Any) -> None: + """Notify observers that *event* happened.""" + + with self._lock: + if self._frozen: + # defer until freeze_notifications() exits + self._pending.append((event, args, kwargs)) + return + + observers = list(self._any_observers) + observers.extend(self._observers.get(event, ())) + + # Call outside lock — prevent deadlocks if observers trigger other + # notifications. + for cb in observers: + try: + cb(self, event, *args, **kwargs) + except Exception: # pragma: no cover + # Optionally log; never allow an observer error to break flow. + import logging + + logging.getLogger(__name__).exception( + "Unhandled error in observer %s for event %s", cb, event + ) + + # ‑‑‑ batching ---------------------------------------------------------- + @contextmanager + def freeze_notifications(self): + """Context manager that batches notifications until exit.""" + + with self._lock: + self._frozen += 1 + try: + yield self + finally: + with self._lock: + self._frozen -= 1 + if self._frozen == 0 and self._pending: + pending = self._pending[:] + self._pending.clear() + for event, args, kw in pending: # type: ignore[has‑type] + self.notify(event, *args, **kw) diff --git a/docker-compose-win.yml b/docker-compose-win.yml deleted file mode 100644 index 13596a74f..000000000 --- a/docker-compose-win.yml +++ /dev/null @@ -1,15 +0,0 @@ -version: "3" - -services: - structural: - build: - context: ./ - dockerfile: DockerfileDev - volumes: - - C:\Users\lachl\OneDrive\Documents\GitHub\LoopStructural:/home/jovyan/LoopStructural - - C:\Users\lachl\OneDrive\Documents\Loop\notebooks:/home/jovyan/notebooks - ports: - - 8888:8888 - - 8050:8050 - - 8080-8090:8080-8090 - # command: jupyter notebook --ip='0.0.0.0' --NotebookApp.token='' --no-browser diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 6102d8281..000000000 --- a/docker-compose.yml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3" - -services: - structural: - build: - context: ./ - dockerfile: DockerfileDev - volumes: - - /home/lgrose/dev/python/LoopStructural/:/home/jovyan/LoopStructural - - /home/lgrose/LoopStructural/:/home/jovyan/notebooks - - /home/lgrose/dev/python/map2loop-2/:/home/jovyan/map2loop - - /home/lgrose/dev/python/LoopProjectFile/:/home/jovyan/LoopProjectFile - - /home/lgrose/dev/fortran/tomofast/:/home/jovyan/tomofast - ports: - - 8888:8888 - - 8050:8050 - - 8080-8090:8080-8090 - # command: jupyter notebook --ip='0.0.0.0' --NotebookApp.token='' --no-browser diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 33482f439..000000000 --- a/setup.cfg +++ /dev/null @@ -1 +0,0 @@ -[metadata] \ No newline at end of file