Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes and update to Cellpose 3 logging #46

Merged
merged 11 commits into from
Apr 30, 2024
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ext.qupathVersion = gradle.ext.qupathVersion

description = 'QuPath extension to use Cellpose'

version = "0.9.2"
version = "0.9.3-SNAPSHOT"

dependencies {
implementation "io.github.qupath:qupath-gui-fx:${qupathVersion}"
Expand Down
105 changes: 73 additions & 32 deletions src/main/java/qupath/ext/biop/cellpose/Cellpose2D.java
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public class Cellpose2D {
private final static Logger logger = LoggerFactory.getLogger(Cellpose2D.class);

public ImageOp extendChannelOp;
public boolean useGPU;

protected double simplifyDistance = 1.4;

Expand Down Expand Up @@ -412,6 +413,11 @@ public void detectObjectsImpl(ImageData<BufferedImage> imageData, Collection<? e

logger.info("All tiles and objects read, now resolving overlaps");

// In case alltiles is null, we are basically done
if( allTiles == null ) {
logger.info("No results from Cellpose", "There is nothing to recover from cellpose");
}

// Group the candidates per parent object, as this is needed to optimize when checking for overlap
Map<PathObject, List<CandidateObject>> candidatesPerParent = allTiles.values().stream()
.flatMap(t -> t.getCandidates().stream())
Expand Down Expand Up @@ -721,12 +727,15 @@ private VirtualEnvironmentRunner getVirtualEnvironmentRunner() {
*/
private LinkedHashMap<File, TileFile> runCellpose(LinkedHashMap<File, TileFile> allTiles) throws InterruptedException, IOException {


// Need to define the name of the command we are running. We used to be able to use 'cellpose' for both but not since Cellpose v2
String runCommand = this.parameters.containsKey("omni") ? "omnipose" : "cellpose";
VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner();

// This is the list of commands after the 'python' call
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand));
// We want to ignore all warnings to make sure the log is clean (-W ignore)
// We want to be able to call the module by name (-m)
// We want to make sure UTF8 mode is by default (-X utf8)
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-Xutf8", "-W", "ignore", "-m", runCommand));

cellposeArguments.add("--dir");
cellposeArguments.add("" + this.tempDirectory);
Expand All @@ -746,25 +755,32 @@ private LinkedHashMap<File, TileFile> runCellpose(LinkedHashMap<File, TileFile>

cellposeArguments.add("--no_npy");

cellposeArguments.add("--use_gpu");
if( this.useGPU ) cellposeArguments.add("--use_gpu");

cellposeArguments.add("--verbose");

veRunner.setArguments(cellposeArguments);

// Finally, we can run Cellpose
veRunner.runCommand();
veRunner.runCommand(false);

return processCellposeFiles(veRunner, allTiles);

}

private LinkedHashMap<File, TileFile> processCellposeFiles(VirtualEnvironmentRunner veRunner, LinkedHashMap<File, TileFile> allTiles) throws CancellationException, InterruptedException, IOException {

// Make sure that allTiles is not null, if it is, just return null
// as we are likely just running validation and thus do not need to give any results back
if (allTiles == null ) {
veRunner.getProcess().waitFor();
return null;
}

// Build a thread pool to process reading the images in parallel
ExecutorService executor = Executors.newFixedThreadPool(5);

if (!this.doReadResultsAsynchronously || allTiles == null) {
if (!this.doReadResultsAsynchronously) {
// We need to wait for the process to finish
veRunner.getProcess().waitFor();
allTiles.entrySet().forEach(entry -> {
Expand Down Expand Up @@ -893,7 +909,7 @@ private void runTraining() throws IOException, InterruptedException {
VirtualEnvironmentRunner veRunner = getVirtualEnvironmentRunner();

// This is the list of commands after the 'python' call
List<String> cellposeArguments = new ArrayList<>(Arrays.asList("-W", "ignore", "-m", runCommand));
List<String> cellposeArguments = new ArrayList<>(Arrays.asList( "-Xutf8", "-W", "ignore", "-m", runCommand));

cellposeArguments.add("--train");

Expand All @@ -917,18 +933,15 @@ private void runTraining() throws IOException, InterruptedException {
}
});


cellposeArguments.add("--use_gpu");
// Some people may deactivate this...
if( this.useGPU ) cellposeArguments.add("--use_gpu");

cellposeArguments.add("--verbose");

veRunner.setArguments(cellposeArguments);

// Finally, we can run Cellpose
veRunner.runCommand();

// Wait for the process to finish
veRunner.getProcess().waitFor();
veRunner.runCommand(true);

// Get the log
this.theLog = veRunner.getProcessLog();
Expand Down Expand Up @@ -989,7 +1002,8 @@ private ResultsTable runCellposeQC() throws IOException, InterruptedException {

qcRunner.setArguments(qcArguments);

qcRunner.runCommand();
qcRunner.runCommand(true);


// The results are stored in the validation directory, open them as a results table
File qcResults = new File( getValidationDirectory(), "QC-Results" + File.separator + "Quality_Control for " + this.modelFile.getName() + ".csv");
Expand Down Expand Up @@ -1047,20 +1061,25 @@ private ResultsTable parseTrainingResults() {

if (this.theLog != null) {
// Try to parse the output of Cellpose to give meaningful information to the user. This is very old school
// Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000"
String epochPattern = ".*Epoch\\s*(\\d+),\\s*Time\\s*(\\d+\\.\\d)s,\\s*Loss\\s*(\\d+\\.\\d+),\\s*Loss Test\\s*(\\d+\\.\\d+),\\s*LR\\s*(\\d+\\.\\d+).*";
// Build Matcher
Pattern pattern = Pattern.compile(epochPattern);
Matcher m;
for (String line : this.theLog) {
m = pattern.matcher(line);
if (m.find()) {
trainingResults.incrementCounter();
trainingResults.addValue("Epoch", Double.parseDouble(m.group(1)));
trainingResults.addValue("Time[s]", Double.parseDouble(m.group(2)));
trainingResults.addValue("Loss", Double.parseDouble(m.group(3)));
trainingResults.addValue("Loss Test", Double.parseDouble(m.group(4)));
trainingResults.addValue("LR", Double.parseDouble(m.group(5)));
Matcher m;
for (LogParser parser : LogParser.values()) {
m = parser.getPattern().matcher(line);
if (m.find()) {
trainingResults.incrementCounter();
trainingResults.addValue("Epoch", Double.parseDouble(m.group("epoch")));
trainingResults.addValue("Time", Double.parseDouble(m.group("time")));
trainingResults.addValue("Loss", Double.parseDouble(m.group("loss")));
if (parser != LogParser.OMNI) { // Omnipose does not provide validation loss
trainingResults.addValue("Validation Loss", Double.parseDouble(m.group("val")));
trainingResults.addValue("LR", Double.parseDouble(m.group("lr")));

} else {
trainingResults.addValue("Validation Loss", Double.NaN);
trainingResults.addValue("LR", Double.NaN);

}
}
}
}
}
Expand Down Expand Up @@ -1104,7 +1123,7 @@ public void showTrainingGraph(boolean show, boolean save) {
//populating the series with data
for (int i = 0; i < output.getCounter(); i++) {
loss.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss", i)));
lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Loss Test", i)));
lossTest.getData().add(new XYChart.Data<>(output.getValue("Epoch", i), output.getValue("Validation Loss", i)));

}
lineChart.getData().add(loss);
Expand Down Expand Up @@ -1166,18 +1185,18 @@ private void saveImagePairs(List<PathObject> annotations, String imageName, Imag
if (annotations.isEmpty()) {
return;
}
int downsample = 1;
double downsample;
if (Double.isFinite(pixelSize) && pixelSize > 0) {
downsample = (int) Math.round(pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue());
downsample = pixelSize / originalServer.getPixelCalibration().getAveragedPixelSize().doubleValue();
} else {
downsample = 1.0;
}

AtomicInteger idx = new AtomicInteger();
int finalDownsample = downsample;

annotations.forEach(a -> {
int i = idx.getAndIncrement();

RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), finalDownsample, a.getROI());
RegionRequest request = RegionRequest.createInstance(originalServer.getPath(), downsample, a.getROI());
File imageFile = new File(saveDirectory, imageName + "_region_" + i + ".tif");
File maskFile = new File(saveDirectory, imageName + "_region_" + i + "_masks.tif");
try {
Expand Down Expand Up @@ -1348,6 +1367,7 @@ private Collection<CandidateObject> readObjectsFromFile(TileFile tileFile) throw
}
}
// Ignore the IDs, because they will be the same across different images, and we don't really need them
if(candidates.isEmpty()) return Collections.emptyList();
return candidates.values();
}

Expand Down Expand Up @@ -1424,4 +1444,25 @@ private static class CandidateObject {
geometry = geometry.getGeometryN(index);
}
}
public enum LogParser {

// Cellpose 2 pattern when training : "Look for "Epoch 0, Time 2.3s, Loss 1.0758, Loss Test 0.6007, LR 0.2000"
// Cellpose 3 pattern when training : "5, train_loss=2.6546, test_loss=2.0054, LR=0.1111, time 2.56s"
// Omnipose pattern when training : "Train epoch: 10 | Time: 0.22min | last epoch: 0.74s | <sec/epoch>: 0.73s | <sec/batch>: 0.33s | <Batch Loss>: 5.076259 | <Epoch Loss>: 4.429341"
// WARNING: Currently Omnipose does not provide any output to the validation loss (Test loss in Cellpose)
CP2("Cellpose v2", ".*Epoch\\s*(?<epoch>\\d+),\\s*Time\\s*(?<time>\\d+\\.\\d)s,\\s*Loss\\s*(?<loss>\\d+\\.\\d+),\\s*Loss Test\\s*(?<val>\\d+\\.\\d+),\\s*LR\\s*(?<lr>\\d+\\.\\d+).*"),
CP3( "Cellpose v3", ".* (?<epoch>\\d+), train_loss=(?<loss>\\d+\\.\\d+), test_loss=(?<val>\\d+\\.\\d+), LR=(?<lr>\\d+\\.\\d+), time (?<time>\\d+\\.\\d+)s.*"),
OMNI("Omnipose", ".*Train epoch: (?<epoch>\\d+) \\| Time: (?<time>\\d+\\.\\d+)min .*\\<Epoch Loss\\>: (?<loss>\\d+\\.\\d+).*");

private final String name;
private final Pattern pattern;

LogParser(String name, String regex) {
this.name = name;
this.pattern = Pattern.compile(regex);
}

public String getName() {return this.name;}
public Pattern getPattern() {return this.pattern;}
}
}
14 changes: 14 additions & 0 deletions src/main/java/qupath/ext/biop/cellpose/CellposeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class CellposeBuilder {
private ImageOp extendChannelOp = null;

private boolean doReadResultsAsynchronously = false;
private boolean useGPU = true;

/**
* can create a cellpose builder from a serialized JSON version of this builder.
Expand Down Expand Up @@ -134,9 +135,20 @@ protected CellposeBuilder(String modelPath) {

}

/**
* overwrite use GPU
* @param useGPU add or remove the option
* @return this builder
*/
public CellposeBuilder useGPU( boolean useGPU ) {
this.useGPU = useGPU;

return this;
}

/**
* Specify the training directory
*
*/
public CellposeBuilder groundTruthDirectory(File groundTruthDirectory) {
this.groundTruthDirectory = groundTruthDirectory;
Expand Down Expand Up @@ -771,6 +783,8 @@ public Cellpose2D build() {
// Give it the number of threads to use
cellpose.nThreads = nThreads;

cellpose.useGPU = useGPU;

// Check the model. If it is a file, then it is a custom model
File file = new File(this.modelNameOrPath);
if (file.exists()) {
Expand Down
31 changes: 21 additions & 10 deletions src/main/java/qupath/ext/biop/cmd/VirtualEnvironmentRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class VirtualEnvironmentRunner {
private final EnvType envType;
private WatchService watchService;
private String name;
private String environmentNameOrPath;
private String pythonPath;

private List<String> arguments;

Expand Down Expand Up @@ -61,7 +61,7 @@ public String toString() {
}

public VirtualEnvironmentRunner(String environmentNameOrPath, EnvType type, String name) {
this.environmentNameOrPath = environmentNameOrPath;
this.pythonPath = environmentNameOrPath;
this.envType = type;
this.name = name;
if (envType.equals(EnvType.OTHER))
Expand All @@ -84,27 +84,29 @@ private List<String> getActivationCommand() {
case CONDA:
switch (platform) {
case WINDOWS:
cmd.addAll(Arrays.asList("CALL", "conda.bat", "activate", environmentNameOrPath, "&", "python"));
// Adjust path to the folder with the env name based on the python location. On Windows it's at the root of the environment
cmd.addAll(Arrays.asList("CALL", "conda.bat", "activate", new File(pythonPath).getParent(), "&", "python"));
break;
case UNIX:
case OSX:
cmd.addAll(Arrays.asList("conda", "activate", environmentNameOrPath, ";", "python"));
// Adjust path to the folder with the env name based on the python location. In Linux/MacOS it's in the 'bin' sub folder
cmd.addAll(Arrays.asList("conda", "activate", new File(pythonPath).getParentFile().getParent(), ";", "python"));
break;
}
break;
case VENV:
switch (platform) {
case WINDOWS:
cmd.add(new File(environmentNameOrPath, "Scripts/python").getAbsolutePath());
cmd.add(new File(pythonPath, "Scripts/python").getAbsolutePath());
break;
case UNIX:
case OSX:
cmd.add(new File(environmentNameOrPath, "bin/python").getAbsolutePath());
cmd.add(new File(pythonPath, "bin/python").getAbsolutePath());
break;
}
break;
case EXE:
cmd.add(environmentNameOrPath);
cmd.add(pythonPath);
break;
case OTHER:
return null;
Expand All @@ -123,10 +125,10 @@ public void setArguments(List<String> arguments) {

/**
* This builds, runs the command and outputs it to the logger as it is being run
*
* @throws IOException // In case there is an issue starting the process
* @param waitUntilDone whether to wait for the process to end or not before exiting this method
* @throws IOException in case there is an issue with the process
*/
public void runCommand() throws IOException {
public void runCommand(boolean waitUntilDone) throws IOException {

// Get how to start the command, based on the VENV Type
List<String> command = getActivationCommand();
Expand Down Expand Up @@ -207,6 +209,15 @@ public void run() {


logger.info("Virtual Environment Runner Started");

// If we ask to wait, let's wait directly here rather than handle it outside
if(waitUntilDone) {
try {
this.process.waitFor();
} catch (InterruptedException e) {
logger.error(e.getMessage());
}
}
}

public Process getProcess() {
Expand Down
Loading