diff --git a/geotile/GeoTile.py b/geotile/GeoTile.py
index 670920b..6b24260 100644
--- a/geotile/GeoTile.py
+++ b/geotile/GeoTile.py
@@ -15,6 +15,7 @@
 from rasterio.warp import calculate_default_transform, reproject
 from rasterio.enums import Resampling
 from rasterio.features import rasterize
+from rasterio.transform import Affine
 
 # geopandas library
 import geopandas as gpd
@@ -104,9 +105,80 @@ def _calculate_offset(self, stride_x: Optional[int] = None, stride_y: Optional[i
         X = [x for x in range(0, self.width, stride_x)]
         Y = [y for y in range(0, self.height, stride_y)]
         offsets = list(itertools.product(X, Y))
-        return offsets
-
+        self.offsets = offsets
+
+    def _windows_transform_to_affine(self, window_transform: Optional[tuple]):
+        """Convert the window transform to affine transform
+
+        Parameters
+        ----------
+        window_transform: tuple
+            Transform coefficients in (a, b, c, d, e, f, ...) order, eg. the
+            Affine returned by rasterio.windows.transform
+
+        Returns
+        -------
+        Affine: affine transform built from the first six coefficients
+        """
+        # only the first six coefficients define the transform, so both
+        # 6-element tuples and 9-element Affine objects are accepted
+        a, b, c, d, e, f = tuple(window_transform)[:6]
+        return Affine(a, b, c, d, e, f)
+
+    def shuffle_tiles(self, random_state: Optional[int] = None):
+        """Shuffle the tiles
+
+        The offsets, window data and window transforms are reordered with
+        the same random permutation, so they stay aligned with each other.
+
+        Parameters
+        ----------
+        random_state: int
+            Seed for the permutation, if None, the global numpy random state is used
+
+        Returns
+        -------
+        None: The offsets, window data and window transforms are shuffled in place
+
+        Raises
+        ------
+        ValueError: If the tiles were not generated yet or the stored lists differ in length
+
+        Examples
+        --------
+        >>> from geotile import GeoTile
+        >>> tiler = GeoTile('/path/to/raster/file.tif')
+        >>> tiler.generate_tiles('/path/to/output/folder', save_tiles=False)
+        >>> tiler.shuffle_tiles(random_state=42)
+        """
+        window_data = getattr(self, "window_data", None)
+        window_transform = getattr(self, "window_transform", None)
+        if window_data is None or window_transform is None:
+            raise ValueError("No tiles to shuffle, run generate_tiles first")
+
+        # explicit check instead of an assert, which is skipped under python -O
+        if not len(self.offsets) == len(window_data) == len(window_transform):
+            raise ValueError("The number of offsets and window data should be same")
+
+        # a seeded local RandomState gives the same permutation as seeding the
+        # global state, without reseeding numpy for the rest of the program
+        if random_state is not None:
+            self.random_state = random_state
+            rng = np.random.RandomState(random_state)
+        else:
+            rng = np.random
+
+        # shuffle the offsets and window data with one shared permutation
+        p = rng.permutation(len(self.offsets))
+        self.offsets = np.array(self.offsets)[p]
+        self.window_data = np.array(window_data)[p]
+        # keep the transforms as Affine objects, np.array would turn them into rows of floats
+        self.window_transform = [window_transform[i] for i in p]
+
+    # backward compatible alias for the original, misspelled method name
+    suffel_tiles = shuffle_tiles
 
+
     def tile_info(self):
         """Get the information of the tiles
 
@@ -126,13 +198,14 @@ def tile_info(self):
     def generate_tiles(
             self,
             output_folder: str,
+            save_tiles: Optional[bool] = True,
             out_bands: Optional[list] = None,
             image_format: Optional[str] = None,
             dtype: Optional[str] = None,
             tile_x: Optional[int] = 256,
             tile_y: Optional[int] = 256,
             stride_x: Optional[int] = 128,
-            stride_y: Optional[int] = 128
+            stride_y: Optional[int] = 128,
     ):
         """
         Save the tiles to the output folder
@@ -141,6 +214,8 @@ def generate_tiles(
         ----------
         output_folder : str
             Path to the output folder
+        save_tiles : bool
+            If True, the tiles are also written to the output folder; the tile data is kept on the instance either way
         out_bands : list
             The bands to save (eg. [3, 2, 1]), if None, the output bands will be same as the input raster bands
         image_format : str
@@ -184,13 +259,25 @@ def generate_tiles(
             os.makedirs(output_folder)
 
         # offset calculation
-        offsets = self._calculate_offset(self.stride_x, self.stride_y)
+        self._calculate_offset(self.stride_x, self.stride_y)
+
+        # store all the windows data as a list, windows shape: (band, tile_y, tile_x)
+        self.window_data = []
+
+        # store all the transform data as a list
+        self.window_transform = []
 
-        # iterate through the offsets
-        for col_off, row_off in offsets:
+        # iterate through the offsets and save the tiles
+        for col_off, row_off in self.offsets:
             window = windows.Window(
                 col_off=col_off, row_off=row_off, width=self.tile_x, height=self.tile_y)
             transform = windows.transform(window, self.ds.transform)
+
+            # convert the window transform to affine transform and append to the list
+            transform = self._windows_transform_to_affine(transform)
+            self.window_transform.append(transform)
+
+            # copy the meta data
             meta = self.ds.meta.copy()
             nodata = meta['nodata']
 
@@ -206,6 +293,10 @@ def generate_tiles(
             else:
                 meta.update({"count": len(out_bands)})
 
+            # read the window data and append to the list
+            single_window_data = self.ds.read(out_bands, window=window, fill_value=nodata, boundless=True)
+            self.window_data.append(single_window_data)
+
             # if data_type, update the meta
             if dtype:
                 meta.update({"dtype": dtype})
@@ -218,10 +309,84 @@ def generate_tiles(
                 str(row_off) + '.' + image_format
             tile_path = os.path.join(output_folder, tile_name)
 
+            if save_tiles:
+                # save the tiles with new metadata, reusing the window data read above
+                with rio.open(tile_path, 'w', **meta) as outds:
+                    outds.write(single_window_data.astype(meta["dtype"]))
+
+    def save_tiles(self, output_folder: str, image_format: Optional[str] = None, dtype: Optional[str] = None):
+        """Save the tiles to the output folder
+
+        Writes the tiles stored by generate_tiles, in their current order,
+        as one raster file per tile.
+
+        Parameters
+        ----------
+        output_folder : str
+            Path to the output folder
+        image_format : str
+            The image format (eg. tif), if None, the image format will be the same as the input raster format (eg. tif)
+        dtype : str, np.dtype
+            The output dtype (eg. uint8, float32), if None, the dtype will be the same as the input raster
+
+        Returns
+        -------
+        None: save the tiles to the output folder
+
+        Raises
+        ------
+        ValueError: If the tiles were not generated yet
+
+        Examples
+        --------
+        >>> from geotile import GeoTile
+        >>> tiler = GeoTile('/path/to/raster/file.tif')
+        >>> tiler.generate_tiles('/path/to/output/folder', save_tiles=False)
+        >>> tiler.save_tiles('/path/to/output/folder')
+        """
+        window_data = getattr(self, "window_data", None)
+        window_transform = getattr(self, "window_transform", None)
+        if window_data is None or window_transform is None:
+            raise ValueError("No tiles to save, run generate_tiles first")
+
+        # create the output folder if it doesn't exist
+        os.makedirs(output_folder, exist_ok=True)
+
+        # meta data of the source raster, resized to one tile
+        meta = self.ds.meta.copy()
+        meta.update({
+            "width": self.tile_x,
+            "height": self.tile_y,
+        })
+
+        # check if image_format is None
+        if image_format is None:
+            image_format = pathlib.Path(self.path).suffix[1:]
+
+        # if data_type, update the meta
+        if dtype:
+            meta.update({"dtype": dtype})
+
+        # iterate through the offsets and windows_data and save the tiles
+        for (col_off, row_off), wd, wt in zip(self.offsets, window_data, window_transform):
+
+            # every tile has its own transform, the band count follows the stored
+            # data because generate_tiles may have read only a subset of the bands
+            meta.update({
+                "transform": Affine(*tuple(wt)[:6]),
+                "count": wd.shape[0],
+            })
+
+            # tile name and path
+            tile_name = 'tile_' + str(col_off) + '_' + \
+                str(row_off) + '.' + image_format
+            tile_path = os.path.join(output_folder, tile_name)
+
             # save the tiles with new metadata
             with rio.open(tile_path, 'w', **meta) as outds:
-                outds.write(self.ds.read(
-                    out_bands, window=window, fill_value=nodata, boundless=True).astype(dtype))
+                # cast to the dtype stored in meta, astype(None) would silently produce float64
+                outds.write(wd.astype(meta["dtype"]))
+
 
     def mask(self, input_vector: str, out_path: str, crop=False, invert=False, **kwargs):
         """Generate a mask raster from a vector