Skip to content

Commit

Permalink
VRT expressions: add GDAL_EXPRTK_MAX_VECTOR_SIZE and GDAL_EXPRTK_ENAB…
Browse files Browse the repository at this point in the history
…LE_LOOPS
  • Loading branch information
dbaston committed Nov 13, 2024
1 parent cfd9754 commit c5bec4e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
81 changes: 59 additions & 22 deletions autotest/gdrivers/vrtprocesseddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,20 +1203,22 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):


@pytest.mark.parametrize(
"expression,src,expected,error",
"expression,src,expected,error,env",
[
pytest.param(
"return [ALL_BANDS[1], ALL_BANDS[2]]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[3, 4]], [[5, 6]]]),
None,
{},
id="multiple bands in, multiple bands out (1)",
),
pytest.param(
"return [ALL_BANDS]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
None,
{},
id="multiple bands in, multiple bands out (2)",
),
pytest.param(
Expand All @@ -1234,41 +1236,47 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
np.arange(100).reshape(50, 1, 2),
np.array([[[9, 10]], [[29, 30]], [[49, 50]], [[69, 70]], [[89, 90]]]),
None,
{},
id="procedural",
),
pytest.param(
"B1",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
{},
id="multiple bands in, single band out (1)",
),
pytest.param(
"ALL_BANDS[0]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
{},
id="multiple bands in, single band out (2)",
),
pytest.param(
"return [B1];",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
{},
id="multiple bands in, single band out (3)",
),
pytest.param(
"return [B1, B2]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
"returned 2 values but 1 output band",
{},
id="return wrong number of bands",
),
pytest.param(
"return [ALL_BANDS, B2]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
"must return a vector or a list of scalars",
{},
id="return wrong number of bands",
),
pytest.param(
Expand All @@ -1282,11 +1290,39 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[1, 1]], [[2, 2]], [[3, 3]]]),
"Attempted to access index 3",
{},
id="out of bounds vector access",
),
pytest.param(
"""
var out[5];
return [out];
""",
np.array([[[1, 2]]]),
np.array([[[1, 1]]]),
"Failed to parse expression",
{"GDAL_EXPRTK_MAX_VECTOR_SIZE": "4"},
id="vector too large",
),
pytest.param(
"""
var out[3];
for (var i := 0; i < out[]; i += 1) {
out[i] := i;
}
return [out];
""",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[1, 1]], [[2, 2]], [[3, 3]]]),
"Failed to parse expression",
{"GDAL_EXPRTK_ENABLE_LOOPS": "NO"},
id="loops disabled",
),
],
)
def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, error):
def test_vrtprocesseddataset_expression(
tmp_vsimem, expression, src, expected, env, error
):

src_filename = tmp_vsimem / "src.tif"

Expand All @@ -1304,28 +1340,29 @@ def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, e
for i in range(expected_output_bands)
)

ds = gdal.Open(
f"""<VRTDataset subclass='VRTProcessedDataset'>
<Input>
<SourceFilename>{src_filename}</SourceFilename>
</Input>
<ProcessingSteps>
<Step>
<Algorithm>Expression</Algorithm>
<Argument name="expression">{expression.replace('<', '&lt;').replace('>', '&gt;')}</Argument>
</Step>
</ProcessingSteps>
{output_band_xml}
</VRTDataset>
"""
)
vrt_xml = f"""<VRTDataset subclass='VRTProcessedDataset'>
<Input>
<SourceFilename>{src_filename}</SourceFilename>
</Input>
<ProcessingSteps>
<Step>
<Algorithm>Expression</Algorithm>
<Argument name="expression">{expression.replace('<', '&lt;').replace('>', '&gt;')}</Argument>
</Step>
</ProcessingSteps>
{output_band_xml}
</VRTDataset>
"""

if error:
with pytest.raises(Exception, match=error):
with gdal.config_options(env):
if error:
with pytest.raises(Exception, match=error):
ds = gdal.Open(vrt_xml)
result = ds.ReadAsArray()
else:
ds = gdal.Open(vrt_xml)
result = ds.ReadAsArray()
else:
result = ds.ReadAsArray()
np.testing.assert_equal(result, expected)
np.testing.assert_equal(result, expected)


###############################################################################
Expand Down
24 changes: 24 additions & 0 deletions frmts/vrt/vrtexpression.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "cpl_conv.h"
#include "cpl_error.h"
#include "cpl_string.h"
#include "vrtexpression.h"

#define exprtk_disable_caseinsensitivity
Expand Down Expand Up @@ -47,7 +49,29 @@ class GDALExpressionEvaluator::Impl

Impl()
{
using settings_t = std::decay_t<decltype(m_oParser.settings())>;

m_oParser.register_vector_access_runtime_check(m_oVectorAccessCheck);

int nMaxVectorSize = std::atoi(
CPLGetConfigOption("GDAL_EXPRTK_MAX_VECTOR_SIZE", "100000"));

if (nMaxVectorSize > 0)
{
m_oParser.settings().set_max_local_vector_size(nMaxVectorSize);
}

bool bEnableLoops =
CPLTestBool(CPLGetConfigOption("GDAL_EXPRTK_ENABLE_LOOPS", "YES"));
if (!bEnableLoops)
{
m_oParser.settings().disable_control_structure(
settings_t::e_ctrl_for_loop);
m_oParser.settings().disable_control_structure(
settings_t::e_ctrl_while_loop);
m_oParser.settings().disable_control_structure(
settings_t::e_ctrl_repeat_loop);
}
}

CPLErr compile()
Expand Down

0 comments on commit c5bec4e

Please sign in to comment.