Skip to content

Commit

Permalink
Refactor telemetry logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 25, 2024
1 parent 6eeadb3 commit 8ae0b9a
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions spleenseg/core/neuralnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ def metaTensor_toNIfTI(self, metaTensor: MetaTensor, savefile: Path):
nib.save(niftiVolume, savefile)
pass

def sample_showInfoTest(
self,
# sample: int,
tensor: list[MetaTensor | torch.Tensor],
prefix: list[str],
saveto: list[Path],
saveVolumes: bool = True,
):
# if sample != 1:
# return
pudb.set_trace()
for T, txt, savefile in zip(tensor, prefix, saveto):
print(f"{txt} shape: {T.shape}")
if self.trainingParams.options.logTrainingTransformVols and saveVolumes:
print(f"{txt} save: {savefile}")
self.metaTensor_toNIfTI(T, savefile)

def sample_showInfo(
self,
sample: int,
Expand Down Expand Up @@ -341,6 +358,45 @@ def diceMetric_do(
return Tdm

def validate(
self,
sample: dict[str, MetaTensor | torch.Tensor],
space: data.LoaderCache,
# index: int,
result: torch.Tensor,
telemetry: dict[str, list] | None = None,
) -> float:
"""
This is callback method called in the inference stage.
Given a 'sample' from a LoaderCache iteration, a data space containing
the sample, the sample 'index', and the result, perform some validation.
"""
metric: float = -1.0
outputPostProc: list[MetaTensor] = [
self.f_outputPost(i)
for i in decollate_batch(result) # type: ignore[arg-type]
]
Tdm: torch.Tensor = self.diceMetric_do(
outputPostProc, sample["label"].to(self.network.device)
)
if space.loader.batch_size:
print(
f" validation run {sample['index']:02}/"
f"{len(space.cache) // space.loader.batch_size:02}, "
f"dice metric: {Tdm}"
)
if sample["index"] == len(space.cache) // space.loader.batch_size:
if telemetry is not None:
self.sample_showInfoTest(
[sample["input"], sample["output"]],
telemetry["prefix"],
telemetry["destination"],
)
metric = self.inference_metricsProcess()
return metric

def validateOrig(
self,
sample: dict[str, MetaTensor | torch.Tensor],
space: data.LoaderCache,
Expand Down Expand Up @@ -380,13 +436,15 @@ def slidingWindowInference_do(
[
dict[str, MetaTensor | torch.Tensor],
data.LoaderCache,
int,
# int,
torch.Tensor,
dict[str, list] | None,
],
float,
]
| None
) = None,
telemetry: dict[str, list[str | Path]] | None = None,
) -> float:
metric: float = 0.0
self.network.model.eval()
Expand All @@ -402,21 +460,22 @@ def slidingWindowInference_do(
input, roi_size, sw_batch_size, self.network.model
)
)
sample["index"] = index
sample["input"] = input.cpu()
sample["output"] = outputRaw.cpu()
self.sample_showInfo(
index,
[input, outputRaw],
["validation inference input", "validation inference output"],
[
self.trainingParams.whileTrainingValidation / "input.nii.gz",
self.trainingParams.whileTrainingValidation / "output.nii.gz",
],
not self.whileTrainingValidationNIfTIsaved,
)
self.whileTrainingValidationNIfTIsaved = True
# self.sample_showInfo(
# index,
# [input, outputRaw],
# ["validation inference input", "validation inference output"],
# [
# self.trainingParams.whileTrainingValidation / "input.nii.gz",
# self.trainingParams.whileTrainingValidation / "output.nii.gz",
# ],
# not self.whileTrainingValidationNIfTIsaved,
# )
# self.whileTrainingValidationNIfTIsaved = True
if f_callback is not None:
metric = f_callback(sample, inferSpace, index, outputRaw)
metric = f_callback(sample, inferSpace, outputRaw, telemetry)
return metric

def plot_bestModel(
Expand Down

0 comments on commit 8ae0b9a

Please sign in to comment.