diff --git a/app/backend/api/schemas/project_schema.py b/app/backend/api/schemas/project_schema.py index e5d015d..f0538ca 100644 --- a/app/backend/api/schemas/project_schema.py +++ b/app/backend/api/schemas/project_schema.py @@ -229,4 +229,4 @@ class ProtocolWizardExecuteResponse(BaseModel): requiresUserInput: bool = False inputSchema: Optional[WizardInputSchemaResponse] = None preview: Optional[WizardPreviewResponse] = None - viewerState: Optional[WizardViewerStateResponse] = None \ No newline at end of file + viewerState: Optional[Dict[str, Any]] = None \ No newline at end of file diff --git a/app/backend/api/services/protocol_service.py b/app/backend/api/services/protocol_service.py index 9effe25..a8aa404 100644 --- a/app/backend/api/services/protocol_service.py +++ b/app/backend/api/services/protocol_service.py @@ -1,3 +1,28 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** # app/backend/api/services/protocol_service.py from typing import Dict, Any, Optional from app.backend.api.services.project_service import ProjectService diff --git a/app/backend/api/services/protocol_wizard_service.py b/app/backend/api/services/protocol_wizard_service.py index 719609f..45c129e 100644 --- a/app/backend/api/services/protocol_wizard_service.py +++ b/app/backend/api/services/protocol_wizard_service.py @@ -4,6 +4,24 @@ # * # * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC # * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * # ****************************************************************************** from __future__ import annotations @@ -81,15 +99,36 @@ def findWizardsWeb(self, protocol) -> Dict[str, List[Dict[str, Any]]]: return wizardMap + def _sanitizeWizardFormValues( + self, + params: Dict[str, Any], + ) -> Dict[str, Any]: + cleaned: Dict[str, Any] = {} + + for key, value in (params or {}).items(): + if value is None: + continue + + if isinstance(value, str) and value.strip() == "": + continue + + cleaned[key] = value + + return cleaned + def _serializeWizardDescriptor( - self, - wizardClass, - protocol, - targetParams: List[str], + self, + wizardClass, + protocol, + targetParams: List[str], ) -> Dict[str, Any]: wizardId = f"{wizardClass.__module__}.{wizardClass.__name__}" webView = self._safeGetWizardView(wizardClass) - kind = self._classifyWizardKind(wizardClass, webView) + kind = self._classifyWizardKind( + wizardClass=wizardClass, + webView=webView, + targetParams=targetParams, + ) computeKinds = { "compute", @@ -103,6 +142,7 @@ def _serializeWizardDescriptor( "downsample_preview", "filter_preview", "gaussian_preview", + "point_in_volume" } webSupported = kind in computeKinds @@ -120,6 +160,22 @@ def _serializeWizardDescriptor( "webView": webView, } + def _getWizardBaseClassNames(self, wizardClass) -> List[str]: + names: List[str] = [] + + try: + for cls in getattr(wizardClass, "__mro__", ()) or (): + name = getattr(cls, "__name__", None) + if not name: + continue + token = str(name).strip() + if token and token not in names: + names.append(token) + except Exception: + pass + + return names + def _safeGetWizardView(self, wizardClass) -> Optional[str]: try: getViewFn = getattr(wizardClass, "getView", None) @@ -133,11 +189,29 @@ def _safeGetWizardView(self, wizardClass) -> Optional[str]: return None - def _classifyWizardKind(self, wizardClass, webView: Optional[str]) -> str: + def _classifyWizardKind( + self, + wizardClass, + webView: Optional[str], + targetParams: Optional[List[str]] = None, + ) -> str: className = getattr(wizardClass, "__name__", "") or "" classNameLower = className.lower() webViewLower = (webView or "").lower() + normalizedTargetParams = [ + str(item).strip() + for item in (targetParams or []) + if str(item).strip() + ] + targetParamsLower = [item.lower() for item in normalizedTargetParams] + targetParamsSet = set(targetParamsLower) + + baseClassNamesLower = { + name.lower() + for name in self._getWizardBaseClassNames(wizardClass) + } + explicitKinds = { "XmippBoxSizeWizard": "box_size", "XmippParticleConsensusRadiusWizard": "consensus_radius", @@ -148,6 +222,73 @@ def _classifyWizardKind(self, wizardClass, webView: Optional[str]) -> str: if className in explicitKinds: return explicitKinds[className] + baseKindMap = { + "downsamplewizard": "downsample_preview", + "ctfwizard": "ctf_preview", + "particlemaskradiuswizard": "mask_radius", + "volumemaskradiuswizard": "mask_radius", + "particlesmaskradiiwizard": "mask_radii", + "volumemaskradiiwizard": "mask_radii", + "filterparticleswizard": "filter_preview", + "filtervolumeswizard": "filter_preview", + "gaussianparticleswizard": "gaussian_preview", + "gaussianvolumeswizard": "gaussian_preview", + "colorscalewizardbase": "viewer_color_scale", + } + + for baseName, kind in baseKindMap.items(): + if baseName in baseClassNamesLower: + return kind + + if {"xin", "yin", "zin"}.issubset(targetParamsSet): + return "point_in_volume" + + if targetParamsSet in ( + {"innerradius", "outerradius"}, + {"particleradius", "noiseradius"}, + ): + return "mask_radii" + + if len(targetParamsLower) == 1: + onlyParam = targetParamsLower[0] + if onlyParam in { + "radius", + "maskradius", + "volumeradius", + "volumeradiushalf", + "cylinderouterradius", + "cylinderinnerradius", + "consensusradius", + "cirmaskrad", + "rmax", + }: + return "mask_radius" + + if ( + len(targetParamsLower) >= 2 + and all("radius" in item for item in targetParamsLower) + ): + return "mask_radii" + + if ( + len(targetParamsLower) == 1 + and any(token in targetParamsLower[0] for token in ("down", "factor")) + ): + return "downsample_preview" + + if {"ctfdownfactor", "lowres", "highres"}.issubset(targetParamsSet): + return "ctf_preview" + + if any(item in targetParamsSet for item in {"freqsigma"}): + return "gaussian_preview" + + if ( + any(item in targetParamsSet for item in {"lowfreqa", "lowfreqdig"}) + and any(item in targetParamsSet for item in {"highfreqa", "highfreqdig"}) + and any(item in targetParamsSet for item in {"freqdecaya", "freqdecaydig"}) + ): + return "filter_preview" + if "lane" in classNameLower and "wizard" in classNameLower: return "compute_lane_selector" @@ -187,7 +328,7 @@ def _classifyWizardKind(self, wizardClass, webView: Optional[str]) -> str: return "legacy_web_view" if classNameLower.endswith("wizard") and any( - token in classNameLower for token in ("boxsize", "radius", "classes") + token in classNameLower for token in ("boxsize", "radius", "classes") ): return "compute" @@ -276,7 +417,8 @@ def _buildWizardReadyProtocol( self.currentProject._fixProtParamsConfiguration(protocol) - errors = self._applyFormValuesToProtocolInstance(protocol, formValues or {}) + sanitizedFormValues = self._sanitizeWizardFormValues(formValues or {}) + errors = self._applyFormValuesToProtocolInstance(protocol, sanitizedFormValues) if errors: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, @@ -381,7 +523,6 @@ def executeProtocolWizard( if kind in { "viewer_color_scale", - "point_in_volume", "legacy_web_view", "unknown", }: diff --git a/app/backend/api/services/wizards/point_in_volume.py b/app/backend/api/services/wizards/point_in_volume.py new file mode 100644 index 0000000..e427c4b --- /dev/null +++ b/app/backend/api/services/wizards/point_in_volume.py @@ -0,0 +1,342 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** +from __future__ import annotations + +import math +from typing import Any, Dict, Optional, Tuple + +import numpy as np +from fastapi import HTTPException, status + +from app.backend.utils.volume_utils import readVolumeArray3d + + +POINT_IN_VOLUME_HELP_MESSAGE = ( + "Select the new center inside the input volume and apply the coordinates." +) + + +def executePointInVolumeWizard( + *, + wizardClass, + protocol, + paramName: str, + descriptor: Optional[Dict[str, Any]] = None, + wizardInputs: Optional[Dict[str, Any]] = None, + currentProject=None, + projectId: Optional[int] = None, +) -> Dict[str, Any]: + wizardInputs = wizardInputs or {} + action = _normalizePointInVolumeAction(wizardInputs) + + volumePath = _resolveInputVolumePath(protocol) + volumeData, _props = readVolumeArray3d(volumePath) + + if volumeData is None or volumeData.ndim != 3: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Input volume is not a valid 3D map", + ) + + volumeData = np.asarray(volumeData, dtype=np.float32) + dimsZYX = [int(volumeData.shape[0]), int(volumeData.shape[1]), int(volumeData.shape[2])] + + if action == "apply": + point = _resolveAppliedPoint(wizardInputs, dimsZYX) + return { + "paramUpdates": { + "xin": float(point["x"]), + "yin": float(point["y"]), + "zin": float(point["z"]), + }, + "message": "Point in volume applied", + "availableValues": [], + } + + currentPoint = { + "x": _readProtocolFloatValue(protocol, "xin", default=0.0), + "y": _readProtocolFloatValue(protocol, "yin", default=0.0), + "z": _readProtocolFloatValue(protocol, "zin", default=0.0), + } + currentVoxel = _centerCoordsToVoxel(currentPoint, dimsZYX) + + previewVolume = _downsampleVolumePreviewUint8(volumeData, maxDim=64) + + return { + "paramUpdates": {}, + "message": POINT_IN_VOLUME_HELP_MESSAGE, + "requiresUserInput": True, + "availableValues": [], + "inputSchema": { + "type": "point_in_volume", + "paramName": paramName, + "title": "Wizard", + "fields": [], + }, + "viewerState": { + "dims": dimsZYX, + "previewDims": [ + int(previewVolume.shape[0]), + int(previewVolume.shape[1]), + int(previewVolume.shape[2]), + ], + "previewValues": previewVolume.ravel(order="C").astype(np.uint8).tolist(), + "axisOrder": ["z", "y", "x"], + "point": currentPoint, + "pointVoxel": currentVoxel, + "bounds": { + "xMin": -0.5 * float(dimsZYX[2]), + "xMax": 0.5 * float(dimsZYX[2]), + "yMin": -0.5 * float(dimsZYX[1]), + "yMax": 0.5 * float(dimsZYX[1]), + "zMin": -0.5 * float(dimsZYX[0]), + "zMax": 0.5 * float(dimsZYX[0]), + }, + }, + } + + +def _normalizePointInVolumeAction(wizardInputs: Dict[str, Any]) -> str: + if not wizardInputs: + return "open" + + actionRaw = wizardInputs.get("action") + if actionRaw is None: + if any( + key in wizardInputs + for key in ("x", "y", "z", "point", "pointVoxel", "voxelX", "voxelY", "voxelZ") + ): + return "apply" + return "open" + + action = str(actionRaw).strip().lower() + if action in {"open", "preview", "apply"}: + return action + + return "open" + + +def _resolveInputVolumePath(protocol) -> str: + inputVolHolder = getattr(protocol, "inputVol", None) + if inputVolHolder is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="This wizard requires protocol.inputVol", + ) + + getFn = getattr(inputVolHolder, "get", None) + volumeObj = getFn() if callable(getFn) else inputVolHolder + if volumeObj is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Select an input volume first", + ) + + fileNameFn = getattr(volumeObj, "getFileName", None) + if not callable(fileNameFn): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Input volume does not expose getFileName()", + ) + + volumePath = fileNameFn() + if not volumePath: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Input volume file not found", + ) + + return str(volumePath) + + +def _readProtocolFloatValue(protocol, paramName: str, default: float = 0.0) -> float: + protVar = getattr(protocol, paramName, None) + if protVar is None: + return float(default) + + getter = getattr(protVar, "get", None) + value = None + + if callable(getter): + try: + value = getter() + except Exception: + value = None + + if value is None: + value = protVar + + if value in (None, ""): + return float(default) + + try: + parsed = float(value) + if not math.isfinite(parsed): + raise ValueError + return parsed + except Exception: + return float(default) + + +def _coerceFloat(value: Any, default: float = 0.0) -> float: + if value in (None, ""): + return float(default) + + try: + parsed = float(value) + if not math.isfinite(parsed): + raise ValueError + return parsed + except Exception: + return float(default) + + +def _centerCoordsToVoxel(point: Dict[str, float], dimsZYX) -> Dict[str, float]: + zDim, yDim, xDim = dimsZYX + + voxelX = float(point["x"]) + 0.5 * float(xDim) + voxelY = float(point["y"]) + 0.5 * float(yDim) + voxelZ = float(point["z"]) + 0.5 * float(zDim) + + voxelX = min(max(voxelX, 0.0), float(xDim - 1)) + voxelY = min(max(voxelY, 0.0), float(yDim - 1)) + voxelZ = min(max(voxelZ, 0.0), float(zDim - 1)) + + return { + "x": voxelX, + "y": voxelY, + "z": voxelZ, + } + + +def _voxelCoordsToCenter(pointVoxel: Dict[str, float], dimsZYX) -> Dict[str, float]: + zDim, yDim, xDim = dimsZYX + + voxelX = min(max(float(pointVoxel["x"]), 0.0), float(xDim - 1)) + voxelY = min(max(float(pointVoxel["y"]), 0.0), float(yDim - 1)) + voxelZ = min(max(float(pointVoxel["z"]), 0.0), float(zDim - 1)) + + return { + "x": voxelX - 0.5 * float(xDim), + "y": voxelY - 0.5 * float(yDim), + "z": voxelZ - 0.5 * float(zDim), + } + + +def _resolveAppliedPoint(wizardInputs: Dict[str, Any], dimsZYX) -> Dict[str, float]: + point = wizardInputs.get("point") + if isinstance(point, dict): + if all(axis in point for axis in ("x", "y", "z")): + return { + "x": _coerceFloat(point.get("x"), 0.0), + "y": _coerceFloat(point.get("y"), 0.0), + "z": _coerceFloat(point.get("z"), 0.0), + } + + if all(key in wizardInputs for key in ("x", "y", "z")): + return { + "x": _coerceFloat(wizardInputs.get("x"), 0.0), + "y": _coerceFloat(wizardInputs.get("y"), 0.0), + "z": _coerceFloat(wizardInputs.get("z"), 0.0), + } + + pointVoxel = wizardInputs.get("pointVoxel") + if isinstance(pointVoxel, dict) and all(axis in pointVoxel for axis in ("x", "y", "z")): + return _voxelCoordsToCenter(pointVoxel, dimsZYX) + + if all(key in wizardInputs for key in ("voxelX", "voxelY", "voxelZ")): + return _voxelCoordsToCenter( + { + "x": _coerceFloat(wizardInputs.get("voxelX"), 0.0), + "y": _coerceFloat(wizardInputs.get("voxelY"), 0.0), + "z": _coerceFloat(wizardInputs.get("voxelZ"), 0.0), + }, + dimsZYX, + ) + + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Missing point coordinates for point_in_volume wizard", + ) + + +def _downsampleVolumePreviewUint8(volumeData: np.ndarray, maxDim: int = 64) -> np.ndarray: + volumeData = np.asarray(volumeData, dtype=np.float32) + + if volumeData.ndim != 3: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Invalid volume preview shape", + ) + + preview = _binVolumeToMaxDim(volumeData, maxDim=maxDim) + + finiteMask = np.isfinite(preview) + if not finiteMask.any(): + return np.zeros(preview.shape, dtype=np.uint8) + + valid = preview[finiteMask] + low = float(np.percentile(valid, 1.0)) + high = float(np.percentile(valid, 99.0)) + + if high <= low: + low = float(valid.min()) + high = float(valid.max()) + + if high <= low: + return np.zeros(preview.shape, dtype=np.uint8) + + clipped = np.clip(preview, low, high) + norm = (clipped - low) / (high - low + 1e-12) + return (255.0 * norm).astype(np.uint8) + + +def _binVolumeToMaxDim(volumeData: np.ndarray, maxDim: int = 64) -> np.ndarray: + zDim, yDim, xDim = volumeData.shape + maxCurrentDim = max(zDim, yDim, xDim) + + if maxCurrentDim <= maxDim: + return volumeData.astype(np.float32, copy=False) + + factor = int(np.ceil(float(maxCurrentDim) / float(maxDim))) + if factor <= 1: + return volumeData.astype(np.float32, copy=False) + + zCrop = (zDim // factor) * factor + yCrop = (yDim // factor) * factor + xCrop = (xDim // factor) * factor + + if min(zCrop, yCrop, xCrop) <= 0: + return volumeData.astype(np.float32, copy=False) + + cropped = volumeData[:zCrop, :yCrop, :xCrop] + binned = cropped.reshape( + zCrop // factor, factor, + yCrop // factor, factor, + xCrop // factor, factor, + ).mean(axis=(1, 3, 5)) + + return binned.astype(np.float32, copy=False) \ No newline at end of file diff --git a/app/backend/api/services/wizards/registry.py b/app/backend/api/services/wizards/registry.py index 2b842cb..57ec022 100644 --- a/app/backend/api/services/wizards/registry.py +++ b/app/backend/api/services/wizards/registry.py @@ -19,6 +19,7 @@ executeMaskRadiiWizard, ) from .downsample import executeDownsamplePreviewWizard +from .point_in_volume import executePointInVolumeWizard HANDLERS: Dict[str, Callable[..., Dict[str, Any]]] = { "compute": executeGenericComputeWizard, @@ -32,6 +33,7 @@ "filter_preview": executeFilterPreviewWizard, "gaussian_preview": executeGaussianPreviewWizard, "downsample_preview": executeDownsamplePreviewWizard, + "point_in_volume": executePointInVolumeWizard, } diff --git a/tests/unit/backend/api/services/test_project_service_coords3d.py b/tests/unit/backend/api/services/test_project_service_coords3d.py new file mode 100644 index 0000000..c57627d --- /dev/null +++ b/tests/unit/backend/api/services/test_project_service_coords3d.py @@ -0,0 +1,515 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import numpy as np +import pytest +from fastapi import HTTPException + + +class FakeTomogram: + # fakeTomogram + def __init__(self, tsId, label, fileName, samplingRate, dims): + self._tsId = tsId + self._label = label + self._fileName = fileName + self._samplingRate = samplingRate + self._dims = dims + + def getTsId(self): + return self._tsId + + def getObjId(self): + return self._tsId + + def getObjLabel(self): + return self._label + + def getFileName(self): + return self._fileName + + def getSamplingRate(self): + return self._samplingRate + + def getDim(self): + return self._dims + + +class FakeCoord: + # fakeCoord + def __init__( + self, + x, + y, + z, + objId=None, + classId=None, + label=None, + score=None, + weight=None, + matrix=None, + ): + self._x = x + self._y = y + self._z = z + self._objId = objId + self._classId = classId + self._objLabel = label + self._score = score + self._weight = weight + self._matrix = matrix if matrix is not None else np.eye(4, dtype=float) + + def getX(self, corner): + return self._x + + def getY(self, corner): + return self._y + + def getZ(self, corner): + return self._z + + def getObjId(self): + return self._objId + + def getClassId(self): + return self._classId + + def getObjLabel(self): + return self._objLabel + + def getScore(self): + return self._score + + def getWeight(self): + return self._weight + + def getMatrix(self): + return self._matrix + + +class FakeCoordinatesSet: + # fakeCoordinatesSet + def __init__(self, tomograms=None, coordsByTomogram=None, boxSize=24): + self._tomograms = tomograms or [] + self._coordsByTomogram = coordsByTomogram or {} + self._boxSize = boxSize + + def iterTomograms(self): + return iter(self._tomograms) + + def getBoxSize(self): + return self._boxSize + + def iterCoordinates(self, tomogram): + return iter(self._coordsByTomogram.get(tomogram.getTsId(), [])) + + def _getTomogram(self, key): + for tomo in self._tomograms: + if str(tomo.getTsId()) == str(key): + return tomo + return None + + def createCopy(self, protocolPath, prefix=None, copyInfo=True): + return FakeCreatedCoordinatesSet(prefix=prefix) + + def getTomograms(self): + return self._tomograms + + +class FakeCreatedCoordinatesSet: + # fakeCreatedCoordinatesSet + def __init__(self, prefix=None): + self.prefix = prefix + self.appended = [] + self.tomograms = None + self.written = False + self._samplingRate = 3.0 + + def setTomograms(self, tomograms): + self.tomograms = tomograms + + def append(self, coord): + self.appended.append(coord) + + def write(self): + self.written = True + + def getSamplingRate(self): + return self._samplingRate + + +class FakeCreatedCoordinate3D: + # fakeCreatedCoordinate3D + def __init__(self): + self.objId = None + self.volume = None + self.position = None + self.groupId = None + self.tomoId = None + self.boxSize = None + self.score = None + self.matrix = None + + def setObjId(self, value): + self.objId = value + + def setVolume(self, value): + self.volume = value + + def setPosition(self, x, y, z, corner): + self.position = { + "x": x, + "y": y, + "z": z, + "corner": corner, + } + + def setGroupId(self, value): + self.groupId = value + + def setTomoId(self, value): + self.tomoId = value + + def setBoxSize(self, value): + self.boxSize = value + + def setScore(self, value): + self.score = value + + def setMatrix(self, value): + self.matrix = value + + +class FakeProtocol: + # fakeProtocol + def __init__(self, outputName, output): + setattr(self, outputName, output) + self.definedOutputs = {} + self.stored = False + + def getNextOutputName(self, baseName): + return baseName + "_edited" + + def _getPath(self): + return "/tmp/fake-protocol-path" + + def _defineOutputs(self, **kwargs): + self.definedOutputs.update(kwargs) + + def _store(self): + self.stored = True + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocol): + self._protocol = protocol + + def getProtocol(self, protocolId): + return self._protocol + + +@pytest.fixture +def projectServiceModule(authTestEnv): + # projectServiceModule + return importlib.import_module("app.backend.api.services.project_service") + + +@pytest.fixture +def service(projectServiceModule): + # service + instance = object.__new__(projectServiceModule.ProjectService) + instance.currentProject = None + instance.tomoList = {} + return instance + + +def test_ListCoordinates3dTomogramsServiceBuildsTomogramList(service, tmp_path): + tomoPath1 = tmp_path / "tomo1.mrc" + tomoPath2 = tmp_path / "tomo2.mrc" + tomoPath1.write_text("placeholder", encoding="utf-8") + tomoPath2.write_text("placeholder", encoding="utf-8") + + tomo1 = FakeTomogram( + tsId="TS_001", + label="Tomogram 1", + fileName=str(tomoPath1), + samplingRate=2.5, + dims=[128, 128, 64], + ) + tomo2 = FakeTomogram( + tsId="TS_002", + label="Tomogram 2", + fileName=str(tomoPath2), + samplingRate=3.0, + dims=[64, 64, 32], + ) + + output = FakeCoordinatesSet(tomograms=[tomo1, tomo2]) + protocol = FakeProtocol("outputCoords3d", output) + service.currentProject = FakeCurrentProject(protocol) + + result = service.listCoordinates3dTomogramsService( + projectId=1, + protocolId=10, + outputName="outputCoords3d", + ) + + assert result == [ + { + "id": "TS_001", + "name": "Tomogram 1", + "label": "TS_001", + "dims": [128, 128, 64], + "voxelSize": [2.5, 2.5, 2.5], + }, + { + "id": "TS_002", + "name": "Tomogram 2", + "label": "TS_002", + "dims": [64, 64, 32], + "voxelSize": [3.0, 3.0, 3.0], + }, + ] + assert service.tomoList["TS_001"] is tomo1 + assert service.tomoList["TS_002"] is tomo2 + + +def test_GetCoordinates3dPointsServiceBuildsPointPayload(service, tmp_path): + tomoPath = tmp_path / "tomo1.mrc" + tomoPath.write_text("placeholder", encoding="utf-8") + + tomo = FakeTomogram( + tsId="TS_001", + label="Tomogram 1", + fileName=str(tomoPath), + samplingRate=2.5, + dims=[128, 128, 64], + ) + + coords = [ + FakeCoord( + x=10.0, + y=20.0, + z=30.0, + objId=101, + classId=7, + label="point-101", + score=0.87, + weight=1.5, + matrix=np.array([[1, 0], [0, 1]], dtype=float), + ), + ] + + output = FakeCoordinatesSet( + tomograms=[tomo], + coordsByTomogram={"TS_001": coords}, + boxSize=48, + ) + protocol = FakeProtocol("outputCoords3d", output) + service.currentProject = FakeCurrentProject(protocol) + service.tomoList = {"TS_001": tomo} + + result = service.getCoordinates3dPointsService( + projectId=1, + protocolId=10, + outputName="outputCoords3d", + tomogramId="TS_001", + ) + + assert result == [ + { + "x": 10.0, + "y": 20.0, + "z": 30.0, + "id": 101, + "classId": 7, + "label": "point-101", + "score": 0.87, + "weight": 1.5, + "radius": 48.0, + "matrix": [[1.0, 0.0], [0.0, 1.0]], + "tomoId": "TS_001", + } + ] + + +def test_GetCoordinates3dPointsServiceReturns404WhenTomogramMissing(service): + output = FakeCoordinatesSet(tomograms=[], coordsByTomogram={}) + protocol = FakeProtocol("outputCoords3d", output) + service.currentProject = FakeCurrentProject(protocol) + service.tomoList = {} + + with pytest.raises(HTTPException) as exc: + service.getCoordinates3dPointsService( + projectId=1, + protocolId=10, + outputName="outputCoords3d", + tomogramId="missing", + ) + + assert exc.value.status_code == 404 + assert exc.value.detail == "Tomogram 'missing' not found in SetOfCoordinates3D" + + +def test_RenderCoords3dTomogramSliceServiceReturnsImageResponse(projectServiceModule, service, monkeypatch, tmp_path): + tomoPath = tmp_path / "tomo1.mrc" + tomoPath.write_text("placeholder", encoding="utf-8") + + tomo = FakeTomogram( + tsId="TS_001", + label="Tomogram 1", + fileName=str(tomoPath), + samplingRate=2.5, + dims=[4, 4, 4], + ) + + output = FakeCoordinatesSet(tomograms=[tomo], coordsByTomogram={}) + protocol = FakeProtocol("outputCoords3d", output) + service.currentProject = FakeCurrentProject(protocol) + service.tomoList = {"TS_001": tomo} + + monkeypatch.setattr( + projectServiceModule, + "readVolumeArray3d", + lambda volumePath: ( + np.arange(64, dtype=np.float32).reshape((4, 4, 4)), + {}, + ), + ) + + response = service.renderCoords3dTomogramSliceService( + projectId=1, + protocolId=10, + outputName="outputCoords3d", + tomogramId="TS_001", + sliceIndex=1, + axis="z", + colormap=None, + normalize="minmax", + scale=1.0, + inline=True, + fmt="png", + thumb=None, + fast=False, + quality=75, + ) + + assert response.media_type == "image/png" + assert response.headers["x-preview-depth"] == "4" + assert response.headers["x-preview-tomogramid"] == "TS_001" + assert response.headers["x-preview-format"] == "PNG" + assert len(response.body) > 0 + + +def test_CreateCoords3dOutputFromPointsServiceCreatesNewOutput(projectServiceModule, service, monkeypatch, tmp_path): + tomoPath = tmp_path / "tomo1.mrc" + tomoPath.write_text("placeholder", encoding="utf-8") + + tomo = FakeTomogram( + tsId="TS_001", + label="Tomogram 1", + fileName=str(tomoPath), + samplingRate=2.5, + dims=[128, 128, 64], + ) + + output = FakeCoordinatesSet( + tomograms=[tomo], + coordsByTomogram={}, + boxSize=48, + ) + protocol = FakeProtocol("outputCoords3d", output) + service.currentProject = FakeCurrentProject(protocol) + service.tomoList = {"TS_001": tomo} + + monkeypatch.setattr(projectServiceModule, "Coordinate3D", FakeCreatedCoordinate3D) + + payload = { + "tomograms": [ + { + "tomoId": "TS_001", + "coords": [ + { + "x": 1.0, + "y": 2.0, + "z": 3.0, + "groupId": 5, + "score": 0.9, + "matrix": [[1, 0], [0, 1]], + "tomoId": "TS_001", + }, + { + "x": 4.0, + "y": 5.0, + "z": 6.0, + "tomoId": "TS_001", + }, + ], + } + ] + } + + result = service.createCoords3dOutputFromPointsService( + projectId=1, + protocolId=10, + outputName="outputCoords3d", + payload=payload, + ) + + assert result == { + "success": True, + "outputName": "outputCoords3d_edited", + "message": "Created new coords3d output 'outputCoords3d_edited'", + "data": { + "sourceOutputName": "outputCoords3d", + "replacedPoints": 2, + "copiedPoints": 0, + }, + } + + assert "outputCoords3d_edited" in protocol.definedOutputs + createdSet = protocol.definedOutputs["outputCoords3d_edited"] + assert createdSet.written is True + assert len(createdSet.appended) == 2 + + firstCoord = createdSet.appended[0] + assert firstCoord.position["x"] == 1.0 + assert firstCoord.position["y"] == 2.0 + assert firstCoord.position["z"] == 3.0 + assert firstCoord.groupId == 5 + assert firstCoord.tomoId == "TS_001" + assert firstCoord.score == 0.9 + assert firstCoord.boxSize == 3.0 + assert firstCoord.matrix.tolist() == [[1, 0], [0, 1]] + + secondCoord = createdSet.appended[1] + assert secondCoord.groupId == 0 + assert secondCoord.score == 0 + + assert protocol.stored is True \ No newline at end of file diff --git a/tests/unit/backend/api/services/test_project_service_ctftomo.py b/tests/unit/backend/api/services/test_project_service_ctftomo.py new file mode 100644 index 0000000..8355b7c --- /dev/null +++ b/tests/unit/backend/api/services/test_project_service_ctftomo.py @@ -0,0 +1,606 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import pytest + + +class FakeAcquisition: + # fakeAcquisition + def __init__(self, accumDose=None): + self._accumDose = accumDose + + def getAccumDose(self): + return self._accumDose + + +class FakeTiltView: + # fakeTiltView + def __init__(self, acqOrder, tiltAngle, accumDose): + self._acqOrder = acqOrder + self._tiltAngle = tiltAngle + self._acquisition = FakeAcquisition(accumDose=accumDose) + + def getTiltAngle(self): + return self._tiltAngle + + def getAcquisition(self): + return self._acquisition + + +class FakeAssociatedTiltSeries: + # fakeAssociatedTiltSeries + def __init__(self, items=None, dims=None, samplingRate=1.5): + self._items = items or {} + self._dims = dims or [128, 128, 40] + self._samplingRate = samplingRate + + def getItem(self, key, value): + if key != "_acqOrder": + return None + return self._items.get(value) + + def getDim(self): + return self._dims + + def getSamplingRate(self): + return self._samplingRate + + def getSize(self): + return len(self._items) + + +class FakeCtfMeasurement: + # fakeCtfMeasurement + def __init__( + self, + objId, + index, + defocusU, + defocusV, + defocusAngle, + resolution, + phaseShift, + acquisitionOrder, + psdFile, + enabled=True, + ): + self._objId = objId + self._index = index + self._defocusU = defocusU + self._defocusV = defocusV + self._defocusAngle = defocusAngle + self._resolution = resolution + self._phaseShift = phaseShift + self._acquisitionOrder = acquisitionOrder + self._psdFile = psdFile + self._enabled = enabled + + def getObjId(self): + return self._objId + + def getIndex(self): + return self._index + + def getDefocusU(self): + return self._defocusU + + def getDefocusV(self): + return self._defocusV + + def getDefocusAngle(self): + return self._defocusAngle + + def getResolution(self): + return self._resolution + + def getPhaseShift(self): + return self._phaseShift + + def getAcquisitionOrder(self): + return self._acquisitionOrder + + def getPsdFile(self): + return self._psdFile + + def isEnabled(self): + return self._enabled + + def setEnabled(self, value): + self._enabled = value + + def clone(self): + return FakeCtfMeasurement( + objId=self._objId, + index=self._index, + defocusU=self._defocusU, + defocusV=self._defocusV, + defocusAngle=self._defocusAngle, + resolution=self._resolution, + phaseShift=self._phaseShift, + acquisitionOrder=self._acquisitionOrder, + psdFile=self._psdFile, + enabled=self._enabled, + ) + + +class FakeCtftomoSeries: + # fakeCtftomoSeries + def __init__(self, tsId, label, tiltSeries, items=None): + self._tsId = tsId + self._label = label + self._tiltSeries = tiltSeries + self._items = items or [] + self._enabled = True + self._written = False + + def getTsId(self): + return self._tsId + + def getObjLabel(self): + return self._label + + def getTiltSeries(self): + return self._tiltSeries + + def iterItems(self, iterate=False): + return list(self._items) + + def clone(self): + return FakeCtftomoSeries( + tsId=self._tsId, + label=self._label, + tiltSeries=self._tiltSeries, + items=[], + ) + + def setEnabled(self, value): + self._enabled = value + + def append(self, item): + self._items.append(item) + + def write(self): + self._written = True + + +class FakeCtftomoOutputSet: + # fakeCtftomoOutputSet + def __init__(self, seriesList=None, associatedTiltSeriesSet=None): + self._seriesList = seriesList or [] + self._associatedTiltSeriesSet = associatedTiltSeriesSet + self._updated = [] + self._written = False + self._linkedTiltSeries = None + + def iterItems(self, iterate=False): + return list(self._seriesList) + + def getItem(self, key, value): + if key != "_tsId": + return None + for item in self._seriesList: + if str(item.getTsId()) == str(value): + return item + return None + + def getSetOfTiltSeries(self): + return self._associatedTiltSeriesSet + + def createCopy(self, protocolPath, prefix=None, copyInfo=True): + return FakeCtftomoOutputSet(seriesList=[], associatedTiltSeriesSet=self._associatedTiltSeriesSet) + + def append(self, item): + self._seriesList.append(item) + + def update(self, item): + self._updated.append(item) + + def write(self): + self._written = True + + def isEmpty(self): + return len(self._seriesList) == 0 + + def setSetOfTiltSeries(self, tiltSeriesSet): + self._linkedTiltSeries = tiltSeriesSet + + +class FakeProtocol: + # fakeProtocol + def __init__(self, outputName, output, protocolPath): + setattr(self, outputName, output) + self._protocolPath = protocolPath + self._stored = False + self._definedOutputs = {} + self._nextOutputName = "CTFTomoSeries_0" + + def getPath(self): + return str(self._protocolPath) + + def _getPath(self): + return str(self._protocolPath) + + def getNextOutputName(self, prefix): + return self._nextOutputName + + def _defineOutputs(self, **kwargs): + self._definedOutputs.update(kwargs) + + def _store(self): + self._stored = True + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocol): + self._protocol = protocol + + def getProtocol(self, protocolId): + return self._protocol + + +class FakeOutputsPreview: + # fakeOutputsPreview + instances = [] + + def __init__(self, currentProject, protocol, output=None, requestHeaders=None): + self.currentProject = currentProject + self.protocol = protocol + self.output = output + self.requestHeaders = requestHeaders + self.lastRenderCall = None + FakeOutputsPreview.instances.append(self) + + def renderImageFromFilePath( + self, + filePath, + size, + fmt, + index, + inline, + quality, + applyTransform, + rot, + shifts, + ): + self.lastRenderCall = { + "filePath": filePath, + "size": size, + "fmt": fmt, + "index": index, + "inline": inline, + "quality": quality, + "applyTransform": applyTransform, + "rot": rot, + "shifts": shifts, + } + return { + "rendered": True, + "filePath": filePath, + } + + +@pytest.fixture +def projectServiceModule(authTestEnv): + # projectServiceModule + return importlib.import_module("app.backend.api.services.project_service") + + +@pytest.fixture +def service(projectServiceModule): + # service + instance = object.__new__(projectServiceModule.ProjectService) + instance.tomoList = {} + instance.currentProject = None + return instance + + +def test_ListOutputCtftomoSeriesServiceBuildsSummaries(service, tmp_path): + associatedTs = FakeAssociatedTiltSeries( + items={ + 1: FakeTiltView(acqOrder=1, tiltAngle=-60.0, accumDose=2.5), + 2: FakeTiltView(acqOrder=2, tiltAngle=-58.0, accumDose=3.0), + }, + dims=[128, 128, 40], + samplingRate=1.25, + ) + series1 = FakeCtftomoSeries( + tsId="TS_001", + label="Series 1", + tiltSeries=associatedTs, + items=[], + ) + series2 = FakeCtftomoSeries( + tsId="TS_002", + label="Series 2", + tiltSeries=associatedTs, + items=[], + ) + output = FakeCtftomoOutputSet(seriesList=[series1, series2], associatedTiltSeriesSet=None) + protocol = FakeProtocol("outputCtftomo", output, tmp_path) + service.currentProject = FakeCurrentProject(protocol) + + result = service.listOutputCtftomoSeriesService( + projectId=1, + protocolId=10, + outputName="outputCtftomo", + ) + + assert result == [ + { + "tiltSeriesId": "TS_001", + "label": "Series 1", + "nViews": 2, + "dims": [128, 128, 40], + "pixelSize": 1.25, + "index": 0, + }, + { + "tiltSeriesId": "TS_002", + "label": "Series 2", + "nViews": 2, + "dims": [128, 128, 40], + "pixelSize": 1.25, + "index": 1, + }, + ] + + +def test_GetCtftomoSeriesViewsServiceBuildsFrames(service, tmp_path): + associatedTs = FakeAssociatedTiltSeries( + items={ + 1: FakeTiltView(acqOrder=1, tiltAngle=-60.0, accumDose=2.5), + 2: FakeTiltView(acqOrder=2, tiltAngle=-58.0, accumDose=3.0), + }, + dims=[128, 128, 40], + samplingRate=1.25, + ) + ctf1 = FakeCtfMeasurement( + objId=100, + index=1, + defocusU=12000.0, + defocusV=11000.0, + defocusAngle=45.0, + resolution=3.2, + phaseShift=0.15, + acquisitionOrder=1, + psdFile="psd1.mrc", + enabled=True, + ) + ctf2 = FakeCtfMeasurement( + objId=101, + index=2, + defocusU=13000.0, + defocusV=12500.0, + defocusAngle=50.0, + resolution=3.5, + phaseShift=0.12, + acquisitionOrder=2, + psdFile="psd2.mrc", + enabled=False, + ) + series = FakeCtftomoSeries( + tsId="TS_001", + label="Series 1", + tiltSeries=associatedTs, + items=[ctf1, ctf2], + ) + output = FakeCtftomoOutputSet(seriesList=[series], associatedTiltSeriesSet=FakeCtftomoOutputSet(seriesList=[])) + output._associatedTiltSeriesSet = type( + "TiltSeriesSet", + (), + { + "getItem": lambda self, key, value: associatedTs, + }, + )() + + protocol = FakeProtocol("outputCtftomo", output, tmp_path) + service.currentProject = FakeCurrentProject(protocol) + + result = service.getCtftomoSeriesViewsService( + projectId=1, + protocolId=10, + outputName="outputCtftomo", + tiltSeriesId="TS_001", + ) + + assert result["tiltSeriesId"] == "TS_001" + assert result["label"] == "Series 1" + assert result["nViews"] == 2 + assert result["dims"] == [128, 128, 40] + assert result["pixelSize"] == 1.25 + assert len(result["frames"]) == 2 + + frame1 = result["frames"][0] + assert frame1 == { + "index": 100, + "viewIndex": 100, + "tiltAngle": -60.0, + "dose": 2.5, + "defocusU": 12000.0, + "defocusV": 11000.0, + "astigmatism": 1000.0, + "defocusAngle": 45.0, + "resolution": 3.2, + "phaseShift": 0.15, + "order": 1, + "psdFile": "psd1.mrc", + "excluded": False, + } + + frame2 = result["frames"][1] + assert frame2["index"] == 101 + assert frame2["viewIndex"] == 101 + assert frame2["tiltAngle"] == -58.0 + assert frame2["dose"] == 3.0 + assert frame2["astigmatism"] == 500.0 + assert frame2["excluded"] is True + + +def test_RenderCtfTomoPsdImageServiceDelegatesToOutputsPreview(projectServiceModule, service, monkeypatch, tmp_path): + FakeOutputsPreview.instances = [] + + psdFile = tmp_path / "psd_001.mrc" + psdFile.write_text("placeholder", encoding="utf-8") + + output = FakeCtftomoOutputSet(seriesList=[], associatedTiltSeriesSet=None) + protocol = FakeProtocol("outputCtftomo", output, tmp_path) + service.currentProject = FakeCurrentProject(protocol) + + monkeypatch.setattr(projectServiceModule, "OutputsPreview", FakeOutputsPreview) + + result = service.renderCtfTomoPsdImageService( + projectId=1, + protocolId=10, + outputName="outputCtftomo", + psdPath="3@" + psdFile.name, + size=512, + fmt="png", + inline=False, + quality=80, + applyTransform=True, + rot=12.0, + shifts=(4.0, -1.0), + ) + + assert result == { + "rendered": True, + "filePath": str(psdFile.resolve()), + } + assert len(FakeOutputsPreview.instances) == 1 + assert FakeOutputsPreview.instances[0].lastRenderCall == { + "filePath": str(psdFile.resolve()), + "size": 512, + "fmt": "png", + "index": 3, + "inline": False, + "quality": 80, + "applyTransform": True, + "rot": 12.0, + "shifts": (4.0, -1.0), + } + + +def test_CreateNewSetOfCtftomoSeriesServiceReturnsEmptyWhenEverythingExcluded(service, tmp_path): + associatedTs = FakeAssociatedTiltSeries() + series1 = FakeCtftomoSeries( + tsId="TS_001", + label="Series 1", + tiltSeries=associatedTs, + items=[], + ) + inputSet = FakeCtftomoOutputSet(seriesList=[series1], associatedTiltSeriesSet=associatedTs) + protocol = FakeProtocol("outputCtftomo", inputSet, tmp_path) + service.currentProject = FakeCurrentProject(protocol) + + result = service.createNewSetOfCtftomoSeriesService( + projectId=1, + protocolId=10, + outputName="outputCtftomo", + exclusions={ + "TS_001": { + "excluded": True, + "tiltimages": [], + } + }, + restack=False, + ) + + assert result == { + "status": "empty", + "outputName": "CTFTomoSeries_0", + "createdSeries": 0, + "restack": False, + "message": "No output was generated because it cannot be empty", + } + + +def test_CreateNewSetOfCtftomoSeriesServiceCreatesFilteredSeries(service, tmp_path): + associatedTs = FakeAssociatedTiltSeries() + ctf1 = FakeCtfMeasurement( + objId=100, + index=1, + defocusU=12000.0, + defocusV=11000.0, + defocusAngle=45.0, + resolution=3.2, + phaseShift=0.15, + acquisitionOrder=1, + psdFile="psd1.mrc", + enabled=True, + ) + ctf2 = FakeCtfMeasurement( + objId=101, + index=2, + defocusU=13000.0, + defocusV=12500.0, + defocusAngle=50.0, + resolution=3.5, + phaseShift=0.12, + acquisitionOrder=2, + psdFile="psd2.mrc", + enabled=True, + ) + inputSeries = FakeCtftomoSeries( + tsId="TS_001", + label="Series 1", + tiltSeries=associatedTs, + items=[ctf1, ctf2], + ) + inputSet = FakeCtftomoOutputSet(seriesList=[inputSeries], associatedTiltSeriesSet=associatedTs) + protocol = FakeProtocol("outputCtftomo", inputSet, tmp_path) + service.currentProject = FakeCurrentProject(protocol) + + result = service.createNewSetOfCtftomoSeriesService( + projectId=1, + protocolId=10, + outputName="outputCtftomo", + exclusions={ + "TS_001": { + "excluded": False, + "tiltimages": [2], + } + }, + restack=False, + ) + + assert result == { + "status": "ok", + "outputName": "CTFTomoSeries_0", + "createdSeries": 1, + "restack": False, + } + assert "CTFTomoSeries_0" in protocol._definedOutputs + createdSet = protocol._definedOutputs["CTFTomoSeries_0"] + assert createdSet.isEmpty() is False + createdSeries = createdSet._seriesList[0] + assert len(createdSeries._items) == 2 + assert createdSeries._items[0]._enabled is True + assert createdSeries._items[1]._enabled is False + assert protocol._stored is True \ No newline at end of file diff --git a/tests/unit/backend/api/services/test_project_service_io_thumbnails.py b/tests/unit/backend/api/services/test_project_service_io_thumbnails.py new file mode 100644 index 0000000..6febd13 --- /dev/null +++ b/tests/unit/backend/api/services/test_project_service_io_thumbnails.py @@ -0,0 +1,406 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +import json +from pathlib import Path + +import pytest +from fastapi import HTTPException + + +class FakeOutput: + # fakeOutput + def __init__(self, fileName): + self._fileName = fileName + + def getFileName(self): + return self._fileName + + +class FakeProtocol: + # fakeProtocol + def __init__(self, protocolId, outputName=None, output=None): + self.protocolId = protocolId + if outputName is not None: + setattr(self, outputName, output) + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocols=None, exportPayload=None): + self.protocols = protocols or {} + self.exportPayload = exportPayload if exportPayload is not None else [{"id": 10}] + + def getProtocol(self, protocolId): + return self.protocols[int(protocolId)] + + def getProtocolsJson(self, protocolList): + return self.exportPayload + + +class FakeOutputsPreview: + # fakeOutputsPreview + instances = [] + + def __init__(self, currentProject, protocol, output, requestHeaders=None, colormapOverride=None): + self.currentProject = currentProject + self.protocol = protocol + self.output = output + self.requestHeaders = requestHeaders + self.colormapOverride = colormapOverride + self.lastPreviewCall = None + FakeOutputsPreview.instances.append(self) + + def preview(self, protocolId, outputPath, objMgr): + self.lastPreviewCall = { + "protocolId": protocolId, + "outputPath": outputPath, + "objMgr": objMgr, + } + return { + "preview": True, + "protocolId": protocolId, + "outputPath": outputPath, + "colormap": self.colormapOverride, + } + + +class FakeThumbnailService: + # fakeThumbnailService + instances = [] + + def __init__(self, currentProject): + self.currentProject = currentProject + self.calls = [] + FakeThumbnailService.instances.append(self) + + def buildProtocolThumbnail(self, protocolId, force=False, size=320, outputName=None): + self.calls.append( + { + "method": "buildProtocolThumbnail", + "protocolId": protocolId, + "force": force, + "size": size, + "outputName": outputName, + } + ) + return {"kind": "protocol", "protocolId": protocolId, "outputName": outputName} + + def buildProjectThumbnail(self, force=False, size=640, maxProtocols=6): + self.calls.append( + { + "method": "buildProjectThumbnail", + "force": force, + "size": size, + "maxProtocols": maxProtocols, + } + ) + return {"kind": "project", "size": size} + + def buildProtocolOutputThumbnail(self, protocolId, outputName, force=False, size=320): + self.calls.append( + { + "method": "buildProtocolOutputThumbnail", + "protocolId": protocolId, + "outputName": outputName, + "force": force, + "size": size, + } + ) + return {"kind": "output", "protocolId": protocolId, "outputName": outputName} + + def listProtocolThumbnailItems(self, projectId, force=False, size=320, maxProtocols=12, maxOutputsPerProtocol=4): + self.calls.append( + { + "method": "listProtocolThumbnailItems", + "projectId": projectId, + "force": force, + "size": size, + "maxProtocols": maxProtocols, + "maxOutputsPerProtocol": maxOutputsPerProtocol, + } + ) + return [{"projectId": projectId, "kind": "thumbnail-item"}] + + +class FakePayload: + # fakePayload + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + +@pytest.fixture +def projectServiceModule(authTestEnv): + # projectServiceModule + return importlib.import_module("app.backend.api.services.project_service") + + +@pytest.fixture +def service(projectServiceModule): + # service + instance = object.__new__(projectServiceModule.ProjectService) + instance.currentProject = FakeCurrentProject() + instance.tomoList = {} + return instance + + +def test_OutputPreviewDelegatesToOutputsPreview(projectServiceModule, service, monkeypatch, tmp_path): + FakeOutputsPreview.instances = [] + + outputFile = tmp_path / "output.sqlite" + outputFile.write_text("placeholder", encoding="utf-8") + + output = FakeOutput(str(outputFile)) + protocol = FakeProtocol(protocolId=10, outputName="outputMetadata", output=output) + service.currentProject = FakeCurrentProject(protocols={10: protocol}) + + monkeypatch.setattr(projectServiceModule, "OutputsPreview", FakeOutputsPreview) + monkeypatch.setattr(service, "_createObjectManager", lambda: {"manager": "fresh"}) + + result = service.outputPreview( + protocolId=10, + outputName="outputMetadata", + requestHeaders={"x-preview-colormap": "viridis"}, + colormap="plasma", + ) + + assert result == { + "preview": True, + "protocolId": 10, + "outputPath": str(outputFile), + "colormap": "plasma", + } + assert len(FakeOutputsPreview.instances) == 1 + assert FakeOutputsPreview.instances[0].lastPreviewCall == { + "protocolId": 10, + "outputPath": str(outputFile), + "objMgr": {"manager": "fresh"}, + } + + +def test_BuildProtocolThumbnailDelegatesToThumbnailService(projectServiceModule, service, monkeypatch): + FakeThumbnailService.instances = [] + monkeypatch.setattr(projectServiceModule, "ThumbnailService", FakeThumbnailService) + + result = service.buildProtocolThumbnail(protocolId=10, force=True, size=400, outputName="outputA") + + assert result == {"kind": "protocol", "protocolId": 10, "outputName": "outputA"} + assert FakeThumbnailService.instances[0].calls == [ + { + "method": "buildProtocolThumbnail", + "protocolId": 10, + "force": True, + "size": 400, + "outputName": "outputA", + } + ] + + +def test_BuildProjectThumbnailDelegatesToThumbnailService(projectServiceModule, service, monkeypatch): + FakeThumbnailService.instances = [] + monkeypatch.setattr(projectServiceModule, "ThumbnailService", FakeThumbnailService) + + result = service.buildProjectThumbnail(force=True, size=800, maxProtocols=9) + + assert result == {"kind": "project", "size": 800} + assert FakeThumbnailService.instances[0].calls == [ + { + "method": "buildProjectThumbnail", + "force": True, + "size": 800, + "maxProtocols": 9, + } + ] + + +def test_BuildProtocolOutputThumbnailDelegatesToThumbnailService(projectServiceModule, service, monkeypatch): + FakeThumbnailService.instances = [] + monkeypatch.setattr(projectServiceModule, "ThumbnailService", FakeThumbnailService) + + result = service.buildProtocolOutputThumbnail(protocolId=11, outputName="outputVol", force=False, size=256) + + assert result == {"kind": "output", "protocolId": 11, "outputName": "outputVol"} + assert FakeThumbnailService.instances[0].calls == [ + { + "method": "buildProtocolOutputThumbnail", + "protocolId": 11, + "outputName": "outputVol", + "force": False, + "size": 256, + } + ] + + +def test_ListProjectThumbnailItemsDelegatesToThumbnailService(projectServiceModule, service, monkeypatch): + FakeThumbnailService.instances = [] + monkeypatch.setattr(projectServiceModule, "ThumbnailService", FakeThumbnailService) + + result = service.listProjectThumbnailItems( + projectId=3, + force=True, + size=300, + maxProtocols=8, + maxOutputsPerProtocol=2, + ) + + assert result == [{"projectId": 3, "kind": "thumbnail-item"}] + assert FakeThumbnailService.instances[0].calls == [ + { + "method": "listProtocolThumbnailItems", + "projectId": 3, + "force": True, + "size": 300, + "maxProtocols": 8, + "maxOutputsPerProtocol": 2, + } + ] + + +def test_NormalizeExportJsonContentAcceptsJsonString(service): + content = service._normalizeExportJsonContent('[{"id": 10}]') + assert json.loads(content) == [{"id": 10}] + + +def test_NormalizeExportJsonContentSerializesDictAndList(service): + content = service._normalizeExportJsonContent({"ok": True}) + assert json.loads(content) == {"ok": True} + + +def test_NormalizeExportJsonContentRejectsEmptyString(service): + with pytest.raises(HTTPException) as exc: + service._normalizeExportJsonContent(" ") + + assert exc.value.status_code == 500 + assert exc.value.detail == "Scipion export returned empty content" + + +def test_NormalizeProtocolIdsForExportSkipsProjectAndDeduplicates(service): + result = service._normalizeProtocolIdsForExport([1, "1", "PROJECT", " ", "2", 2]) + assert result == ["1", "2"] + + +def test_SanitizeExportFilenameAddsJsonExtension(service): + assert service._sanitizeExportFilename("workflow_export") == "workflow_export.json" + assert service._sanitizeExportFilename("folder/name.json") == "name.json" + + +def test_GuardFsPathWithinRootForWriteRejectsEscape(service, tmp_path): + rootPath = tmp_path / "root" + rootPath.mkdir(parents=True, exist_ok=True) + + with pytest.raises(HTTPException) as exc: + service._guardFsPathWithinRootForWrite(rootPath, "../outside/file.json") + + assert exc.value.status_code == 403 + assert exc.value.detail == "Path escapes browser root" + + +def test_ExportProtocolsServiceWritesJsonFile(service, monkeypatch, tmp_path): + rootPath = tmp_path / "browser-root" + rootPath.mkdir(parents=True, exist_ok=True) + + protocol10 = FakeProtocol(protocolId=10) + protocol11 = FakeProtocol(protocolId=11) + service.currentProject = FakeCurrentProject( + protocols={10: protocol10, 11: protocol11}, + exportPayload=[{"protocolId": 10}, {"protocolId": 11}], + ) + + monkeypatch.setattr(service, "_resolveFsRootForWrite", lambda protocolId: rootPath) + + payload = FakePayload( + protocolIds=[10, 11], + directoryPath="exports", + filename="workflow-export", + ) + + result = service.exportProtocolsService( + mapper=object(), + projectId=1, + currentUser={"id": 1}, + payload=payload, + ) + + exportedPath = rootPath / "exports" / "workflow-export.json" + assert exportedPath.exists() is True + assert json.loads(exportedPath.read_text(encoding="utf-8")) == [ + {"protocolId": 10}, + {"protocolId": 11}, + ] + assert result == { + "success": True, + "path": str(exportedPath.resolve()), + "filename": "workflow-export.json", + "size": exportedPath.stat().st_size, + "mimeType": "application/json", + "protocolIds": ["10", "11"], + } + + +def test_ExportProtocolsServiceRejectsMissingProtocolIds(service): + payload = FakePayload( + protocolIds=[], + directoryPath="exports", + filename="workflow-export", + ) + + with pytest.raises(HTTPException) as exc: + service.exportProtocolsService( + mapper=object(), + projectId=1, + currentUser={"id": 1}, + payload=payload, + ) + + assert exc.value.status_code == 422 + assert exc.value.detail == "Missing protocolIds" + + +def test_WriteRemoteFileServiceWritesContent(service, monkeypatch, tmp_path): + rootPath = tmp_path / "browser-root" + rootPath.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr(service, "_resolveFsRootForWrite", lambda protocolId: rootPath) + + payload = FakePayload( + path="exports/result.json", + content='{"ok": true}', + mimeType="application/json", + ) + + result = service.writeRemoteFileService(protocolId="-1", payload=payload) + + targetPath = rootPath / "exports" / "result.json" + assert targetPath.exists() is True + assert targetPath.read_text(encoding="utf-8") == '{"ok": true}' + assert result == { + "success": True, + "path": str(targetPath.resolve()), + "size": targetPath.stat().st_size, + "mimeType": "application/json", + } \ No newline at end of file diff --git a/tests/unit/backend/api/services/test_project_service_tiltseries.py b/tests/unit/backend/api/services/test_project_service_tiltseries.py new file mode 100644 index 0000000..ef4cd9f --- /dev/null +++ b/tests/unit/backend/api/services/test_project_service_tiltseries.py @@ -0,0 +1,510 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +import math +from pathlib import Path + +import pytest + + +class FakeAcquisition: + # fakeAcquisition + def __init__(self, tiltAxisAngle=None, accumDose=None): + self._tiltAxisAngle = tiltAxisAngle + self._accumDose = accumDose + + def getTiltAxisAngle(self): + return self._tiltAxisAngle + + def getAccumDose(self): + return self._accumDose + + +class FakeTransform: + # fakeTransform + def __init__(self, rotDegrees=15.0, shiftX=3.5, shiftY=-2.0): + self._rotDegrees = rotDegrees + self._shiftX = shiftX + self._shiftY = shiftY + + def getEulerAngles(self): + return (0.0, 0.0, math.radians(-self._rotDegrees)) + + def getMatrixAsList(self): + return [1, 0, self._shiftX, 0, 1, self._shiftY, 0, 0, 1] + + +class FakeTiltImage: + # fakeTiltImage + def __init__( + self, + objId, + index, + order, + tiltAngle, + fileName, + enabled=True, + accumDose=None, + transform=None, + ): + self._objId = objId + self._index = index + self._order = order + self._tiltAngle = tiltAngle + self._fileName = fileName + self._enabled = enabled + self._acquisition = FakeAcquisition(accumDose=accumDose) + self._transform = transform + + def getObjId(self): + return self._objId + + def getIndex(self): + return self._index + + def getAcquisitionOrder(self): + return self._order + + def getTiltAngle(self): + return self._tiltAngle + + def isEnabled(self): + return self._enabled + + def getAcquisition(self): + return self._acquisition + + def getFileName(self): + return self._fileName + + def hasTransform(self): + return self._transform is not None + + def getTransform(self): + return self._transform + + +class FakeTiltSeries: + # fakeTiltSeries + def __init__(self, tsId, size, dims, samplingRate, tiltAxisAngle, items=None): + self._tsId = tsId + self._size = size + self._dims = dims + self._samplingRate = samplingRate + self._acquisition = FakeAcquisition(tiltAxisAngle=tiltAxisAngle) + self._items = items or [] + + def getTsId(self): + return self._tsId + + def getSize(self): + return self._size + + def getDim(self): + return self._dims + + def getSamplingRate(self): + return self._samplingRate + + def getAcquisition(self): + return self._acquisition + + def iterItems(self, iterate=False): + return list(self._items) + + def getItem(self, key, value): + if key != "_index": + return None + for item in self._items: + if item.getIndex() == value: + return item + return None + + +class FakeTiltSeriesSet: + # fakeTiltSeriesSet + def __init__(self, items=None, hasOddEven=False, dims=None): + self._items = items or [] + self._hasOddEven = hasOddEven + self._dims = dims or [64, 64, 32] + + def iterItems(self, iterate=False): + return list(self._items) + + def getItem(self, key, value): + if key != "_tsId": + return None + for item in self._items: + if str(item.getTsId()) == str(value): + return item + return None + + def hasOddEven(self): + return self._hasOddEven + + def getDim(self): + return self._dims + + def getSize(self): + return len(self._items) + + +class FakeProtocol: + # fakeProtocol + def __init__(self, outputName, output): + setattr(self, outputName, output) + + def getPath(self): + return "/tmp/fake-protocol-path" + + def _getExtraPath(self): + return "/tmp/fake-extra" + + def getOutputsSize(self): + return 0 + + def getNextOutputName(self, prefix): + return "TiltSeries_0" + + def _defineOutputs(self, **kwargs): + self._definedOutputs = kwargs + + def _store(self): + self._stored = True + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocol): + self._protocol = protocol + + def getProtocol(self, protocolId): + return self._protocol + + +class FakeOutputsPreview: + # fakeOutputsPreview + instances = [] + + def __init__(self, currentProject, protocol, output, requestHeaders=None): + self.currentProject = currentProject + self.protocol = protocol + self.output = output + self.requestHeaders = requestHeaders + self.lastRenderCall = None + FakeOutputsPreview.instances.append(self) + + def renderImageFromFilePath( + self, + filePath, + size, + fmt, + index, + applyTransform, + inline, + rot, + shifts, + ): + self.lastRenderCall = { + "filePath": filePath, + "size": size, + "fmt": fmt, + "index": index, + "applyTransform": applyTransform, + "inline": inline, + "rot": rot, + "shifts": shifts, + } + return {"rendered": True, "filePath": filePath} + + +class FakeCreatedTiltSeriesOutputSet: + # fakeCreatedTiltSeriesOutputSet + def __init__(self): + self._items = [] + self._dim = None + self._copiedInfoFrom = None + self._written = False + + def copyInfo(self, inputSet): + self._copiedInfoFrom = inputSet + + def setDim(self, dims): + self._dim = dims + + def append(self, item): + self._items.append(item) + + def getSize(self): + return len(self._items) + + def write(self): + self._written = True + + +@pytest.fixture +def projectServiceModule(authTestEnv): + # projectServiceModule + return importlib.import_module("app.backend.api.services.project_service") + + +@pytest.fixture +def service(projectServiceModule): + # service + instance = object.__new__(projectServiceModule.ProjectService) + instance.tomoList = {} + instance.currentProject = None + return instance + + +def test_ListOutputTiltSeriesServiceBuildsSummaries(service): + ts1 = FakeTiltSeries( + tsId="TS_001", + size=5, + dims=[128, 128, 40], + samplingRate=1.5, + tiltAxisAngle=90.0, + ) + ts2 = FakeTiltSeries( + tsId="TS_002", + size=3, + dims=[64, 64, 20], + samplingRate=2.0, + tiltAxisAngle=85.0, + ) + output = FakeTiltSeriesSet([ts1, ts2]) + protocol = FakeProtocol("outputTiltSeries", output) + service.currentProject = FakeCurrentProject(protocol) + + result = service.listOutputTiltSeriesService( + projectId=1, + protocolId=10, + outputName="outputTiltSeries", + ) + + assert result == [ + { + "tiltSeriesId": "TS_001", + "label": "TiltSeries TS_001", + "nViews": 5, + "dims": [128, 128, 40], + "pixelSize": 1.5, + "tiltAxisAngle": 90.0, + }, + { + "tiltSeriesId": "TS_002", + "label": "TiltSeries TS_002", + "nViews": 3, + "dims": [64, 64, 20], + "pixelSize": 2.0, + "tiltAxisAngle": 85.0, + }, + ] + + +def test_GetTiltSeriesFramesServiceBuildsFramesFromSelectedSeries(service, tmp_path): + imagePath1 = tmp_path / "tilt-1.mrc" + imagePath2 = tmp_path / "tilt-2.mrc" + imagePath1.write_text("placeholder", encoding="utf-8") + imagePath2.write_text("placeholder", encoding="utf-8") + + item1 = FakeTiltImage( + objId=101, + index=1, + order=1, + tiltAngle=-60.0, + fileName=str(imagePath1), + enabled=True, + accumDose=2.5, + transform=None, + ) + item2 = FakeTiltImage( + objId=102, + index=2, + order=2, + tiltAngle=-58.0, + fileName=str(imagePath2), + enabled=False, + accumDose=3.0, + transform=FakeTransform(rotDegrees=12.0, shiftX=4.0, shiftY=-1.0), + ) + + ts = FakeTiltSeries( + tsId="TS_001", + size=2, + dims=[128, 128, 40], + samplingRate=1.5, + tiltAxisAngle=90.0, + items=[item1, item2], + ) + output = FakeTiltSeriesSet([ts]) + protocol = FakeProtocol("outputTiltSeries", output) + service.currentProject = FakeCurrentProject(protocol) + + result = service.getTiltSeriesFramesService( + projectId=1, + protocolId=10, + outputName="outputTiltSeries", + tiltSeriesId="TS_001", + ) + + assert result["tiltSeriesId"] == "TS_001" + assert result["label"] == "TS_001" + assert len(result["frames"]) == 2 + + frame1 = result["frames"][0] + assert frame1 == { + "viewId": 101, + "index": 1, + "order": 1, + "tiltAngle": -60.0, + "excluded": False, + "dose": 2.5, + "path": "1@" + str(imagePath1), + } + + frame2 = result["frames"][1] + assert frame2["viewId"] == 102 + assert frame2["index"] == 2 + assert frame2["order"] == 2 + assert frame2["tiltAngle"] == -58.0 + assert frame2["excluded"] is True + assert frame2["dose"] == 3.0 + assert frame2["path"] == "2@" + str(imagePath2) + assert frame2["rot"] == pytest.approx(12.0) + assert frame2["shiftX"] == pytest.approx(4.0) + assert frame2["shiftY"] == pytest.approx(-1.0) + + +def test_RenderTiltSeriesImageServiceDelegatesToOutputsPreview(projectServiceModule, service, monkeypatch, tmp_path): + FakeOutputsPreview.instances = [] + + imagePath = tmp_path / "tilt-1.mrc" + imagePath.write_text("placeholder", encoding="utf-8") + + tiltImage = FakeTiltImage( + objId=101, + index=2, + order=2, + tiltAngle=-58.0, + fileName=str(imagePath), + enabled=True, + accumDose=2.0, + transform=FakeTransform(rotDegrees=20.0, shiftX=6.0, shiftY=-3.0), + ) + ts = FakeTiltSeries( + tsId="TS_001", + size=1, + dims=[128, 128, 40], + samplingRate=1.5, + tiltAxisAngle=90.0, + items=[tiltImage], + ) + output = FakeTiltSeriesSet([ts]) + protocol = FakeProtocol("outputTiltSeries", output) + service.currentProject = FakeCurrentProject(protocol) + + monkeypatch.setattr(projectServiceModule, "OutputsPreview", FakeOutputsPreview) + + result = service.renderTiltSeriesImageService( + projectId=1, + protocolId=10, + outputName="outputTiltSeries", + tiltSeriesId="TS_001", + index=2, + size=512, + fmt="png", + applyTransform=True, + inline=False, + ) + + assert result == { + "rendered": True, + "filePath": str(imagePath), + } + assert len(FakeOutputsPreview.instances) == 1 + assert FakeOutputsPreview.instances[0].lastRenderCall == { + "filePath": str(imagePath), + "size": 512, + "fmt": "png", + "index": 2, + "applyTransform": True, + "inline": False, + "rot": 20.0, + "shifts": (6.0, -3.0), + } + + +def test_CreateNewSetOfTiltSeriesServiceReturnsEmptyWhenNoSeriesCreated(projectServiceModule, service, monkeypatch): + createdOutputSet = FakeCreatedTiltSeriesOutputSet() + + class FakeSetOfTiltSeriesFactory: + # fakeSetOfTiltSeriesFactory + @staticmethod + def create(projectPath, suffix): + return createdOutputSet + + inputSet = FakeTiltSeriesSet(items=[], hasOddEven=False, dims=[128, 128, 40]) + protocol = FakeProtocol("outputTiltSeries", inputSet) + service.currentProject = FakeCurrentProject(protocol) + + monkeypatch.setattr(projectServiceModule, "SetOfTiltSeries", FakeSetOfTiltSeriesFactory) + + result = service.createNewSetOfTiltSeriesService( + projectId=1, + protocolId=10, + outputName="outputTiltSeries", + exclusions={}, + restack=False, + ) + + assert result == { + "status": "empty", + "outputName": "TiltSeries_0", + "createdTiltSeries": 0, + "hasOddEven": False, + "restack": False, + "message": "No output was generated because it cannot be empty", + } + assert createdOutputSet.getSize() == 0 + assert createdOutputSet._copiedInfoFrom is inputSet + assert createdOutputSet._dim == [128, 128, 40] + + +def test_ResolveOutputForTiltSeriesReturns404WhenProtocolMissing(service): + class BrokenCurrentProject: + # brokenCurrentProject + def getProtocol(self, protocolId): + raise RuntimeError("missing protocol") + + service.currentProject = BrokenCurrentProject() + + with pytest.raises(Exception) as exc: + service._resolveOutputForTiltSeries(10, "outputTiltSeries") + + assert exc.value.status_code == 404 + assert exc.value.detail == "Protocol not found" \ No newline at end of file diff --git a/tests/unit/backend/utils/test_file_handlers.py b/tests/unit/backend/utils/test_file_handlers.py new file mode 100644 index 0000000..76dd887 --- /dev/null +++ b/tests/unit/backend/utils/test_file_handlers.py @@ -0,0 +1,224 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import pytest +from fastapi import HTTPException, Response + + +class FakeProtocol: + # fakeProtocol + def __init__(self, protocolPath): + self._protocolPath = protocolPath + + def getPath(self): + return self._protocolPath + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, projectPath, protocolPath): + self._projectPath = projectPath + self._protocol = FakeProtocol(protocolPath) + + def getProtocol(self, protocolId): + return self._protocol + + def getPath(self): + return self._projectPath + + +@pytest.fixture +def fileHandlersModule(authTestEnv): + # fileHandlersModule + return importlib.import_module("app.backend.utils.file_handlers") + + +@pytest.fixture +def handlers(fileHandlersModule, tmp_path): + # handlers + projectRoot = tmp_path / "DemoProject" + protocolPath = projectRoot / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject( + projectPath=str(projectRoot), + protocolPath=str(protocolPath), + ) + return fileHandlersModule.FileHandlers(currentProject) + + +def test_GetProtocolPathBuildsBrowserContract(handlers): + result = handlers.getProtocolPath("10") + + assert result["rootAbs"].endswith("/DemoProject") + assert result["startPath"] == "Runs/000010_ProtImport" + assert result["protocolRoot"] == "Runs/000010_ProtImport" + assert result["path"].endswith("/DemoProject/Runs/000010_ProtImport") + + +def test_NormalizeRelPathClampsTraversal(fileHandlersModule): + assert fileHandlersModule.FileHandlers._normalizeRelPath("../a/../../b/./c") == "b/c" + assert fileHandlersModule.FileHandlers._normalizeRelPath("") == "" + assert fileHandlersModule.FileHandlers._normalizeRelPath("/") == "" + + +def test_GuardJoinRejectsAbsolutePaths(fileHandlersModule, tmp_path): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True, exist_ok=True) + + with pytest.raises(HTTPException) as exc: + fileHandlersModule.FileHandlers._guardJoin(root, "/etc/passwd") + + assert exc.value.status_code == 400 + assert exc.value.detail == "Invalid path" + + +def test_ResolveWithinRootAcceptsAbsoluteChildPath(handlers, tmp_path): + root = (tmp_path / "root").resolve() + root.mkdir(parents=True, exist_ok=True) + + child = root / "folder" / "file.txt" + child.parent.mkdir(parents=True, exist_ok=True) + child.write_text("hello", encoding="utf-8") + + resolved = handlers._resolveWithinRoot(root, str(child)) + + assert resolved == child + + +def test_ListRemoteDirectoryUnderRootReturnsSortedEntries(handlers, tmp_path): + root = tmp_path / "browser-root" + root.mkdir(parents=True, exist_ok=True) + + folder = root / "FolderA" + folder.mkdir() + + fileTxt = root / "zeta.txt" + fileTxt.write_text("demo", encoding="utf-8") + + result = handlers.listRemoteDirectoryUnderRoot(root, "") + + assert result[0]["name"] == "FolderA" + assert result[0]["isDir"] is True + assert result[0]["path"] == "FolderA" + + assert result[1]["name"] == "zeta.txt" + assert result[1]["isDir"] is False + assert result[1]["path"] == "zeta.txt" + assert result[1]["size"] == 4 + assert result[1]["mime"] == "text/plain" + + +def test_PreviewTextFileUnderRootReturnsPlainText(handlers, tmp_path): + root = tmp_path / "browser-root" + root.mkdir(parents=True, exist_ok=True) + + fileTxt = root / "notes.txt" + fileTxt.write_text("hello file handlers", encoding="utf-8") + + response = handlers.previewTextFileUnderRoot(root, "notes.txt") + + assert isinstance(response, Response) + assert response.media_type == "text/plain; charset=utf-8" + assert response.body.decode("utf-8") == "hello file handlers" + + +def test_PreviewTextFileUnderRootRejectsBinaryFile(handlers, tmp_path): + root = tmp_path / "browser-root" + root.mkdir(parents=True, exist_ok=True) + + fileBin = root / "data.bin" + fileBin.write_bytes(b"\x00\x01\x02") + + with pytest.raises(HTTPException) as exc: + handlers.previewTextFileUnderRoot(root, "data.bin") + + assert exc.value.status_code == 415 + assert exc.value.detail == "Preview not available for this file type" + + +def test_BuildPreviewHeadersIncludesExposeList(handlers): + headers = handlers._buildPreviewHeaders( + { + "kind": "text", + "name": "notes.txt", + "mime": "text/plain", + "responseMime": "text/plain; charset=utf-8", + "width": 10, + "height": 20, + "sizeBytes": 123, + "note": "preview note", + } + ) + + assert headers["X-Preview-Kind"] == "text" + assert headers["X-Preview-Name"] == "notes.txt" + assert headers["X-Preview-Mime"] == "text/plain" + assert headers["X-Preview-ResponseMime"] == "text/plain; charset=utf-8" + assert headers["X-Preview-Width"] == "10" + assert headers["X-Preview-Height"] == "20" + assert headers["X-Preview-SizeBytes"] == "123" + assert headers["X-Preview-Note"] == "preview note" + assert headers["X-Preview-Schema"] == "scipion" + assert "X-Preview-Kind" in headers["Access-Control-Expose-Headers"] + + +def test_AttachPreviewContractAddsHeaders(handlers): + response = Response(content=b"hello", media_type="text/plain") + + enriched = handlers._attachPreviewContract( + response=response, + kind="text", + name="notes.txt", + meta={"mime": "text/plain", "sizeBytes": 5}, + ) + + assert enriched.headers["Content-Disposition"] == 'inline; filename="notes.txt"' + assert enriched.headers["X-Preview-Kind"] == "text" + assert enriched.headers["X-Preview-Name"] == "notes.txt" + assert enriched.headers["X-Preview-Mime"] == "text/plain" + assert enriched.headers["X-Preview-ResponseMime"] == "text/plain" + assert enriched.headers["X-Preview-SizeBytes"] == "5" + + +def test_PreviewRemoteEntryUnderRootWrapsTextPreviewWithContract(handlers, tmp_path): + root = tmp_path / "browser-root" + root.mkdir(parents=True, exist_ok=True) + + fileTxt = root / "notes.txt" + fileTxt.write_text("hello preview contract", encoding="utf-8") + + response = handlers.previewRemoteEntryUnderRoot(root, "notes.txt") + + assert response.media_type == "text/plain; charset=utf-8" + assert response.body.decode("utf-8") == "hello preview contract" + assert response.headers["X-Preview-Kind"] == "text" + assert response.headers["X-Preview-Name"] == "notes.txt" + assert response.headers["X-Preview-Mime"] == "text/plain" + assert response.headers["X-Preview-Schema"] == "scipion" \ No newline at end of file diff --git a/tests/unit/backend/utils/test_outputs_preview_core.py b/tests/unit/backend/utils/test_outputs_preview_core.py new file mode 100644 index 0000000..76e247a --- /dev/null +++ b/tests/unit/backend/utils/test_outputs_preview_core.py @@ -0,0 +1,350 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +import io +import json +import tarfile +import zipfile +from pathlib import Path + +import pytest +from fastapi.responses import Response + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocolPath): + self._protocolPath = protocolPath + + def getPath(self): + return self._protocolPath + + def getProtocol(self, protocolId): + return FakeProtocol(self._protocolPath) + + +class FakeProtocol: + # fakeProtocol + def __init__(self, protocolPath): + self._protocolPath = protocolPath + + def getPath(self): + return self._protocolPath + + def getObjId(self): + return 10 + + +class FakeOutput: + # fakeOutput + def __init__(self, objId=None): + self._objId = objId + + def getObjId(self): + return self._objId + + +class FakeObjectManager: + # fakeObjectManager + def __init__(self): + self._fileName = None + self._dao = None + self._tables = {} + self.selected = False + self.loaded = False + + def selectDAO(self): + self.selected = True + + def getTables(self): + self.loaded = True + return {} + + +@pytest.fixture +def outputsPreviewModule(authTestEnv): + # outputsPreviewModule + return importlib.import_module("app.backend.utils.outputs_preview") + + +@pytest.fixture +def preview(outputsPreviewModule, tmp_path): + # preview + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject(str(protocolPath.parent.parent)) + protocol = FakeProtocol(str(protocolPath)) + output = FakeOutput(objId=77) + + return outputsPreviewModule.OutputsPreview( + currentProject=currentProject, + protocol=protocol, + output=output, + ) + + +def test_OutputSignatureUsesObjId(outputsPreviewModule): + output = FakeOutput(objId=77) + + result = outputsPreviewModule._outputSignature(output) + + assert result == "FakeOutput:77" + + +def test_OutputSignatureFallsBackToPythonId(outputsPreviewModule): + class NoObjId: + pass + + output = NoObjId() + result = outputsPreviewModule._outputSignature(output) + + assert result.startswith("NoObjId:") + + +def test_PreviewPdfReturnsInlineResponse(preview, tmp_path): + pdfPath = tmp_path / "demo.pdf" + pdfPath.write_bytes(b"%PDF-1.4 demo") + + response = preview._previewPdf(pdfPath, inline=True) + + assert response.media_type == "application/pdf" + assert response.headers["Content-Disposition"] == 'inline; filename="demo.pdf"' + assert response.body == b"%PDF-1.4 demo" + + +def test_IsArchiveSuffixRecognizesArchiveExtensions(preview): + assert preview._isArchiveSuffix(".zip") is True + assert preview._isArchiveSuffix(".tar") is True + assert preview._isArchiveSuffix(".txt") is False + + +def test_PreviewArchiveInlineZipReturnsEntries(preview, tmp_path): + zipPath = tmp_path / "demo.zip" + with zipfile.ZipFile(zipPath, "w") as zf: + zf.writestr("folder/", "") + zf.writestr("folder/file.txt", "hello") + + response = preview._previewArchive(zipPath, inline=True) + + assert response.headers["X-Preview-Type"] == "archive" + assert response.headers["X-Archive-Kind"] == "zip" + + payload = json.loads(response.body.decode("utf-8")) + assert payload["entries"] == [ + {"name": "folder/", "isDir": True, "size": None, "compressedSize": None}, + {"name": "folder/file.txt", "isDir": False, "size": 5, "compressedSize": 5}, + ] + + +def test_PreviewArchiveAttachmentReturnsRawBytes(preview, tmp_path): + zipPath = tmp_path / "demo.zip" + with zipfile.ZipFile(zipPath, "w") as zf: + zf.writestr("file.txt", "hello") + + response = preview._previewArchive(zipPath, inline=False) + + assert response.headers["Content-Disposition"] == 'attachment; filename="demo.zip"' + assert response.body == zipPath.read_bytes() + + +def test_PreviewCsvTsvReturnsColumnsAndRows(preview, tmp_path): + csvPath = tmp_path / "table.csv" + csvPath.write_text("id,name\n1,alpha\n2,beta\n", encoding="utf-8") + + response = preview._previewCsvTsv(csvPath, limit=10, delimiter=",") + + assert response.headers["X-Preview-Type"] == "table" + assert response.headers["X-Preview-Format"] == "csv" + + payload = json.loads(response.body.decode("utf-8")) + assert payload == { + "columns": ["id", "name"], + "rows": [ + {"id": "1", "name": "alpha"}, + {"id": "2", "name": "beta"}, + ], + } + + +def test_PreviewStarParsesLoopBlock(preview, tmp_path): + starPath = tmp_path / "particles.star" + starPath.write_text( + "\n".join( + [ + "data_particles", + "loop_", + "_rlnImageName #1", + "_rlnDefocusU #2", + "1@stack.mrcs 15000", + "2@stack.mrcs 16000", + ] + ), + encoding="utf-8", + ) + + response = preview._previewStar(starPath, limit=10) + + assert response.headers["X-Preview-Type"] == "table" + assert response.headers["X-Preview-Format"] == "star" + + payload = json.loads(response.body.decode("utf-8")) + assert payload == { + "columns": ["rlnImageName", "rlnDefocusU"], + "rows": [ + {"rlnImageName": "1@stack.mrcs", "rlnDefocusU": "15000"}, + {"rlnImageName": "2@stack.mrcs", "rlnDefocusU": "16000"}, + ], + } + + +def test_PreviewStarWithoutLoopReturnsTextPreview(preview, tmp_path): + starPath = tmp_path / "nolook.star" + starPath.write_text("data_particles\n# no loop block here\n", encoding="utf-8") + + response = preview._previewStar(starPath, limit=10) + + assert response.headers["X-Preview-Type"] == "text" + assert "STAR without loop_ block" in response.headers["X-Preview-Note"] + + payload = json.loads(response.body.decode("utf-8")) + assert "data_particles" in payload["text"] + + +def test_FallbackBinaryAddsPreviewHeaders(preview, tmp_path): + binPath = tmp_path / "data.bin" + binPath.write_bytes(b"\x00\x01\x02\x03") + + response = preview._fallbackBinary(binPath, inline=True) + + assert response.headers["Content-Disposition"] == 'inline; filename="data.bin"' + assert response.headers["X-Preview-Mime"] == "application/octet-stream" + assert response.headers["X-Preview-SizeBytes"] == "4" + assert response.body == b"\x00\x01\x02\x03" + + +def test_MergeHeadersAndGetHeaderAreCaseInsensitive(preview): + preview._mergeHeaders( + { + "X-Scipion-Colormap": "viridis", + "X-Preview-Colormap": "plasma", + } + ) + + assert preview.requestHeaders["x-scipion-colormap"] == "viridis" + assert preview.requestHeaders["x-preview-colormap"] == "plasma" + assert preview._getHeader("X-Preview-Colormap") == "plasma" + + +def test_ResolveColormapForOutputTypeUsesOverrideFirst(outputsPreviewModule, tmp_path, monkeypatch): + class FakeSetOfVolumes: + # fakeSetOfVolumes + pass + + monkeypatch.setattr(outputsPreviewModule, "SetOfVolumes", FakeSetOfVolumes) + monkeypatch.setattr(outputsPreviewModule, "SetOfClasses3D", type("FakeSetOfClasses3D", (), {})) + monkeypatch.setattr( + outputsPreviewModule.RegistryViewerConfig, + "getConfig", + staticmethod(lambda outputType: {}), + ) + + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + preview = outputsPreviewModule.OutputsPreview( + currentProject=FakeCurrentProject(str(protocolPath.parent.parent)), + protocol=FakeProtocol(str(protocolPath)), + output=FakeSetOfVolumes(), + colormapOverride="inferno", + ) + + assert preview._resolveColormapForOutputType(defaultCmap="viridis") == "inferno" + + +def test_ResolveColormapForOutputTypeUsesHeaderWhenValid(outputsPreviewModule, tmp_path, monkeypatch): + class FakeSetOfVolumes: + # fakeSetOfVolumes + pass + + monkeypatch.setattr(outputsPreviewModule, "SetOfVolumes", FakeSetOfVolumes) + monkeypatch.setattr(outputsPreviewModule, "SetOfClasses3D", type("FakeSetOfClasses3D", (), {})) + monkeypatch.setattr( + outputsPreviewModule.RegistryViewerConfig, + "getConfig", + staticmethod(lambda outputType: {}), + ) + + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + preview = outputsPreviewModule.OutputsPreview( + currentProject=FakeCurrentProject(str(protocolPath.parent.parent)), + protocol=FakeProtocol(str(protocolPath)), + output=FakeSetOfVolumes(), + requestHeaders={"X-Scipion-Colormap": "viridis"}, + ) + + assert preview._resolveColormapForOutputType(defaultCmap="plasma") == "viridis" + + +def test_PreviewDispatchesToPdf(preview, monkeypatch, tmp_path): + pdfPath = tmp_path / "demo.pdf" + pdfPath.write_bytes(b"%PDF-1.4 demo") + + monkeypatch.setattr(preview, "_previewPdf", lambda filePath, inline: {"kind": "pdf", "path": str(filePath)}) + + result = preview.preview(protocolId=10, path=str(pdfPath), objectManager=FakeObjectManager()) + + assert result == {"kind": "pdf", "path": str(pdfPath)} + + +def test_PreviewDispatchesToTextDelegate(preview, monkeypatch, tmp_path): + textPath = tmp_path / "notes.txt" + textPath.write_text("hello", encoding="utf-8") + + monkeypatch.setattr(preview, "previewProtocolTextFile", lambda protocolId, path: {"kind": "text", "path": path}) + + result = preview.preview(protocolId=10, path=str(textPath), objectManager=FakeObjectManager()) + + assert result == {"kind": "text", "path": str(textPath)} + + +def test_PreviewDispatchesToImageDelegate(preview, monkeypatch, tmp_path): + imgPath = tmp_path / "slice.mrc" + imgPath.write_bytes(b"dummy") + + monkeypatch.setattr(preview, "_isPreviewableMrc", lambda filePath: True) + monkeypatch.setattr( + preview, + "previewProtocolImageFile", + lambda protocolId, path, inline: {"kind": "image", "path": path, "inline": inline}, + ) + + result = preview.preview(protocolId=10, path=str(imgPath), objectManager=FakeObjectManager()) + + assert result == {"kind": "image", "path": str(imgPath), "inline": True} \ No newline at end of file diff --git a/tests/unit/backend/utils/test_outputs_preview_gallery_fsc_tiltseries.py b/tests/unit/backend/utils/test_outputs_preview_gallery_fsc_tiltseries.py new file mode 100644 index 0000000..76426b4 --- /dev/null +++ b/tests/unit/backend/utils/test_outputs_preview_gallery_fsc_tiltseries.py @@ -0,0 +1,413 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import numpy as np +import pytest +from PIL import Image + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, protocolPath): + self._protocolPath = protocolPath + + def getPath(self): + return self._protocolPath + + +class FakeProtocol: + # fakeProtocol + def __init__(self, protocolPath): + self._protocolPath = protocolPath + + def getPath(self): + return self._protocolPath + + def getObjId(self): + return 10 + + +class FakeColumn: + # fakeColumn + def __init__(self, name): + self._name = name + + def getName(self): + return self._name + + +class FakeRow: + # fakeRow + def __init__(self, rowId, values): + self._id = rowId + self._values = values + + def getId(self): + return self._id + + def getValues(self): + return self._values + + +class FakeObjectManager: + # fakeObjectManager + def __init__(self, rowsByTable): + self.rowsByTable = rowsByTable + + def getTableRowCount(self, tableName): + return len(self.rowsByTable.get(tableName, [])) + + def getRows(self, tableName, offset, limit): + rows = self.rowsByTable.get(tableName, []) + return rows[offset:offset + limit] + + +class FakeFSC: + # fakeFSC + def __init__(self, label, x, y, resolution=None): + self._label = label + self._x = x + self._y = y + self._resolution = resolution + + def clone(self): + return self + + def getObjLabel(self): + return self._label + + def getData(self): + return [self._x, self._y] + + def calculateResolution(self, threshold): + return self._resolution + + +class FakeTiltSeries: + # fakeTiltSeries + def __init__(self, tsId, fileName, tiltAngles=None, label=None): + self._tsId = tsId + self._fileName = fileName + self._tiltAngles = tiltAngles or [] + self._label = label + + def getTsId(self): + return self._tsId + + def getObjLabel(self): + return self._label + + def getFileName(self): + return self._fileName + + def getTiltAngles(self): + return self._tiltAngles + + +class FakeTiltSeriesSet(list): + # fakeTiltSeriesSet + pass + + +class FakeImageReader: + # fakeImageReader + instances = [] + + def __init__(self, imageArray): + self.imageArray = imageArray + FakeImageReader.instances.append(self) + + def getImages(self): + return self.imageArray + + def getImage(self, index=0, pilImage=False): + slice2d = self.imageArray[index] + img = Image.fromarray(slice2d.astype(np.uint8), mode="L") + return img if pilImage else slice2d + + def getCentralImage(self, pilImage=False): + mid = self.imageArray.shape[0] // 2 + slice2d = self.imageArray[mid] + img = Image.fromarray(slice2d.astype(np.uint8), mode="L") + return img if pilImage else slice2d + + def highlightSlice(self, arr): + return arr + + def normalizeSlice(self, arr): + return arr + + +@pytest.fixture +def outputsPreviewModule(authTestEnv): + # outputsPreviewModule + return importlib.import_module("app.backend.utils.outputs_preview") + + +@pytest.fixture +def preview(outputsPreviewModule, tmp_path): + # preview + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject(str(protocolPath.parent.parent)) + protocol = FakeProtocol(str(protocolPath)) + + class GenericOutput: + # genericOutput + pass + + output = GenericOutput() + + return outputsPreviewModule.OutputsPreview( + currentProject=currentProject, + protocol=protocol, + output=output, + ) + + +def test_MakeGalleryFromTilesBuildsGrayscaleGallery(preview): + tiles = [ + np.full((20, 20), 50, dtype=np.uint8), + np.full((20, 20), 180, dtype=np.uint8), + ] + + pngBytes, meta = preview.makeGalleryFromTiles( + tiles=tiles, + cols=2, + tileSize=32, + labels=["A", "B"], + summary="2 items", + ) + + assert isinstance(pngBytes, bytes) + assert len(pngBytes) > 0 + assert meta["tiles"] == 2 + assert meta["grid"] == [1, 2] + assert meta["tileSize"] == 32 + assert meta["hasSummary"] is True + + +def test_MakeGalleryFromTilesBuildsRgbGallery(preview): + tiles = [ + np.zeros((16, 16, 3), dtype=np.uint8), + np.full((16, 16, 3), 255, dtype=np.uint8), + ] + + pngBytes, meta = preview.makeGalleryFromTiles( + tiles=tiles, + cols=2, + tileSize=32, + labels=None, + summary=None, + forceRgb=True, + ) + + assert isinstance(pngBytes, bytes) + assert len(pngBytes) > 0 + assert meta["tiles"] == 2 + assert meta["grid"] == [1, 2] + + +def test_BuildPreviewHeadersFallbackIncludesExpectedFields(preview): + headers = preview.buildPreviewHeadersFallback( + { + "mime": "image/png", + "width": 320, + "height": 180, + "tiles": 6, + "note": "gallery", + } + ) + + assert headers["X-Preview-Mime"] == "image/png" + assert headers["X-Preview-Width"] == "320" + assert headers["X-Preview-Height"] == "180" + assert headers["X-Preview-Tiles"] == "6" + assert headers["X-Preview-Note"] == "gallery" + assert "X-Preview-Mime" in headers["Access-Control-Expose-Headers"] + + +def test_PickSampleRowsReturnsFirstDeterministicRows(preview): + rows = [FakeRow(i, [i]) for i in range(10)] + objMgr = FakeObjectManager({"objects": rows}) + + result = preview._pickSampleRows(objMgr, "objects", want=4) + + assert [row._id for row in result] == [0, 1, 2, 3] + + +def test_GetRenderColumnIndexSupportsCaseInsensitiveAndSubstring(preview): + columns = [ + FakeColumn("_filename"), + FakeColumn("MicName"), + FakeColumn("stackReference"), + ] + + assert preview.getRenderColumnIndex(["micname"], columns) == 0 + assert preview.getRenderColumnIndex(["stack"], columns) == 0 + + +def test_ExtractPathFromRowParsesStackSpec(preview): + row = FakeRow(1, ["3@Runs/stack.mrcs"]) + + relPath, sliceIndex = preview.extractPathFromRow(row, 0) + + assert relPath == "Runs/stack.mrcs" + assert sliceIndex == 3 + + +def test_MakeFSCResponseBuildsPng(outputsPreviewModule, tmp_path, monkeypatch): + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + class FakeSetOfFSCs(list): + # fakeSetOfFSCs + pass + + monkeypatch.setattr(outputsPreviewModule, "SetOfFSCs", FakeSetOfFSCs) + + output = FakeSetOfFSCs( + [ + FakeFSC( + label="gold-standard", + x=np.array([0.05, 0.1, 0.15, 0.2], dtype=float), + y=np.array([0.9, 0.6, 0.3, 0.1], dtype=float), + resolution=5.2, + ) + ] + ) + + preview = outputsPreviewModule.OutputsPreview( + currentProject=FakeCurrentProject(str(protocolPath.parent.parent)), + protocol=FakeProtocol(str(protocolPath)), + output=output, + ) + + response = preview._makeFSCResponse("fsc_preview.png") + + assert response.media_type == "image/png" + assert response.headers["Content-Disposition"] == 'inline; filename="fsc_preview.png"' + assert response.headers["X-Preview-Mime"] == "image/png" + assert len(response.body) > 0 + + +def test_ListTiltSeriesFramesReturnsMetadata(outputsPreviewModule, tmp_path, monkeypatch): + FakeImageReader.instances = [] + + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + stackPath = protocolPath / "ts1.mrcs" + stackPath.write_text("placeholder", encoding="utf-8") + + imageArray = np.arange(3 * 4 * 5, dtype=np.uint8).reshape((3, 4, 5)) + monkeypatch.setattr( + outputsPreviewModule.ImageReadersRegistry, + "open", + staticmethod(lambda path: FakeImageReader(imageArray)), + ) + monkeypatch.setattr(outputsPreviewModule, "SetOfTiltSeries", FakeTiltSeriesSet) + + output = FakeTiltSeriesSet( + [ + FakeTiltSeries( + tsId="TS_001", + fileName=str(stackPath), + tiltAngles=[-60.0, -58.0, -56.0], + label="Series 1", + ) + ] + ) + + preview = outputsPreviewModule.OutputsPreview( + currentProject=FakeCurrentProject(str(protocolPath.parent.parent)), + protocol=FakeProtocol(str(protocolPath)), + output=output, + ) + + result = preview.listTiltSeriesFrames("TS_001") + + assert result == { + "name": "TS_001", + "nFrames": 3, + "dims": [5, 4], + "stackRelPath": "ts1.mrcs", + "tiltAngles": [-60.0, -58.0, -56.0], + } + + +def test_RenderTiltSeriesFrameBuildsImageResponse(outputsPreviewModule, tmp_path, monkeypatch): + FakeImageReader.instances = [] + + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + stackPath = protocolPath / "ts1.mrcs" + stackPath.write_text("placeholder", encoding="utf-8") + + imageArray = np.arange(3 * 4 * 5, dtype=np.uint8).reshape((3, 4, 5)) + monkeypatch.setattr( + outputsPreviewModule.ImageReadersRegistry, + "open", + staticmethod(lambda path: FakeImageReader(imageArray)), + ) + monkeypatch.setattr(outputsPreviewModule, "SetOfTiltSeries", FakeTiltSeriesSet) + + output = FakeTiltSeriesSet( + [ + FakeTiltSeries( + tsId="TS_001", + fileName=str(stackPath), + tiltAngles=[-60.0, -58.0, -56.0], + label="Series 1", + ) + ] + ) + + preview = outputsPreviewModule.OutputsPreview( + currentProject=FakeCurrentProject(str(protocolPath.parent.parent)), + protocol=FakeProtocol(str(protocolPath)), + output=output, + ) + + response = preview.renderTiltSeriesFrame( + tiltSeriesName="TS_001", + index=1, + size=64, + fmt="png", + inline=True, + applyTransform=True, + ) + + assert response.media_type == "image/png" + assert response.headers["Content-Disposition"] == 'inline; filename="TS_001_tilt-1.png"' + assert response.headers["X-Preview-Mime"] == "image/png" + assert response.headers["X-Preview-Note"] == "tiltSeries=TS_001 index=1" + assert len(response.body) > 0 \ No newline at end of file diff --git a/tests/unit/backend/utils/test_thumbnail_service_core.py b/tests/unit/backend/utils/test_thumbnail_service_core.py new file mode 100644 index 0000000..86fc87e --- /dev/null +++ b/tests/unit/backend/utils/test_thumbnail_service_core.py @@ -0,0 +1,367 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import pytest + + +class FakeOutput: + # fakeOutput + def __init__(self, className="SetOfParticles", size=5): + self._className = className + self._size = size + + def getClassName(self): + return self._className + + def getSize(self): + return self._size + + +class FakeProtocol: + # fakeProtocol + def __init__(self, objId, label="Protocol", status="finished", outputs=None): + self._objId = objId + self._label = label + self._status = status + self._outputs = outputs or [] + + def getObjId(self): + return self._objId + + def getObjLabel(self): + return self._label + + def getStatus(self): + return self._status + + def iterOutputAttributes(self): + for item in self._outputs: + yield item + + +class FakeNode: + # fakeNode + def __init__(self, run=None): + self.run = run + + +class FakeGraph: + # fakeGraph + def __init__(self, nodesDict): + self._nodesDict = nodesDict + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, projectPath, protocols=None, graph=None): + self._projectPath = projectPath + self._protocols = protocols or {} + self._graph = graph + + def getPath(self): + return self._projectPath + + def getProtocol(self, protocolId): + return self._protocols.get(int(protocolId)) + + def getRunsGraph(self, refresh=False, checkPids=False): + return self._graph + + +@pytest.fixture +def thumbnailServiceModule(authTestEnv): + # thumbnailServiceModule + return importlib.import_module("app.backend.utils.thumbnail_service") + + +@pytest.fixture +def service(thumbnailServiceModule, tmp_path): + # service + projectPath = tmp_path / "DemoProject" + projectPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject(projectPath=str(projectPath)) + return thumbnailServiceModule.ThumbnailService(currentProject) + + +def test_CacheSafeTokenSanitizesAndAddsDigest(service): + token = service._cacheSafeToken("output particles / avg") + + assert token.startswith("output_particles_avg_") + assert len(token.split("_")[-1]) == 10 + + +def test_SlugOutputNameReturnsBestForEmpty(service): + assert service._slugOutputName(None) == "best" + assert service._slugOutputName("") == "best" + + +def test_SlugOutputNameSanitizesValue(service): + assert service._slugOutputName("output particles / avg") == "output_particles_avg" + + +def test_GetProtocolCachePathUsesExpectedPattern(service): + path = service._getProtocolCachePath(protocolId=10, size=320, outputName="outputAvg") + + assert path.name == "protocol_10_outputAvg_320_v1.png" + assert path.parent.name == ".thumbnail_cache" + + +def test_GetProjectCachePathUsesExpectedPattern(service): + path = service._getProjectCachePath(size=720, maxProtocols=6) + + assert path.name == "project_720_6_v1.png" + assert path.parent.name == ".thumbnail_cache" + + +def test_UniqueIntsDeduplicatesAndSkipsNegatives(service): + assert service._uniqueInts([3, "3", -1, 5, 5, "bad", 0]) == [3, 5, 0] + + +def test_IsLikelyPreviewFileRecognizesRelevantNames(service): + assert service._isLikelyPreviewFile("thumb_protocol.png") is True + assert service._isLikelyPreviewFile("class_average.mrc") is True + assert service._isLikelyPreviewFile("notes.txt") is False + + +def test_ProjectProtocolSizeScalesByMaxProtocols(service): + assert service._projectProtocolSize(size=900, maxProtocols=1) >= 400 + assert service._projectProtocolSize(size=900, maxProtocols=2) >= 340 + assert service._projectProtocolSize(size=900, maxProtocols=6) >= 300 + + +def test_ScoreProtocolStatusMapsKnownStates(service): + protocolFinished = FakeProtocol(1, status="finished") + protocolRunning = FakeProtocol(2, status="running") + protocolFailed = FakeProtocol(3, status="failed") + protocolUnknown = FakeProtocol(4, status="whatever") + + assert service._scoreProtocolStatus(protocolFinished) == 120 + assert service._scoreProtocolStatus(protocolRunning) == 70 + assert service._scoreProtocolStatus(protocolFailed) == -200 + assert service._scoreProtocolStatus(protocolUnknown) == 10 + + +def test_ScoreOutputRewardsUsefulOutputs(service): + particles = FakeOutput(className="SetOfParticles", size=5) + mask = FakeOutput(className="VolumeMask", size=2) + generic = FakeOutput(className="SomethingRenderable", size=1) + + assert service._scoreOutput("outputParticles", particles) > 0 + assert service._scoreOutput("outputMask", mask) > 0 + assert service._scoreOutput("tmpDebug", generic) == 0 + + +def test_IterProtocolsSkipsProjectAndSortsByStatus(thumbnailServiceModule, tmp_path): + projectPath = tmp_path / "DemoProject" + projectPath.mkdir(parents=True, exist_ok=True) + + protocol1 = FakeProtocol(1, label="One", status="running") + protocol2 = FakeProtocol(2, label="Two", status="finished") + + graph = FakeGraph( + { + "PROJECT": FakeNode(run=None), + "1": FakeNode(run=protocol1), + "2": FakeNode(run=protocol2), + } + ) + currentProject = FakeCurrentProject( + projectPath=str(projectPath), + protocols={1: protocol1, 2: protocol2}, + graph=graph, + ) + service = thumbnailServiceModule.ThumbnailService(currentProject) + + protocols = service._iterProtocols() + + assert [p.getObjId() for p in protocols] == [2, 1] + + +def test_ListUsefulProtocolsFiltersAndSortsCandidates(service, monkeypatch): + protocol1 = FakeProtocol(1, label="Prot 1", status="finished") + protocol2 = FakeProtocol(2, label="Prot 2", status="running") + protocol3 = FakeProtocol(3, label="Prot 3", status="failed") + + monkeypatch.setattr(service, "_iterProtocols", lambda: [protocol1, protocol2, protocol3]) + monkeypatch.setattr( + service, + "_selectBestOutput", + lambda protocol: { + "outputName": "outputA", + "output": FakeOutput(className="SetOfParticles", size=5), + "outputClassName": "SetOfParticles", + "score": {1: 120, 2: 110, 3: 40}[protocol.getObjId()], + }, + ) + monkeypatch.setattr( + service, + "_safeOutputSize", + lambda output: output.getSize(), + ) + monkeypatch.setattr( + service, + "_scoreProtocolStatus", + lambda protocol: {1: 120, 2: 70, 3: -200}[protocol.getObjId()], + ) + + result = service.listUsefulProtocols(maxProtocols=5) + + assert [item["protocolId"] for item in result] == [1, 2] + assert result[0]["protocolLabel"] == "Prot 1" + assert result[0]["outputName"] == "outputA" + assert result[0]["itemsCount"] == 5 + assert result[0]["score"] > result[1]["score"] + + +def test_BuildProtocolThumbnailReturnsCachedEntry(service, monkeypatch, tmp_path): + protocol = FakeProtocol(10, label="Prot 10", status="finished") + service.currentProject._protocols = {10: protocol} + + cachePath = tmp_path / "protocol_10_best_320_v1.png" + cachePath.write_text("cached", encoding="utf-8") + + monkeypatch.setattr(service, "_getProtocolCachePath", lambda protocolId, size, outputName=None: cachePath) + + result = service.buildProtocolThumbnail(protocolId=10, force=False, size=320) + + assert result == { + "protocolId": 10, + "protocolLabel": "Prot 10", + "status": "finished", + "outputName": None, + "outputClassName": None, + "absolutePath": str(cachePath), + "cached": True, + "exists": True, + } + + +def test_BuildProtocolThumbnailReturnsMissingWhenRequestedOutputDoesNotExist(service, monkeypatch): + protocol = FakeProtocol(10, label="Prot 10", status="finished") + service.currentProject._protocols = {10: protocol} + + monkeypatch.setattr(service, "_collectSortedOutputCandidates", lambda protocolObj: []) + + result = service.buildProtocolThumbnail( + protocolId=10, + force=False, + size=320, + outputName="missingOutput", + ) + + assert result == { + "protocolId": 10, + "protocolLabel": "Prot 10", + "status": "finished", + "outputName": "missingOutput", + "outputClassName": None, + "absolutePath": None, + "cached": False, + "exists": False, + } + + +def test_BuildProjectThumbnailReturnsCachedStrip(service, monkeypatch, tmp_path): + cachePath = tmp_path / "project_720_6_v1.png" + cachePath.write_text("cached", encoding="utf-8") + + monkeypatch.setattr(service, "_getProjectCachePath", lambda size, maxProtocols: cachePath) + + result = service.buildProjectThumbnail(force=False, size=720, maxProtocols=6) + + assert result == { + "absolutePath": str(cachePath), + "cached": True, + "items": None, + } + + +def test_ListProtocolThumbnailItemsBuildsGroups(service, monkeypatch): + protocol = FakeProtocol(11, label="Prot 11", status="running") + + monkeypatch.setattr(service, "_iterProtocols", lambda: [protocol]) + monkeypatch.setattr( + service, + "_collectSortedOutputCandidates", + lambda protocolObj: [ + { + "outputName": "outputVol", + "outputClassName": "SetOfVolumes", + "score": 100, + "itemsCount": 3, + }, + { + "outputName": "outputParticles", + "outputClassName": "SetOfParticles", + "score": 95, + "itemsCount": 6, + }, + ], + ) + monkeypatch.setattr( + service, + "buildProtocolThumbnail", + lambda protocolId, force, size, outputName=None: { + "exists": True, + "absolutePath": f"/tmp/{protocolId}_{outputName}.png", + }, + ) + + result = service.listProtocolThumbnailItems( + projectId=7, + force=True, + size=300, + maxProtocols=5, + maxOutputsPerProtocol=2, + ) + + assert result == [ + { + "protocolId": 11, + "label": "Prot 11", + "status": "running", + "outputs": [ + { + "outputName": "outputVol", + "outputClassName": "SetOfVolumes", + "exists": True, + "thumbnailUrl": "/projects/7/protocols/11/thumbnail?outputName=outputVol", + "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputVol", + }, + { + "outputName": "outputParticles", + "outputClassName": "SetOfParticles", + "exists": True, + "thumbnailUrl": "/projects/7/protocols/11/thumbnail?outputName=outputParticles", + "thumbnailRebuildUrl": "/projects/7/protocols/11/thumbnail/rebuild?outputName=outputParticles", + }, + ], + } + ] \ No newline at end of file diff --git a/tests/unit/backend/utils/test_thumbnail_service_helpers.py b/tests/unit/backend/utils/test_thumbnail_service_helpers.py new file mode 100644 index 0000000..5d6fa24 --- /dev/null +++ b/tests/unit/backend/utils/test_thumbnail_service_helpers.py @@ -0,0 +1,230 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import numpy as np +import pytest +from PIL import Image + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, projectPath): + self._projectPath = projectPath + + def getPath(self): + return self._projectPath + + +@pytest.fixture +def thumbnailServiceModule(authTestEnv): + # thumbnailServiceModule + return importlib.import_module("app.backend.utils.thumbnail_service") + + +@pytest.fixture +def service(thumbnailServiceModule, tmp_path): + # service + projectPath = tmp_path / "DemoProject" + projectPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject(str(projectPath)) + return thumbnailServiceModule.ThumbnailService(currentProject) + + +def test_NormalizeArrayToUint8ScalesFiniteData(service): + array = np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32) + + result = service._normalizeArrayToUint8(array) + + assert result.dtype == np.uint8 + assert result.shape == (2, 2) + assert int(result.min()) == 0 + assert int(result.max()) == 255 + + +def test_NormalizeArrayToUint8HandlesNaNs(service): + array = np.array([[0.0, np.nan], [np.inf, 1.0]], dtype=np.float32) + + result = service._normalizeArrayToUint8(array) + + assert result.dtype == np.uint8 + assert result.shape == (2, 2) + + +def test_GrayTileToImageReturnsRgbImage(service): + gray = np.array([[0, 128], [255, 64]], dtype=np.uint8) + + image = service._grayTileToImage(gray) + + assert image is not None + assert image.mode == "RGB" + assert image.size == (2, 2) + + +def test_RgbTileToImageReturnsRgbImage(service): + rgb = np.zeros((3, 4, 3), dtype=np.uint8) + rgb[:, :, 0] = 255 + + image = service._rgbTileToImage(rgb) + + assert image is not None + assert image.mode == "RGB" + assert image.size == (4, 3) + + +def test_ArrayToImageUsesColormap(service): + array = np.array([[0.0, 0.5], [0.75, 1.0]], dtype=np.float32) + + image = service._arrayToImage(array, cmapName="viridis") + + assert image is not None + assert image.mode == "RGB" + assert image.size == (2, 2) + + +def test_ApplyColormapReturnsRgbUint8(service): + gray = np.array([[0, 64], [128, 255]], dtype=np.uint8) + + rgb = service._applyColormap(gray, cmapName="viridis") + + assert rgb.dtype == np.uint8 + assert rgb.shape == (2, 2, 3) + + +def test_NormalizePilImageConvertsToRgb(service): + image = Image.new("L", (10, 6), color=120) + + normalized = service._normalizePilImage(image) + + assert normalized.mode == "RGB" + assert normalized.size == (10, 6) + + +def test_StatusAccentReturnsExpectedColors(service): + assert service._statusAccent("finished") == (75, 170, 96) + assert service._statusAccent("running") == (59, 130, 246) + assert service._statusAccent("failed") == (220, 38, 38) + assert service._statusAccent("unknown") == (148, 163, 184) + + +def test_MixColorInterpolatesBetweenTwoColors(service): + mixed = service._mixColor((0, 0, 0), (255, 255, 255), 0.5) + + assert mixed == (128, 128, 128) + + +def test_BuildRoundedMaskCreatesExpectedSize(service): + mask = service._buildRoundedMask((120, 80), radius=12) + + assert mask.mode == "L" + assert mask.size == (120, 80) + + +def test_MakeProtocolPlaceholderPreviewBuildsCanvas(service): + image = service._makeProtocolPlaceholderPreview(status="running", size=320) + + assert image.mode == "RGB" + assert image.size[0] >= 180 + assert image.size[1] >= 120 + + +def test_FinalizeProtocolThumbnailProducesExpectedAspect(service): + preview = Image.new("RGB", (120, 120), color=(200, 200, 200)) + + thumb = service._finalizeProtocolThumbnail( + previewImage=preview, + size=320, + protocolId=10, + ) + + assert thumb.mode == "RGB" + assert thumb.size == (320, 218) + + +def test_ComposeProjectStripBuildsHorizontalCanvas(service, tmp_path): + img1 = tmp_path / "thumb1.png" + img2 = tmp_path / "thumb2.png" + + Image.new("RGB", (120, 80), color=(255, 0, 0)).save(img1) + Image.new("RGB", (120, 80), color=(0, 255, 0)).save(img2) + + strip = service._composeProjectStrip( + items=[ + {"absolutePath": str(img1)}, + {"absolutePath": str(img2)}, + ], + size=720, + ) + + assert strip.mode == "RGB" + assert strip.size[0] > strip.size[1] + assert strip.size[0] > 200 + + +def test_ComposeCleanGridBuildsMosaic(service): + tiles = [ + Image.new("RGB", (80, 80), color=(255, 0, 0)), + Image.new("RGB", (80, 80), color=(0, 255, 0)), + Image.new("RGB", (80, 80), color=(0, 0, 255)), + ] + + grid = service._composeCleanGrid( + tiles=tiles, + maxCols=2, + targetWidth=320, + ) + + assert grid.mode == "RGB" + assert grid.size[0] > 0 + assert grid.size[1] > 0 + + +def test_ComposeCleanStripBuildsStrip(service): + panels = [ + Image.new("RGB", (90, 120), color=(255, 0, 0)), + Image.new("RGB", (90, 120), color=(0, 255, 0)), + ] + + strip = service._composeCleanStrip(panels=panels, targetHeight=180) + + assert strip.mode == "RGB" + assert strip.size[0] > strip.size[1] + + +def test_ComposeTriptychBuildsThreePanelLayout(service): + panels = [ + Image.new("RGB", (120, 120), color=(255, 0, 0)), + Image.new("RGB", (120, 120), color=(0, 255, 0)), + Image.new("RGB", (120, 120), color=(0, 0, 255)), + ] + + triptych = service._composeTriptych(panels=panels, targetHeight=180) + + assert triptych.mode == "RGB" + assert triptych.size[0] > triptych.size[1] \ No newline at end of file diff --git a/tests/unit/backend/utils/test_thumbnail_service_utils.py b/tests/unit/backend/utils/test_thumbnail_service_utils.py new file mode 100644 index 0000000..0773a67 --- /dev/null +++ b/tests/unit/backend/utils/test_thumbnail_service_utils.py @@ -0,0 +1,296 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import importlib +from pathlib import Path + +import pytest + + +class FakeCurrentProject: + # fakeCurrentProject + def __init__(self, projectPath): + self._projectPath = projectPath + + def getPath(self): + return self._projectPath + + +class FakeProtocol: + # fakeProtocol + def __init__(self, objId=10, label="Protocol", status="finished", protocolPath=None): + self._objId = objId + self._label = label + self._status = status + self._protocolPath = protocolPath + + def getObjId(self): + return self._objId + + def getObjLabel(self): + return self._label + + def getStatus(self): + return self._status + + def getPath(self): + return self._protocolPath + + def __str__(self): + return "ProtocolString" + + +class FakeOutput: + # fakeOutput + def __init__(self, className="SetOfParticles", size=5, fileName=None): + self._className = className + self._size = size + self._fileName = fileName + + def getClassName(self): + return self._className + + def getSize(self): + return self._size + + def getFileName(self): + return self._fileName + + +class FakeOutputBrokenSize: + # fakeOutputBrokenSize + def getSize(self): + raise RuntimeError("size error") + + +class FakeItem: + # fakeItem + def __init__(self, fileName=None, enabled=True): + self._fileName = fileName + self._enabled = enabled + + def getFileName(self): + return self._fileName + + def isEnabled(self): + return self._enabled + + +class FakeOutputWithItems: + # fakeOutputWithItems + def __init__(self, fileName=None, items=None): + self._fileName = fileName + self._items = items or [] + + def getFileName(self): + return self._fileName + + def iterItems(self, iterate=False): + return list(self._items) + + +class FakeRenderableByTomograms: + # fakeRenderableByTomograms + def getTomograms(self): + return [] + + +@pytest.fixture +def thumbnailServiceModule(authTestEnv): + # thumbnailServiceModule + return importlib.import_module("app.backend.utils.thumbnail_service") + + +@pytest.fixture +def service(thumbnailServiceModule, tmp_path): + # service + projectPath = tmp_path / "DemoProject" + projectPath.mkdir(parents=True, exist_ok=True) + + currentProject = FakeCurrentProject(str(projectPath)) + return thumbnailServiceModule.ThumbnailService(currentProject) + + +def test_GetProtocolLabelUsesObjLabel(service): + protocol = FakeProtocol(label="My Protocol") + + assert service._getProtocolLabel(protocol) == "My Protocol" + + +def test_GetProtocolLabelFallsBackToString(service): + class NoLabelProtocol: + # noLabelProtocol + def __str__(self): + return "StringFallback" + + assert service._getProtocolLabel(NoLabelProtocol()) == "StringFallback" + + +def test_GetProtocolStatusUsesCallable(service): + protocol = FakeProtocol(status="running") + + assert service._getProtocolStatus(protocol) == "running" + + +def test_GetProtocolStatusFallsBackToUnknown(service): + class NoStatusProtocol: + pass + + assert service._getProtocolStatus(NoStatusProtocol()) == "unknown" + + +def test_GetOutputClassNameUsesGetter(service): + output = FakeOutput(className="SetOfVolumes") + + assert service._getOutputClassName(output) == "SetOfVolumes" + + +def test_SafeOutputSizeReturnsInteger(service): + output = FakeOutput(size=12) + + assert service._safeOutputSize(output) == 12 + + +def test_SafeOutputSizeReturnsNoneOnFailure(service): + assert service._safeOutputSize(FakeOutputBrokenSize()) is None + + +def test_IsEnabledUsesMethod(service): + assert service._isEnabled(FakeItem(enabled=True)) is True + assert service._isEnabled(FakeItem(enabled=False)) is False + + +def test_IsEnabledFallsBackToTrueWhenMissing(service): + class NoEnabledInfo: + pass + + assert service._isEnabled(NoEnabledInfo()) is True + + +def test_ResolveFilePathFindsAbsoluteExistingPath(service, tmp_path): + filePath = tmp_path / "absolute.mrc" + filePath.write_text("placeholder", encoding="utf-8") + + protocol = FakeProtocol(protocolPath=str(tmp_path / "Runs" / "Prot")) + + resolved = service._resolveFilePath(protocol, str(filePath)) + + assert resolved == filePath.resolve() + + +def test_ResolveFilePathFindsRelativePathUnderProtocol(service, tmp_path): + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + relativeFile = protocolPath / "extra" / "image.mrc" + relativeFile.parent.mkdir(parents=True, exist_ok=True) + relativeFile.write_text("placeholder", encoding="utf-8") + + protocol = FakeProtocol(protocolPath=str(protocolPath)) + + resolved = service._resolveFilePath(protocol, "extra/image.mrc") + + assert resolved == relativeFile.resolve() + + +def test_ResolveFilePathReturnsNoneWhenMissing(service, tmp_path): + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + protocol = FakeProtocol(protocolPath=str(protocolPath)) + + resolved = service._resolveFilePath(protocol, "missing/file.mrc") + + assert resolved is None + + +def test_CollectDirectVolumePathsDeduplicates(thumbnailServiceModule, service, monkeypatch, tmp_path): + protocolPath = tmp_path / "DemoProject" / "Runs" / "000010_ProtImport" + protocolPath.mkdir(parents=True, exist_ok=True) + + vol1 = protocolPath / "vol1.mrc" + vol2 = protocolPath / "vol2.mrc" + vol1.write_text("placeholder", encoding="utf-8") + vol2.write_text("placeholder", encoding="utf-8") + + monkeypatch.setattr(thumbnailServiceModule, "EMSet", FakeOutputWithItems) + + protocol = FakeProtocol(protocolPath=str(protocolPath)) + output = FakeOutputWithItems( + fileName="vol1.mrc", + items=[ + FakeItem(fileName="vol1.mrc"), + FakeItem(fileName="vol2.mrc"), + FakeItem(fileName="vol2.mrc"), + ], + ) + + paths = service._collectDirectVolumePaths(protocol, output, maxItems=6) + + + assert paths == [vol1.resolve(), vol2.resolve()] + + +def test_LooksRenderableOutputRecognizesFileBackedOutput(service): + output = FakeOutput(fileName="output.mrc") + + assert service._looksRenderableOutput(output) is True + + +def test_LooksRenderableOutputRecognizesIterableOutput(service): + output = FakeOutputWithItems(fileName=None, items=[]) + + assert service._looksRenderableOutput(output) is True + + +def test_LooksRenderableOutputRecognizesTomogramsGetter(service): + assert service._looksRenderableOutput(FakeRenderableByTomograms()) is True + + +def test_LooksRenderableOutputRejectsNone(service): + assert service._looksRenderableOutput(None) is False + + +def test_FilesystemPreviewSortKeyPrioritizesThumbnailNames(service, tmp_path): + thumb = tmp_path / "thumb_preview.png" + raw = tmp_path / "raw_data.mrc" + + thumb.write_text("thumb", encoding="utf-8") + raw.write_text("raw", encoding="utf-8") + + keyThumb = service._filesystemPreviewSortKey(thumb) + keyRaw = service._filesystemPreviewSortKey(raw) + + assert keyThumb < keyRaw + + +def test_VolumeLikeExtensionsContainsExpectedTypes(service): + exts = service._volumeLikeExtensions() + + assert ".mrc" in exts + assert ".map" in exts + assert ".mrcs" in exts + assert ".h5" in exts \ No newline at end of file