diff --git a/build.gradle b/build.gradle index 8ec5ea3..5910529 100644 --- a/build.gradle +++ b/build.gradle @@ -25,7 +25,7 @@ ext.qupathVersion = gradle.ext.qupathVersion description = 'QuPath extension to use Cellpose' -version = "0.9.2" +version = "0.9.3-SNAPSHOT" dependencies { implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}" diff --git a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java index b6da2c8..a7cf65a 100644 --- a/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java +++ b/src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java @@ -109,6 +109,7 @@ public class Cellpose2D { private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class); public ImageOp extendChannelOp; + public boolean useGPU; protected double simplifyDistance = 1.4; @@ -412,6 +413,11 @@ public void detectObjectsImpl(ImageData imageData, Collection> candidatesPerParent = allTiles.values().stream() .flatMap(t -> t.getCandidates().stream()) @@ -721,12 +727,15 @@ private VirtualEnvironmentRunner getVirtualEnvironmentRunner() { */ private LinkedHashMap runCellpose(LinkedHashMap allTiles) throws InterruptedException, IOException { - + // Need to define the name of the command we are running. We used to be able to use 'cellpose' for both but not since Cellpose v2 String runCommand = this.parameters.containsKey("omni") ? "omnipose" : "cellpose"; VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner(); // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand)); + // We want to ignore all warnings to make sure the log is clean (-W ignore) + // We want to be able to call the module by name (-m) + // We want to make sure UTF8 mode is by default (-X utf8) + List cellposeArguments = new ArrayList<>(Arrays.asList("-Xutf8", "-W", "ignore", "-m", runCommand)); cellposeArguments.add("--dir"); cellposeArguments.add("" + this.tempDirectory); @@ -746,14 +755,14 @@ private LinkedHashMap runCellpose(LinkedHashMap cellposeArguments.add("--no_npy"); - cellposeArguments.add("--use_gpu"); + if( this.useGPU ) cellposeArguments.add("--use_gpu"); cellposeArguments.add("--verbose"); veRunner.setArguments(cellposeArguments); // Finally, we can run Cellpose - veRunner.runCommand(); + veRunner.runCommand(false); return processCellposeFiles(veRunner, allTiles); @@ -761,10 +770,17 @@ private LinkedHashMap runCellpose(LinkedHashMap private LinkedHashMap processCellposeFiles(VirtualEnvironmentRunner veRunner, LinkedHashMap allTiles) throws CancellationException, InterruptedException, IOException { + // Make sure that allTiles is not null, if it is, just return null + // as we are likely just running validation and thus do not need to give any results back + if (allTiles == null ) { + veRunner.getProcess().waitFor(); + return null; + } + // Build a thread pool to process reading the images in parallel ExecutorService executor = Executors.newFixedThreadPool(5); - if (!this.doReadResultsAsynchronously || allTiles == null) { + if (!this.doReadResultsAsynchronously) { // We need to wait for the process to finish veRunner.getProcess().waitFor(); allTiles.entrySet().forEach(entry -> { @@ -893,7 +909,7 @@ private void runTraining() throws IOException, InterruptedException { VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner(); // This is the list of commands after the 'python' call - List cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand)); + List cellposeArguments = new ArrayList<>(Arrays.asList( "-Xutf8", "-W", "ignore", "-m", runCommand)); cellposeArguments.add("--train"); @@ -917,18 +933,15 @@ private void runTraining() throws IOException, InterruptedException { } }); - - cellposeArguments.add("--use_gpu"); + // Some people may deactivate this... + if( this.useGPU ) cellposeArguments.add("--use_gpu"); cellposeArguments.add("--verbose"); veRunner.setArguments(cellposeArguments); // Finally, we can run Cellpose - veRunner.runCommand(); - - // Wait for the process to finish - veRunner.getProcess().waitFor(); + veRunner.runCommand(true); // Get the log this.theLog = veRunner.getProcessLog(); @@ -989,7 +1002,8 @@ private ResultsTable runCellposeQC() throws IOException, InterruptedException { qcRunner.setArguments(qcArguments); - qcRunner.runCommand(); + qcRunner.runCommand(true); + // The results are stored in the validation directory, open them as a results table File qcResults = new File( getValidationDirectory(), "QC-Results" + File.separator + "Quality_Control for " + this.modelFile.getName() + ".csv"); @@ -1047,20 +1061,25 @@ private ResultsTable parseTrainingResults() { if (this.theLog != null) { // Try to parse the output of Cellpose to give meaningful information to the user. This is very old school - // Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000" - String epochPattern = ".*Epoch\\s*(\\d+),\\s*Time\\s*(\\d+\\.\\d)s,\\s*Loss\\s*(\\d+\\.\\d+),\\s*Loss Test\\s*(\\d+\\.\\d+),\\s*LR\\s*(\\d+\\.\\d+).*"; - // Build Matcher - Pattern pattern = Pattern.compile(epochPattern); - Matcher m; for (String line : this.theLog) { - m = pattern.matcher(line); - if (m.find()) { - trainingResults.incrementCounter(); - trainingResults.addValue("Epoch", Double.parseDouble(m.group(1))); - trainingResults.addValue("Time[s]", Double.parseDouble(m.group(2))); - trainingResults.addValue("Loss", Double.parseDouble(m.group(3))); - trainingResults.addValue("Loss Test", Double.parseDouble(m.group(4))); - trainingResults.addValue("LR", Double.parseDouble(m.group(5))); + Matcher m; + for (LogParser parser : LogParser.values()) { + m = parser.getPattern().matcher(line); + if (m.find()) { + trainingResults.incrementCounter(); + trainingResults.addValue("Epoch", Double.parseDouble(m.group("epoch"))); + trainingResults.addValue("Time", Double.parseDouble(m.group("time"))); + trainingResults.addValue("Loss", Double.parseDouble(m.group("loss"))); + if (parser != LogParser.OMNI) { // Omnipose does not provide validation loss + trainingResults.addValue("Validation Loss", Double.parseDouble(m.group("val"))); + trainingResults.addValue("LR", Double.parseDouble(m.group("lr"))); + + } else { + trainingResults.addValue("Validation Loss", Double.NaN); + trainingResults.addValue("LR", Double.NaN); + + } + } } } } @@ -1104,7 +1123,7 @@ public void showTrainingGraph(boolean show, boolean save) { //populating the series with data for (int i = 0; i < output.getCounter(); i++) { loss.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss", i))); - lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss Test", i))); + lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Validation Loss", i))); } lineChart.getData().add(loss); @@ -1166,18 +1185,18 @@ private void saveImagePairs(List annotations, String imageName, Imag if (annotations.isEmpty()) { return; } - int downsample = 1; + double downsample; if (Double.isFinite(pixelSize) && pixelSize > 0) { - downsample = (int) Math.round(pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue()); + downsample = pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue(); + } else { + downsample = 1.0; } - AtomicInteger idx = new AtomicInteger(); - int finalDownsample = downsample; annotations.forEach(a -> { int i = idx.getAndIncrement(); - RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), finalDownsample, a.getROI()); + RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), downsample, a.getROI()); File imageFile = new File(saveDirectory, imageName + "_region_" + i + ".tif"); File maskFile = new File(saveDirectory, imageName + "_region_" + i + "_masks.tif"); try { @@ -1348,6 +1367,7 @@ private Collection readObjectsFromFile(TileFile tileFile) throw } } // Ignore the IDs, because they will be the same across different images, and we don't really need them + if(candidates.isEmpty()) return Collections.emptyList(); return candidates.values(); } @@ -1424,4 +1444,25 @@ private static class CandidateObject { geometry = geometry.getGeometryN(index); } } + public enum LogParser { + + // Cellpose 2 pattern when training : "Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000" + // Cellpose 3 pattern when training : "5, train_loss=2.6546, test_loss=2.0054, LR=0.1111, time 2.56s" + // Omnipose pattern when training : "Train epoch: 10 | Time: 0.22min | last epoch: 0.74s | : 0.73s | : 0.33s | : 5.076259 | : 4.429341" + // WARNING: Currently Omnipose does not provide any output to the validation loss (Test loss in Cellpose) + CP2("Cellpose v2", ".*Epoch\\s*(?\\d+),\\s*Time\\s*(?