diff --git a/autotest/gdrivers/vrtprocesseddataset.py b/autotest/gdrivers/vrtprocesseddataset.py index eea8907b0ced..5d2bf3eb8792 100755 --- a/autotest/gdrivers/vrtprocesseddataset.py +++ b/autotest/gdrivers/vrtprocesseddataset.py @@ -1198,6 +1198,102 @@ def test_vrtprocesseddataset_trimming_errors(tmp_vsimem): ) +############################################################################### +# Test expressions + + +@pytest.mark.parametrize( + "expression,src,expected,error", + [ + pytest.param( + "return [ALL_BANDS[1], ALL_BANDS[2]]", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[[3, 4]], [[5, 6]]]), + None, + id="multiple bands in, multiple bands out (1)", + ), + pytest.param( + "return [ALL_BANDS]", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + None, + id="multiple bands in, multiple bands out (2)", + ), + pytest.param( + "B1", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[1, 2]]), + None, + id="multiple bands in, single band out (1)", + ), + pytest.param( + "ALL_BANDS[0]", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[1, 2]]), + None, + id="multiple bands in, single band out (2)", + ), + pytest.param( + "return [B1];", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[1, 2]]), + None, + id="multiple bands in, single band out (3)", + ), + pytest.param( + "return [B1, B2]", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[1, 2]]), + "returned 2 values but 1 output band", + id="return wrong number of bands", + ), + pytest.param( + "return [ALL_BANDS, B2]", + np.array([[[1, 2]], [[3, 4]], [[5, 6]]]), + np.array([[1, 2]]), + "must return a vector or a list of scalars", + id="return wrong number of bands", + ), + ], +) +def test_vrtprocesseddataset_expression(tmp_vsimem, expression, src, expected, error): + + src_filename = tmp_vsimem / "src.tif" + with gdal.GetDriverByName("GTiff").Create(src_filename, 2, 1, 3) as src_ds: + src_ds.WriteArray(src) + src_ds.SetGeoTransform([0, 1, 0, 0, 0, 1]) + + expected_output_bands = 1 if len(expected.shape) == 2 else expected.shape[0] + + output_band_xml = "".join( + f"""""" + for i in range(expected_output_bands) + ) + + ds = gdal.Open( + f""" + + {src_filename} + + + + Expression + {expression} + + + {output_band_xml} + + """ + ) + + if error: + with pytest.raises(Exception, match=error): + result = ds.ReadAsArray() + else: + result = ds.ReadAsArray() + np.testing.assert_equal(result, expected) + + ############################################################################### # Test that serialization (for example due to statistics computation) properly # works diff --git a/frmts/vrt/vrtprocesseddataset.cpp b/frmts/vrt/vrtprocesseddataset.cpp index b8cd27b76915..700665066012 100644 --- a/frmts/vrt/vrtprocesseddataset.cpp +++ b/frmts/vrt/vrtprocesseddataset.cpp @@ -436,9 +436,11 @@ CPLErr VRTProcessedDataset::Init(const CPLXMLNode *psTree, if (nCurrentBandCount != nBands) { - CPLError(CE_Failure, CPLE_AppDefined, - "Number of output bands of last step is not consistent with " - "number of VRTProcessedRasterBand's"); + CPLError( + CE_Failure, CPLE_AppDefined, + "Number of output bands of last step (%d) is not consistent with " + "number of VRTProcessedRasterBand's (%d)", + nCurrentBandCount, nBands); return CE_Failure; } diff --git a/frmts/vrt/vrtprocesseddatasetfunctions.cpp b/frmts/vrt/vrtprocesseddatasetfunctions.cpp index f6fb184fb1f6..6db4d0c03451 100644 --- a/frmts/vrt/vrtprocesseddatasetfunctions.cpp +++ b/frmts/vrt/vrtprocesseddatasetfunctions.cpp @@ -14,9 +14,13 @@ #include "cpl_string.h" #include "vrtdataset.h" +#include + #include +#include #include #include +#include #include #include @@ -1468,6 +1472,218 @@ static CPLErr TrimmingProcess( return CE_None; } +/************************************************************************/ +/* ExpressionInit() */ +/************************************************************************/ + +class ExpressionData +{ + public: + ExpressionData(int nInBands, std::string_view osExpression) + : m_osExpression(osExpression), m_adfValuesForPixel(nInBands) + { + m_oSymbolTable.add_vector("ALL_BANDS", m_adfValuesForPixel); + m_oSymbolTable.add_vector("OUTPUT_BANDS", m_adfResults); + } + + void set_variable_band(const char *pszVariable, int nBand) + { + m_oSymbolTable.add_variable(pszVariable, m_adfValuesForPixel[nBand]); + } + + bool compile() + { + m_oExpression.register_symbol_table(m_oSymbolTable); + bool bSuccess = m_oParser.compile(m_osExpression, m_oExpression); + + if (!bSuccess) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Failed to parse expression."); + + for (size_t i = 0; i < m_oParser.error_count(); i++) + { + const auto &oError = m_oParser.get_error(i); + + CPLError(CE_Warning, CPLE_AppDefined, + "Position: %02d " + "Type: [%s] " + "Message: %s\n", + static_cast(oError.token.position), + exprtk::parser_error::to_str(oError.mode).c_str(), + oError.diagnostic.c_str()); + } + } + + return bSuccess; + } + + const std::vector *evaluate(const double *padfValues) + { + std::copy(padfValues, padfValues + m_adfValuesForPixel.size(), + m_adfValuesForPixel.begin()); + + m_adfResults.clear(); + double value = m_oExpression.value(); // force evaluation + + const auto &results = m_oExpression.results(); + + // We follow a different method to get the result depending on + // how the expression was formed. If a "return" statement was + // used, the result will be accessible via the "result" object. + // If no "return" statement was used, the result is accessible + // from the "value" variable (and must not be a vector.) + if (results.count() == 0) + { + m_adfResults.resize(1); + m_adfResults[0] = value; + } + else if (results.count() == 1) + { + + if (results[0].type == exprtk::type_store::e_scalar) + { + m_adfResults.resize(1); + results.get_scalar(0, m_adfResults[0]); + } + else if (results[0].type == exprtk::type_store::e_vector) + { + results.get_vector(0, m_adfResults); + } + else + { + CPLError(CE_Failure, CPLE_AppDefined, + "Expression returned an unexpected type."); + return nullptr; + } + } + else + { + m_adfResults.resize(results.count()); + for (size_t i = 0; i < results.count(); i++) + { + if (results[i].type != exprtk::type_store::e_scalar) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Expression must return a vector or a list of " + "scalars."); + return nullptr; + } + else + { + results.get_scalar(i, m_adfResults[i]); + } + } + } + + return &m_adfResults; + } + + private: + std::string m_osExpression{}; + std::vector m_adfValuesForPixel; + std::vector m_adfResults{}; + + exprtk::expression m_oExpression{}; + exprtk::parser m_oParser{}; + exprtk::symbol_table m_oSymbolTable{}; +}; + +static CPLErr ExpressionInit(const char * /*pszFuncName*/, void * /*pUserData*/, + CSLConstList papszFunctionArgs, int nInBands, + GDALDataType eInDT, double * /* padfInNoData */, + int * /*pnOutBands */, GDALDataType *peOutDT, + double ** /* ppadfOutNoData */, + const char * /* pszVRTPath */, + VRTPDWorkingDataPtr *ppWorkingData) +{ + CPLAssert(eInDT == GDT_Float64); + + *peOutDT = eInDT; + *ppWorkingData = nullptr; + + const char *pszExpression = + CSLFetchNameValue(papszFunctionArgs, "expression"); + + auto data = std::make_unique(nInBands, pszExpression); + + for (int i = 0; i < nInBands; i++) + { + std::string osVar = "B" + std::to_string(i + 1); + data->set_variable_band(osVar.c_str(), i); + } + + if (!data->compile()) + { + return CE_Failure; + } + + *ppWorkingData = data.release(); + + return CE_None; +} + +static void ExpressionFree(const char * /* pszFuncName */, + void * /* pUserData */, + VRTPDWorkingDataPtr pWorkingData) +{ + ExpressionData *data = static_cast(pWorkingData); + delete data; +} + +static CPLErr ExpressionProcess( + const char * /* pszFuncName */, void * /* pUserData */, + VRTPDWorkingDataPtr pWorkingData, CSLConstList /* papszFunctionArgs */, + int nBufXSize, int nBufYSize, const void *pInBuffer, + size_t /* nInBufferSize */, GDALDataType eInDT, int nInBands, + const double *CPL_RESTRICT /* padfInNoData */, void *pOutBuffer, + size_t /* nOutBufferSize */, GDALDataType eOutDT, int nOutBands, + const double *CPL_RESTRICT /* padfOutNoData */, double /* dfSrcXOff */, + double /* dfSrcYOff */, double /* dfSrcXSize */, double /* dfSrcYSize */, + const double /* adfSrcGT */[], const char * /* pszVRTPath "*/, + CSLConstList /* papszExtra */) +{ + ExpressionData *expr = static_cast(pWorkingData); + + const size_t nElts = static_cast(nBufXSize) * nBufYSize; + + CPL_IGNORE_RET_VAL(eInDT); + CPLAssert(eInDT == GDT_Float64); + const double *CPL_RESTRICT padfSrc = static_cast(pInBuffer); + + CPLAssert(eOutDT == GDT_Float64); + CPL_IGNORE_RET_VAL(eOutDT); + double *CPL_RESTRICT padfDst = static_cast(pOutBuffer); + + for (size_t i = 0; i < nElts; i++) + { + const auto *padfResults = expr->evaluate(padfSrc); + + if (!padfResults) + { + return CE_Failure; + } + + if (padfResults->size() != static_cast(nOutBands)) + { + CPLError(CE_Failure, CPLE_AppDefined, + "Expression returned %d values but " + "%d output bands were expected.", + static_cast(padfResults->size()), nOutBands); + return CE_Failure; + } + + for (int iDstBand = 0; iDstBand < nOutBands; iDstBand++) + { + *padfDst++ = (*padfResults)[iDstBand]; + } + + padfSrc += nInBands; + } + + return CE_None; +} + /************************************************************************/ /* GDALVRTRegisterDefaultProcessedDatasetFuncs() */ /************************************************************************/ @@ -1573,4 +1789,13 @@ void GDALVRTRegisterDefaultProcessedDatasetFuncs() "", GDT_Float64, nullptr, 0, nullptr, 0, TrimmingInit, TrimmingFree, TrimmingProcess, nullptr); + + GDALVRTRegisterProcessedDatasetFunc( + "Expression", nullptr, + "" + " " + "", + GDT_Float64, nullptr, 0, nullptr, 0, ExpressionInit, ExpressionFree, + ExpressionProcess, nullptr); }