Skip to content

Commit

Permalink
adds new logic for setting channels for cellpose
Browse files Browse the repository at this point in the history
This can help for RGB images where we want to send the whole image to cellpose but use only the gray channel.

Adds the possibility to invert channels via Builder.invertChannels(boolean)

Code formatting
  • Loading branch information
lacan committed Oct 28, 2021
1 parent 59567fb commit 62515e4
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 136 deletions.
87 changes: 46 additions & 41 deletions src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ public class Cellpose2D {
private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class);
private int channel1;
private int channel2;
private int nChannels;
private double iouThreshold = 0.1;
private double simplifyDistance = 1.4;
private double probabilityThreshold;
Expand All @@ -104,14 +103,15 @@ public class Cellpose2D {
private boolean measureShape = false;
private Collection<ObjectMeasurements.Compartments> compartments;
private Collection<ObjectMeasurements.Measurements> measurements;
private boolean invert;

/**
* Create a builder to customize detection parameters.
* This accepts either Text describing the built-in models from cellpose (cyto, cyto2, nuc)
* or a path to a custom model (as a String)
*
* @param modelPath name or path to model to use for prediction.
* @return
* @return this builder
*/
public static Builder builder(String modelPath) {
return new Builder(modelPath);
Expand Down Expand Up @@ -198,8 +198,7 @@ public void detectObjects(ImageData<BufferedImage> imageData, Collection<? exten
// Make a new RegionRequest
var region = RegionRequest.createInstance(server.getPath(), finalDownsample, t);
try {
TileFile file = saveTileImage(op, imageData, region);
return file;
return saveTileImage(op, imageData, region);
} catch (IOException e) {
e.printStackTrace();
}
Expand All @@ -212,17 +211,15 @@ public void detectObjects(ImageData<BufferedImage> imageData, Collection<? exten
// Here the files are saved, and we can run cellpose.
try {
runCellPose();
} catch (IOException e) {
logger.error("Failed to Run Cellpose", e);
} catch (InterruptedException e) {
} catch (IOException | InterruptedException e) {
logger.error("Failed to Run Cellpose", e);
}

// Recover all the images from CellPose to get the masks
allTiles.parallelStream().forEach(tileMap -> {
PathObject parent = tileMap.getObject();
// Read each image
List<PathObject> allDetections = Collections.synchronizedList(new ArrayList<PathObject>());
List<PathObject> allDetections = Collections.synchronizedList(new ArrayList<>());
tileMap.getTileFiles().parallelStream().forEach(tilefile -> {
File ori = tilefile.getFile();
File maskFile = new File(ori.getParent(), FilenameUtils.removeExtension(ori.getName()) + "_cp_masks.tif");
Expand Down Expand Up @@ -260,15 +257,15 @@ public void detectObjects(ImageData<BufferedImage> imageData, Collection<? exten
logger.warn("Error converting to object: " + e.getLocalizedMessage(), e);
return null;
}
}).filter(n -> n != null)
}).filter(Objects::nonNull)
.collect(Collectors.toList());

// Resolve cell overlaps, if needed
if (expansion > 0 && !ignoreCellOverlaps) {
logger.info("Resolving cell overlaps for {}", parent);
if (creatorFun != null) {
// It's awkward, but we need to temporarily convert to cells and back
var cells = filteredDetections.stream().map(c -> objectToCell(c)).collect(Collectors.toList());
var cells = filteredDetections.stream().map(Cellpose2D::objectToCell).collect(Collectors.toList());
cells = CellTools.constrainCellOverlaps(cells);
filteredDetections = cells.stream().map(c -> cellToObject(c, creatorFun)).collect(Collectors.toList());
} else
Expand Down Expand Up @@ -368,7 +365,7 @@ private PathObject convertToObject(PathObject object, ImagePlane plane, double c
private List<PathObject> filterDetections(List<PathObject> rawDetections) {

// Sort by size
Collections.sort(rawDetections, Comparator.comparingDouble(o -> -1 * o.getROI().getArea()));
rawDetections.sort(Comparator.comparingDouble(o -> -1 * o.getROI().getArea()));

// Create array of detections to keep & to skip
var detections = new LinkedHashSet<PathObject>();
Expand Down Expand Up @@ -430,7 +427,7 @@ private List<PathObject> filterDetections(List<PathObject> rawDetections) {
* @param imageData the current ImageData
* @param request the region we want to save
* @return a simple object that contains the request and the associated file in the temp folder
* @throws IOException
* @throws IOException an error in case of read/write issue
*/
private TileFile saveTileImage(ImageDataOp op, ImageData<BufferedImage> imageData, RegionRequest request) throws IOException {

Expand Down Expand Up @@ -479,9 +476,8 @@ private void runCellPose() throws IOException, InterruptedException {
VirtualEnvironmentRunner veRunner = new VirtualEnvironmentRunner(cellposeOptions.getEnvironmentNameorPath(), cellposeOptions.getEnvironmentType());

// This is the list of commands after the 'python' call
List<String> cellposeArguments = new ArrayList<>();

cellposeArguments.addAll(Arrays.asList("-W", "ignore", "-m", "cellpose"));
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", "cellpose"));

cellposeArguments.add("--dir");
cellposeArguments.add("" + this.cellposeTempFolder);
Expand All @@ -507,8 +503,12 @@ private void runCellPose() throws IOException, InterruptedException {
cellposeArguments.add("--save_tif");

cellposeArguments.add("--no_npy");

cellposeArguments.add("--resample");

if (invert) cellposeArguments.add("--invert");


if (cellposeOptions.useGPU()) cellposeArguments.add("--use_gpu");

veRunner.setArguments(cellposeArguments);
Expand All @@ -524,7 +524,7 @@ private void runCellPose() throws IOException, InterruptedException {
*/
public static class Builder {

private String modelPath;
private final String modelPath;
private ColorTransform[] channels = new ColorTransform[0];

private double probabilityThreshold = 0.5;
Expand All @@ -538,8 +538,8 @@ public static class Builder {

private double pixelSize = Double.NaN;

private int tileWidth = 2048;
private int tileHeight = 2048;
private int tileWidth = 1024;
private int tileHeight = 1024;

private Function<ROI, PathObject> creatorFun;

Expand All @@ -551,11 +551,12 @@ public static class Builder {

private boolean constrainToParent = true;

private List<ImageOp> ops = new ArrayList<>();
private final List<ImageOp> ops = new ArrayList<>();
private double iouThreshold = 0.1;

private int channel1 = 0; //GRAY
private int channel2 = 0; // NONE
private boolean isInvert = false;


private Builder(String modelPath) {
Expand All @@ -566,7 +567,7 @@ private Builder(String modelPath) {
/**
* Probability threshold to apply for detection, between 0 and 1.
*
* @param threshold
* @param threshold probability threshold between 0 and 1 (default 0.5)
* @return this builder
*/
public Builder probabilityThreshold(double threshold) {
Expand All @@ -577,7 +578,7 @@ public Builder probabilityThreshold(double threshold) {
/**
* Flow threshold to apply for detection, between 0 and 1.
*
* @param threshold
* @param threshold flow threshold (default 0.0)
* @return this builder
*/
public Builder flowThreshold(double threshold) {
Expand All @@ -590,7 +591,7 @@ public Builder flowThreshold(double threshold) {
* their expected diameter
*
* @param diameter in pixels
* @return
* @return this builder
*/
public Builder diameter(double diameter) {
this.diameter = diameter;
Expand All @@ -600,12 +601,11 @@ public Builder diameter(double diameter) {
/**
* Add preprocessing operations, if required.
*
* @param ops
* @param ops series of ImageOps to apply to this server before saving the images
* @return this builder
*/
public Builder preprocess(ImageOp... ops) {
for (var op : ops)
this.ops.add(op);
Collections.addAll(this.ops, ops);
return this;
}

Expand Down Expand Up @@ -648,7 +648,7 @@ public Builder iou(double iouThreshold) {
*/
public Builder channels(int... channels) {
return channels(Arrays.stream(channels)
.mapToObj(c -> ColorTransforms.createChannelExtractor(c))
.mapToObj(ColorTransforms::createChannelExtractor)
.toArray(ColorTransform[]::new));
}

Expand All @@ -662,7 +662,7 @@ public Builder channels(int... channels) {
*/
public Builder channels(String... channels) {
return channels(Arrays.stream(channels)
.map(c -> ColorTransforms.createChannelExtractor(c))
.map(ColorTransforms::createChannelExtractor)
.toArray(ColorTransform[]::new));
}

Expand All @@ -671,7 +671,7 @@ public Builder channels(String... channels) {
* <p>
* This makes it possible to supply color deconvolved channels, for example.
*
* @param channels
* @param channels ColorTransform channels to use, typically only used internally
* @return this builder
*/
public Builder channels(ColorTransform... channels) {
Expand All @@ -692,7 +692,7 @@ public Builder channels(ColorTransform... channels) {
* <p>
* In short, be wary.
*
* @param distance
* @param distance cell expansion distance in microns
* @return this builder
*/
public Builder cellExpansion(double distance) {
Expand All @@ -705,7 +705,7 @@ public Builder cellExpansion(double distance) {
* the nucleus size. Only meaningful for values &gt; 1; the nucleus is expanded according
* to the scale factor, and used to define the maximum permitted cell expansion.
*
* @param scale
* @param scale a number to multiply each pixel in the image by
* @return this builder
*/
public Builder cellConstrainScale(double scale) {
Expand All @@ -720,14 +720,14 @@ public Builder cellConstrainScale(double scale) {
* @return this builder
*/
public Builder createAnnotations() {
this.creatorFun = r -> PathObjects.createAnnotationObject(r);
this.creatorFun = PathObjects::createAnnotationObject;
return this;
}

/**
* Request that a classification is applied to all created objects.
*
* @param pathClass
* @param pathClass the PathClass of all detections resulting from this run
* @return this builder
*/
public Builder classify(PathClass pathClass) {
Expand All @@ -739,7 +739,7 @@ public Builder classify(PathClass pathClass) {
* Request that a classification is applied to all created objects.
* This is a convenience method that get a {@link PathClass} from {@link PathClassFactory}.
*
* @param pathClassName
* @param pathClassName the name of the PathClass for all detections
* @return this builder
*/
public Builder classify(String pathClassName) {
Expand All @@ -749,7 +749,7 @@ public Builder classify(String pathClassName) {
/**
* If true, ignore overlaps when computing cell expansion.
*
* @param ignore
* @param ignore ignore overlaps when computing cell expansion.
* @return this builder
*/
public Builder ignoreCellOverlaps(boolean ignore) {
Expand All @@ -760,7 +760,7 @@ public Builder ignoreCellOverlaps(boolean ignore) {
/**
* If true, constrain nuclei and cells to any parent annotation (default is true).
*
* @param constrainToParent
* @param constrainToParent constrain nuclei and cells to any parent annotation
* @return this builder
*/
public Builder constrainToParent(boolean constrainToParent) {
Expand Down Expand Up @@ -824,7 +824,7 @@ public Builder compartments(Compartments... compartments) {
* <p>
* For an image calibrated in microns, the recommended default is approximately 0.5.
*
* @param pixelSize
* @param pixelSize Pixel size in microns for the analysis
* @return this builder
*/
public Builder pixelSize(double pixelSize) {
Expand All @@ -837,7 +837,7 @@ public Builder pixelSize(double pixelSize) {
* Note that tiles are independently normalized, and therefore tiling can impact
* the results. Default is 1024.
*
* @param tileSize
* @param tileSize if the regions must be broken down, how large should the tiles be, in pixels (width and height)
* @return this builder
*/
public Builder tileSize(int tileSize) {
Expand All @@ -849,8 +849,8 @@ public Builder tileSize(int tileSize) {
* Note that tiles are independently normalized, and therefore tiling can impact
* the results. Default is 1024.
*
* @param tileWidth
* @param tileHeight
* @param tileWidth if the regions must be broken down, how large should the tiles be (width), in pixels
* @param tileHeight if the regions must be broken down, how large should the tiles be (height), in pixels
* @return this builder
*/
public Builder tileSize(int tileWidth, int tileHeight) {
Expand Down Expand Up @@ -912,7 +912,12 @@ public Builder inputScale(double... values) {
return this;
}

public Builder cellPoseChannels(int channel1, int channel2) {
public Builder invertChannels(boolean isInvert) {
this.isInvert = isInvert;
return this;
}

public Builder cellposeChannels(int channel1, int channel2) {
this.channel1 = channel1;
this.channel2 = channel2;
return this;
Expand All @@ -921,7 +926,7 @@ public Builder cellPoseChannels(int channel1, int channel2) {
/**
* Create a {@link Cellpose2D}, all ready for detection.
*
* @return
* @return a CellPose2D object, ready to be run
*/
public Cellpose2D build() {

Expand Down Expand Up @@ -950,13 +955,13 @@ public Cellpose2D build() {
cellpose.op = ImageOps.buildImageDataOp(channels).appendOps(mergedOps.toArray(ImageOp[]::new));

// CellPose accepts either one or two channels. This will help the final command
cellpose.nChannels = channels.length;
cellpose.channel1 = channel1;
cellpose.channel2 = channel2;
cellpose.probabilityThreshold = probabilityThreshold;
cellpose.flowThreshold = flowThreshold;
cellpose.pixelSize = pixelSize;
cellpose.diameter = diameter;
cellpose.invert = isInvert;

// Overlap for segmentation of tiles. Should be large enough that any object must be "complete"
// in at least one tile for resolving overlaps
Expand Down
Loading

0 comments on commit 62515e4

Please sign in to comment.