diff --git a/app/backend/api/routers/coords2d_router.py b/app/backend/api/routers/coords2d_router.py new file mode 100644 index 0000000..577b789 --- /dev/null +++ b/app/backend/api/routers/coords2d_router.py @@ -0,0 +1,187 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import logging +from typing import Any, List, Optional +from pydantic import BaseModel, Field + +from fastapi import APIRouter, Depends, Query, status + +from app.backend.api.dependencies import getCurrentUser +from app.backend.api.services.coords2d_service import Coords2dService +from app.backend.database import getMapper +from app.backend.mapper.postgresql import PostgresqlFlatMapper + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/projects", tags=["coordinates2d"]) + + +class Coords2dPointPayload(BaseModel): + id: Optional[Any] = None + micId: Optional[Any] = None + x: float + y: float + + +class Coords2dMicrographPayload(BaseModel): + id: Any + coordinates: List[Coords2dPointPayload] = Field(default_factory=list) + + +class CreateCoords2dOutputPayload(BaseModel): + boxSize: Optional[int] = None + outputName: Optional[str] = None + micrographs: List[Coords2dMicrographPayload] = Field(default_factory=list) + + +def getCoords2dService() -> Coords2dService: + return Coords2dService() + + +@router.get( + "/{projectId}/protocols/{protocolId}/outputs/{outputName}/coords2d/micrographs", + response_model=Any, + status_code=status.HTTP_200_OK, +) +def listCoords2dMicrographs( + projectId: int, + protocolId: int, + outputName: str, + currentUser=Depends(getCurrentUser), + mapper: PostgresqlFlatMapper = Depends(getMapper), + service: Coords2dService = Depends(getCoords2dService), +): + return service.listMicrographs( + mapper=mapper, + projectId=projectId, + currentUser=currentUser, + protocolId=protocolId, + outputName=outputName, + ) + + +@router.get( + "/{projectId}/protocols/{protocolId}/outputs/{outputName}/coords2d/micrographs/{micId}/coordinates", + response_model=Any, + status_code=status.HTTP_200_OK, +) +def getCoords2dMicrographCoordinates( + projectId: int, + protocolId: int, + outputName: str, + micId: str, + currentUser=Depends(getCurrentUser), + mapper: PostgresqlFlatMapper = Depends(getMapper), + service: Coords2dService = Depends(getCoords2dService), +): + return service.listCoordinatesForMicrograph( + mapper=mapper, + projectId=projectId, + currentUser=currentUser, + protocolId=protocolId, + outputName=outputName, + micId=micId, + ) + + +@router.get( + "/{projectId}/protocols/{protocolId}/outputs/{outputName}/coords2d/micrographs/{micId}/image", + response_model=Any, + status_code=status.HTTP_200_OK, +) +def getCoords2dMicrographImage( + projectId: int, + protocolId: int, + outputName: str, + micId: str, + size: int = Query(2200, ge=64, le=4096), + format: str = Query("png", pattern="^(png|webp|jpeg|jpg)$"), + currentUser=Depends(getCurrentUser), + mapper: PostgresqlFlatMapper = Depends(getMapper), + service: Coords2dService = Depends(getCoords2dService), +): + return service.renderMicrographImage( + mapper=mapper, + projectId=projectId, + currentUser=currentUser, + protocolId=protocolId, + outputName=outputName, + micId=micId, + size=size, + fmt=format, + ) + + +@router.get( + "/{projectId}/protocols/{protocolId}/outputs/{outputName}/coords2d/micrographs/{micId}/thumbnail", + response_model=Any, + status_code=status.HTTP_200_OK, +) +def getCoords2dMicrographThumbnail( + projectId: int, + protocolId: int, + outputName: str, + micId: str, + size: int = Query(180, ge=32, le=512), + format: str = Query("png", pattern="^(png|webp|jpeg|jpg)$"), + currentUser=Depends(getCurrentUser), + mapper: PostgresqlFlatMapper = Depends(getMapper), + service: Coords2dService = Depends(getCoords2dService), +): + return service.renderMicrographImage( + mapper=mapper, + projectId=projectId, + currentUser=currentUser, + protocolId=protocolId, + outputName=outputName, + micId=micId, + size=size, + fmt=format, + ) + +@router.post( + "/{projectId}/protocols/{protocolId}/outputs/{outputName}/coords2d/create-output", + response_model=Any, + status_code=status.HTTP_200_OK, +) +def createCoords2dCoordinatesOutput( + projectId: int, + protocolId: int, + outputName: str, + payload: CreateCoords2dOutputPayload, + currentUser=Depends(getCurrentUser), + mapper: PostgresqlFlatMapper = Depends(getMapper), + service: Coords2dService = Depends(getCoords2dService), +): + return service.createCoordinatesOutput( + mapper=mapper, + projectId=projectId, + currentUser=currentUser, + protocolId=protocolId, + outputName=outputName, + payload=payload.dict(), + ) \ No newline at end of file diff --git a/app/backend/api/services/coords2d_service.py b/app/backend/api/services/coords2d_service.py new file mode 100644 index 0000000..be97aea --- /dev/null +++ b/app/backend/api/services/coords2d_service.py @@ -0,0 +1,818 @@ +# ****************************************************************************** +# * +# * Authors: Yunior C. Fonseca Reyna +# * +# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC +# * +# * This program is free software; you can redistribute it and/or modify +# * it under the terms of the GNU General Public License as published by +# * the Free Software Foundation; either version 3 of the License, or +# * (at your option) any later version. +# * +# * This program is distributed in the hope that it will be useful, +# * but WITHOUT ANY WARRANTY; without even the implied warranty of +# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# * GNU General Public License for more details. +# * +# * You should have received a copy of the GNU General Public License +# * along with this program; if not, write to the Free Software +# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA +# * 02111-1307 USA +# * +# * All comments concerning this program package may be sent to the +# * e-mail address 'scipion@cnb.csic.es' +# * +# ****************************************************************************** + +import io +import logging +import os +from typing import Any, Dict, List, Optional, Tuple +from uuid import uuid4 + +from fastapi import HTTPException, Response, status +from PIL import Image, ImageEnhance, ImageOps +from pwem.emlib.image.image_readers import ImageReadersRegistry +from pwem.objects import Coordinate + +from app.backend.api.services.project_service import ProjectService +from app.backend.mapper.postgresql import PostgresqlFlatMapper + +logger = logging.getLogger(__name__) + + +class Coords2dService: + def __init__(self): + self.projectService = ProjectService() + + def _loadCoordinatesOutput( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + currentUser: Any, + protocolId: int, + outputName: str, + ) -> Tuple[Any, Any]: + project = self.projectService.getProjectById( + mapper, + projectId, + currentUser, + refresh=False, + checkPid=False, + ) + if not project: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + + currentProject = self.projectService.currentProject + if currentProject is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Project could not be loaded", + ) + + try: + protocol = currentProject.getProtocol(int(protocolId)) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol '{protocolId}' not found: {e}", + ) + + if protocol is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Protocol '{protocolId}' not found", + ) + + if not hasattr(protocol, outputName): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Output '{outputName}' not found in protocol '{protocolId}'", + ) + + coordinatesSet = getattr(protocol, outputName) + if coordinatesSet is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Output '{outputName}' is empty", + ) + + if not hasattr(coordinatesSet, "getMicrographs") or not hasattr(coordinatesSet, "iterCoordinates"): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Output '{outputName}' is not a SetOfCoordinates output", + ) + + return protocol, coordinatesSet + + @staticmethod + def _safeCall(obj: Any, methodName: str, default: Any = None) -> Any: + try: + method = getattr(obj, methodName, None) + if not callable(method): + return default + value = method() + return default if value is None else value + except Exception: + return default + + @staticmethod + def _safeNumber(value: Any, default: Optional[float] = None) -> Optional[float]: + try: + if value is None: + return default + return float(value) + except Exception: + return default + + @staticmethod + def _tryInt(value: Any) -> Optional[int]: + try: + if value is None: + return None + return int(value) + except Exception: + return None + + @staticmethod + def _micrographId(micrograph: Any) -> str: + value = Coords2dService._safeCall(micrograph, "getObjId", None) + return str(value) if value is not None else "" + + @staticmethod + def _splitLocationValue(location: Any) -> Tuple[Optional[int], str]: + if location is None: + return None, "" + + if isinstance(location, (tuple, list)) and len(location) >= 2: + first = location[0] + second = location[1] + + firstIndex = Coords2dService._tryInt(first) + secondIndex = Coords2dService._tryInt(second) + + if firstIndex is not None: + return firstIndex, str(second or "") + + if secondIndex is not None: + return secondIndex, str(first or "") + + return None, str(second or first or "") + + locationText = str(location or "").strip() + if not locationText: + return None, "" + + if "@" in locationText: + rawIndex, rawFileName = locationText.split("@", 1) + imageIndex = Coords2dService._tryInt(rawIndex) + return imageIndex, rawFileName + + return None, locationText + + @staticmethod + def _micrographLocation(micrograph: Any) -> Tuple[Optional[int], str]: + location = Coords2dService._safeCall(micrograph, "getLocation", None) + imageIndex, fileName = Coords2dService._splitLocationValue(location) + + if fileName: + return imageIndex, fileName + + fileName = str(Coords2dService._safeCall(micrograph, "getFileName", "") or "") + parsedIndex, parsedFileName = Coords2dService._splitLocationValue(fileName) + + if parsedFileName: + return parsedIndex if parsedIndex is not None else imageIndex, parsedFileName + + return imageIndex, fileName + + @staticmethod + def _micrographFileName(micrograph: Any) -> str: + _, fileName = Coords2dService._micrographLocation(micrograph) + return fileName + + @staticmethod + def _micrographName(micrograph: Any) -> str: + micName = Coords2dService._safeCall(micrograph, "getMicName", None) + if micName: + return str(micName) + + label = Coords2dService._safeCall(micrograph, "getObjLabel", None) + if label: + return str(label) + + _, fileName = Coords2dService._micrographLocation(micrograph) + return os.path.basename(str(fileName)) or "Untitled" + + @staticmethod + def _micrographDims(micrograph: Any) -> Tuple[Optional[int], Optional[int]]: + dims = Coords2dService._safeCall(micrograph, "getDim", None) + if not dims: + return None, None + + try: + width = int(dims[0]) if len(dims) > 0 and dims[0] is not None else None + height = int(dims[1]) if len(dims) > 1 and dims[1] is not None else None + return width, height + except Exception: + return None, None + + @staticmethod + def _coordinateMicId(coordinate: Any) -> Optional[str]: + value = Coords2dService._safeCall(coordinate, "getMicId", None) + return str(value) if value is not None else None + + @staticmethod + def _extractCoordinateScore(coordinate: Any) -> Optional[float]: + for methodName in ("getScore", "getWeight"): + value = Coords2dService._safeCall(coordinate, methodName, None) + score = Coords2dService._safeNumber(value, None) + if score is not None: + return score + return None + + @staticmethod + def _extractCoordinateClassLabel(coordinate: Any) -> Optional[str]: + for methodName in ("getClassId", "getObjLabel"): + value = Coords2dService._safeCall(coordinate, methodName, None) + if value is not None and str(value).strip(): + return str(value) + return None + + @staticmethod + def _micrographSortKey(micId: str): + try: + return 0, int(micId) + except Exception: + return 1, str(micId).lower() + + def _buildMicrographMap(self, coordinatesSet: Any) -> Dict[str, Any]: + try: + micrographsSet = coordinatesSet.getMicrographs() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not load coordinate micrographs: {e}", + ) + + micrographs: Dict[str, Any] = {} + try: + iterator = micrographsSet.iterItems(iterate=False) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not iterate micrographs: {e}", + ) + + for micrograph in iterator: + micId = self._micrographId(micrograph) + if micId: + micrographs[micId] = micrograph.clone() + + return micrographs + + def _countCoordinatesByMicrograph(self, coordinatesSet: Any) -> Dict[str, int]: + counts: Dict[str, int] = {} + + try: + for coordinate in coordinatesSet.iterItems(iterate=False): + micId = self._coordinateMicId(coordinate) + if not micId: + continue + counts[micId] = counts.get(micId, 0) + 1 + return counts + except Exception: + pass + + try: + micrographs = self._buildMicrographMap(coordinatesSet) + for micId in micrographs: + counts[micId] = len(list(coordinatesSet.iterCoordinates(int(micId)))) + return counts + except Exception: + return counts + + def listMicrographs( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + currentUser: Any, + protocolId: int, + outputName: str, + ) -> Dict[str, Any]: + _, coordinatesSet = self._loadCoordinatesOutput( + mapper, + projectId, + currentUser, + protocolId, + outputName, + ) + + micrographMap = self._buildMicrographMap(coordinatesSet) + counts = self._countCoordinatesByMicrograph(coordinatesSet) + boxSize = self._safeCall(coordinatesSet, "getBoxSize", None) + totalPicks = self._safeCall(coordinatesSet, "getSize", None) + + micrographs: List[Dict[str, Any]] = [] + sortedMicIds = sorted(micrographMap.keys(), key=self._micrographSortKey) + + for index, micId in enumerate(sortedMicIds, start=1): + micrograph = micrographMap[micId] + imageIndex, fileName = self._micrographLocation(micrograph) + width, height = self._micrographDims(micrograph) + + micrographs.append({ + "id": micId, + "index": index, + "fileName": fileName, + "label": self._micrographName(micrograph), + "particles": int(counts.get(micId, 0)), + "updated": False, + "width": width, + "height": height, + "locationIndex": imageIndex, + "thumbnailUrl": None, + }) + + if totalPicks is None: + totalPicks = sum(int(item.get("particles") or 0) for item in micrographs) + + return { + "micrographs": micrographs, + "totalMicrographs": len(micrographs), + "totalPicks": int(totalPicks or 0), + "boxSize": int(boxSize) if boxSize else None, + } + + def _findMicrograph(self, coordinatesSet: Any, micId: str) -> Any: + micrograph = self._buildMicrographMap(coordinatesSet).get(str(micId)) + if micrograph is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Micrograph '{micId}' not found in coordinates output", + ) + return micrograph + + def listCoordinatesForMicrograph( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + currentUser: Any, + protocolId: int, + outputName: str, + micId: str, + ) -> Dict[str, Any]: + _, coordinatesSet = self._loadCoordinatesOutput( + mapper, + projectId, + currentUser, + protocolId, + outputName, + ) + + self._findMicrograph(coordinatesSet, micId) + + try: + coordinatesIterator = coordinatesSet.iterCoordinates(int(micId)) + except Exception: + coordinatesIterator = [] + try: + coordinatesIterator = [ + coordinate + for coordinate in coordinatesSet.iterItems(iterate=False) + if self._coordinateMicId(coordinate) == str(micId) + ] + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not iterate coordinates for micrograph '{micId}': {e}", + ) + + coordinates: List[Dict[str, Any]] = [] + for index, coordinate in enumerate(coordinatesIterator): + x = self._safeNumber(self._safeCall(coordinate, "getX", None), None) + y = self._safeNumber(self._safeCall(coordinate, "getY", None), None) + if x is None or y is None: + continue + + objId = self._safeCall(coordinate, "getObjId", None) + coordinates.append({ + "id": objId if objId is not None else f"{micId}:{index}", + "micId": str(micId), + "x": x, + "y": y, + "score": self._extractCoordinateScore(coordinate), + "classLabel": self._extractCoordinateClassLabel(coordinate), + }) + + return {"coordinates": coordinates} + + def _resolveMicrographById(self, micrographsSet: Any, micId: str) -> Any: + try: + return micrographsSet[int(micId)] + except Exception: + pass + + try: + for micrograph in micrographsSet.iterItems(iterate=False): + if self._micrographId(micrograph) == str(micId): + return micrograph + except Exception: + pass + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Micrograph '{micId}' not found in source micrographs", + ) + + def _newCoordinateLike(self, coordinatesSet: Any) -> Any: + try: + firstItem = coordinatesSet.getFirstItem() + if firstItem is not None: + return firstItem.clone() + except Exception: + pass + + return Coordinate() + + def _appendCoordinateToSet( + self, + coordinatesSet: Any, + coordSet: Any, + micrographsSet: Any, + micId: str, + x: float, + y: float, + objId: int, + ) -> None: + coordinate = self._newCoordinateLike(coordinatesSet) + + try: + coordinate.setObjId(objId) + except Exception: + pass + + micrograph = self._resolveMicrographById(micrographsSet, micId) + + try: + coordinate.setMicrograph(micrograph) + except Exception: + pass + + try: + coordinate.setPosition(float(x), float(y)) + except Exception: + try: + coordinate.setX(float(x)) + coordinate.setY(float(y)) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid coordinate for micrograph '{micId}': {e}", + ) + + coordSet.append(coordinate) + + def createCoordinatesOutput( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + currentUser: Any, + protocolId: int, + outputName: str, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + payload = payload or {} + + protocol, coordinatesSet = self._loadCoordinatesOutput( + mapper, + projectId, + currentUser, + protocolId, + outputName, + ) + + try: + micrographsSet = coordinatesSet.getMicrographs() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not load source micrographs: {e}", + ) + + replacementMap: Dict[str, Dict[str, Any]] = {} + + for item in payload.get("micrographs") or []: + if not isinstance(item, dict): + continue + + rawMicId = item.get("id", item.get("micId")) + if rawMicId is None: + continue + + micId = str(rawMicId) + existingCoordinates: Dict[int, Dict[str, float]] = {} + newCoordinates: List[Dict[str, float]] = [] + + for point in item.get("coordinates") or []: + if not isinstance(point, dict): + continue + + x = self._safeNumber(point.get("x"), None) + y = self._safeNumber(point.get("y"), None) + + if x is None or y is None: + continue + + pointId = self._tryInt(point.get("id")) + if pointId is None: + newCoordinates.append({"x": x, "y": y}) + else: + existingCoordinates[pointId] = {"x": x, "y": y} + + replacementMap[micId] = { + "existing": existingCoordinates, + "new": newCoordinates, + } + + if not replacementMap: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="No coordinate changes provided", + ) + + try: + originalCoordinates = list(coordinatesSet.iterItems(iterate=False)) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not read original coordinates: {e}", + ) + + try: + maxObjId = coordinatesSet.aggregate(["MAX"], "_objId")[0]["MAX"] or 0 + except Exception: + maxObjId = 0 + for coordinate in originalCoordinates: + objId = self._tryInt(self._safeCall(coordinate, "getObjId", None)) + if objId is not None: + maxObjId = max(maxObjId, objId) + + try: + suffix = f"{protocol.getOutputsSize()}_{uuid4().hex[:8]}" + coordSet = protocol._createSetOfCoordinates(micrographsSet, suffix=suffix) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not create coordinates output: {e}", + ) + + try: + coordSet.copyInfo(coordinatesSet) + except Exception: + pass + + boxSize = payload.get("boxSize", None) + if boxSize is not None: + try: + coordSet.setBoxSize(int(boxSize)) + except Exception: + pass + + totalCoordinates = 0 + + try: + for coordinate in originalCoordinates: + micId = self._coordinateMicId(coordinate) + if not micId: + continue + + objId = self._tryInt(self._safeCall(coordinate, "getObjId", None)) + if objId is None: + continue + + if micId in replacementMap: + existingCoordinates = replacementMap[micId]["existing"] + + if objId not in existingCoordinates: + continue + + x = existingCoordinates[objId]["x"] + y = existingCoordinates[objId]["y"] + else: + x = self._safeNumber(self._safeCall(coordinate, "getX", None), None) + y = self._safeNumber(self._safeCall(coordinate, "getY", None), None) + + if x is None or y is None: + continue + + coord = Coordinate() + coord.setObjId(objId) + coord.setMicrograph(self._resolveMicrographById(micrographsSet, micId)) + coord.setPosition(float(x), float(y)) + coordSet.append(coord) + totalCoordinates += 1 + + coordinateTemplate = self._newCoordinateLike(coordinatesSet) + + for micId, replacement in replacementMap.items(): + micrograph = self._resolveMicrographById(micrographsSet, micId) + + for point in replacement["new"]: + maxObjId += 1 + newCoordinate = coordinateTemplate.clone() + newCoordinate.setObjId(maxObjId) + newCoordinate.setMicrograph(micrograph) + newCoordinate.setPosition(float(point["x"]), float(point["y"])) + coordSet.append(newCoordinate) + totalCoordinates += 1 + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not append coordinates: {e}", + ) + + try: + coordSet.write() + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not write coordinates output: {e}", + ) + + requestedOutputName = str(payload.get("outputName") or "").strip() + + if requestedOutputName and not hasattr(protocol, requestedOutputName): + nextOutputName = requestedOutputName + else: + try: + nextOutputName = protocol.getNextOutputName("coordinates_") + except Exception: + nextOutputName = f"coordinates_{protocol.getOutputsSize()}" + + try: + protocol._defineOutputs(**{nextOutputName: coordSet}) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Could not define coordinates output: {e}", + ) + + try: + protocol._defineSourceRelation(micrographsSet, coordSet) + except Exception: + logger.warning("Could not define source relation for coords2d output", exc_info=True) + + return { + "success": True, + "outputName": nextOutputName, + "totalCoordinates": int(totalCoordinates), + "message": f"The new set of coordinates has been created: {nextOutputName}", + } + + @staticmethod + def _normalizeImageFormat(fmt: str) -> Tuple[str, str]: + value = (fmt or "png").strip().lower() + if value in {"jpg", "jpeg"}: + return "JPEG", "image/jpeg" + if value == "webp": + return "WEBP", "image/webp" + return "PNG", "image/png" + + @staticmethod + def _prepareImage(image: Image.Image, size: int) -> Image.Image: + if image.mode not in {"L", "RGB", "RGBA"}: + image = image.convert("L") + + if image.mode == "L": + image = ImageOps.autocontrast(image) + image = ImageEnhance.Contrast(image).enhance(1.6) + elif image.mode == "RGBA": + image = image.convert("RGB") + + if size and size > 0: + image.thumbnail((int(size), int(size)), Image.Resampling.LANCZOS) + + return image + + @staticmethod + def _readMicrographImage(imagePath: str, imageIndex: Optional[int]) -> Image.Image: + imageStack = ImageReadersRegistry.open(imagePath) + + if imageIndex is None: + return imageStack.getImage(pilImage=True) + + try: + return imageStack.getImage(index=imageIndex, pilImage=True) + except Exception: + pass + + try: + return imageStack.getImage(imageIndex, pilImage=True) + except Exception: + pass + + if imageIndex > 0: + zeroBasedIndex = imageIndex - 1 + + try: + return imageStack.getImage(index=zeroBasedIndex, pilImage=True) + except Exception: + pass + + try: + return imageStack.getImage(zeroBasedIndex, pilImage=True) + except Exception: + pass + + return imageStack.getImage(pilImage=True) + + def renderMicrographImage( + self, + mapper: PostgresqlFlatMapper, + projectId: int, + currentUser: Any, + protocolId: int, + outputName: str, + micId: str, + size: int = 2200, + fmt: str = "png", + ) -> Response: + _, coordinatesSet = self._loadCoordinatesOutput( + mapper, + projectId, + currentUser, + protocolId, + outputName, + ) + + micrograph = self._findMicrograph(coordinatesSet, micId) + imageIndex, imagePath = self._micrographLocation(micrograph) + + if not imagePath: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Micrograph '{micId}' does not have a file path", + ) + + imagePath = os.path.abspath(imagePath) + if not os.path.exists(imagePath): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Micrograph image file not found: {imagePath}", + ) + + try: + image = self._readMicrographImage(imagePath, imageIndex) + originalWidth, originalHeight = image.size + image = self._prepareImage(image, size) + except Exception as e: + logger.exception("Failed to render coords2d micrograph image: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to render micrograph image: {e}", + ) + + imageFormat, mediaType = self._normalizeImageFormat(fmt) + buffer = io.BytesIO() + saveOptions: Dict[str, Any] = {} + + if imageFormat == "JPEG": + if image.mode != "RGB": + image = image.convert("RGB") + saveOptions["quality"] = 90 + elif imageFormat == "WEBP": + saveOptions["quality"] = 85 + + image.save(buffer, format=imageFormat, **saveOptions) + + scaleX = image.width / originalWidth if originalWidth else 1 + scaleY = image.height / originalHeight if originalHeight else 1 + + headers = { + "X-Preview-Width": str(image.width), + "X-Preview-Height": str(image.height), + "X-Preview-Original-Width": str(originalWidth), + "X-Preview-Original-Height": str(originalHeight), + "X-Preview-Scale-X": f"{scaleX:.8f}", + "X-Preview-Scale-Y": f"{scaleY:.8f}", + "X-Preview-Origin": "top-left", + "X-Preview-Orientation": "scipion-top-left-no-flip", + "X-Preview-MicrographId": str(micId), + "X-Preview-Source-Index": "" if imageIndex is None else str(imageIndex), + "X-Preview-Source-File": os.path.basename(imagePath), + "X-Preview-Format": imageFormat, + "Cache-Control": "no-store", + } + + return Response( + content=buffer.getvalue(), + media_type=mediaType, + headers=headers, + ) \ No newline at end of file diff --git a/app/backend/api/services/project_service.py b/app/backend/api/services/project_service.py index 8ce8468..248f2b1 100644 --- a/app/backend/api/services/project_service.py +++ b/app/backend/api/services/project_service.py @@ -1136,13 +1136,16 @@ def sortKey(row: Dict[str, Any]): inputs = [] outputs = [] - cpuTime = "" - elapsedTime = "" + cpuTime = '' + elapsedTime = '' isinteractive = False numberOfSteps = 0 stepsDone = 0 thumbnailUrl = None thumbnailRebuildUrl = None + runName = '' + comment = '' + title = '' # Prefer the live protocol object coming from runs graph protocol = liveRuns.get(nodeId) @@ -1159,6 +1162,18 @@ def sortKey(row: Dict[str, Any]): except Exception: pass + try: + runName = protocol.runName.get() + if runName is None: + runName = protocol.getRunName() + except Exception: + pass + + try: + comment = protocol._objComment + except Exception: + pass + try: protStatus = protocol.getStatus() if protStatus: @@ -1264,6 +1279,9 @@ def sortKey(row: Dict[str, Any]): "children": childrenIds, "parents": parentIds, "label": label, + "title": title, + "runName": runName, + "comment": comment, "status": status, "parameter": [], "inputs": inputs, @@ -1519,7 +1537,7 @@ def applyWorkflowToProject( # 8) Return a compact, useful payload for the frontend return { - "status": "ok", + "status": 0, "projectId": projectId, "workflowId": workflowIdStr, "workflowName": getattr(selectedTemplate, "name", workflowIdStr), @@ -1640,6 +1658,11 @@ def attachContainerWizardMetadata(container: Optional[Dict[str, Any]]) -> None: logoPath = self.getResourceLogo(path) protName = str(protocol) + + if protocol.runName.get() is None: + runName = protocol.getRunName() + else: + runName = protocol.runName.get() status = protocol.getStatus() protocolClassName = protocol.getClassName() hosts = self.currentProject.getHostNames() @@ -1648,6 +1671,7 @@ def attachContainerWizardMetadata(container: Optional[Dict[str, Any]]) -> None: info = { "protocolId": protocol.getObjId(), "label": protName, + "runName": runName, "status": status, "expertLevel": hasExpert, "packageLogo": logoPath, @@ -1889,7 +1913,7 @@ def attachContainerWizardMetadata(container: Optional[Dict[str, Any]]) -> None: if paramProcessed: if paramName == 'runName': paramProcessed['default'] = '' - paramValue = protName + paramValue = runName elif paramName == 'numberOfThreads': paramValue = protocol.getScipionThreads() elif paramName == 'gpuList': @@ -2256,6 +2280,7 @@ def saveProtocol(self, mapper, projectId, protocolId, protocolClassName, params, protocol.setAttributeValue(key, castedValue) if key == "runName": + protocol.runName.set(castedValue) protocol.setObjLabel(castedValue) logger.info("[INFO] Set param %s = %s", key, castedValue) @@ -2410,7 +2435,7 @@ def launchProtocol(self, mapper, projectId, protocolId, protocolClassName, param ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Protocol execution finished but graph sync to PostgreSQL failed: {e}", + detail=f"{e}", ) def findViewersWeb(self, protocol): @@ -2993,6 +3018,7 @@ def renameProtocol(self, protocolId, newName): ) try: + protocol.runName.set(newName) protocol.setObjLabel(newName) self.currentProject._storeProtocol(protocol) except Exception as e: @@ -3098,7 +3124,7 @@ def deleteProtocol(self, mapper, projectId, protocols: Any): ) return { - "status": "ok", + "status": 0, "message": "Protocol deleted successfully", "protocolsCount": syncInfo.get("protocols"), "dependenciesCount": syncInfo.get("dependencies"), @@ -4629,7 +4655,7 @@ def createNewSetOfCtftomoSeriesService( "The new Ctftomo set (%s) has been created successfully with %d series", newOutputName, createdCount,) return { - "status": "ok", + "status": 0, "outputName": newOutputName, "createdSeries": createdCount, "restack": bool(restack), @@ -4880,7 +4906,7 @@ def createNewSetOfTiltSeriesService( logger.info("The new set (%s) has been created successfully", newOutputName) return { - "status": "ok", + "status": 0, "outputName": newOutputName, "createdTiltSeries": createdCount, "hasOddEven": bool(hasOddEven), diff --git a/app/backend/main.py b/app/backend/main.py index 53f259e..7813c11 100644 --- a/app/backend/main.py +++ b/app/backend/main.py @@ -44,6 +44,7 @@ from app.backend.api.routers.auth_router import router as auth from app.backend.api.routers.user_router import router as users from app.backend.api.routers.settings_router import router as settingsRouter +from app.backend.api.routers.coords2d_router import router as coords2dRouter from app.backend.utils.error_handlers import registerAllErrorHandlers from starlette.staticfiles import StaticFiles from starlette.exceptions import HTTPException as StarletteHttpException @@ -104,6 +105,12 @@ def _buildApiApp() -> FastAPI: "X-Preview-Mime", "X-Preview-Width", "X-Preview-Height", + "X-Preview-Original-Width", + "X-Preview-Original-Height", + "X-Preview-Scale-X", + "X-Preview-Scale-Y", + "X-Preview-Origin", + "X-Preview-Orientation", "X-Preview-Depth", "X-Preview-Colormap", "X-Preview-Colormap-Note", @@ -115,6 +122,8 @@ def _buildApiApp() -> FastAPI: "X-Preview-VoxelSize", "X-Preview-Schema", "X-Preview-Name", + "X-Preview-MicrographId", + "X-Preview-Format", ], ) @@ -125,6 +134,7 @@ def _buildApiApp() -> FastAPI: apiApp.include_router(auth) apiApp.include_router(users) apiApp.include_router(settingsRouter) + apiApp.include_router(coords2dRouter) @apiApp.get("/health") def health_check(): diff --git a/app/backend/utils/thumbnail_service.py b/app/backend/utils/thumbnail_service.py index 7d0d3bc..41d5fcf 100644 --- a/app/backend/utils/thumbnail_service.py +++ b/app/backend/utils/thumbnail_service.py @@ -32,7 +32,6 @@ import hashlib from urllib.parse import quote from pathlib import Path -import tempfile import threading from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple diff --git a/tests/conftest.py b/tests/conftest.py index 423a0bb..55cc702 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -207,6 +207,11 @@ def __init__(self): self.lastRenameProtocolCall = None self.duplicateProtocolError = None + self.duplicateProtocolResult = { + "status": 0, + "errors": [], + "duplicated": [], + } self.lastDuplicateProtocolCall = None self.deleteProtocolError = None @@ -454,6 +459,7 @@ def duplicateProtocol(self, mapper, projectId, items): } if self.duplicateProtocolError is not None: raise self.duplicateProtocolError + return self.duplicateProtocolResult def deleteProtocol(self, mapper, projectId, protocolIds): self.lastDeleteProtocolCall = { diff --git a/tests/integration/api/test_projects_router_protocol_ops.py b/tests/integration/api/test_projects_router_protocol_ops.py index 796498c..bc8078c 100644 --- a/tests/integration/api/test_projects_router_protocol_ops.py +++ b/tests/integration/api/test_projects_router_protocol_ops.py @@ -280,6 +280,7 @@ def test_DuplicateProtocolDelegatesToService(projectClient, fakeProjectService): "status": 0, "errors": [], "workflow": [], + "duplicated": [], } items = fakeProjectService.lastDuplicateProtocolCall["items"] diff --git a/tests/unit/backend/api/services/test_project_service_ctftomo.py b/tests/unit/backend/api/services/test_project_service_ctftomo.py index 8355b7c..e37f624 100644 --- a/tests/unit/backend/api/services/test_project_service_ctftomo.py +++ b/tests/unit/backend/api/services/test_project_service_ctftomo.py @@ -591,7 +591,7 @@ def test_CreateNewSetOfCtftomoSeriesServiceCreatesFilteredSeries(service, tmp_pa ) assert result == { - "status": "ok", + "status": 0, "outputName": "CTFTomoSeries_0", "createdSeries": 1, "restack": False, diff --git a/tests/unit/backend/api/services/test_project_service_protocols.py b/tests/unit/backend/api/services/test_project_service_protocols.py index 600ab31..4e21fb8 100644 --- a/tests/unit/backend/api/services/test_project_service_protocols.py +++ b/tests/unit/backend/api/services/test_project_service_protocols.py @@ -266,6 +266,11 @@ def mapper(): return FakeMapper() +def assertSuccessEnvelope(result): + assert result["status"] == 0 + assert result["errors"] == [] + + def test_CastParamValueSupportsEnumLookup(projectServiceModule, service, monkeypatch): monkeypatch.setattr(projectServiceModule, "EnumParam", FakeEnumParam) @@ -510,7 +515,7 @@ def test_RenameProtocolStoresNewLabel(service): result = service.renameProtocol(10, "Renamed protocol") - assert result == {"status": "ok", "message": "Protocol renamed successfully"} + assertSuccessEnvelope(result) assert protocol._label == "Renamed protocol" assert service.currentProject.storedProtocols == [protocol] @@ -558,12 +563,9 @@ def __init__(self, itemId): protocols=[DuplicateItem("10"), DuplicateItem("11")], ) - assert result == { - "status": "ok", - "message": "Protocol was duplicated successfully", - "protocolsCount": 2, - "dependenciesCount": 0, - } + assertSuccessEnvelope(result) + assert result["protocolsCount"] == 2 + assert result["dependenciesCount"] == 0 assert service.currentProject.copiedProtocolInputs == [[protocolA, protocolB]] assert mapper.savedProtocolContexts == [ {"projectId": 1, "protocolId": 110}, @@ -636,7 +638,7 @@ def test_ContinueProtocolAllLaunchesActiveProtocolsInResumeMode(projectServiceMo currentUser={"id": 1}, ) - assert result == {"status": "ok", "message": "Protocol subtree continued successfully"} + assertSuccessEnvelope(result) assert activeProtocol.runMode.get() == "resume-mode" assert service.currentProject.launchedProtocols == [activeProtocol] @@ -648,7 +650,7 @@ def test_ResetProtocolFromReturnsSuccessWhenWorkflowResets(service): result = service.resetProtocolFrom(10) - assert result == {"status": "ok", "message": "Protocol subtree reset successfully"} + assertSuccessEnvelope(result) def test_StopProtocolStopsEachProtocol(service): @@ -659,5 +661,5 @@ def test_StopProtocolStopsEachProtocol(service): result = service.stopProtocol(["10", "11"]) - assert result == {"status": "ok", "message": "Protocol stopped successfully"} + assertSuccessEnvelope(result) assert service.currentProject.stoppedProtocols == [protocolA, protocolB] \ No newline at end of file