Skip to content

Commit

Permalink
Merge pull request #9487 from dbaston/python-dataset-getlayer-nocrash
Browse files Browse the repository at this point in the history
Python bindings: Invalidate layer ref when Dataset closes
  • Loading branch information
rouault authored Mar 17, 2024
2 parents 838909a + 1e64eb5 commit 4383d1a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
19 changes: 19 additions & 0 deletions autotest/gcore/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,25 @@ def test_band_use_after_dataset_close_2():
band.Checksum()


def test_layer_use_after_dataset_close_1():
with gdal.OpenEx("../ogr/data/poly.shp") as ds:
lyr = ds.GetLayer(0)

# Make sure ds.__exit__() has invalidated "lyr" so we don't crash here
with pytest.raises(Exception):
lyr.GetFeatureCount()


def test_layer_use_after_dataset_close_2():
ds = gdal.OpenEx("../ogr/data/poly.shp")
lyr = ds.GetLayerByName("poly")

del ds
# Make sure ds.__del__() has invalidated "lyr" so we don't crash here
with pytest.raises(Exception):
lyr.GetFeatureCount()


def test_mask_band_use_after_dataset_close():
with gdal.Open("data/byte.tif") as ds:
m1 = ds.GetRasterBand(1).GetMaskBand()
Expand Down
46 changes: 31 additions & 15 deletions swig/include/python/gdal_python.i
Original file line number Diff line number Diff line change
Expand Up @@ -744,12 +744,12 @@ void wrapper_VSIGetMemFileBuffer(const char *utf8_path, GByte **out, vsi_l_offse

%feature("pythonappend") GetMaskBand %{
if hasattr(self, '_parent_ds') and self._parent_ds():
self._parent_ds()._add_band_ref(val)
self._parent_ds()._add_child_ref(val)
%}

%feature("pythonappend") GetOverview %{
if hasattr(self, '_parent_ds') and self._parent_ds():
self._parent_ds()._add_band_ref(val)
self._parent_ds()._add_child_ref(val)
%}

%feature("shadow") ComputeStatistics %{
Expand Down Expand Up @@ -1449,25 +1449,25 @@ CPLErr ReadRaster1( double xoff, double yoff, double xsize, double ysize,
else:
return self._SetGCPs2(gcps, wkt_or_spatial_ref)

def _add_band_ref(self, band):
if band is None:
def _add_child_ref(self, child):
if child is None:
return

import weakref

if not hasattr(self, '_band_references'):
self._band_references = weakref.WeakSet()
if not hasattr(self, '_child_references'):
self._child_references = weakref.WeakSet()

self._band_references.add(band)
band._parent_ds = weakref.ref(self)
self._child_references.add(child)
child._parent_ds = weakref.ref(self)

def _invalidate_bands(self):
if hasattr(self, '_band_references'):
for band in self._band_references:
band.this = None
def _invalidate_children(self):
if hasattr(self, '_child_references'):
for child in self._child_references:
child.this = None

def __del__(self):
self._invalidate_bands()
self._invalidate_children()

def __enter__(self):
return self
Expand All @@ -1479,7 +1479,7 @@ CPLErr ReadRaster1( double xoff, double yoff, double xsize, double ysize,
%feature("pythonappend") Close %{
self.thisown = 0
self.this = None
self._invalidate_bands()
self._invalidate_children()
return val
%}

Expand Down Expand Up @@ -1580,7 +1580,23 @@ def ReleaseResultSet(self, sql_lyr):
%}

%feature("pythonappend") GetRasterBand %{
self._add_band_ref(val)
self._add_child_ref(val)
%}

%feature("pythonappend") GetLayerByName %{
self._add_child_ref(val)
%}

%feature("pythonappend") GetLayerByIndex %{
self._add_child_ref(val)
%}

%feature("pythonappend") CreateLayer %{
self._add_child_ref(val)
%}

%feature("pythonappend") CopyLayer %{
self._add_child_ref(val)
%}

}
Expand Down

0 comments on commit 4383d1a

Please sign in to comment.