diff --git a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java index dd1114e..1994688 100644 --- a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java +++ b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java @@ -82,7 +82,6 @@ public class Cellpose2D { private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class); private int channel1; private int channel2; - private int nChannels; private double iouThreshold = 0.1; private double simplifyDistance = 1.4; private double probabilityThreshold; @@ -104,6 +103,7 @@ public class Cellpose2D { private boolean measureShape = false; private Collection compartments; private Collection measurements; + private boolean invert; /** * Create a builder to customize detection parameters. @@ -111,7 +111,7 @@ public class Cellpose2D { * or a path to a custom model (as a String) * * @param modelPath name or path to model to use for prediction. - * @return + * @return this builder */ public static Builder builder(String modelPath) { return new Builder(modelPath); @@ -198,8 +198,7 @@ public void detectObjects(ImageData imageData, Collection imageData, Collection imageData, Collection { PathObject parent = tileMap.getObject(); // Read each image - List allDetections = Collections.synchronizedList(new ArrayList()); + List allDetections = Collections.synchronizedList(new ArrayList<>()); tileMap.getTileFiles().parallelStream().forEach(tilefile -> { File ori = tilefile.getFile(); File maskFile = new File(ori.getParent(), FilenameUtils.removeExtension(ori.getName()) + "_cp_masks.tif"); @@ -260,7 +257,7 @@ public void detectObjects(ImageData imageData, Collection n != null) + }).filter(Objects::nonNull) .collect(Collectors.toList()); // Resolve cell overlaps, if needed @@ -268,7 +265,7 @@ public void detectObjects(ImageData imageData, Collection objectToCell(c)).collect(Collectors.toList()); + var cells = 
filteredDetections.stream().map(Cellpose2D::objectToCell).collect(Collectors.toList()); cells = CellTools.constrainCellOverlaps(cells); filteredDetections = cells.stream().map(c -> cellToObject(c, creatorFun)).collect(Collectors.toList()); } else @@ -368,7 +365,7 @@ private PathObject convertToObject(PathObject object, ImagePlane plane, double c private List filterDetections(List rawDetections) { // Sort by size - Collections.sort(rawDetections, Comparator.comparingDouble(o -> -1 * o.getROI().getArea())); + rawDetections.sort(Comparator.comparingDouble(o -> -1 * o.getROI().getArea())); // Create array of detections to keep & to skip var detections = new LinkedHashSet(); @@ -430,7 +427,7 @@ private List filterDetections(List rawDetections) { * @param imageData the current ImageData * @param request the region we want to save * @return a simple object that contains the request and the associated file in the temp folder - * @throws IOException + * @throws IOException an error in case of read/write issue */ private TileFile saveTileImage(ImageDataOp op, ImageData imageData, RegionRequest request) throws IOException { @@ -479,9 +476,8 @@ private void runCellPose() throws IOException, InterruptedException { VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeOptions.getEnvironmentNameorPath(), cellposeOptions.getEnvironmentType()); // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(); - cellposeArguments.addAll(Arrays.asList("-W", "ignore", "-m", "cellpose")); + List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", "cellpose")); cellposeArguments.add("--dir"); cellposeArguments.add("" + this.cellposeTempFolder); @@ -507,8 +503,12 @@ private void runCellPose() throws IOException, InterruptedException { cellposeArguments.add("--save_tif"); cellposeArguments.add("--no_npy"); + cellposeArguments.add("--resample"); + if (invert) cellposeArguments.add("--invert"); + + if 
(cellposeOptions.useGPU()) cellposeArguments.add("--use_gpu"); veRunner.setArguments(cellposeArguments); @@ -524,7 +524,7 @@ private void runCellPose() throws IOException, InterruptedException { */ public static class Builder { - private String modelPath; + private final String modelPath; private ColorTransform[] channels = new ColorTransform[0]; private double probabilityThreshold = 0.5; @@ -538,8 +538,8 @@ public static class Builder { private double pixelSize = Double.NaN; - private int tileWidth = 2048; - private int tileHeight = 2048; + private int tileWidth = 1024; + private int tileHeight = 1024; private Function creatorFun; @@ -551,11 +551,12 @@ public static class Builder { private boolean constrainToParent = true; - private List ops = new ArrayList<>(); + private final List ops = new ArrayList<>(); private double iouThreshold = 0.1; private int channel1 = 0; //GRAY private int channel2 = 0; // NONE + private boolean isInvert = false; private Builder(String modelPath) { @@ -566,7 +567,7 @@ private Builder(String modelPath) { /** * Probability threshold to apply for detection, between 0 and 1. * - * @param threshold + * @param threshold probability threshold between 0 and 1 (default 0.5) * @return this builder */ public Builder probabilityThreshold(double threshold) { @@ -577,7 +578,7 @@ public Builder probabilityThreshold(double threshold) { /** * Flow threshold to apply for detection, between 0 and 1. * - * @param threshold + * @param threshold flow threshold (default 0.0) * @return this builder */ public Builder flowThreshold(double threshold) { @@ -590,7 +591,7 @@ public Builder flowThreshold(double threshold) { * their expected diameter * * @param diameter in pixels - * @return + * @return this builder */ public Builder diameter(double diameter) { this.diameter = diameter; @@ -600,12 +601,11 @@ public Builder diameter(double diameter) { /** * Add preprocessing operations, if required. 
* - * @param ops + * @param ops series of ImageOps to apply to this server before saving the images * @return this builder */ public Builder preprocess(ImageOp... ops) { - for (var op : ops) - this.ops.add(op); + Collections.addAll(this.ops, ops); return this; } @@ -648,7 +648,7 @@ public Builder iou(double iouThreshold) { */ public Builder channels(int... channels) { return channels(Arrays.stream(channels) - .mapToObj(c -> ColorTransforms.createChannelExtractor(c)) + .mapToObj(ColorTransforms::createChannelExtractor) .toArray(ColorTransform[]::new)); } @@ -662,7 +662,7 @@ public Builder channels(int... channels) { */ public Builder channels(String... channels) { return channels(Arrays.stream(channels) - .map(c -> ColorTransforms.createChannelExtractor(c)) + .map(ColorTransforms::createChannelExtractor) .toArray(ColorTransform[]::new)); } @@ -671,7 +671,7 @@ public Builder channels(String... channels) { *

* This makes it possible to supply color deconvolved channels, for example. * - * @param channels + * @param channels ColorTransform channels to use, typically only used internally * @return this builder */ public Builder channels(ColorTransform... channels) { @@ -692,7 +692,7 @@ public Builder channels(ColorTransform... channels) { *

* In short, be wary. * - * @param distance + * @param distance cell expansion distance in microns * @return this builder */ public Builder cellExpansion(double distance) { @@ -705,7 +705,7 @@ public Builder cellExpansion(double distance) { * the nucleus size. Only meaningful for values > 1; the nucleus is expanded according * to the scale factor, and used to define the maximum permitted cell expansion. * - * @param scale + * @param scale a number to multiply each pixel in the image by * @return this builder */ public Builder cellConstrainScale(double scale) { @@ -720,14 +720,14 @@ public Builder cellConstrainScale(double scale) { * @return this builder */ public Builder createAnnotations() { - this.creatorFun = r -> PathObjects.createAnnotationObject(r); + this.creatorFun = PathObjects::createAnnotationObject; return this; } /** * Request that a classification is applied to all created objects. * - * @param pathClass + * @param pathClass the PathClass of all detections resulting from this run * @return this builder */ public Builder classify(PathClass pathClass) { @@ -739,7 +739,7 @@ public Builder classify(PathClass pathClass) { * Request that a classification is applied to all created objects. * This is a convenience method that get a {@link PathClass} from {@link PathClassFactory}. * - * @param pathClassName + * @param pathClassName the name of the PathClass for all detections * @return this builder */ public Builder classify(String pathClassName) { @@ -749,7 +749,7 @@ public Builder classify(String pathClassName) { /** * If true, ignore overlaps when computing cell expansion. * - * @param ignore + * @param ignore ignore overlaps when computing cell expansion. * @return this builder */ public Builder ignoreCellOverlaps(boolean ignore) { @@ -760,7 +760,7 @@ public Builder ignoreCellOverlaps(boolean ignore) { /** * If true, constrain nuclei and cells to any parent annotation (default is true). 
* - * @param constrainToParent + * @param constrainToParent constrain nuclei and cells to any parent annotation * @return this builder */ public Builder constrainToParent(boolean constrainToParent) { @@ -824,7 +824,7 @@ public Builder compartments(Compartments... compartments) { *

* For an image calibrated in microns, the recommended default is approximately 0.5. * - * @param pixelSize + * @param pixelSize Pixel size in microns for the analysis * @return this builder */ public Builder pixelSize(double pixelSize) { @@ -837,7 +837,7 @@ public Builder pixelSize(double pixelSize) { * Note that tiles are independently normalized, and therefore tiling can impact * the results. Default is 1024. * - * @param tileSize + * @param tileSize if the regions must be broken down, how large should the tiles be, in pixels (width and height) * @return this builder */ public Builder tileSize(int tileSize) { @@ -849,8 +849,8 @@ public Builder tileSize(int tileSize) { * Note that tiles are independently normalized, and therefore tiling can impact * the results. Default is 1024. * - * @param tileWidth - * @param tileHeight + * @param tileWidth if the regions must be broken down, how large should the tiles be (width), in pixels + * @param tileHeight if the regions must be broken down, how large should the tiles be (height), in pixels * @return this builder */ public Builder tileSize(int tileWidth, int tileHeight) { @@ -912,7 +912,12 @@ public Builder inputScale(double... values) { return this; } - public Builder cellPoseChannels(int channel1, int channel2) { + public Builder invertChannels(boolean isInvert) { + this.isInvert = isInvert; + return this; + } + + public Builder cellposeChannels(int channel1, int channel2) { this.channel1 = channel1; this.channel2 = channel2; return this; @@ -921,7 +926,7 @@ public Builder cellPoseChannels(int channel1, int channel2) { /** * Create a {@link Cellpose2D}, all ready for detection. * - * @return + * @return a CellPose2D object, ready to be run */ public Cellpose2D build() { @@ -950,13 +955,13 @@ public Cellpose2D build() { cellpose.op = ImageOps.buildImageDataOp(channels).appendOps(mergedOps.toArray(ImageOp[]::new)); // CellPose accepts either one or two channels. 
This will help the final command - cellpose.nChannels = channels.length; cellpose.channel1 = channel1; cellpose.channel2 = channel2; cellpose.probabilityThreshold = probabilityThreshold; cellpose.flowThreshold = flowThreshold; cellpose.pixelSize = pixelSize; cellpose.diameter = diameter; + cellpose.invert = isInvert; // Overlap for segmentation of tiles. Should be large enough that any object must be "complete" // in at least one tile for resolving overlaps diff --git a/src/main/java/qupath/ext/biop/cellpose/Cellpose2DTraining.java b/src/main/java/qupath/ext/biop/cellpose/Cellpose2DTraining.java index ec521cf..4d3259c 100644 --- a/src/main/java/qupath/ext/biop/cellpose/Cellpose2DTraining.java +++ b/src/main/java/qupath/ext/biop/cellpose/Cellpose2DTraining.java @@ -33,7 +33,8 @@ public class Cellpose2DTraining { private File trainDirectory; private File valDirectory; private double diameter; - private int nChannels; + private int channel1; + private int channel2; private int nEpochs; private ImageDataOp op; private double pixelSize; @@ -41,7 +42,9 @@ public class Cellpose2DTraining { private void saveImagePairs(List annotations, String imageName, ImageServer originalServer, ImageServer labelServer, File saveDirectory) { - if(annotations.isEmpty()) { return; } + if (annotations.isEmpty()) { + return; + } int downsample = 1; if (Double.isFinite(pixelSize) && pixelSize > 0) { downsample = (int) Math.round(pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue()); @@ -73,9 +76,9 @@ private void saveTrainingImages() throws IOException { Project project = QP.getProject(); // Prepare location to save images - project.getImageList().stream().forEach(e -> { + project.getImageList().parallelStream().forEach(e -> { - ImageData imageData = null; + ImageData imageData; try { imageData = e.readImageData(); String imageName = GeneralTools.getNameWithoutExtension(imageData.getServer().getMetadata().getName()); @@ -100,7 +103,7 @@ private void 
saveTrainingImages() throws IOException { saveImagePairs(trainingAnnotations, imageName, processed, labelServer, trainDirectory); saveImagePairs(validationAnnotations, imageName, processed, labelServer, valDirectory); } catch (Exception ex) { - logger.error( ex.getMessage()); + logger.error(ex.getMessage(), ex); } }); @@ -111,7 +114,7 @@ private File moveAndReturnModelFile() throws IOException { // Find the first file in there File[] all = cellPoseModelFolder.listFiles(); Optional cellPoseModel = Arrays.stream(all).filter(f -> f.getName().contains("cellpose")).findFirst(); - if(cellPoseModel.isPresent()) { + if (cellPoseModel.isPresent()) { logger.info("Found model file at {} ", cellPoseModel); File model = cellPoseModel.get(); File newModel = new File(modelDirectory, model.getName()); @@ -120,6 +123,72 @@ private File moveAndReturnModelFile() throws IOException { } return null; } + + public File train() { + + try { + saveTrainingImages(); + runCellPose(); + return moveAndReturnModelFile(); + + } catch (IOException | InterruptedException e) { + logger.error(e.getMessage(), e); + } + return null; + + } + + public void runCellPose() throws IOException, InterruptedException { + + //python -m cellpose --train --dir ~/images_cyto/train/ --test_dir ~/images_cyto/test/ --pretrained_model cyto --chan 2 --chan2 1 + // Get options + CellposeOptions cellposeOptions = CellposeOptions.getInstance(); + + // Create command to run + VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeOptions.getEnvironmentNameorPath(), cellposeOptions.getEnvironmentType()); + + // This is the list of commands after the 'python' call + + List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", "cellpose")); + + cellposeArguments.add("--train"); + + cellposeArguments.add("--dir"); + cellposeArguments.add("" + trainDirectory.getAbsolutePath()); + cellposeArguments.add("--test_dir"); + cellposeArguments.add("" + valDirectory.getAbsolutePath()); + + 
cellposeArguments.add("--pretrained_model"); + if (pretrainedModel != null && !pretrainedModel.isEmpty()) { + cellposeArguments.add("" + pretrainedModel); + } else { + cellposeArguments.add("None"); + } + + // The channel order will always be 1 and 2, in the order defined by channels(...) in the builder + cellposeArguments.add("--chan"); + cellposeArguments.add("" + channel1); + cellposeArguments.add("--chan2"); + cellposeArguments.add("" + channel2); + + cellposeArguments.add("--diameter"); + cellposeArguments.add("" + diameter); + + cellposeArguments.add("--n_epochs"); + cellposeArguments.add("" + nEpochs); + + if (invert) cellposeArguments.add("--invert"); + + if (cellposeOptions.useGPU()) cellposeArguments.add("--use_gpu"); + + veRunner.setArguments(cellposeArguments); + + // Finally, we can run Cellpose + veRunner.runCommand(); + + logger.info("Cellpose command finished running"); + } + /** * Builder to help create a {@link Cellpose2D} with custom parameters. */ @@ -127,14 +196,16 @@ public static class Builder { private File modelDirectory; private ColorTransforms.ColorTransform[] channels = new ColorTransforms.ColorTransform[0]; - private String pretrainedModel = "cyto"; + private String pretrainedModel; private double diameter = 100; private int nEpochs; + private int channel1 = 0; //GRAY + private int channel2 = 0; // NONE private double pixelSize = Double.NaN; - private List ops = new ArrayList<>(); + private final List ops = new ArrayList<>(); private boolean isInvert; private Builder(String pretrainedModel) { @@ -142,7 +213,7 @@ private Builder(String pretrainedModel) { this.ops.add(ImageOps.Core.ensureType(PixelType.FLOAT32)); } - Builder modelDirectory( File modelDirectory) { + Builder modelDirectory(File modelDirectory) { this.modelDirectory = modelDirectory; return this; } @@ -157,7 +228,7 @@ Builder epochs(int nEpochs) { * their expected diameter * * @param diameter in pixels - * @return + * @return this builder */ public Builder diameter(double diameter) { 
this.diameter = diameter; @@ -167,12 +238,11 @@ public Builder diameter(double diameter) { /** * Add preprocessing operations, if required. * - * @param ops + * @param ops ImageOps to preprocess the image * @return this builder */ public Builder preprocess(ImageOp... ops) { - for (var op : ops) - this.ops.add(op); + Collections.addAll(this.ops, ops); return this; } @@ -186,7 +256,7 @@ public Builder preprocess(ImageOp... ops) { */ public Builder channels(int... channels) { return channels(Arrays.stream(channels) - .mapToObj(c -> ColorTransforms.createChannelExtractor(c)) + .mapToObj(ColorTransforms::createChannelExtractor) .toArray(ColorTransforms.ColorTransform[]::new)); } @@ -200,7 +270,7 @@ public Builder channels(int... channels) { */ public Builder channels(String... channels) { return channels(Arrays.stream(channels) - .map(c -> ColorTransforms.createChannelExtractor(c)) + .map(ColorTransforms::createChannelExtractor) .toArray(ColorTransforms.ColorTransform[]::new)); } @@ -209,7 +279,7 @@ public Builder channels(String... channels) { *

* This makes it possible to supply color deconvolved channels, for example. * - * @param channels + * @param channels the ColorTransform channels we want. Mostly used internally * @return this builder */ public Builder channels(ColorTransforms.ColorTransform... channels) { @@ -225,7 +295,7 @@ public Builder channels(ColorTransforms.ColorTransform... channels) { *

* For an image calibrated in microns, the recommended default is approximately 0.5. * - * @param pixelSize + * @param pixelSize the requested pixel size in microns * @return this builder */ public Builder pixelSize(double pixelSize) { @@ -265,20 +335,26 @@ public Builder inputScale(double... values) { return this; } - public Builder invertChannels( boolean isInvert ) { + public Builder invertChannels(boolean isInvert) { this.isInvert = isInvert; return this; } + public Builder cellposeChannels(int channel1, int channel2) { + this.channel1 = channel1; + this.channel2 = channel2; + return this; + } + /** * Create a {@link Cellpose2D}, all ready for detection. * - * @return + * @return a Cellpose2DTraining object, ready to be run */ public Cellpose2DTraining build() { // Directory to move trained models. - if( modelDirectory == null) { + if (modelDirectory == null) { modelDirectory = new File(QP.getProject().getPath().getParent().toFile(), "models"); } @@ -313,11 +389,12 @@ public Cellpose2DTraining build() { } mergedOps.add(ImageOps.Core.ensureType(PixelType.FLOAT32)); - cellpose.op = ImageOps.buildImageDataOp(channels) + cellpose.op = ImageOps.buildImageDataOp(channels) .appendOps(mergedOps.toArray(ImageOp[]::new)); // CellPose accepts either one or two channels. 
This will help the final command - cellpose.nChannels = channels.length; + cellpose.channel1 = channel1; + cellpose.channel2 = channel2; cellpose.pixelSize = pixelSize; cellpose.diameter = diameter; cellpose.invert = isInvert; @@ -330,77 +407,4 @@ public Cellpose2DTraining build() { } } - - public File train() { - - try { - saveTrainingImages(); - runCellPose(); - File model = moveAndReturnModelFile(); - - return model; - } catch (IOException | InterruptedException e) { - logger.error(e.getMessage(), e); - } - return null; - - } - public void runCellPose() throws IOException, InterruptedException { - - //python -m cellpose --train --dir ~/images_cyto/train/ --test_dir ~/images_cyto/test/ --pretrained_model cyto --chan 2 --chan2 1 - // Get options - CellposeOptions cellposeOptions = CellposeOptions.getInstance(); - - // Create command to run - VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeOptions.getEnvironmentNameorPath(), cellposeOptions.getEnvironmentType()); - - // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(); - - cellposeArguments.addAll(Arrays.asList("-W", "ignore", "-m", "cellpose")); - - cellposeArguments.add("--train"); - - cellposeArguments.add("--dir"); - cellposeArguments.add("" + trainDirectory.getAbsolutePath()); - cellposeArguments.add("--test_dir"); - cellposeArguments.add("" + valDirectory.getAbsolutePath()); - - cellposeArguments.add("--pretrained_model"); - if(pretrainedModel != "") { - cellposeArguments.add("" + pretrainedModel); - } else { - cellposeArguments.add("None"); - } - - // The channel order will always be 1 and 2, in the order defined by channels(...) 
in the builder - if (nChannels > 1) { - cellposeArguments.add("--chan"); - cellposeArguments.add("1"); - - cellposeArguments.add("--chan2"); - cellposeArguments.add("2"); - } else { - cellposeArguments.add("--chan"); - cellposeArguments.add("0"); - } - - - cellposeArguments.add("--diameter"); - cellposeArguments.add("" + diameter); - - cellposeArguments.add("--n_epochs"); - cellposeArguments.add("" + nEpochs); - - if(invert) cellposeArguments.add("--invert"); - - if (cellposeOptions.useGPU()) cellposeArguments.add("--use_gpu"); - - veRunner.setArguments(cellposeArguments); - - // Finally, we can run Cellpose - veRunner.runCommand(); - - logger.info("Cellpose command finished running"); - } } \ No newline at end of file