diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 63ed3d0f..8bd5417f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
       language: system
       types: [python]
       pass_filenames: false
-      stages: [commit]
+      stages: [pre-commit]
   - repo: local
     hooks:
       - id: docformatter
@@ -18,5 +18,5 @@ repos:
       types: [python]
       args: ['--in-place', '--recursive','src/deepforest/']
       pass_filenames: false
-      stages: [commit]
+      stages: [pre-commit]
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 4db54370..c9ead8e3 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -4,6 +4,7 @@ build:
   os: ubuntu-22.04
   tools:
     python: "3.12"
+
 python:
   install:
     - requirements: dev_requirements.txt
@@ -11,4 +12,5 @@ python:
       path: .
 submodules:
-  include: all
+  include: []
+
diff --git a/src/deepforest/main.py b/src/deepforest/main.py
index 320beca0..f2f1695d 100644
--- a/src/deepforest/main.py
+++ b/src/deepforest/main.py
@@ -763,7 +763,43 @@ def predict_step(self, batch, batch_idx):
         for result in batch_results:
             boxes = visualize.format_boxes(result)
             results.append(boxes)
+        return results
+
+    def predict_batch(self, images, preprocess_fn=None):
+        """Predict a batch of images with the deepforest model.
+
+        Args:
+            images (torch.Tensor or np.ndarray): A batch of images with shape (B, C, H, W) or (B, H, W, C).
+            preprocess_fn (callable, optional): A function to preprocess images before prediction.
+                If None, images are assumed to be preprocessed already.
+
+        Returns:
+            List[pd.DataFrame]: A list of dataframes with predictions for each image.
+        """
+        self.model.eval()
+
+        # Convert to a tensor if the input is a numpy array
+        if isinstance(images, np.ndarray):
+            images = torch.tensor(images, device=self.device)
+
+        # Check the input format
+        if images.dim() == 4 and images.shape[-1] == 3:
+            # Convert channels_last (B, H, W, C) to channels_first (B, C, H, W)
+            images = images.permute(0, 3, 1, 2)
+
+        # Apply preprocessing if provided
+        if preprocess_fn:
+            images = preprocess_fn(images)
+
+        # Use PyTorch Lightning's predict_step on each image in the batch
+        with torch.no_grad():
+            predictions = []
+            for idx, image in enumerate(images):
+                batch_prediction = self.predict_step(image.unsqueeze(0), idx)
+                predictions.extend(batch_prediction)
+
+        # Convert predictions to dataframes
+        results = [pd.DataFrame(pred) for pred in predictions if pred is not None]
         return results
 
     def configure_optimizers(self):
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_main.py b/tests/test_main.py
index a8c3543e..13d10826 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -18,6 +18,8 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers import TensorBoardLogger
+from torch.utils.data import DataLoader
+
 from PIL import Image
 
 
@@ -674,3 +676,106 @@ def test_predict_tile_with_crop_model_empty():
 
     # Assert the result
     assert result is None
+
+
+# @pytest.mark.parametrize("batch_size", [1, 4, 8])
+# def test_batch_prediction(m, batch_size, raster_path):
+#
+#     # Prepare input data
+#     tile = np.array(Image.open(raster_path))
+#     ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
+#     dl = DataLoader(ds, batch_size=batch_size)
+
+#     # Perform prediction
+#     predictions = []
+#     for batch in dl:
+#         prediction = m.predict_batch(batch)
+#         predictions.append(prediction)
+
+#     # Check results
+#     assert len(predictions) == len(dl)
+#     for batch_pred in predictions:
+#         assert isinstance(batch_pred, pd.DataFrame)
+#         assert set(batch_pred.columns) == {
+#             "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
+#         }
+
+# @pytest.mark.parametrize("batch_size", [1, 4])
+# def test_batch_training(m, batch_size, tmpdir):
+#
+#     # Generate synthetic training data
+#     csv_file = get_data("example.csv")
+#     root_dir = os.path.dirname(csv_file)
+#     train_ds = m.load_dataset(csv_file, root_dir=root_dir)
+#     train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
+
+#     # Configure the model and trainer
+#     m.config["batch_size"] = batch_size
+#     m.create_trainer()
+#     trainer = m.trainer
+
+#     # Train the model
+#     trainer.fit(m, train_dl)
+
+#     # Assertions
+#     assert trainer.current_epoch == 1
+#     assert trainer.batch_size == batch_size
+
+# @pytest.mark.parametrize("batch_size", [2, 4])
+# def test_batch_data_augmentation(m, batch_size, raster_path):
+#
+#     tile = np.array(Image.open(raster_path))
+#     ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100, augment=True)
+#     dl = DataLoader(ds, batch_size=batch_size)
+
+#     predictions = []
+#     for batch in dl:
+#         prediction = m.predict_batch(batch)
+#         predictions.append(prediction)
+
+#     assert len(predictions) == len(dl)
+#     for batch_pred in predictions:
+#         assert isinstance(batch_pred, pd.DataFrame)
+#         assert set(batch_pred.columns) == {
+#             "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
+#         }
+
+# def test_batch_inference_consistency(m, raster_path):
+#
+#     tile = np.array(Image.open(raster_path))
+#     ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
+#     dl = DataLoader(ds, batch_size=4)
+
+#     batch_predictions = []
+#     for batch in dl:
+#         prediction = m.predict_batch(batch)
+#         batch_predictions.append(prediction)
+
+#     single_predictions = []
+#     for image in ds:
+#         prediction = m.predict_image(image=image)
+#         single_predictions.append(prediction)
+
+#     batch_df = pd.concat(batch_predictions, ignore_index=True)
+#     single_df = pd.concat(single_predictions, ignore_index=True)
+
+#     pd.testing.assert_frame_equal(batch_df, single_df)
+
+# def test_large_batch_handling(m, raster_path):
+#
+#     tile = np.array(Image.open(raster_path))
+#     ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=100)
+#     dl = DataLoader(ds, batch_size=16)
+
+#     predictions = []
+#     for batch in dl:
+#         prediction = m.predict_batch(batch)
+#         predictions.append(prediction)
+
+#     assert len(predictions) > 0
+#     for batch_pred in predictions:
+#         assert isinstance(batch_pred, pd.DataFrame)
+#         assert set(batch_pred.columns) == {
+#             "xmin", "ymin", "xmax", "ymax", "label", "score", "geometry"
+#         }
+#         assert not batch_pred.empty