Skip to content

Commit

Permalink
VRTProcessedDataset: Add Expression algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Nov 11, 2024
1 parent d7eb5bb commit 98ceafa
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 3 deletions.
96 changes: 96 additions & 0 deletions autotest/gdrivers/vrtprocesseddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,102 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem):
)


###############################################################################
# Test expressions


@pytest.mark.parametrize(
"expression,src,expected,error",
[
pytest.param(
"return [ALL_BANDS[1], ALL_BANDS[2]]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[3, 4]], [[5, 6]]]),
None,
id="multiple bands in, multiple bands out (1)",
),
pytest.param(
"return [ALL_BANDS]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
None,
id="multiple bands in, multiple bands out (2)",
),
pytest.param(
"B1",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
id="multiple bands in, single band out (1)",
),
pytest.param(
"ALL_BANDS[0]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
id="multiple bands in, single band out (2)",
),
pytest.param(
"return [B1];",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
None,
id="multiple bands in, single band out (3)",
),
pytest.param(
"return [B1, B2]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
"returned 2 values but 1 output band",
id="return wrong number of bands",
),
pytest.param(
"return [ALL_BANDS, B2]",
np.array([[[1, 2]], [[3, 4]], [[5, 6]]]),
np.array([[1, 2]]),
"must return a vector or a list of scalars",
id="return wrong number of bands",
),
],
)
def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, error):

src_filename = tmp_vsimem / "src.tif"
with gdal.GetDriverByName("GTiff").Create(src_filename, 2, 1, 3) as src_ds:
src_ds.WriteArray(src)
src_ds.SetGeoTransform([0, 1, 0, 0, 0, 1])

expected_output_bands = 1 if len(expected.shape) == 2 else expected.shape[0]

output_band_xml = "".join(
f"""<VRTRasterBand band="{i+1}" dataType="Float32" subClass="VRTProcessedRasterBand"/>"""
for i in range(expected_output_bands)
)

ds = gdal.Open(
f"""<VRTDataset subclass='VRTProcessedDataset'>
<Input>
<SourceFilename>{src_filename}</SourceFilename>
</Input>
<ProcessingSteps>
<Step>
<Algorithm>Expression</Algorithm>
<Argument name="expression">{expression}</Argument>
</Step>
</ProcessingSteps>
{output_band_xml}
</VRTDataset>
"""
)

if error:
with pytest.raises(Exception, match=error):
result = ds.ReadAsArray()
else:
result = ds.ReadAsArray()
np.testing.assert_equal(result, expected)


###############################################################################
# Test that serialization (for example due to statistics computation) properly
# works
Expand Down
8 changes: 5 additions & 3 deletions frmts/vrt/vrtprocesseddataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,11 @@ CPLErr VRTProcessedDataset::Init(const CPLXMLNode *psTree,

if (nCurrentBandCount != nBands)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Number of output bands of last step is not consistent with "
"number of VRTProcessedRasterBand's");
CPLError(
CE_Failure, CPLE_AppDefined,
"Number of output bands of last step (%d) is not consistent with "
"number of VRTProcessedRasterBand's (%d)",
nCurrentBandCount, nBands);
return CE_Failure;
}

Expand Down
225 changes: 225 additions & 0 deletions frmts/vrt/vrtprocesseddatasetfunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
#include "cpl_string.h"
#include "vrtdataset.h"

#include <exprtk.hpp>

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <optional>
#include <set>
#include <vector>

Expand Down Expand Up @@ -1468,6 +1472,218 @@ static CPLErr TrimmingProcess(
return CE_None;
}

/************************************************************************/
/* ExpressionInit() */
/************************************************************************/

class ExpressionData
{
public:
ExpressionData(int nInBands, std::string_view osExpression)
: m_osExpression(osExpression), m_adfValuesForPixel(nInBands)
{
m_oSymbolTable.add_vector("ALL_BANDS", m_adfValuesForPixel);
m_oSymbolTable.add_vector("OUTPUT_BANDS", m_adfResults);
}

void set_variable_band(const char *pszVariable, int nBand)
{
m_oSymbolTable.add_variable(pszVariable, m_adfValuesForPixel[nBand]);
}

bool compile()
{
m_oExpression.register_symbol_table(m_oSymbolTable);
bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression);

if (!bSuccess)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Failed to parse expression.");

for (size_t i = 0; i < m_oParser.error_count(); i++)
{
const auto &oError = m_oParser.get_error(i);

CPLError(CE_Warning, CPLE_AppDefined,
"Position: %02d "
"Type: [%s] "
"Message: %s\n",
static_cast<int>(oError.token.position),
exprtk::parser_error::to_str(oError.mode).c_str(),
oError.diagnostic.c_str());
}
}

return bSuccess;
}

const std::vector<double> *evaluate(const double *padfValues)
{
std::copy(padfValues, padfValues + m_adfValuesForPixel.size(),
m_adfValuesForPixel.begin());

m_adfResults.clear();
double value = m_oExpression.value(); // force evaluation

const auto &results = m_oExpression.results();

// We follow a different method to get the result depending on
// how the expression was formed. If a "return" statement was
// used, the result will be accessible via the "result" object.
// If no "return" statement was used, the result is accessible
// from the "value" variable (and must not be a vector.)
if (results.count() == 0)
{
m_adfResults.resize(1);
m_adfResults[0] = value;
}
else if (results.count() == 1)
{

if (results[0].type == exprtk::type_store<double>::e_scalar)
{
m_adfResults.resize(1);
results.get_scalar(0, m_adfResults[0]);
}
else if (results[0].type == exprtk::type_store<double>::e_vector)
{
results.get_vector(0, m_adfResults);
}
else
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression returned an unexpected type.");
return nullptr;
}
}
else
{
m_adfResults.resize(results.count());
for (size_t i = 0; i < results.count(); i++)
{
if (results[i].type != exprtk::type_store<double>::e_scalar)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression must return a vector or a list of "
"scalars.");
return nullptr;
}
else
{
results.get_scalar(i, m_adfResults[i]);
}
}
}

return &m_adfResults;
}

private:
std::string m_osExpression{};
std::vector<double> m_adfValuesForPixel;
std::vector<double> m_adfResults{};

exprtk::expression<double> m_oExpression{};
exprtk::parser<double> m_oParser{};
exprtk::symbol_table<double> m_oSymbolTable{};
};

static CPLErr ExpressionInit(const char * /*pszFuncName*/, void * /*pUserData*/,
CSLConstList papszFunctionArgs, int nInBands,
GDALDataType eInDT, double * /* padfInNoData */,
int * /*pnOutBands */, GDALDataType *peOutDT,
double ** /* ppadfOutNoData */,
const char * /* pszVRTPath */,
VRTPDWorkingDataPtr *ppWorkingData)
{
CPLAssert(eInDT == GDT_Float64);

*peOutDT = eInDT;
*ppWorkingData = nullptr;

const char *pszExpression =
CSLFetchNameValue(papszFunctionArgs, "expression");

auto data = std::make_unique<ExpressionData>(nInBands, pszExpression);

for (int i = 0; i < nInBands; i++)
{
std::string osVar = "B" + std::to_string(i + 1);
data->set_variable_band(osVar.c_str(), i);
}

if (!data->compile())
{
return CE_Failure;
}

*ppWorkingData = data.release();

return CE_None;
}

static void ExpressionFree(const char * /* pszFuncName */,
void * /* pUserData */,
VRTPDWorkingDataPtr pWorkingData)
{
ExpressionData *data = static_cast<ExpressionData *>(pWorkingData);
delete data;
}

static CPLErr ExpressionProcess(
const char * /* pszFuncName */, void * /* pUserData */,
VRTPDWorkingDataPtr pWorkingData, CSLConstList /* papszFunctionArgs */,
int nBufXSize, int nBufYSize, const void *pInBuffer,
size_t /* nInBufferSize */, GDALDataType eInDT, int nInBands,
const double *CPL_RESTRICT /* padfInNoData */, void *pOutBuffer,
size_t /* nOutBufferSize */, GDALDataType eOutDT, int nOutBands,
const double *CPL_RESTRICT /* padfOutNoData */, double /* dfSrcXOff */,
double /* dfSrcYOff */, double /* dfSrcXSize */, double /* dfSrcYSize */,
const double /* adfSrcGT */[], const char * /* pszVRTPath "*/,
CSLConstList /* papszExtra */)
{
ExpressionData *expr = static_cast<ExpressionData *>(pWorkingData);

const size_t nElts = static_cast<size_t>(nBufXSize) * nBufYSize;

CPL_IGNORE_RET_VAL(eInDT);
CPLAssert(eInDT == GDT_Float64);
const double *CPL_RESTRICT padfSrc = static_cast<const double *>(pInBuffer);

CPLAssert(eOutDT == GDT_Float64);
CPL_IGNORE_RET_VAL(eOutDT);
double *CPL_RESTRICT padfDst = static_cast<double *>(pOutBuffer);

for (size_t i = 0; i < nElts; i++)
{
const auto *padfResults = expr->evaluate(padfSrc);

if (!padfResults)
{
return CE_Failure;
}

if (padfResults->size() != static_cast<std::size_t>(nOutBands))
{
CPLError(CE_Failure, CPLE_AppDefined,
"Expression returned %d values but "
"%d output bands were expected.",
static_cast<int>(padfResults->size()), nOutBands);
return CE_Failure;
}

for (int iDstBand = 0; iDstBand < nOutBands; iDstBand++)
{
*padfDst++ = (*padfResults)[iDstBand];
}

padfSrc += nInBands;
}

return CE_None;
}

/************************************************************************/
/* GDALVRTRegisterDefaultProcessedDatasetFuncs() */
/************************************************************************/
Expand Down Expand Up @@ -1573,4 +1789,13 @@ void GDALVRTRegisterDefaultProcessedDatasetFuncs()
"</ProcessedDatasetFunctionArgumentsList>",
GDT_Float64, nullptr, 0, nullptr, 0, TrimmingInit, TrimmingFree,
TrimmingProcess, nullptr);

GDALVRTRegisterProcessedDatasetFunc(
"Expression", nullptr,
"<ProcessedDatasetFunctionArgumentsList>"
" <Argument name='expression' description='the expression to "
"evaluate' type='string' required='true' />"
"</ProcessedDatasetFunctionArgumentsList>",
GDT_Float64, nullptr, 0, nullptr, 0, ExpressionInit, ExpressionFree,
ExpressionProcess, nullptr);
}

0 comments on commit 98ceafa

Please sign in to comment.