Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch Predicitions #856

Merged
merged 8 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
language: system
types: [python]
pass_filenames: false
stages: [commit]
stages: [pre-commit]
- repo: local
hooks:
- id: docformatter
Expand All @@ -18,5 +18,5 @@ repos:
types: [python]
args: ['--in-place', '--recursive','src/deepforest/']
pass_filenames: false
stages: [commit]
stages: [pre-commit]
bw4sz marked this conversation as resolved.
Show resolved Hide resolved

4 changes: 3 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ build:
os: ubuntu-22.04
tools:
python: "3.12"

python:
install:
- requirements: dev_requirements.txt
- method: pip
path: .

submodules:
include: all
include: []

36 changes: 36 additions & 0 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,43 @@ def predict_step(self, batch, batch_idx):
for result in batch_results:
boxes = visualize.format_boxes(result)
results.append(boxes)
return results

def predict_batch(self, images, preprocess_fn=None):
"""Predict a batch of images with the deepforest model.

Args:
images (torch.Tensor or np.ndarray): A batch of images with shape (B, C, H, W) or (B, H, W, C).
preprocess_fn (callable, optional): A function to preprocess images before prediction.
If None, assumes images are preprocessed.

Returns:
List[pd.DataFrame]: A list of dataframes with predictions for each image.
"""

self.model.eval()

#conver to tensor if input is array
if isinstance(images, np.ndarray):
images = torch.tensor(images, device=self.device)

#check input format
if images.dim() == 4 and images.shape[-1] == 3:
#Convert channels_last (B, H, W, C) to channels_first (B, C, H, W)
images = images.permute(0, 3, 1, 2)

#appy preprocessing if available
if preprocess_fn:
images = preprocess_fn(images)

#using Pytorch Ligthning's predict_step
with torch.no_grad():
predictions = []
for idx, image in enumerate(images):
predictions = self.predict_step(image.unsqueeze(0), idx)
predictions.extend(predictions)
#convert predictions to dataframes
bw4sz marked this conversation as resolved.
Show resolved Hide resolved
results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
return results

def configure_optimizers(self):
Expand Down
Empty file added test.py
Empty file.
105 changes: 105 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader


from PIL import Image

Expand Down Expand Up @@ -674,3 +676,106 @@ def test_predict_tile_with_crop_model_empty():

# Assert the result
assert result is None


# @pytest.mark.parametrize("batch_size", [1, 4, 8])
# def test_batch_prediction(m, batch_size, raster_path):
#
# # Prepare input data
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=batch_size)

# # Perform prediction
# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# # Check results
# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# @pytest.mark.parametrize("batch_size", [1, 4])
# def test_batch_training(m, batch_size, tmpdir):
#
# # Generate synthetic training data
# csv_file = get_data("example.csv")
# root_dir = os.path.dirname(csv_file)
# train_ds = m.load_dataset(csv_file, root_dir=root_dir)
# train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# # Configure the model and trainer
# m.config["batch_size"] = batch_size
# m.create_trainer()
# trainer = m.trainer

# # Train the model
# trainer.fit(m, train_dl)

# # Assertions
# assert trainer.current_epoch == 1
# assert trainer.batch_size == batch_size

# @pytest.mark.parametrize("batch_size", [2, 4])
# def test_batch_data_augmentation(m, batch_size, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
# dl = DataLoader(ds, batch_size=batch_size)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) == len(dl)
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }

# def test_batch_inference_consistency(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=4)

# batch_predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
bw4sz marked this conversation as resolved.
Show resolved Hide resolved
# batch_predictions.append(prediction)

# single_predictions = []
# for image in ds:
# prediction = m.predict_image(image=image)
# single_predictions.append(prediction)

# batch_df = pd.concat(batch_predictions, ignore_index=True)
# single_df = pd.concat(single_predictions, ignore_index=True)

# pd.testing.assert_frame_equal(batch_df, single_df)

# def test_large_batch_handling(m, raster_path):
#
# tile = np.array(Image.open(raster_path))
# ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
# dl = DataLoader(ds, batch_size=16)

# predictions = []
# for batch in dl:
# prediction = m.predict_batch(batch)
# predictions.append(prediction)

# assert len(predictions) > 0
# for batch_pred in predictions:
# assert isinstance(batch_pred, pd.DataFrame)
# assert set(batch_pred.columns) == {
# "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
# }
# assert not batch_pred.empty
Loading