diff --git a/README.md b/README.md index a0c548f..22cb99d 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,26 @@ -# QuPath Cellpose extension +# QuPath Cellpose/Omnipose extension This repo adds some support to use 2D Cellpose within QuPath through a Python virtual environment. +We also want to use Omnipose, which offers some amazing features, but there is currently no consensus between the +developers and there are incompatibilities between the current Cellpose and Omnipose versions. + +We have decided to provide support for both using cellpose and omnipose, in the form of two separate environments, so that +they can play nice. +======= > **Warning** > Versions above v0.6.0 of this extension **will only work on QuPath 0.4.0 or later**. Please update QuPath to the latest version. > **Warning** > In case you are stuck with QuPath v0.3.2, [the last release to work is v0.5.1](https://github.com/BIOP/qupath-extension-cellpose/releases/tag/v0.5.1) - # Installation -## Step 1: Install Cellpose - -Follow the instructions to install Cellpose from [the main Cellpose repository](https://github.com/mouseland/cellpose). -This extension will need to know the path to your Cellpose environment. +## Step 1: Install Cellpose and Omnipose -Note that there is currently a bug in Cellpose v2.1.1, where all training and prediction is done twice, -so until the issue is fixed, the recommended version of cellpose is v2.0.5. +Follow the instructions to install Cellpose from [the main Cellpose repository](https://github.com/mouseland/cellpose). And +Omnipose from [the main Omnipose repository](https://omnipose.readthedocs.io/installation.html) +This extension will need to know the path to at least your Cellpose environment. If you plan on using Omnipose, you will also need to install it. ### NOTE: `scikit-image` Dependency As of version 0.4 of this extension, QC (quality control) is run **automatically** when training a model. @@ -25,71 +28,38 @@ As of version 0.4 of this extension, QC (quality control) is run **automatically Due to the dependencies of the validation code, located inside [run-cellpose-qc.py](QC/run-cellpose-qc.py) requires an extra dependency on `scikit-image`. -The simplest way to add it, is when installing Cellpose as instructed in the oficial repository and adding scikit-image +The simplest way to add it, is when installing Cellpose as instructed in the official repository and adding scikit-image ```bash python -m pip install cellpose scikit-image ``` +or +```bash +python -m pip install omnipose scikit-image +``` -### This extension no longer works with cellpose versions before 2.0. -Please keep this in mind and update your cellpose installation in case of problems. +## Installation with Conda/Mamba +We provide the following YAML file that installs Cellpose and omnipose in the same environment. +The configuration files are without guarantee, but these are the ones we use for our Windows machines. +[Download `cellpose-omnipose-biop-gpu.yml`](files/cellpose-omnipose-biop-gpu.yml) -### Example Cellpose 2.0.5 installation with CUDA 11.3 GPU support +You can create the environment with the following command using either conda or mamba: -First, we create the conda environment: -``` -conda create -n cellpose-205 python=3.8 -conda activate cellpose-205 -pip install cellpose==2.0.5 scikit-image==0.19.3 -pip uninstall torch -pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 +```bash +mamba env create -f cellpose-omnipose-biop-gpu.yml ``` -
- See the 'pip freeze' result -``` -cellpose==2.0.5 -certifi @ file:///C:/Windows/TEMP/abs_e9b7158a-aa56-4a5b-87b6-c00d295b01fanefpc8_o/croots/recipe/certifi_1655968940823/work/certifi -charset-normalizer==2.0.12 -colorama==0.4.5 -fastremap==1.13.0 -idna==3.3 -imagecodecs==2022.2.22 -imageio==2.21.2 -llvmlite==0.38.1 -natsort==8.1.0 -networkx==2.8.6 -numba==0.55.2 -numpy==1.22.4 -opencv-python-headless==4.6.0.66 -packaging==21.3 -Pillow==9.1.1 -pyparsing==3.0.9 -PyWavelets==1.3.0 -requests==2.28.0 -scikit-image==0.19.3 -scipy==1.8.1 -tifffile==2022.5.4 -torch==1.11.0+cu113 -torchaudio==0.11.0+cu113 -torchvision==0.12.0+cu113 -tqdm==4.64.0 -typing_extensions==4.2.0 -urllib3==1.26.9 -wincertstore==0.2 -``` -
- -Next, we look for the Python executable, **which we will need later when configuring the QuPath Cellpose extension**. +### Check the path to the Python executable +We will need this information later when configuring the QuPath Cellpose extension. ``` +mamba activa te cellpose-omnipose-biop-gpu where python -C:\Users\oburri\.conda\envs\cellpose-205\python.exe +F:\conda-envs\cellpose-omnipose-biop-gpu\python.exe ``` > **Note** > While this example is done under Windows, this will work on Mac and Linux as well. - ## Step 2: Install the QuPath Cellpose extension Download the latest `qupath-extension-cellpose-[version].zip` file from [releases](https://github.com/biop/qupath-extension-cellpose/releases) and unzip it into your `extensions` directory. @@ -105,14 +75,21 @@ You might then need to restart QuPath (but not your computer). > In case you do not do this step, Cellpose training will still work, but the QC step will be skipped, and you will be notified that `run-cellpose-qc.py` cannot be found. -## QuPath Extension Cellpose: First time setup +## QuPath Extension Cellpose/Omnipose: First time setup -Go to `Edit > Preferences > Cellpose` +Go to `Edit > Preferences > Cellpose/Omnipose` Complete the fields with the requested information. based on the `conda` installation above, this is what it should look like: ![Cellpose setup example](files/cellpose-qupath-setup-example.png) > **Note** -> Prefer using "Python Executable" as Cellpose Environment Type, as this is more OS-agnostic than the other methods, which may become deprecated in the future. +You have the possibility to provide **two** different environments. One for Cellpose and one for Omnipose. +> If you do not plan on using Omnipose or have installed both cellpose and Omnipose in the same environment, you can leave it blank. + +The reason for this is that there may be versions of cellpose and its dependencies that might not match with Omnipose. Adding to that, some parameters +in cellpose and omnipose are currently out of sync, so it could be wiser to keep them separate. +> **Warning** as of this writing, the versions used are `cellpose==2.2.1` and `omnipose==0.4.4` + +**The extension handles switching between the two based on the `useOmnipose()` flag in the builder.** ## Running Cellpose the first time in standalone @@ -122,7 +99,6 @@ QuPath due to permission issues. One trick is to **run Cellpose from the command line** once with the model you want to use. The download should work from there, and you can then use it within the QuPath Extension Cellpose. - # Using the Cellpose QuPath Extension ## Training @@ -223,33 +199,37 @@ All builder options that are implemented are [in the Javadoc](https://biop.githu ### Breaking changes after QuPath 0.4.0 In order to make the extension more flexible and less dependent on the builder, a new Builder method `addParameter(name, value)` is available that can take [any cellpose CLI argument or argument pair](https://cellpose.readthedocs.io/en/latest/command.html#options). -For this to work, some elements that were "hard coded" on the builder have been removed, so you will get some errors. For example: `useOmnipose()`, `excludeEdges()` and `clusterDBSCAN()` no longer exist. -You can use `addParameter("omni")`, `addParameter("exclude_on_edges")`, and `addParameter("cluster")` instead. +For this to work, some elements that were "hard coded" on the builder have been removed, so you will get some errors. For example: `excludeEdges()` and `clusterDBSCAN()` no longer exist. +You can use `addParameter("exclude_on_edges")`, and `addParameter("cluster")` instead. ```groovy import qupath.ext.biop.cellpose.Cellpose2D +// For all the options from cellpose: https://cellpose.readthedocs.io/en/latest/cli.html +// For all the options from omnipose: https://omnipose.readthedocs.io/command.html#all-options // Specify the model name (cyto, nuc, cyto2, omni_bact or a path to your custom model) def pathModel = 'cyto2' def cellpose = Cellpose2D.builder( pathModel ) - .pixelSize( 0.5 ) // Resolution for detection in um - .channels( 'DAPI' ) // Select detection channel(s) + .pixelSize( 0.5 ) // Resolution for detection in um + .channels( 'DAPI' ) // Select detection channel(s) // .preprocess( ImageOps.Filters.median(1) ) // List of preprocessing ImageOps to run on the images before exporting them - .normalizePercentilesGlobal(0.1, 99.8, 10) // Convenience global percentile normalization. arguments are percentileMin, percentileMax, dowsample. - .tileSize(1024) // If your GPU can take it, make larger tiles to process fewer of them. Useful for Omnipose -// .cellposeChannels(1,2) // Overwrites the logic of this plugin with these two values. These will be sent directly to --chan and --chan2 -// .cellprobThreshold(0.0) // Threshold for the mask detection, defaults to 0.0 -// .flowThreshold(0.4) // Threshold for the flows, defaults to 0.4 +// .normalizePercentilesGlobal(0.1, 99.8, 10) // Convenience global percentile normalization. arguments are percentileMin, percentileMax, dowsample. +// .tileSize(1024) // If your GPU can take it, make larger tiles to process fewer of them. Useful for Omnipose +// .cellposeChannels(1,2) // Overwrites the logic of this plugin with these two values. These will be sent directly to --chan and --chan2 +// .cellprobThreshold(0.0) // Threshold for the mask detection, defaults to 0.0 +// .flowThreshold(0.4) // Threshold for the flows, defaults to 0.4 // .diameter(15) // Median object diameter. Set to 0.0 for the `bact_omni` model or for automatic computation -// .addParameter("save_flows") // Any parameter from cellpose not available in the builder. See https://cellpose.readthedocs.io/en/latest/command.html -// .addParameter("anisotropy", "3") // Any parameter from cellpose not available in the builder. See https://cellpose.readthedocs.io/en/latest/command.html +// .useOmnipose() // Use the omnipose instead +// .addParameter("cluster") // Any parameter from cellpose or omnipose not available in the builder. +// .addParameter("save_flows") // Any parameter from cellpose or omnipose not available in the builder. +// .addParameter("anisotropy", "3") // Any parameter from cellpose or omnipose not available in the builder. // .cellExpansion(5.0) // Approximate cells based upon nucleus expansion -// .cellConstrainScale(1.5) // Constrain cell expansion using nucleus size -// .classify("My Detections") // PathClass to give newly created objects -// .measureShape() // Add shape measurements -// .measureIntensity() // Add cell measurements (in all compartments) -// .createAnnotations() // Make annotations instead of detections. This ignores cellExpansion -// .simplify(0) // Simplification 1.6 by default, set to 0 to get the cellpose masks as precisely as possible +// .cellConstrainScale(1.5) // Constrain cell expansion using nucleus size +// .classify("My Detections") // PathClass to give newly created objects +// .measureShape() // Add shape measurements +// .measureIntensity() // Add cell measurements (in all compartments) +// .createAnnotations() // Make annotations instead of detections. This ignores cellExpansion +// .simplify(0) // Simplification 1.6 by default, set to 0 to get the cellpose masks as precisely as possible .build() // Run detection for the selected objects @@ -265,13 +245,16 @@ println 'Done!' # Citing -If you use this extension, you should cite the original Cellpose publication -- Stringer, C., Wang, T., Michaelos, M. et al. -[*Cellpose: a generalist algorithm for cellular segmentation*](https://arxiv.org/abs/1806.03535) -Nat Methods 18, 100–106 (2021). https://doi.org/10.1038/s41592-020-01018-x +If you use this extension, you should cite the following publications -**You should also cite the QuPath publication, as described [here](https://qupath.readthedocs.io/en/stable/docs/intro/citing.html).** +Stringer, C., Wang, T., Michaelos, M. et al **Cellpose: a generalist algorithm for cellular segmentation**. Nat Methods 18, 100–106 (2021). https://doi.org/10.1038/s41592-020-01018-x +Pachitariu, M., Stringer, C. **Cellpose 2.0: how to train your own model**. Nat Methods 19, 1634–1641 (2022). https://doi.org/10.1038/s41592-022-01663-4 + +Cutler, K.J., Stringer, C., Lo, T.W. et al. **Omnipose: a high-precision morphology-independent solution for bacterial cell segmentation**. Nat Methods 19, 1438–1448 (2022). https://doi.org/10.1038/s41592-022-01639-4 + +Bankhead, P. et al. **QuPath: Open source software for digital pathology image analysis**. Scientific Reports (2017). +https://doi.org/10.1038/s41598-017-17204-5 # Building @@ -306,7 +289,6 @@ In case you end up with split detections, this is caused by the overlap calculat In turn, this causes the QuPath extension to fail to extract tiles with sufficient overlap. Use `setOverlap( int )` in the builder to set the overlap (in pixels) to a value 2x larger than the largest object you are segmenting. - ### To find the overlap You can draw a line ROI across your largest object in QuPath and run the following one-line script @@ -316,6 +298,5 @@ print "Selected Line length is " + Math.round(getSelectedObject().getROI().getLe Double whatever value is output from the script and use it in `setOverlap( int )` in the builder. ## Ubuntu Error 13: Permission Denied - [As per this post here](https://forum.image.sc/t/could-not-execute-system-command-in-qupath-thanks-to-groovy-script-and-java-processbuilder-class/61629/2?u=oburri), there is a permissions issue when using Ubuntu, which does not allow Java's `ProcessBuilder` to run. The current workaround is [to build QuPath from source](https://qupath.readthedocs.io/en/stable/docs/reference/building.html) in Ubuntu, which then allows the use of the `ProcessBuilder`, which is the magic piece of code that actually calls Cellpose. diff --git a/build.gradle b/build.gradle index a720198..8764ac7 100644 --- a/build.gradle +++ b/build.gradle @@ -24,7 +24,7 @@ ext.qupathVersion = gradle.ext.qupathVersion description = 'QuPath extension to use Cellpose' -version = "0.6.2-SNAPSHOT" +version = "0.7.0" dependencies { implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}" diff --git a/files/cellpose-omnipose-biop-gpu.yml b/files/cellpose-omnipose-biop-gpu.yml new file mode 100644 index 0000000..abe623a --- /dev/null +++ b/files/cellpose-omnipose-biop-gpu.yml @@ -0,0 +1,15 @@ +name: cellpose-omnipose-biop-gpu +channels: + - pytorch + - nvidia + - conda-forge +dependencies: + - python>=3.8 + - pytorch-cuda=11.7 + - pytorch + - mahotas=1.4.13 + - pip + - pip: + - cellpose==2.2.1 + - omnipose==0.4.4 + - scikit-image==0.20.0 \ No newline at end of file diff --git a/files/cellpose-qupath-setup-example.png b/files/cellpose-qupath-setup-example.png index 6ae7874..e2aa794 100644 Binary files a/files/cellpose-qupath-setup-example.png and b/files/cellpose-qupath-setup-example.png differ diff --git a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java index 5cbf182..aae1987 100644 --- a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java +++ b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java @@ -32,7 +32,6 @@ import org.locationtech.jts.geom.Envelope; import org.locationtech.jts.geom.Geometry; import org.locationtech.jts.geom.GeometryCollection; -import org.locationtech.jts.geom.Polygon; import org.locationtech.jts.geom.prep.PreparedGeometry; import org.locationtech.jts.geom.prep.PreparedGeometryFactory; import org.locationtech.jts.index.strtree.STRtree; @@ -89,13 +88,19 @@ import java.util.stream.Collectors; /** - * Dense object detection based on the following publication and code + * Dense object detection based on the cellpose and omnipose publications *
  * Stringer, C., Wang, T., Michaelos, M. et al.
  *     "Cellpose: a generalist algorithm for cellular segmentation"
  *     Nat Methods 18, 100–106 (2021). https://doi.org/10.1038/s41592-020-01018-x
  * 
- * See the main repo at https://github.com/mouseland/cellpose + * And + *
+ * Cutler, K.J., Stringer, C., Lo, T.W. et al.
+ *     "Omnipose: a high-precision morphology-independent solution for bacterial cell segmentation"
+ *     Nat Methods 19, 1438–1448 (2022). https://doi.org/10.1038/s41592-022-01639-4
+ * 
+ * See the main repos at https://github.com/mouseland/cellpose and https://github.com/kevinjohncutler/omnipose * *

* The structure of this extension was adapted from the qupath-stardist-extension at https://github.com/qupath/qupath-extension-stardist @@ -103,13 +108,12 @@ *

* * @author Olivier Burri + * @author Pete Bankhead */ public class Cellpose2D { private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class); - protected boolean doLog = false; - protected double simplifyDistance = 1.4; protected ImageDataOp op; @@ -148,46 +152,15 @@ public class Cellpose2D { protected File modelDirectory; protected File trainDirectory; - protected File valDirectory ; - + protected File valDirectory; + protected int nThreads = -1; private File cellposeTempFolder; - private String[] theLog; - // Results table from the training private ResultsTable trainingResults; private ResultsTable qcResults; private File modelFile; - protected int nThreads = -1; - - /** - * Optionally submit runnable to a thread pool. This limits the parallelization used by parallel streams. - * @param runnable - */ - private void runInPool(Runnable runnable) { - if (nThreads > 0) { - if (nThreads == 1) - logger.info("Processing with {} thread", nThreads); - else - logger.info("Processing with {} threads", nThreads); - // Using an outer thread poll impacts any parallel streams created inside - var pool = new ForkJoinPool(nThreads); - try { - pool.submit(() -> runnable.run()); - } finally { - pool.shutdown(); - try { - pool.awaitTermination(24, TimeUnit.HOURS); - } catch (InterruptedException e) { - logger.warn("Process was interrupted! " + e.getLocalizedMessage(), e); - } - } - } else { - runnable.run(); - } - } - /** * Create a builder to customize detection parameters. * This accepts either Text describing the built-in models from cellpose (cyto, cyto2, nuc) @@ -211,33 +184,6 @@ public static CellposeBuilder builder(File builderPath) { return new CellposeBuilder(builderPath); } - /** - * The directory that was used for saving the training images - * @return - */ - public File getTrainingDirectory() { - return trainDirectory; - } - - /** - * The directory that was used for saving the validation images - * @return - */ - public File getValidationDirectory() { - return valDirectory; - } - - - private Geometry simplify(Geometry geom) { - if (simplifyDistance <= 0) - return geom; - try { - return VWSimplifier.simplify(geom, simplifyDistance); - } catch (Exception e) { - return geom; - } - } - /** * Build a normalization op that can be based upon the entire (2D) image, rather than only local tiles. *

@@ -294,7 +240,7 @@ private static List filterObjects(List objectGeometry) { var envelope = envelopes.computeIfAbsent(cell, g -> g.getEnvelopeInternal()); @SuppressWarnings("unchecked") - var overlaps = (List)tree.query(envelope); + var overlaps = (List) tree.query(envelope); // Remove the overlaps that we can be sure don't apply using quick tests, to avoid expensive ones var iter = overlaps.iterator(); @@ -361,7 +307,7 @@ private static List filterObjects(List objectGeometry) { int skipCount = skippedObjects.size(); String s = skipErrorCount == 1 ? "1 nucleus" : skipErrorCount + " nuclei"; logger.warn("Skipped {} due to error in resolving overlaps ({}% of all skipped)", - s, GeneralTools.formatNumber(skipErrorCount*100.0/skipCount, 1)); + s, GeneralTools.formatNumber(skipErrorCount * 100.0 / skipCount, 1)); } return new ArrayList<>(objects); } @@ -378,7 +324,7 @@ else if (children.size() > 1) private static PathObject cellToObject(PathObject cell, Function creator) { var parent = creator.apply(cell.getROI()); - var nucleusROI = cell instanceof PathCellObject ? ((PathCellObject)cell).getNucleusROI() : null; + var nucleusROI = cell instanceof PathCellObject ? ((PathCellObject) cell).getNucleusROI() : null; if (nucleusROI != null) { var nucleus = creator.apply(nucleusROI); nucleus.setPathClass(cell.getPathClass()); @@ -394,6 +340,62 @@ private static PathObject cellToObject(PathObject cell, Function 0) { + if (nThreads == 1) + logger.info("Processing with {} thread", nThreads); + else + logger.info("Processing with {} threads", nThreads); + // Using an outer thread poll impacts any parallel streams created inside + var pool = new ForkJoinPool(nThreads); + try { + pool.submit(() -> runnable.run()); + } finally { + pool.shutdown(); + try { + pool.awaitTermination(24, TimeUnit.HOURS); + } catch (InterruptedException e) { + logger.warn("Process was interrupted! " + e.getLocalizedMessage(), e); + } + } + } else { + runnable.run(); + } + } + + /** + * The directory that was used for saving the training images + * + * @return + */ + public File getTrainingDirectory() { + return trainDirectory; + } + + /** + * The directory that was used for saving the validation images + * + * @return + */ + public File getValidationDirectory() { + return valDirectory; + } + + private Geometry simplify(Geometry geom) { + if (simplifyDistance <= 0) + return geom; + try { + return VWSimplifier.simplify(geom, simplifyDistance); + } catch (Exception e) { + return geom; + } + } + /** * Detect cells within one or more parent objects, firing update events upon completion. * @@ -444,9 +446,9 @@ public void detectObjectsImpl(ImageData imageData, Collection 0) { - downsample = (int) Math.round(pixelSize / imageData.getServer().getPixelCalibration().getAveragedPixelSize().doubleValue()); + downsample = Math.round(pixelSize / imageData.getServer().getPixelCalibration().getAveragedPixelSize().doubleValue()); } ImageServer server = imageData.getServer(); @@ -454,13 +456,13 @@ public void detectObjectsImpl(ImageData imageData, Collection allTiles = parents.parallelStream().map(parent -> { RegionRequest request = RegionRequest.createInstance( - opServer.getPath(), - opServer.getDownsampleForResolution(0), - parent.getROI()); + opServer.getPath(), + opServer.getDownsampleForResolution(0), + parent.getROI()); // Get all the required tiles that intersect with the mask ROI Geometry mask = parent.getROI().getGeometry(); @@ -473,10 +475,10 @@ public void detectObjectsImpl(ImageData imageData, Collection rois = RoiTools.computeTiledROIs(parent.getROI(), ImmutableDimension.getInstance(tileWidth * finalDownsample, tileWidth * finalDownsample), ImmutableDimension.getInstance(tileWidth * finalDownsample, tileHeight * finalDownsample), true, overlap * finalDownsample); + Collection rois = RoiTools.computeTiledROIs(parent.getROI(), ImmutableDimension.getInstance((int) (tileWidth * finalDownsample), (int) (tileWidth * finalDownsample)), ImmutableDimension.getInstance((int) (tileWidth * finalDownsample), (int) (tileHeight * finalDownsample)), true, (int) (overlap * finalDownsample)); - List tiles = rois.stream().map( r -> { - return RegionRequest.createInstance( opServer.getPath(),opServer.getDownsampleForResolution(0), r); + List tiles = rois.stream().map(r -> { + return RegionRequest.createInstance(opServer.getPath(), opServer.getDownsampleForResolution(0), r); }).collect(Collectors.toList()); // Detect all potential nuclei @@ -540,13 +542,13 @@ public void detectObjectsImpl(ImageData imageData, Collection { PathObject parent = tileMap.getObject(); // Read each image @@ -561,7 +563,7 @@ public void detectObjectsImpl(ImageData imageData, Collection candidates = null; try { - candidates = readObjectsFromFile(maskFile, tilefile.getTile()).stream().filter(c-> parent.getROI().getGeometry().intersects(c.geometry)).collect(Collectors.toList()); + candidates = readObjectsFromFile(maskFile, tilefile.getTile()).stream().filter(c -> parent.getROI().getGeometry().intersects(c.geometry)).collect(Collectors.toList()); } catch (IOException e) { e.printStackTrace(); @@ -775,7 +777,7 @@ private List filterDetections(Collection rawCa difference = difference.getGeometryN(index); } // difference instanceof Polygon && - if ( difference.getArea() > overlappingCandidate.area / 2.0) + if (difference.getArea() > overlappingCandidate.area / 2.0) overlappingCandidate.geometry = difference; else { skippedObjects.add(overlappingCandidate); @@ -855,19 +857,35 @@ private TileFile saveTileImage(ImageDataOp op, ImageData imageDat return new TileFile(request, tempFile); } + /** + * Selects the right folder to run from, based on whether it's cellpose or omnipose. + * Hopefully this will become deprecated soon + * @return + */ + private VirtualEnvironmentRunner getVirtualEnvironmentRunner() { + // Need to decide whether to call cellpose or omnipose + VirtualEnvironmentRunner veRunner; + + if (this.parameters.containsKey("omni") && !cellposeSetup.getOmniposePytonPath().isEmpty()) { + veRunner = new VirtualEnvironmentRunner(cellposeSetup.getOmniposePytonPath(), VirtualEnvironmentRunner.EnvType.EXE, this.getClass().getSimpleName()); + } else { + veRunner = new VirtualEnvironmentRunner(cellposeSetup.getCellposePytonPath(), VirtualEnvironmentRunner.EnvType.EXE, this.getClass().getSimpleName()); + } + return veRunner; + } + /** * This class actually runs Cellpose by calling the virtual environment * * @throws IOException Exception in case files could not be read * @throws InterruptedException Exception in case of command thread has some failing */ - private void runCellpose() throws IOException, InterruptedException { - - // Create command to run - VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeSetup.getEnvironmentNameOrPath(), cellposeSetup.getEnvironmentType(), this.getClass().getSimpleName()); + private void runDetection() throws IOException, InterruptedException { + String runCommand = this.parameters.containsKey("omni") ? "omnipose" : "cellpose"; + VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner(); // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", "cellpose")); + List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand)); cellposeArguments.add("--dir"); cellposeArguments.add("" + this.cellposeTempFolder); @@ -876,8 +894,8 @@ private void runCellpose() throws IOException, InterruptedException { cellposeArguments.add("" + this.model); this.parameters.forEach((parameter, value) -> { - cellposeArguments.add("--"+parameter); - if( value != null) { + cellposeArguments.add("--" + parameter); + if (value != null) { cellposeArguments.add(value); } }); @@ -917,7 +935,7 @@ public File train() { saveTrainingImages(); - runCellposeTraining(); + runTraining(); this.modelFile = moveAndReturnModelFile(); @@ -943,12 +961,12 @@ public File train() { * @throws IOException Exception in case files could not be read * @throws InterruptedException Exception in case of command thread has some failing */ - private void runCellposeTraining() throws IOException, InterruptedException { - // Create command to run - VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeSetup.getEnvironmentNameOrPath(), cellposeSetup.getEnvironmentType(), this.getClass().getSimpleName() + "-train"); + private void runTraining() throws IOException, InterruptedException { + String runCommand = this.parameters.containsKey("omni") ? "omnipose" : "cellpose"; + VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner(); // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", "cellpose")); + List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand)); cellposeArguments.add("--train"); @@ -966,8 +984,8 @@ private void runCellposeTraining() throws IOException, InterruptedException { } this.parameters.forEach((parameter, value) -> { - cellposeArguments.add("--"+parameter); - if( value != null) { + cellposeArguments.add("--" + parameter); + if (value != null) { cellposeArguments.add(value); } }); @@ -1001,7 +1019,7 @@ private void runCellposeOnValidationImages() { this.model = modelFile.getAbsolutePath(); try { - runCellpose(); + runDetection(); } catch (InterruptedException | IOException e) { logger.error(e.getMessage(), e); } @@ -1036,7 +1054,7 @@ private ResultsTable runCellposeQC() throws IOException, InterruptedException { } // Start the Virtual Environment Runner - VirtualEnvironmentRunner qcRunner = new VirtualEnvironmentRunner(cellposeSetup.getEnvironmentNameOrPath(), cellposeSetup.getEnvironmentType(), this.getClass().getSimpleName() + "-qc"); + VirtualEnvironmentRunner qcRunner = getVirtualEnvironmentRunner(); List qcArguments = new ArrayList<>(Arrays.asList(qcPythonFile.getAbsolutePath(), this.valDirectory.getAbsolutePath(), this.modelFile.getName())); qcRunner.setArguments(qcArguments); @@ -1267,12 +1285,41 @@ public void saveTrainingImages() { List testingAnnotations = allAnnotations.stream().filter(a -> a.getPathClass() == PathClassFactory.getPathClass("Test")).collect(Collectors.toList()); + PixelCalibration resolution = imageData.getServer().getPixelCalibration(); + if (Double.isFinite(pixelSize) && pixelSize > 0) { + double downsample = pixelSize / resolution.getAveragedPixelSize().doubleValue(); + resolution = resolution.createScaledInstance(downsample, downsample); + } logger.info("Found {} Training objects and {} Validation objects in image {}", trainingAnnotations.size(), validationAnnotations.size(), imageName); if (!trainingAnnotations.isEmpty() || !validationAnnotations.isEmpty()) { // Make the server using the required ops - ImageServer processed = ImageOps.buildServer(imageData, op, imageData.getServer().getPixelCalibration(), 2048, 2048); + // Do global preprocessing calculations, if required + ArrayList fullPreprocess = new ArrayList(); + fullPreprocess.add(ImageOps.Core.ensureType(PixelType.FLOAT32)); + if (globalPreprocess != null) { + try { + var normalizeOps = globalPreprocess.createOps(op, imageData, null, null); + fullPreprocess.addAll(normalizeOps); + + // If this has happened, then we should expect to not use the cellpose normalization? + this.parameters.put("no_norm", null); + + } catch (IOException ex) { + throw new RuntimeException("Exception computing global normalization", ex); + } + } + + if (!preprocess.isEmpty()) { + fullPreprocess.addAll(preprocess); + } + if (fullPreprocess.size() > 1) + fullPreprocess.add(ImageOps.Core.ensureType(PixelType.FLOAT32)); + + op = op.appendOps(fullPreprocess.toArray(ImageOp[]::new)); + + ImageServer processed = ImageOps.buildServer(imageData, op, resolution, tileWidth, tileHeight); LabeledImageServer labelServer = new LabeledImageServer.Builder(imageData) .backgroundLabel(0, ColorTools.BLACK) @@ -1323,7 +1370,7 @@ private Collection readObjectsFromFile(File maskFile, RegionReq if (p > maxValue) maxValue = p; } - int maxLabel = (int)maxValue; + int maxLabel = (int) maxValue; Map candidates = new TreeMap<>(); float lastLabel = Float.NaN; @@ -1383,8 +1430,8 @@ public List getTileFiles() { } private static class CandidateObject { - private Geometry geometry; private final double area; + private Geometry geometry; CandidateObject(Geometry geom) { this.geometry = geom; diff --git a/src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java b/src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java index d14fb32..28d19c8 100644 --- a/src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java +++ b/src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java @@ -17,7 +17,6 @@ package qupath.ext.biop.cellpose; import com.google.gson.Gson; -import org.apache.commons.io.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import qupath.lib.analysis.features.ObjectMeasurements.Compartments; @@ -49,9 +48,9 @@ * "Cell Detection with Star-convex Polygons." * International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), Granada, Spain, September 2018. * - * See the main repo at https://github.com/mpicbg-csbd/stardist + * See the main repo at ... *

- * Very much inspired by stardist-imagej at https://github.com/mpicbg-csbd/stardist-imagej but re-written from scratch to use OpenCV and + * Very much inspired by stardist-imagej at ... but re-written from scratch to use OpenCV and * adapt the method of converting predictions to contours (very slightly) to be more QuPath-friendly. *

* Models are expected in the same format as required by the Fiji plugin, or converted to a frozen .pb file for use with OpenCV. @@ -61,24 +60,23 @@ public class CellposeBuilder { private static final Logger logger = LoggerFactory.getLogger(CellposeBuilder.class); - + private final transient CellposeSetup cellposeSetup; + private final Double probaThreshold = 0.0; + private final Double flowThreshold = 0.0; + private final Double diameter = 0.0; + private final File trainDirectory = null; + private final File valDirectory = null; + private final Integer nEpochs = null; + private final Integer batchSize = null; + private final Double learningRate = Double.NaN; + private final Integer minTrainMasks = null; + private final List ops = new ArrayList<>(); + private final List preprocessing = new ArrayList<>(); + private final LinkedHashMap cellposeParameters = new LinkedHashMap<>(); // Cellpose Related Options private String modelNameOrPath; - private transient CellposeSetup cellposeSetup; - - private Double probaThreshold = 0.0; - private Double flowThreshold = 0.0; - private Double diameter = 0.0; - // Cellpose Training options private File modelDirectory = null; - private File trainDirectory = null; - private File valDirectory = null; - private Integer nEpochs = null; - private Integer batchSize = null; - private Double learningRate = Double.NaN; - private Integer minTrainMasks = null; - // QuPath Object handling options private ColorTransform[] channels = new ColorTransform[0]; private Double cellExpansion = Double.NaN; @@ -90,25 +88,15 @@ public class CellposeBuilder { private PathClass globalPathClass = PathClass.getNullClass(); private Boolean measureShape = Boolean.FALSE; private Boolean constrainToParent = Boolean.TRUE; - private Function creatorFun; private Collection compartments = Arrays.asList(Compartments.values()); private Collection measurements; - private List ops = new ArrayList<>(); - private transient boolean saveBuilder; private transient String builderName; - private Integer overlap = null; - private double simplifyDistance = 1.4; - private Map classifications; - private TileOpCreator globalPreprocessing; - private List preprocessing = new ArrayList<>(); - - private LinkedHashMap cellposeParameters = new LinkedHashMap<>(); private int nThreads = -1; @@ -133,6 +121,7 @@ protected CellposeBuilder(File builderFile) { /** * Build a cellpose model by providing a string which can be the name of a pretrained model or a path to a custom model + * * @param modelPath the model name or path */ protected CellposeBuilder(String modelPath) { @@ -149,7 +138,8 @@ protected CellposeBuilder(String modelPath) { * Specify the number of threads to use for processing. * If you encounter problems, setting this to 1 may help to resolve them by preventing * multithreading. - * @param nThreads + * + * @param nThreads the number of threads to use * @return this builder */ public CellposeBuilder nThreads(int nThreads) { @@ -176,12 +166,11 @@ public CellposeBuilder pixelSize(double pixelSize) { /** * Add preprocessing operations, if required. * - * @param ops + * @param ops a series of ImageOps to apply to the input image * @return this builder */ public CellposeBuilder preprocess(ImageOp... ops) { - for (var op : ops) - this.preprocessing.add(op); + Collections.addAll(this.preprocessing, ops); return this; } @@ -251,7 +240,7 @@ public CellposeBuilder channels(String... channels) { *

* This makes it possible to supply color deconvolved channels, for example. * - * @param channels + * @param channels the channels to use * @return this builder */ public CellposeBuilder channels(ColorTransform... channels) { @@ -268,7 +257,7 @@ public CellposeBuilder channels(ColorTransform... channels) { *

* In short, be wary. * - * @param distance + * @param distance expansion distance in microns * @return this builder */ public CellposeBuilder cellExpansion(double distance) { @@ -303,7 +292,7 @@ public CellposeBuilder createAnnotations() { /** * Request that a classification is applied to all created objects. * - * @param pathClass + * @param pathClass the classification to give to all detected PathObjects * @return this builder */ public CellposeBuilder classify(PathClass pathClass) { @@ -315,11 +304,11 @@ public CellposeBuilder classify(PathClass pathClass) { * Request that a classification is applied to all created objects. * This is a convenience method that get a {@link PathClass} from a String representation. * - * @param pathClassName + * @param pathClassName the classification to give to all detected PathObjects as a String * @return this builder */ public CellposeBuilder classify(String pathClassName) { - return classify(PathClass.fromString(pathClassName, (Integer) null)); + return classify(PathClass.fromString(pathClassName, null)); } /** @@ -554,6 +543,8 @@ public CellposeBuilder addParameter(String flagName) { * @return this builder */ public CellposeBuilder useOmnipose() { + if (cellposeSetup.getOmniposePytonPath() == "") + logger.warn("Omnipose environment path not set. Using cellpose path instead."); addParameter("omni"); return this; } @@ -568,6 +559,12 @@ public CellposeBuilder excludeEdges() { return this; } + /** + * Explicitly set the cellpose channels manually. This corresponds to --chan and --chan2 + * @param channel1 --chan value passed to cellpose/omnipose + * @param channel2 --chan2 value passed to cellpose/omnipose + * @return + */ public CellposeBuilder cellposeChannels(Integer channel1, Integer channel2) { addParameter("chan", channel1.toString()); addParameter("chan2", channel2.toString()); @@ -707,8 +704,9 @@ public CellposeBuilder setOverlap(int overlap) { /** * Convenience method to call global normalization for the dataset - * @param percentileMin the min percentile 0-100 - * @param percentileMax the max percentile 0-100 + * + * @param percentileMin the min percentile 0-100 + * @param percentileMax the max percentile 0-100 * @param normDownsample a large downsample for the computation to be efficient over the whole image * @return this builder */ @@ -729,7 +727,8 @@ public CellposeBuilder normalizePercentilesGlobal(double percentileMin, double p /** * convenience method? to deactivate cellpose normalization. - * @return + * + * @return this builder */ public CellposeBuilder noCellposeNormalization() { return this.addParameter("no_norm"); @@ -738,7 +737,7 @@ public CellposeBuilder noCellposeNormalization() { /** * Create a {@link Cellpose2D}, all ready for detection. * - * @return + * @return a new {@link Cellpose2D} instance */ public Cellpose2D build() { Cellpose2D cellpose = new Cellpose2D(); @@ -773,13 +772,6 @@ public Cellpose2D build() { trainDirectory.mkdirs(); valDirectory.mkdirs(); - // Cleanup a previous run - try { - FileUtils.cleanDirectory(trainDirectory); - FileUtils.cleanDirectory(valDirectory); - } catch (IOException e) { - logger.error(e.getMessage(), e); - } cellpose.modelDirectory = modelDirectory; cellpose.trainDirectory = trainDirectory; cellpose.valDirectory = valDirectory; diff --git a/src/main/java/qupath/ext/biop/cellpose/CellposeExtension.java b/src/main/java/qupath/ext/biop/cellpose/CellposeExtension.java index 2cba2b8..ca62e94 100644 --- a/src/main/java/qupath/ext/biop/cellpose/CellposeExtension.java +++ b/src/main/java/qupath/ext/biop/cellpose/CellposeExtension.java @@ -1,12 +1,6 @@ package qupath.ext.biop.cellpose; -import javafx.beans.property.BooleanProperty; -import javafx.beans.property.ObjectProperty; import javafx.beans.property.StringProperty; -import javafx.collections.FXCollections; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import qupath.ext.biop.cmd.VirtualEnvironmentRunner; import qupath.lib.common.Version; import qupath.lib.gui.QuPathGUI; import qupath.lib.gui.extensions.GitHubProject; @@ -14,7 +8,6 @@ import qupath.lib.gui.panes.PreferencePane; import qupath.lib.gui.prefs.PathPrefs; -import qupath.ext.biop.cmd.VirtualEnvironmentRunner.EnvType; /** * Install Cellpose as an extension. @@ -24,8 +17,6 @@ * @author Olivier Burri */ public class CellposeExtension implements QuPathExtension, GitHubProject { - private final static Logger logger = LoggerFactory.getLogger(CellposeExtension.class); - @Override public GitHubRepo getRepository() { @@ -39,26 +30,24 @@ public void installExtension(QuPathGUI qupath) { CellposeSetup options = CellposeSetup.getInstance(); // Create the options we need - ObjectProperty envType = PathPrefs.createPersistentPreference("cellposeEnvType", EnvType.CONDA, EnvType.class); - StringProperty envPath = PathPrefs.createPersistentPreference("cellposeEnvPath", ""); + StringProperty cellposePath = PathPrefs.createPersistentPreference("cellposePythonPath", ""); + StringProperty omniposePath = PathPrefs.createPersistentPreference("cellposeOmniposePath", ""); //Set options to current values - options.setEnvironmentType(envType.get()); - options.setEnvironmentNameOrPath(envPath.get()); + options.setCellposePytonPath(cellposePath.get()); + options.setOmniposePytonPath(omniposePath.get()); // Listen for property changes - envType.addListener((v,o,n) -> options.setEnvironmentType(n)); - envPath.addListener((v,o,n) -> options.setEnvironmentNameOrPath(n)); + cellposePath.addListener((v, o, n) -> options.setCellposePytonPath(n)); + omniposePath.addListener((v, o, n) -> options.setOmniposePytonPath(n)); // Add Permanent Preferences and Populate Preferences PreferencePane prefs = QuPathGUI.getInstance().getPreferencePane(); - prefs.addPropertyPreference(envPath, String.class, "Cellpose Environment name or directory", "Cellpose", - "Enter either the directory where your chosen Cellpose virtual environment (conda or venv) is located. Or the name of the conda environment you created."); - prefs.addChoicePropertyPreference(envType, - FXCollections.observableArrayList(VirtualEnvironmentRunner.EnvType.values()), - VirtualEnvironmentRunner.EnvType.class,"Cellpose Environment Type", "Cellpose", - "This changes how the environment is started."); + prefs.addPropertyPreference(cellposePath, String.class, "Cellpose Python executable location", "Cellpose/Omnipose", + "Enter the full path to your cellpose environment, including 'python.exe' or equivalent."); + prefs.addPropertyPreference(omniposePath, String.class, "Omnipose Python executable location (Optional)", "Cellpose/Omnipose", + "Enter the full path to your omnipose environment, including 'python.exe' or equivalent."); } @Override @@ -76,11 +65,4 @@ public Version getQuPathVersion() { return QuPathExtension.super.getQuPathVersion(); } - /* // Removed as no longer needed because version gets read into manifest - @Override - public Version getVersion() { - return Version.parse("0.3.5"); - } - - */ } \ No newline at end of file diff --git a/src/main/java/qupath/ext/biop/cellpose/CellposeSetup.java b/src/main/java/qupath/ext/biop/cellpose/CellposeSetup.java index bd118e3..0f38327 100644 --- a/src/main/java/qupath/ext/biop/cellpose/CellposeSetup.java +++ b/src/main/java/qupath/ext/biop/cellpose/CellposeSetup.java @@ -3,29 +3,27 @@ import qupath.ext.biop.cmd.VirtualEnvironmentRunner.EnvType; public class CellposeSetup { - private EnvType envType; - private String environmentNameOrPath; + private static final CellposeSetup instance = new CellposeSetup(); + private String cellposePythonPath = null; + private String omniposePythonPath = null; - - private static CellposeSetup instance = new CellposeSetup(); - - public EnvType getEnvironmentType() { - return envType; + public static CellposeSetup getInstance() { + return instance; } - public void setEnvironmentType(EnvType envType) { - this.envType = envType; + public String getCellposePytonPath() { + return cellposePythonPath; } - public String getEnvironmentNameOrPath() { - return environmentNameOrPath; + public void setCellposePytonPath(String path) { + this.cellposePythonPath = path; } - public void setEnvironmentNameOrPath(String environmentNameOrPath) { - this.environmentNameOrPath = environmentNameOrPath; + public String getOmniposePytonPath() { + return omniposePythonPath; } - public static CellposeSetup getInstance() { - return instance; + public void setOmniposePytonPath(String path) { + this.omniposePythonPath = path; } } diff --git a/src/main/java/qupath/ext/biop/cellpose/OpCreators.java b/src/main/java/qupath/ext/biop/cellpose/OpCreators.java index 7074c7d..9e25a12 100644 --- a/src/main/java/qupath/ext/biop/cellpose/OpCreators.java +++ b/src/main/java/qupath/ext/biop/cellpose/OpCreators.java @@ -1,12 +1,12 @@ /*- * Copyright 2022 QuPath developers, University of Edinburgh - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -41,376 +41,380 @@ /** * Helper class for creating new {@linkplain ImageOp ImageOps} based upon other image properties. *

- * This addresses that problem that every {@link ImageOp} only knows about the image tile that it - * 'sees' at runtime. + * This addresses that problem that every {@link ImageOp} only knows about the image tile that it + * 'sees' at runtime. * This means that all processing needs to be local. *

- * Often, we want ops to use information from across the entire image - particularly for - * normalization as a step in preprocessing, such as when normalizing to zero mean and unit variance + * Often, we want ops to use information from across the entire image - particularly for + * normalization as a step in preprocessing, such as when normalizing to zero mean and unit variance * across the entire image. *

- * Before this class, this was problematic because either the parameters needed to be calculated - * elsewhere (which was awkward), or else normalization would always treat each image tile independent - + * Before this class, this was problematic because either the parameters needed to be calculated + * elsewhere (which was awkward), or else normalization would always treat each image tile independent - * which could result in tiles within the same image being normalized in very different ways. - * + * * @author Pete Bankhead + * @implNote This is currently in development. If it proves useful enough, it is likely to be + * refined and moved to the core QuPath software. * @since v0.4.0 - * @implNote This is currently in development. If it proves useful enough, it is likely to be - * refined and moved to the core QuPath software. */ public class OpCreators { - - /** - * Helper class for creating (tile-based) ImageOps with parameters that are derived from an entire image or ROI. - *

- * This is most useful for normalization, where statistics may need to be calculated across the image - * even if they are then applied locally (e.g. an offset and scaling factor). - * - * @author Pete Bankhead - */ - public static interface TileOpCreator { - - /** - * Compute the (tile-based) ops from the image. - * @param op the data op, which determines how to extract channels from the image data - * @param imageData the image data to process - * @param mask ROI mask that may be used to restrict the region being considered (optional) - * @param plane the 2D image plane to use; if not provided, the plane from any ROI will be used, or otherwise the default plane - * @return - * @throws IOException - */ - public List createOps(ImageDataOp op, ImageData imageData, ROI mask, ImagePlane plane) throws IOException; - - } - - abstract static class DownsampledOpCreator implements TileOpCreator { - - private static final Logger logger = LoggerFactory.getLogger(DownsampledOpCreator.class); - - private boolean useMask = false; - - private double downsample = Double.NaN; - private int maxDimension = 2048; - - DownsampledOpCreator(int maxDimension, double downsample, boolean useMask) { - this.maxDimension = maxDimension; - this.downsample = downsample; - this.useMask = useMask; - } - - DownsampledOpCreator() { - this(2048, Double.NaN, false); - } - - @Override - public List createOps(ImageDataOp op, ImageData imageData, ROI mask, ImagePlane plane) throws IOException { - var server = imageData.getServer(); - double downsample = this.downsample; - - int x = 0, y = 0, width = server.getWidth(), height = server.getHeight(); - if (useMask && mask != null) { - x = (int)Math.round(mask.getBoundsX()); - y = (int)Math.round(mask.getBoundsY()); - width = (int)Math.round(mask.getBoundsWidth()); - height = (int)Math.round(mask.getBoundsHeight()); - } - if (plane == null) { - if (mask == null) { - logger.warn("Plane not specified - will use the default plane"); - plane = ImagePlane.getDefaultPlane(); - } else { - logger.debug("Plane not specified - will use the ROI mask plane"); - plane = mask.getImagePlane(); - } - } - - - if (Double.isNaN(downsample)) { - downsample = Math.max(width, height) / (double)maxDimension; - downsample = Math.max(downsample, 1.0); - logger.info("Computed downsample: {}", downsample); - } - - var request = RegionRequest.createInstance(server.getPath(), downsample, - x, y, width, height, plane.getZ(), plane.getT()); - - try (var scope = new PointerScope()) { - var mat = op.apply(imageData, request); - - if (useMask && mask != null) { - var img = BufferedImageTools.createROIMask(mat.cols(), mat.rows(), mask, request); - var matMask = OpenCVTools.imageToMat(img); - opencv_core.bitwise_not(matMask, matMask); - if (mat.depth() != opencv_core.CV_32F && mat.depth() != opencv_core.CV_64F) - mat.convertTo(mat, opencv_core.CV_32F); - mat.setTo(OpenCVTools.scalarMat(Double.NaN, mat.depth()), matMask); - matMask.close(); - // Show image for debugging + + /** + * Build a normalization op that can be based upon the entire (2D) image, rather than only local tiles. + *

+ * Note that currently this requires downsampling the image to a manageable size. + * + * @return + */ + public static ImageNormalizationBuilder imageNormalizationBuilder() { + return new ImageNormalizationBuilder(); + } + + /** + * Helper class for creating (tile-based) ImageOps with parameters that are derived from an entire image or ROI. + *

+ * This is most useful for normalization, where statistics may need to be calculated across the image + * even if they are then applied locally (e.g. an offset and scaling factor). + * + * @author Pete Bankhead + */ + public interface TileOpCreator { + + /** + * Compute the (tile-based) ops from the image. + * + * @param op the data op, which determines how to extract channels from the image data + * @param imageData the image data to process + * @param mask ROI mask that may be used to restrict the region being considered (optional) + * @param plane the 2D image plane to use; if not provided, the plane from any ROI will be used, or otherwise the default plane + * @return + * @throws IOException + */ + List createOps(ImageDataOp op, ImageData imageData, ROI mask, ImagePlane plane) throws IOException; + + } + + abstract static class DownsampledOpCreator implements TileOpCreator { + + private static final Logger logger = LoggerFactory.getLogger(DownsampledOpCreator.class); + + private boolean useMask = false; + + private double downsample = Double.NaN; + private int maxDimension = 2048; + + DownsampledOpCreator(int maxDimension, double downsample, boolean useMask) { + this.maxDimension = maxDimension; + this.downsample = downsample; + this.useMask = useMask; + } + + DownsampledOpCreator() { + this(2048, Double.NaN, false); + } + + @Override + public List createOps(ImageDataOp op, ImageData imageData, ROI mask, ImagePlane plane) throws IOException { + var server = imageData.getServer(); + double downsample = this.downsample; + + int x = 0, y = 0, width = server.getWidth(), height = server.getHeight(); + if (useMask && mask != null) { + x = (int) Math.round(mask.getBoundsX()); + y = (int) Math.round(mask.getBoundsY()); + width = (int) Math.round(mask.getBoundsWidth()); + height = (int) Math.round(mask.getBoundsHeight()); + } + if (plane == null) { + if (mask == null) { + logger.warn("Plane not specified - will use the default plane"); + plane = ImagePlane.getDefaultPlane(); + } else { + logger.debug("Plane not specified - will use the ROI mask plane"); + plane = mask.getImagePlane(); + } + } + + + if (Double.isNaN(downsample)) { + downsample = Math.max(width, height) / (double) maxDimension; + downsample = Math.max(downsample, 1.0); + logger.info("Computed downsample: {}", downsample); + } + + var request = RegionRequest.createInstance(server.getPath(), downsample, + x, y, width, height, plane.getZ(), plane.getT()); + + try (var scope = new PointerScope()) { + var mat = op.apply(imageData, request); + + if (useMask && mask != null) { + var img = BufferedImageTools.createROIMask(mat.cols(), mat.rows(), mask, request); + var matMask = OpenCVTools.imageToMat(img); + opencv_core.bitwise_not(matMask, matMask); + if (mat.depth() != opencv_core.CV_32F && mat.depth() != opencv_core.CV_64F) + mat.convertTo(mat, opencv_core.CV_32F); + mat.setTo(OpenCVTools.scalarMat(Double.NaN, mat.depth()), matMask); + matMask.close(); + // Show image for debugging // OpenCVTools.matToImagePlus("Masked input", mat).show(); - } - - return compute(mat); - } - } - - protected abstract List compute(Mat mat); - - } - - /** - * Tile op creator that computes offset and scale values across the full image - * to normalize using min and max percentiles. - */ - public static class PercentileTileOpCreator extends DownsampledOpCreator { - - private static final Logger logger = LoggerFactory.getLogger(PercentileTileOpCreator.class); - - private double percentileMin = 0; - private double percentileMax = 99.8; - private boolean perChannel = false; - - private double eps = 1e-6; - - private PercentileTileOpCreator(int maxSize, double downsample, boolean useMask, double percentileMin, double percentileMax, boolean perChannel, double eps) { - super(maxSize, downsample, useMask); - this.percentileMin = percentileMin; - this.percentileMax = percentileMax; - this.perChannel = perChannel; - this.eps = eps; - } - - @Override - protected List compute(Mat mat) { - if (perChannel) { - int nChannels = mat.channels(); - double[] toSubtract = new double[nChannels]; - double[] toScale = new double[nChannels]; - int c = 0; - try (var scope = new PointerScope()) { - for (var matChannel : OpenCVTools.splitChannels(mat)) { - double[] percentiles = OpenCVTools.percentiles(matChannel, percentileMin, percentileMax); - toSubtract[c] = percentiles[0]; - toScale[c] = 1.0/Math.max(percentiles[1] - percentiles[0], eps); - c++; - } - } - logger.info("Computed percentile normalization offsets={}, scales={}", Arrays.toString(toSubtract), Arrays.toString(toScale)); - return List.of( - ImageOps.Core.subtract(toSubtract), - ImageOps.Core.multiply(toScale) ); - } else { - double[] percentiles = OpenCVTools.percentiles(mat, percentileMin, percentileMax); - logger.info("Computed percentiles {}, {}", percentiles[0], percentiles[1]); - return List.of( - ImageOps.Core.subtract(percentiles[0]), - ImageOps.Core.multiply(1.0/Math.max(percentiles[1] - percentiles[0], 1e-6)) ); - } - } - - } - - /** - * Tile op creator that computes offset and scale values across the full image - * to normalize to zero mean and unit variance. - */ - public static class ZeroMeanVarianceTileOpCreator extends DownsampledOpCreator { - - private static final Logger logger = LoggerFactory.getLogger(ZeroMeanVarianceTileOpCreator.class); - - private boolean perChannel = false; - private double eps = 1e-6; - - private ZeroMeanVarianceTileOpCreator(int maxSize, double downsample, boolean useMask, boolean perChannel, double eps) { - super(maxSize, downsample, useMask); - this.perChannel = perChannel; - this.eps = eps; - } - - @Override - protected List compute(Mat mat) { - if (perChannel) { - int nChannels = mat.channels(); - double[] toSubtract = new double[nChannels]; - double[] toScale = new double[nChannels]; - int c = 0; - try (var scope = new PointerScope()) { - for (var matChannel : OpenCVTools.splitChannels(mat)) { - toSubtract[c] = OpenCVTools.mean(matChannel); - toScale[c] = 1.0/(OpenCVTools.stdDev(matChannel) + eps); - c++; - } - } - logger.info("Computed mean/variance normalization offsets={}, scales={}", Arrays.toString(toSubtract), Arrays.toString(toScale)); - return List.of( - ImageOps.Core.subtract(toSubtract), - ImageOps.Core.multiply(toScale) - ); - } else { - double toSubtract = OpenCVTools.mean(mat); - double toScale = 1.0/(OpenCVTools.stdDev(mat) + eps); - logger.info("Computed mean/variance normalization offset={}, scale={}", toSubtract, toScale); - return List.of( - ImageOps.Core.subtract(toSubtract), - ImageOps.Core.multiply(toScale) - ); - } - } - - } - - - - /** - * Builder for a {@link TileOpCreator} that can be used for image preprocessing - * using min/max percentiles or zero-mean-unit-variance normalization. - */ - public static class ImageNormalizationBuilder { - - private static final Logger logger = LoggerFactory.getLogger(ImageNormalizationBuilder.class); - - private boolean zeroMeanUnitVariance = false; - - private double minPercentile = 0; - private double maxPercentile = 100; - - private boolean perChannel = false; - private double eps = 1e-6; // 1e-6 - update javadoc if this changes - - private double downsample = Double.NaN; - private int maxDimension = 2048; // 2048 - update javadoc if this changes - private boolean useMask = false; - - /** - * Specify min and max percentiles to calculate normalization values. - * See {@link Normalize#percentile(double, double)}. - * @param minPercentile - * @param maxPercentile - * @return this builder - */ - public ImageNormalizationBuilder percentiles(double minPercentile, double maxPercentile) { - this.minPercentile = minPercentile; - this.maxPercentile = maxPercentile; - if (zeroMeanUnitVariance) { - logger.warn("Specifying percentiles overrides previous zero-mean-unit-variance request"); - zeroMeanUnitVariance = false; - } - return this; - } - - /** - * Error constant used for numerical stability and avoid dividing by zero. - * Default is 1e-6; - * @param eps - * @return this builder - */ - public ImageNormalizationBuilder eps(double eps) { - this.eps = eps; - return this; - } - - /** - * Compute the normalization values separately per channel; if false, values are computed - * jointly across channels. - * @param perChannel - * @return this builder - */ - public ImageNormalizationBuilder perChannel(boolean perChannel) { - this.perChannel = perChannel; - return this; - } - - /** - * Specify the downsample factor to use when calculating the normalization. - * If this is not provided, then {@link #maxDimension(int)} will be used to calculate - * a downsample value automatically. - *

- * The downsample should be ≥ 1.0 and high enough to ensure that the entire image - * can be fit in memory. A downsample of 1.0 for a whole slide image will probably - * fail due to memory or array size limits. - * - * @param downsample - * @return this builder - * see {@link #maxDimension(int)} - */ - public ImageNormalizationBuilder downsample(double downsample) { - this.downsample = downsample; - return this; - } - - /** - * The maximum width or height, which is used to calculate a downsample factor for - * the image if {@link #downsample(double)} is not specified. - *

- * The current default value is 2048; - * - * @param maxDimension - * @return this builder - */ - public ImageNormalizationBuilder maxDimension(int maxDimension) { - this.maxDimension = maxDimension; - return this; - } - - /** - * Optionally use any ROI mask provided for the calculation. - * This can restrict the region that is considered. - * - * @param useMask - * @return this builder - */ - public ImageNormalizationBuilder useMask(boolean useMask) { - this.useMask = useMask; - return this; - } - - /** - * Normalize for zero mean and unit variance. - * This is an alternative to using {@link #percentiles(double, double)}. - * @return this builder - */ - public ImageNormalizationBuilder zeroMeanUnitVariance() { - return zeroMeanUnitVariance(true); - } - - /** - * Optionally normalize for zero mean and unit variance. - * This is an alternative to using {@link #percentiles(double, double)}. - * @param doZeroMeanUnitVariance - * @return this builder - */ - public ImageNormalizationBuilder zeroMeanUnitVariance(boolean doZeroMeanUnitVariance) { - this.zeroMeanUnitVariance = doZeroMeanUnitVariance; - if (zeroMeanUnitVariance && (minPercentile != 0 || maxPercentile != 100)) - logger.warn("Setting zero-mean-unit-variance will override previous percentiles that were set"); - return this; - } - - /** - * Build a {@link TileOpCreator} according to the builder's parameters. - * @return this builder - */ - public TileOpCreator build() { - if (zeroMeanUnitVariance) { - logger.debug("Creating zero-mean-unit-variance normalization op"); - return new ZeroMeanVarianceTileOpCreator(maxDimension, downsample, useMask, perChannel, eps); - } else { - logger.debug("Creating percentile normalization op"); - return new PercentileTileOpCreator(maxDimension, downsample, useMask, minPercentile, maxPercentile, perChannel, eps); - } - } - - } - - - /** - * Build a normalization op that can be based upon the entire (2D) image, rather than only local tiles. - *

- * Note that currently this requires downsampling the image to a manageable size. - * - * @return - */ - public static ImageNormalizationBuilder imageNormalizationBuilder() { - return new ImageNormalizationBuilder(); - } + } + + return compute(mat); + } + } + + protected abstract List compute(Mat mat); + + } + + /** + * Tile op creator that computes offset and scale values across the full image + * to normalize using min and max percentiles. + */ + public static class PercentileTileOpCreator extends DownsampledOpCreator { + + private static final Logger logger = LoggerFactory.getLogger(PercentileTileOpCreator.class); + + private double percentileMin = 0; + private double percentileMax = 99.8; + private boolean perChannel = false; + + private double eps = 1e-6; + + private PercentileTileOpCreator(int maxSize, double downsample, boolean useMask, double percentileMin, double percentileMax, boolean perChannel, double eps) { + super(maxSize, downsample, useMask); + this.percentileMin = percentileMin; + this.percentileMax = percentileMax; + this.perChannel = perChannel; + this.eps = eps; + } + + @Override + protected List compute(Mat mat) { + if (perChannel) { + int nChannels = mat.channels(); + double[] toSubtract = new double[nChannels]; + double[] toScale = new double[nChannels]; + int c = 0; + try (var scope = new PointerScope()) { + for (var matChannel : OpenCVTools.splitChannels(mat)) { + double[] percentiles = OpenCVTools.percentiles(matChannel, percentileMin, percentileMax); + toSubtract[c] = percentiles[0]; + toScale[c] = 1.0 / Math.max(percentiles[1] - percentiles[0], eps); + c++; + } + } + logger.info("Computed percentile normalization offsets={}, scales={}", Arrays.toString(toSubtract), Arrays.toString(toScale)); + return List.of( + ImageOps.Core.subtract(toSubtract), + ImageOps.Core.multiply(toScale)); + } else { + double[] percentiles = OpenCVTools.percentiles(mat, percentileMin, percentileMax); + logger.info("Computed percentiles {}, {}", percentiles[0], percentiles[1]); + return List.of( + ImageOps.Core.subtract(percentiles[0]), + ImageOps.Core.multiply(1.0 / Math.max(percentiles[1] - percentiles[0], 1e-6))); + } + } + + } + + /** + * Tile op creator that computes offset and scale values across the full image + * to normalize to zero mean and unit variance. + */ + public static class ZeroMeanVarianceTileOpCreator extends DownsampledOpCreator { + + private static final Logger logger = LoggerFactory.getLogger(ZeroMeanVarianceTileOpCreator.class); + + private boolean perChannel = false; + private double eps = 1e-6; + + private ZeroMeanVarianceTileOpCreator(int maxSize, double downsample, boolean useMask, boolean perChannel, double eps) { + super(maxSize, downsample, useMask); + this.perChannel = perChannel; + this.eps = eps; + } + + @Override + protected List compute(Mat mat) { + if (perChannel) { + int nChannels = mat.channels(); + double[] toSubtract = new double[nChannels]; + double[] toScale = new double[nChannels]; + int c = 0; + try (var scope = new PointerScope()) { + for (var matChannel : OpenCVTools.splitChannels(mat)) { + toSubtract[c] = OpenCVTools.mean(matChannel); + toScale[c] = 1.0 / (OpenCVTools.stdDev(matChannel) + eps); + c++; + } + } + logger.info("Computed mean/variance normalization offsets={}, scales={}", Arrays.toString(toSubtract), Arrays.toString(toScale)); + return List.of( + ImageOps.Core.subtract(toSubtract), + ImageOps.Core.multiply(toScale) + ); + } else { + double toSubtract = OpenCVTools.mean(mat); + double toScale = 1.0 / (OpenCVTools.stdDev(mat) + eps); + logger.info("Computed mean/variance normalization offset={}, scale={}", toSubtract, toScale); + return List.of( + ImageOps.Core.subtract(toSubtract), + ImageOps.Core.multiply(toScale) + ); + } + } + + } + + /** + * Builder for a {@link TileOpCreator} that can be used for image preprocessing + * using min/max percentiles or zero-mean-unit-variance normalization. + */ + public static class ImageNormalizationBuilder { + + private static final Logger logger = LoggerFactory.getLogger(ImageNormalizationBuilder.class); + + private boolean zeroMeanUnitVariance = false; + + private double minPercentile = 0; + private double maxPercentile = 100; + + private boolean perChannel = false; + private double eps = 1e-6; // 1e-6 - update javadoc if this changes + + private double downsample = Double.NaN; + private int maxDimension = 2048; // 2048 - update javadoc if this changes + private boolean useMask = false; + + /** + * Specify min and max percentiles to calculate normalization values. + * See {@link Normalize#percentile(double, double)}. + * + * @param minPercentile + * @param maxPercentile + * @return this builder + */ + public ImageNormalizationBuilder percentiles(double minPercentile, double maxPercentile) { + this.minPercentile = minPercentile; + this.maxPercentile = maxPercentile; + if (zeroMeanUnitVariance) { + logger.warn("Specifying percentiles overrides previous zero-mean-unit-variance request"); + zeroMeanUnitVariance = false; + } + return this; + } + + /** + * Error constant used for numerical stability and avoid dividing by zero. + * Default is 1e-6; + * + * @param eps + * @return this builder + */ + public ImageNormalizationBuilder eps(double eps) { + this.eps = eps; + return this; + } + + /** + * Compute the normalization values separately per channel; if false, values are computed + * jointly across channels. + * + * @param perChannel + * @return this builder + */ + public ImageNormalizationBuilder perChannel(boolean perChannel) { + this.perChannel = perChannel; + return this; + } + + /** + * Specify the downsample factor to use when calculating the normalization. + * If this is not provided, then {@link #maxDimension(int)} will be used to calculate + * a downsample value automatically. + *

+ * The downsample should be ≥ 1.0 and high enough to ensure that the entire image + * can be fit in memory. A downsample of 1.0 for a whole slide image will probably + * fail due to memory or array size limits. + * + * @param downsample + * @return this builder + * see {@link #maxDimension(int)} + */ + public ImageNormalizationBuilder downsample(double downsample) { + this.downsample = downsample; + return this; + } + + /** + * The maximum width or height, which is used to calculate a downsample factor for + * the image if {@link #downsample(double)} is not specified. + *

+ * The current default value is 2048; + * + * @param maxDimension + * @return this builder + */ + public ImageNormalizationBuilder maxDimension(int maxDimension) { + this.maxDimension = maxDimension; + return this; + } + + /** + * Optionally use any ROI mask provided for the calculation. + * This can restrict the region that is considered. + * + * @param useMask + * @return this builder + */ + public ImageNormalizationBuilder useMask(boolean useMask) { + this.useMask = useMask; + return this; + } + + /** + * Normalize for zero mean and unit variance. + * This is an alternative to using {@link #percentiles(double, double)}. + * + * @return this builder + */ + public ImageNormalizationBuilder zeroMeanUnitVariance() { + return zeroMeanUnitVariance(true); + } + + /** + * Optionally normalize for zero mean and unit variance. + * This is an alternative to using {@link #percentiles(double, double)}. + * + * @param doZeroMeanUnitVariance + * @return this builder + */ + public ImageNormalizationBuilder zeroMeanUnitVariance(boolean doZeroMeanUnitVariance) { + this.zeroMeanUnitVariance = doZeroMeanUnitVariance; + if (zeroMeanUnitVariance && (minPercentile != 0 || maxPercentile != 100)) + logger.warn("Setting zero-mean-unit-variance will override previous percentiles that were set"); + return this; + } + + /** + * Build a {@link TileOpCreator} according to the builder's parameters. + * + * @return this builder + */ + public TileOpCreator build() { + if (zeroMeanUnitVariance) { + logger.debug("Creating zero-mean-unit-variance normalization op"); + return new ZeroMeanVarianceTileOpCreator(maxDimension, downsample, useMask, perChannel, eps); + } else { + logger.debug("Creating percentile normalization op"); + return new PercentileTileOpCreator(maxDimension, downsample, useMask, minPercentile, maxPercentile, perChannel, eps); + } + } + + } }