Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed May 12, 2021
1 parent 4e5d044 commit c512f75
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
15 changes: 13 additions & 2 deletions flash/vision/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,18 @@
class SemanticSegmentationNumpyDataSource(NumpyDataSource):

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = img.shape
return sample


class SemanticSegmentationTensorDataSource(TensorDataSource):

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
img = sample[DefaultDataKeys.INPUT].float()
sample[DefaultDataKeys.INPUT] = img
sample[DefaultDataKeys.METADATA] = img.shape
return sample


Expand Down Expand Up @@ -162,7 +173,7 @@ def __init__(
data_sources={
DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
DefaultDataSources.TENSORS: TensorDataSource(),
DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
},
default_data_source=DefaultDataSources.FILES,
Expand Down
4 changes: 2 additions & 2 deletions tests/vision/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_predict_tensor():
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (196, 196)
assert out[0].shape == (10, 20)


def test_predict_numpy():
Expand All @@ -98,4 +98,4 @@ def test_predict_numpy():
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
assert isinstance(out[0], torch.Tensor)
assert out[0].shape == (196, 196)
assert out[0].shape == (10, 20)

0 comments on commit c512f75

Please sign in to comment.