Skip to content

Commit

Permalink
autotest: move VRT derived expression tests to vrtderived.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Nov 11, 2024
1 parent 43a7c75 commit d7eb5bb
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 142 deletions.
142 changes: 0 additions & 142 deletions autotest/gcore/vrt_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -2837,145 +2837,3 @@ def test_vrt_update_nested_VRTDataset(tmp_vsimem):
with gdal.Open(vrt_filename) as ds:
assert ds.GetRasterBand(1).Checksum() == 4672
assert ds.GetRasterBand(1).GetMinimum() == 74


###############################################################################
# Test arbitrary expression pixel functions


def vrt_expression_xml(tmpdir, expression, sources):

drv = gdal.GetDriverByName("GTiff")

nx = 1
ny = 1

xml = f"""<VRTDataset rasterXSize="{nx}" rasterYSize="{ny}">
<VRTRasterBand dataType="Float64" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionType>expression</PixelFunctionType>
<PixelFunctionArguments expression="{expression}" />"""

for i, source in enumerate(sources):
if type(source) is tuple:
source_name, source_value = source
else:
source_name = ""
source_value = source

src_fname = tmpdir / f"source_{i}.tif"

with drv.Create(src_fname, 1, 1, 1, gdal.GDT_Float64) as ds:
ds.GetRasterBand(1).Fill(source_value)

xml += f"""<SimpleSource name="{source_name}">
<SourceFilename relativeToVRT="0">{src_fname}</SourceFilename>
<SourceBand>1</SourceBand>
</SimpleSource>"""

xml += "</VRTRasterBand></VRTDataset>"

return xml


@pytest.mark.parametrize(
"expression,sources,result",
[
pytest.param("A", [("A", 77)], 77, id="identity"),
pytest.param(
"(NIR-R)/(NIR+R)",
[("NIR", 77), ("R", 63)],
(77 - 63) / (77 + 63),
id="simple expression",
),
pytest.param(
"if (A > B) 1.5*C ; else A",
[("A", 77), ("B", 63), ("C", 18)],
27,
id="conditional (explicit)",
),
pytest.param(
"(A > B)*(1.5*C) + (A <= B)*(A)",
[("A", 77), ("B", 63), ("C", 18)],
27,
id="conditional (implicit)",
),
pytest.param(
"B2 * PopDensity",
[("PopDensity", 3), ("", 7)],
21,
id="implicit source name",
),
pytest.param(
"B1 / sum(ALL)",
[("", 3), ("", 5), ("", 31)],
3 / (3 + 5 + 31),
id="use of ALL variable",
),
pytest.param(
"B1 / sum(B2, B3) ",
[("", 3), ("", 5), ("", 31)],
3 / (5 + 31),
id="aggregate specified inputs",
),
pytest.param(
"var q[2] := {B2, B3}; B1 * q",
[("", 3), ("", 5), ("", 31)],
15, # First value in returned vector. This behavior doesn't seem desirable
# but I haven't figured out how to detect a vector return.
id="return vector",
),
pytest.param(
"B1 + B2 + B3",
(5, 9, float("nan")),
float("nan"),
id="nan propagated via arithmetic",
),
pytest.param(
"if (B3) B1 ; else B2",
(5, 9, float("nan")),
5,
id="nan = truth in conditional?",
),
pytest.param(
"if (B3 > 0) B1 ; else B2",
(5, 9, float("nan")),
9,
id="nan comparison is false in conditional",
),
pytest.param(
"if (B1 > 5) B1",
(1,),
float("nan"),
id="expression returns nodata",
),
],
)
def test_vrt_pixelfn_expression(tmp_path, expression, sources, result):
pytest.importorskip("numpy")

xml = vrt_expression_xml(tmp_path, expression, sources)

with gdal.Open(xml) as ds:
assert pytest.approx(ds.ReadAsArray()[0][0], nan_ok=True) == result


@gdaltest.enable_exceptions()
@pytest.mark.parametrize(
"expression,sources,exception",
[
pytest.param(
"A*B + C",
[("A", 77), ("B", 63)],
"failed to parse expression",
id="undefined variable",
),
],
)
def test_vrt_pixelfn_expression_invalid(tmp_path, expression, sources, exception):
pytest.importorskip("numpy")

xml = vrt_expression_xml(tmp_path, expression, sources)

with gdal.Open(xml) as ds:
with pytest.raises(Exception, match=exception):
ds.ReadAsArray()
142 changes: 142 additions & 0 deletions autotest/gdrivers/vrtderived.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,148 @@ def identity(in_ar, out_ar, *args, **kwargs):
assert vrt_ds.GetRasterBand(1).DataType == dtype


###############################################################################
# Test arbitrary expression pixel functions


def vrt_expression_xml(tmpdir, expression, sources):

drv = gdal.GetDriverByName("GTiff")

nx = 1
ny = 1

xml = f"""<VRTDataset rasterXSize="{nx}" rasterYSize="{ny}">
<VRTRasterBand dataType="Float64" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionType>expression</PixelFunctionType>
<PixelFunctionArguments expression="{expression}" />"""

for i, source in enumerate(sources):
if type(source) is tuple:
source_name, source_value = source
else:
source_name = ""
source_value = source

src_fname = tmpdir / f"source_{i}.tif"

with drv.Create(src_fname, 1, 1, 1, gdal.GDT_Float64) as ds:
ds.GetRasterBand(1).Fill(source_value)

xml += f"""<SimpleSource name="{source_name}">
<SourceFilename relativeToVRT="0">{src_fname}</SourceFilename>
<SourceBand>1</SourceBand>
</SimpleSource>"""

xml += "</VRTRasterBand></VRTDataset>"

return xml


@pytest.mark.parametrize(
"expression,sources,result",
[
pytest.param("A", [("A", 77)], 77, id="identity"),
pytest.param(
"(NIR-R)/(NIR+R)",
[("NIR", 77), ("R", 63)],
(77 - 63) / (77 + 63),
id="simple expression",
),
pytest.param(
"if (A > B) 1.5*C ; else A",
[("A", 77), ("B", 63), ("C", 18)],
27,
id="conditional (explicit)",
),
pytest.param(
"(A > B)*(1.5*C) + (A <= B)*(A)",
[("A", 77), ("B", 63), ("C", 18)],
27,
id="conditional (implicit)",
),
pytest.param(
"B2 * PopDensity",
[("PopDensity", 3), ("", 7)],
21,
id="implicit source name",
),
pytest.param(
"B1 / sum(ALL)",
[("", 3), ("", 5), ("", 31)],
3 / (3 + 5 + 31),
id="use of ALL variable",
),
pytest.param(
"B1 / sum(B2, B3) ",
[("", 3), ("", 5), ("", 31)],
3 / (5 + 31),
id="aggregate specified inputs",
),
pytest.param(
"var q[2] := {B2, B3}; B1 * q",
[("", 3), ("", 5), ("", 31)],
15, # First value in returned vector. This behavior doesn't seem desirable
# but I haven't figured out how to detect a vector return.
id="return vector",
),
pytest.param(
"B1 + B2 + B3",
(5, 9, float("nan")),
float("nan"),
id="nan propagated via arithmetic",
),
pytest.param(
"if (B3) B1 ; else B2",
(5, 9, float("nan")),
5,
id="nan = truth in conditional?",
),
pytest.param(
"if (B3 > 0) B1 ; else B2",
(5, 9, float("nan")),
9,
id="nan comparison is false in conditional",
),
pytest.param(
"if (B1 > 5) B1",
(1,),
float("nan"),
id="expression returns nodata",
),
],
)
def test_vrt_pixelfn_expression(tmp_path, expression, sources, result):
pytest.importorskip("numpy")

xml = vrt_expression_xml(tmp_path, expression, sources)

with gdal.Open(xml) as ds:
assert pytest.approx(ds.ReadAsArray()[0][0], nan_ok=True) == result


@gdaltest.enable_exceptions()
@pytest.mark.parametrize(
"expression,sources,exception",
[
pytest.param(
"A*B + C",
[("A", 77), ("B", 63)],
"failed to parse expression",
id="undefined variable",
),
],
)
def test_vrt_pixelfn_expression_invalid(tmp_path, expression, sources, exception):
pytest.importorskip("numpy")

xml = vrt_expression_xml(tmp_path, expression, sources)

with gdal.Open(xml) as ds:
with pytest.raises(Exception, match=exception):
ds.ReadAsArray()


###############################################################################
# Cleanup.

Expand Down

0 comments on commit d7eb5bb

Please sign in to comment.