diff --git a/.devcontainer/check_test_count.sh b/.devcontainer/check_test_count.sh new file mode 100644 index 0000000..8360d83 --- /dev/null +++ b/.devcontainer/check_test_count.sh @@ -0,0 +1,37 @@ +#!/bin/bash -l + +# Description +# This script checks if all the tests are included in the matrix in the test step in ci-databricks.yml. +# It is used in the pipeline to ensure that all the tests are included in the matrix. +# The script must be invoked with a filter matching the paths NOT included in the matrix + +# $@: (Optional) Can be set to specify a filter for running python tests at the specified path. +echo "Filter (paths): '$@'" + +# Exit immediately with failure status if any command fails +set -e + +cd source/settlement_report_python/tests/ +# Enable extended globbing. E.g. see https://stackoverflow.com/questions/8525437/list-files-not-matching-a-pattern +shopt -s extglob + +# This script runs pytest with the --collect-only flag to get the number of tests. +# 'grep' filters the output to get the line with the number of tests collected. Multiple lines can be returned. +# 'awk' is used to get the second column of the output which contains the number of tests. +# 'head' is used to get the first line of the output which contains the number of tests. +# Example output line returned by the grep filter: 'collected 10 items' +executed_test_count=$(coverage run --branch -m pytest $@ --collect-only | grep collected | awk '{print $2}' | head -n 1) + +total_test_count=$(coverage run --branch -m pytest --collect-only | grep collected | awk '{print $2}' | head -n 1) + +echo "Number of tests being executed: $executed_test_count" +echo "Total number of pytest tests: $total_test_count" + + +if [ "$total_test_count" == "$executed_test_count" ]; then + echo "Not missing any tests." +else + difference=$((total_test_count - executed_test_count)) + echo "Found $difference tests not executed. A folder is missing in the matrix." + exit 1 +fi diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..6af9207 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,55 @@ +{ + "name": "Spark Dev", + "build": { + "dockerfile": "../.docker/Dockerfile", + "args": {} + }, + "customizations": { + "vscode": { + "extensions": [ + "matangover.mypy", + "ms-python.flake8", + "ms-dotnettools.dotnet-interactive-vscode", + "ms-python.python", + "ms-python.black-formatter", + "littlefoxteam.vscode-python-test-adapter", + "hbenl.vscode-test-explorer", + "eamodio.gitlens", + "ms-python.vscode-pylance", + "HashiCorp.terraform", + "christian-kohler.path-intellisense", + "Gruntfuggly.todo-tree", + "DavidAnson.vscode-markdownlint", + "kevinglasson.cornflakes-linter", + "KevinRose.vsc-python-indent", + "sonarsource.sonarlint-vscode" + ], + // Set *default* container specific settings.json values on container create. 
+ "settings": { + "terminal.integrated.shell.linux": "/bin/bash", + "editor.formatOnSave": false, + "[python]": { + "editor.formatOnSave": true + }, + "python.formatting.provider": "black", + "python.defaultInterpreterPath": "/opt/conda/bin/python", + "python.languageServer": "Pylance", + "markdownlint.config": { + "MD007": { + "indent": 4 + } + } + } + } + }, + "containerEnv": { + "GRANT_SUDO": "yes" + }, + "forwardPorts": [ + 5568 + ], + "appPort": [ + "5568:5050" + ], + "containerUser": "root" +} \ No newline at end of file diff --git a/.devcontainer/docker-compose-windows.yml b/.devcontainer/docker-compose-windows.yml new file mode 100644 index 0000000..19cd5e2 --- /dev/null +++ b/.devcontainer/docker-compose-windows.yml @@ -0,0 +1,15 @@ +services: + python-unit-test: + image: ghcr.io/energinet-datahub/geh-settlement-report/python-unit-test:${IMAGE_TAG:-latest} + volumes: + # Forwards the local Docker socket to the container. + - /var/run/docker.sock:/var/run/docker-host.sock + # Update this to wherever you want VS Code to mount the folder of your project + - ..:/workspaces/geh-settlement-report:cached + # Map to Azure CLI token cache location (on Windows) + - "${USERPROFILE}/.azure:/home/joyvan/.azure" + environment: + # Pass the environment variables from your shell straight through to your containers. + # No warning is issued if the variable in the shell environment is not set. + # See https://docs.docker.com/compose/environment-variables/set-environment-variables/#additional-information-1 + - AZURE_KEYVAULT_URL diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml new file mode 100644 index 0000000..211c8d4 --- /dev/null +++ b/.devcontainer/docker-compose.yml @@ -0,0 +1,15 @@ +services: + python-unit-test: + image: ghcr.io/energinet-datahub/geh-settlement-report/python-unit-test:${IMAGE_TAG:-latest} + volumes: + # Forwards the local Docker socket to the container. + - /var/run/docker.sock:/var/run/docker-host.sock + # Update this to wherever you want VS Code to mount the folder of your project + - ..:/workspaces/geh-settlement-report:cached + # Map to Azure CLI token cache location (on Linux) + - "${HOME}/.azure:/home/joyvan/.azure" + environment: + # Pass the environment variables from your shell straight through to your containers. + # No warning is issued if the variable in the shell environment is not set. 
+ # See https://docs.docker.com/compose/environment-variables/set-environment-variables/#additional-information-1 + - AZURE_KEYVAULT_URL diff --git a/.devcontainer/requirements.txt b/.devcontainer/requirements.txt new file mode 100644 index 0000000..e09f175 --- /dev/null +++ b/.devcontainer/requirements.txt @@ -0,0 +1,38 @@ +# This is a pip 'requirements.txt' file +# See https://pip.pypa.io/en/stable/reference/requirements-file-format/ + +# +# PYTHON TOOLS +# +black +build +coverage-threshold +flake8 +mypy +pyspelling +pytest-xdist + +# +# CODE DEPENDENCIES +# - Make sure any packages specified in setup.py are pinned to the same version here +# +databricks-cli==0.18 +dataclasses-json==0.6.7 +delta-spark==3.2.0 +pyspark==3.5.1 +dependency_injector==4.43.0 +azure-identity==1.17.1 +azure-keyvault-secrets==4.8.0 +azure-monitor-opentelemetry==1.6.4 +azure-core==1.32.0 +azure-monitor-query==1.4.0 +opengeh-spark-sql-migrations @ git+https://git@github.com/Energinet-DataHub/opengeh-python-packages@2.4.1#subdirectory=source/spark_sql_migrations +python-dateutil==2.8.2 +types-python-dateutil==2.9.0.20241003 +opengeh-telemetry @ git+https://git@github.com/Energinet-DataHub/opengeh-python-packages@2.4.1#subdirectory=source/telemetry + +coverage==7.6.8 +pytest==8.3.3 +configargparse==1.7.0 +pytest-mock==3.14.0 +virtualenv==20.24.2 \ No newline at end of file diff --git a/.docker/Dockerfile b/.docker/Dockerfile new file mode 100644 index 0000000..ac0bcfc --- /dev/null +++ b/.docker/Dockerfile @@ -0,0 +1,56 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The spark version should follow the spark version in databricks. 
+ # The databricks version of spark is controlled from dh3-infrastructure and uses latest LTS (ATTOW - Spark v3.5.0) +# pyspark-slim version should match pyspark version in requirements.txt +FROM ghcr.io/energinet-datahub/pyspark-slim:3.5.1-5 + +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +USER root + +RUN apt-get update; \ + # Install git as it is needed by spark + apt-get install --no-install-recommends -y git; \ + # Curl is temporarily installed in order to download the Azure CLI (consider multi stage build instead) + apt-get install --no-install-recommends -y curl; \ + # Install Azure CLI, see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli-linux?pivots=apt + # as it is needed by integration tests + curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash; \ + # Cleanup apt cache to reduce image size + apt-get remove -y curl; \ + rm -rf /var/lib/apt/lists/* + +# This replaces the default spark configuration in the docker image with the ones defined in the sibling file +COPY spark-defaults.conf $SPARK_HOME/conf/ + +# Install python packages used in pyspark development (keep spark dependent packages aligned) +# delta-spark version has to be compatible with the spark version (https://docs.delta.io/latest/releases.html) +# example (delta 2.2.x = spark 3.3.x) +COPY requirements.txt requirements.txt +RUN pip --no-cache-dir install -r requirements.txt + +# Set misc environment variables required to properly run spark +# note the amount of memory used on the driver is adjusted here +ENV PATH=$SPARK_HOME/bin:$HADOOP_HOME/bin:$PATH \ + PYTHONPATH="${SPARK_HOME}/python:${SPARK_HOME}/python/lib/py4j-0.10.9-src.zip" \ + SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M --driver-java-options=-Dlog4j.logLevel=info" + +# Dynamically download spark dependencies from conf/spark-defaults.conf. This is done to save time in the build pipeline so that we don't need to download on every build. +RUN spark-shell + +# Make $HOME owned by root, which is the user used in the container +# This is needed for e.g. commands that create files or folders in $HOME +RUN sudo chown -R root:users $HOME diff --git a/.docker/entrypoint.sh b/.docker/entrypoint.sh new file mode 100644 index 0000000..3091364 --- /dev/null +++ b/.docker/entrypoint.sh @@ -0,0 +1,44 @@ +#!/bin/bash -l + +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# $1: Mandatory test folder path +# $2: (Optional) Can be set to specify a filter for running python tests by using 'keyword expressions'.
+ # See use of '-k' and 'keyword expressions' here: https://docs.pytest.org/en/7.4.x/how-to/usage.html#specifying-which-tests-to-run +echo "Tests folder path: '$1'" +echo "Filter (paths): '$2'" + +# Configure Azure CLI to use token cache which must be mapped as a volume from the host machine +export AZURE_CONFIG_DIR=/home/joyvan/.azure + +# These env vars are important to ensure that the driver and worker nodes in spark are aligned +export PYSPARK_PYTHON=/opt/conda/bin/python +export PYSPARK_DRIVER_PYTHON=/opt/conda/bin/python + +# Exit immediately with failure status if any command fails +set -e + +# Enable extended globbing. E.g. see https://stackoverflow.com/questions/8525437/list-files-not-matching-a-pattern +shopt -s extglob + +cd $1 +coverage run --branch -m pytest -vv --junitxml=pytest-results.xml $2 + +# Create data for threshold evaluation +coverage json +# Create a human reader-friendly HTML report +coverage html + +coverage-threshold --line-coverage-min 25 diff --git a/.docker/requirements.txt b/.docker/requirements.txt new file mode 100644 index 0000000..e09f175 --- /dev/null +++ b/.docker/requirements.txt @@ -0,0 +1,38 @@ +# This is a pip 'requirements.txt' file +# See https://pip.pypa.io/en/stable/reference/requirements-file-format/ + +# +# PYTHON TOOLS +# +black +build +coverage-threshold +flake8 +mypy +pyspelling +pytest-xdist + +# +# CODE DEPENDENCIES +# - Make sure any packages specified in setup.py are pinned to the same version here +# +databricks-cli==0.18 +dataclasses-json==0.6.7 +delta-spark==3.2.0 +pyspark==3.5.1 +dependency_injector==4.43.0 +azure-identity==1.17.1 +azure-keyvault-secrets==4.8.0 +azure-monitor-opentelemetry==1.6.4 +azure-core==1.32.0 +azure-monitor-query==1.4.0 +opengeh-spark-sql-migrations @ git+https://git@github.com/Energinet-DataHub/opengeh-python-packages@2.4.1#subdirectory=source/spark_sql_migrations +python-dateutil==2.8.2 +types-python-dateutil==2.9.0.20241003 +opengeh-telemetry @ git+https://git@github.com/Energinet-DataHub/opengeh-python-packages@2.4.1#subdirectory=source/telemetry + +coverage==7.6.8 +pytest==8.3.3 +configargparse==1.7.0 +pytest-mock==3.14.0 +virtualenv==20.24.2 \ No newline at end of file diff --git a/.docker/spark-defaults.conf b/.docker/spark-defaults.conf new file mode 100644 index 0000000..5844dae --- /dev/null +++ b/.docker/spark-defaults.conf @@ -0,0 +1,15 @@ +# Default system properties included when running spark-submit. +# This is useful for setting default environmental settings.
+ +# Example: +# spark.master spark://master:7077 +# spark.eventLog.enabled true +# spark.eventLog.dir hdfs://namenode:8021/directory +# spark.serializer org.apache.spark.serializer.KryoSerializer +# spark.driver.memory 16g +# spark.executor.extraJavaOptions -XX:+PrintGCDetails -Dkey=value -Dnumbers="one two three" + +spark.jars.packages io.delta:delta-core_2.12:2.2.0 + +# spark.hadoop.fs.AbstractFileSystem.abfss.impl org.apache.hadoop.fs.azurebfs.Abfss +# spark.hadoop.fs.abfss.impl org.apache.hadoop.fs.azurebfs.SecureAzureBlobFileSystem diff --git a/.github/workflows/ci-orchestrator.yml b/.github/workflows/ci-orchestrator.yml index 5a2847e..dc3d712 100644 --- a/.github/workflows/ci-orchestrator.yml +++ b/.github/workflows/ci-orchestrator.yml @@ -35,11 +35,25 @@ jobs: changes: uses: ./.github/workflows/detect-changes.yml + ci_docker: + needs: changes + uses: Energinet-DataHub/.github/.github/workflows/python-build-and-push-docker-image.yml@v13 + with: + docker_changed: ${{ needs.changes.outputs.docker == 'true' }} + docker_changed_in_commit: ${{ needs.changes.outputs.docker_in_commit == 'true' }} + ci_dotnet: needs: changes if: ${{ needs.changes.outputs.dotnet == 'true' || needs.changes.outputs.db_migrations == 'true' }} uses: ./.github/workflows/ci-dotnet.yml + ci_python: + needs: changes + if: ${{ needs.changes.outputs.settlement_report_job == 'true' }} + uses: ./.github/workflows/ci-python.yml + with: + image_tag: ${{ needs.ci_docker.outputs.image_tag }} + render_c4model_views: needs: changes if: ${{ needs.changes.outputs.render_c4model_views == 'true' }} diff --git a/.github/workflows/ci-python.yml b/.github/workflows/ci-python.yml new file mode 100644 index 0000000..9803631 --- /dev/null +++ b/.github/workflows/ci-python.yml @@ -0,0 +1,96 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: CI Databricks + +on: + workflow_call: + inputs: + image_tag: + type: string + default: latest + +jobs: + databricks_ci_build: + uses: Energinet-DataHub/.github/.github/workflows/databricks-build-prerelease.yml@v14 + with: + python_version: 3.11.7 + architecture: x64 + wheel_working_directory: ./source/settlement_report_python + multiple_wheels: true + should_include_assets: true + + settlement_report_unit_tests: + strategy: + fail-fast: false + matrix: + # IMPORTANT: When adding a new folder here it should also be added in the `unit_test_check` job! 
+ tests_filter_expression: - name: Settlement report unit testing paths: domain/ entry_points/ infrastructure/ uses: Energinet-DataHub/.github/.github/workflows/python-ci.yml@v14 with: job_name: ${{ matrix.tests_filter_expression.name }} operating_system: dh3-ubuntu-20.04-4core path_static_checks: ./source/settlement_report_python # documented here: https://github.com/Energinet-DataHub/opengeh-wholesale/tree/main/source/databricks#styling-and-formatting ignore_errors_and_warning_flake8: E501,F401,E402,E203,W503 tests_folder_path: ./source/settlement_report_python/tests test_report_path: ./source/settlement_report_python/tests # See .docker/entrypoint.sh on how to use the filter expression tests_filter_expression: ${{ matrix.tests_filter_expression.paths }} image_tag: ${{ inputs.image_tag }} + + # Tests that require the integration test environment + settlement_report_integration_tests: + uses: Energinet-DataHub/.github/.github/workflows/python-ci.yml@v14 + with: + job_name: Settlement report integration testing + operating_system: dh3-ubuntu-20.04-4core + path_static_checks: ./source/settlement_report_python + # documented here: https://github.com/Energinet-DataHub/opengeh-wholesale/tree/main/source/databricks#styling-and-formatting + ignore_errors_and_warning_flake8: E501,F401,E402,E203,W503 + tests_folder_path: ./source/settlement_report_python/tests + test_report_path: ./source/settlement_report_python/tests + # See .docker/entrypoint.sh on how to use the filter expression + tests_filter_expression: integration_test/ + use_integrationtest_environment: true + azure_integrationtest_tenant_id: ${{ vars.integration_test_azure_tenant_id }} + azure_integrationtest_subscription_id: ${{ vars.integration_test_azure_subscription_id }} + azure_integrationtest_spn_id: ${{ vars.integration_test_azure_spn_id_oidc }} + azure_keyvault_url: ${{ vars.integration_test_azure_keyvault_url }} + image_tag: ${{ inputs.image_tag }} + + settlement_report_mypy_check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Run pip install and mypy check of files in package + shell: bash + run: | + pip install --upgrade pip + pip install mypy types-python-dateutil + mypy ./source/settlement_report_python/settlement_report_job --disallow-untyped-defs --ignore-missing-imports + + settlement_report_black_check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + options: --check --diff + src: ./source/settlement_report_python diff --git a/.github/workflows/detect-changes.yml b/.github/workflows/detect-changes.yml index a8b58a4..d957f7e 100644 --- a/.github/workflows/detect-changes.yml +++ b/.github/workflows/detect-changes.yml @@ -29,7 +29,12 @@ on: value: ${{ jobs.changes.outputs.render_c4model_views }} db_migrations: value: ${{ jobs.changes.outputs.db_migrations }} - + settlement_report_job: + value: ${{ jobs.changes.outputs.settlement_report_job }} + docker: + value: ${{ jobs.changes.outputs.docker }} + docker_in_commit: + value: ${{ jobs.changes.outputs.docker_in_commit }} jobs: changes: name: Determine relevant jobs @@ -39,6 +44,9 @@ jobs: dotnet: ${{ steps.filter.outputs.dotnet }} render_c4model_views: ${{ steps.filter.outputs.render_c4model_views }} db_migrations: ${{ steps.filter.outputs.db_migrations }} + settlement_report_job: ${{ steps.filter.outputs.settlement_report_job }} + docker: ${{ steps.filter.outputs.docker }} + docker_in_commit: ${{
steps.docker_changed.outputs.any_changed }} steps: # For pull requests it's not necessary to checkout the code because GitHub REST API is used to determine changes - name: Checkout repository @@ -62,3 +70,14 @@ jobs: - 'docs/diagrams/c4-model/views.dsl' - 'docs/diagrams/c4-model/views.json' - 'docs/diagrams/c4-model/model.dsl' + settlement_report_job: + - 'source/settlement_report_python/**' + docker: + - .docker/** + + - name: Package content or build has changed + id: docker_changed + uses: tj-actions/changed-files@v41 + with: + since_last_remote_commit: true + files: .docker/** \ No newline at end of file diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..0ac2335 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,19 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. + // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp + // List of extensions which should be recommended for users of this workspace. + "recommendations": [ + "editorconfig.editorconfig", + "ms-vscode-remote.remote-containers", + "4ops.terraform", + "vscode-icons-team.vscode-icons", + "davidanson.vscode-markdownlint", + "ms-vscode.powershell", + "redhat.vscode-yaml", + "github.vscode-github-actions", + "systemticks.c4-dsl-extension", + "ms-azuretools.vscode-docker" + ], + // List of extensions recommended by VS Code that should not be recommended for users of this workspace. + "unwantedRecommendations": [] +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..306f58e --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7f13c49 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,23 @@ +{ + "python.testing.pytestArgs": [ + "source" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true, + "powershell.codeFormatting.avoidSemicolonsAsLineTerminators": true, + "powershell.codeFormatting.useCorrectCasing": true, + "powershell.codeFormatting.whitespaceBetweenParameters": true, + "markdownlint.config": { + "MD007": { + "indent": 4 + }, + "MD024": { + "siblings_only": true + }, + }, + "redhat.telemetry.enabled": false, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none", +} \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..46fa57c --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,35 @@ +{ + // See https://go.microsoft.com/fwlink/?LinkId=733558 + // for the documentation about the tasks.json format + "version": "2.0.0", + "tasks": [ + { + "label": "Structurizr Lite: Load 'views'", + "type": "docker-run", + "dockerRun": { + "image": "structurizr/lite:latest", + "ports": [ + { + "containerPort": 8080, + "hostPort": 8080 + } + ], + "volumes": [ + { + "localPath": "${workspaceFolder}/docs/diagrams/c4-model", + "containerPath": "/usr/local/structurizr" + } + ], + "env": { + "STRUCTURIZR_WORKSPACE_FILENAME": "views" + }, + "remove": true + }, + "problemMatcher": [], + "group": { + "kind": "build", + "isDefault": true + } + } + ] +} \ No newline at end of file diff --git a/source/settlement_report_python/contracts/settlement-report-balance-fixing-parameters-reference.txt b/source/settlement_report_python/contracts/settlement-report-balance-fixing-parameters-reference.txt new file mode 100644 index 0000000..673bf68 --- /dev/null +++ b/source/settlement_report_python/contracts/settlement-report-balance-fixing-parameters-reference.txt @@ -0,0 +1,19 @@ +# This file contains all the parameters that the settlement report job consumes. Some are required and some are not. +# +# Empty lines and lines starting with '#' are ignored in the tests. + +# Required parameters +--report-id={report-id} +--period-start=2022-05-31T22:00:00Z +--period-end=2022-06-01T22:00:00Z +--calculation-type=balance_fixing +--requesting-actor-market-role=energy_supplier +# market-role values: datahub_administrator, energy_supplier, grid_access_provider, system_operator +--requesting-actor-id=1234567890123 +--grid-area-codes=[804, 805] + +# Optional parameters +--energy-supplier-ids=[1234567890123] +--split-report-by-grid-area +--prevent-large-text-files +--include-basis-data diff --git a/source/settlement_report_python/contracts/settlement-report-wholesale-calculations-parameters-reference.txt b/source/settlement_report_python/contracts/settlement-report-wholesale-calculations-parameters-reference.txt new file mode 100644 index 0000000..fa392c9 --- /dev/null +++ b/source/settlement_report_python/contracts/settlement-report-wholesale-calculations-parameters-reference.txt @@ -0,0 +1,19 @@ +# This file contains all the parameters that the settlement report job consumes.
Some are required and some are not. +# +# Empty lines and lines starting with '#' are ignored in the tests. + +# Required parameters +--report-id={report-id} +--period-start=2022-05-31T22:00:00Z +--period-end=2022-06-01T22:00:00Z +--calculation-type=wholesale_fixing +--requesting-actor-market-role=energy_supplier +# market-role values: datahub_administrator, energy_supplier, grid_access_provider, system_operator +--requesting-actor-id=1234567890123 +--calculation-id-by-grid-area={"804": "95bd2365-c09b-4ee7-8c25-8dd56b564811", "805": "d3e2b83a-2fd9-4bcd-a6dc-41e4ce74cd6d"} + +# Optional parameters +--energy-supplier-ids=[1234567890123] +--split-report-by-grid-area +--prevent-large-text-files +--include-basis-data diff --git a/source/settlement_report_python/settlement_report_job/__init__.py b/source/settlement_report_python/settlement_report_job/__init__.py new file mode 100644 index 0000000..2b02430 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging as logger + +logger.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel( + logger.WARNING +) +logger.getLogger("azure.monitor.opentelemetry.exporter.export").setLevel(logger.WARNING) diff --git a/source/settlement_report_python/settlement_report_job/domain/__init__.py b/source/settlement_report_python/settlement_report_job/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/__init__.py b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/charge_link_periods_factory.py b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/charge_link_periods_factory.py new file mode 100644 index 0000000..cad223b --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/charge_link_periods_factory.py @@ -0,0 +1,45 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
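Editorial note: the --calculation-id-by-grid-area value in the wholesale parameters reference above is a JSON object mapping grid area codes to calculation ids; inside the job it surfaces as the dict[str, UUID] consumed by the factories below. The real parsing lives in entry_points/job_args, which is not part of this section; the following is only a minimal sketch of that conversion, assuming plain json/uuid handling.

import json
from uuid import UUID

# Example value taken from the wholesale parameters reference file above.
raw_argument = '{"804": "95bd2365-c09b-4ee7-8c25-8dd56b564811", "805": "d3e2b83a-2fd9-4bcd-a6dc-41e4ce74cd6d"}'

# Convert the CLI string into the dict[str, UUID] shape expected by the factories.
calculation_id_by_grid_area = {
    grid_area_code: UUID(calculation_id)
    for grid_area_code, calculation_id in json.loads(raw_argument).items()
}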
+ +from pyspark.sql import DataFrame + +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.charge_link_periods.read_and_filter import ( + read_and_filter, +) +from settlement_report_job.domain.charge_link_periods.prepare_for_csv import ( + prepare_for_csv, +) + + +def create_charge_link_periods( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + charge_link_periods = read_and_filter( + args.period_start, + args.period_end, + args.calculation_id_by_grid_area, + args.energy_supplier_ids, + args.requesting_actor_market_role, + args.requesting_actor_id, + repository, + ) + + return prepare_for_csv( + charge_link_periods, + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/order_by_columns.py new file mode 100644 index 0000000..16eb0a9 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/order_by_columns.py @@ -0,0 +1,21 @@ +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns( + requesting_actor_market_role: MarketRole, +) -> list[str]: + order_by_column_names = [ + CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.charge_link_from_date, + ] + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + order_by_column_names.insert(0, CsvColumnNames.energy_supplier_id) + + return order_by_column_names diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/prepare_for_csv.py new file mode 100644 index 0000000..d48cd9b --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/prepare_for_csv.py @@ -0,0 +1,70 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.map_to_csv_naming import ( + METERING_POINT_TYPES, + CHARGE_TYPES, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + charge_link_periods: DataFrame, +) -> DataFrame: + columns = [ + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + F.col(DataProductColumnNames.metering_point_id).alias( + CsvColumnNames.metering_point_id + ), + map_from_dict(METERING_POINT_TYPES)[ + F.col(DataProductColumnNames.metering_point_type) + ].alias(CsvColumnNames.metering_point_type), + map_from_dict(CHARGE_TYPES)[F.col(DataProductColumnNames.charge_type)].alias( + CsvColumnNames.charge_type + ), + F.col(DataProductColumnNames.charge_owner_id).alias( + CsvColumnNames.charge_owner_id + ), + F.col(DataProductColumnNames.charge_code).alias(CsvColumnNames.charge_code), + F.col(DataProductColumnNames.quantity).alias(CsvColumnNames.charge_quantity), + F.col(DataProductColumnNames.from_date).alias( + CsvColumnNames.charge_link_from_date + ), + F.col(DataProductColumnNames.to_date).alias(CsvColumnNames.charge_link_to_date), + ] + if DataProductColumnNames.energy_supplier_id in charge_link_periods.columns: + columns.append( + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ) + ) + + return charge_link_periods.select(columns) diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/read_and_filter.py b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/read_and_filter.py new file mode 100644 index 0000000..054cdad --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_link_periods/read_and_filter.py @@ -0,0 +1,120 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
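Editorial note: prepare_for_csv builds its select list with map_from_dict from settlement_report_job.domain.utils.map_from_dict, whose implementation is not included in this diff. A minimal sketch of what such a helper plausibly looks like, assuming it only turns a plain dict into a literal map column that can then be indexed by another column (as done in the select above):

from itertools import chain

from pyspark.sql import Column, functions as F


def map_from_dict(mapping: dict) -> Column:
    # Flatten {"key": "value", ...} into [lit(key), lit(value), ...] and build a map column.
    # Indexing the result with another column, e.g. map_from_dict(CHARGE_TYPES)[F.col("charge_type")],
    # then yields the mapped CSV value for each row.
    return F.create_map([F.lit(item) for item in chain(*mapping.items())])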
+from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.join_metering_points_periods_and_charge_link_periods import ( + join_metering_points_periods_and_charge_link_periods, +) +from settlement_report_job.domain.utils.merge_periods import ( + merge_connected_periods, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.repository_filtering import ( + read_charge_link_periods, + read_metering_point_periods_by_calculation_ids, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +logger = Logger(__name__) + + +@use_span() +def read_and_filter( + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + repository: WholesaleRepository, +) -> DataFrame: + logger.info("Creating charge links") + + charge_link_periods = read_charge_link_periods( + repository=repository, + period_start=period_start, + period_end=period_end, + charge_owner_id=requesting_actor_id, + requesting_actor_market_role=requesting_actor_market_role, + ) + + charge_link_periods = _join_with_metering_point_periods( + charge_link_periods, + period_start, + period_end, + calculation_id_by_grid_area, + energy_supplier_ids, + repository, + ) + + charge_link_periods = charge_link_periods.select( + _get_select_columns(requesting_actor_market_role) + ) + + charge_link_periods = merge_connected_periods(charge_link_periods) + + return charge_link_periods + + +def _join_with_metering_point_periods( + charge_link_periods: DataFrame, + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + repository: WholesaleRepository, +) -> DataFrame: + metering_point_periods = read_metering_point_periods_by_calculation_ids( + repository=repository, + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area=calculation_id_by_grid_area, + energy_supplier_ids=energy_supplier_ids, + ) + + charge_link_periods = join_metering_points_periods_and_charge_link_periods( + charge_link_periods, metering_point_periods + ) + + return charge_link_periods + + +def _get_select_columns( + requesting_actor_market_role: MarketRole, +) -> list[str]: + select_columns = [ + DataProductColumnNames.metering_point_id, + DataProductColumnNames.metering_point_type, + DataProductColumnNames.charge_type, + DataProductColumnNames.charge_code, + DataProductColumnNames.charge_owner_id, + DataProductColumnNames.quantity, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + DataProductColumnNames.grid_area_code, + ] + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + select_columns.append(DataProductColumnNames.energy_supplier_id) + + return select_columns diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_price_points/__init__.py b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_price_points/charge_price_points_factory.py 
b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/charge_price_points_factory.py new file mode 100644 index 0000000..0b60868 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/charge_price_points_factory.py @@ -0,0 +1,43 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame + +from settlement_report_job.domain.charge_price_points.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.charge_price_points.read_and_filter import ( + read_and_filter, +) +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +def create_charge_price_points( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + charge_price_points = read_and_filter( + args.period_start, + args.period_end, + args.calculation_id_by_grid_area, + args.energy_supplier_ids, + args.requesting_actor_market_role, + args.requesting_actor_id, + repository, + ) + + return prepare_for_csv(charge_price_points, args.time_zone) diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_price_points/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/order_by_columns.py new file mode 100644 index 0000000..da3f89c --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/order_by_columns.py @@ -0,0 +1,12 @@ +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns() -> list[str]: + return [ + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.resolution, + CsvColumnNames.is_tax, + CsvColumnNames.time, + ] diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_price_points/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/prepare_for_csv.py new file mode 100644 index 0000000..b298928 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/prepare_for_csv.py @@ -0,0 +1,102 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyspark.sql import DataFrame, functions as F, Window + +from telemetry_logging import Logger, use_span + +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.get_start_of_day import get_start_of_day +from settlement_report_job.domain.utils.map_to_csv_naming import ( + CHARGE_TYPES, + TAX_INDICATORS, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + filtered_charge_price_points: DataFrame, + time_zone: str, +) -> DataFrame: + filtered_charge_price_points = filtered_charge_price_points.withColumn( + CsvColumnNames.time, + get_start_of_day(DataProductColumnNames.charge_time, time_zone), + ) + + win = Window.partitionBy( + DataProductColumnNames.grid_area_code, + DataProductColumnNames.charge_type, + DataProductColumnNames.charge_owner_id, + DataProductColumnNames.charge_code, + DataProductColumnNames.resolution, + DataProductColumnNames.is_tax, + CsvColumnNames.time, + ).orderBy(DataProductColumnNames.charge_time) + filtered_charge_price_points = filtered_charge_price_points.withColumn( + "chronological_order", F.row_number().over(win) + ) + + pivoted_df = ( + filtered_charge_price_points.groupBy( + DataProductColumnNames.grid_area_code, + DataProductColumnNames.charge_type, + DataProductColumnNames.charge_owner_id, + DataProductColumnNames.charge_code, + DataProductColumnNames.resolution, + DataProductColumnNames.is_tax, + CsvColumnNames.time, + ) + .pivot( + "chronological_order", + list(range(1, 25 + 1)), + ) + .agg(F.first(DataProductColumnNames.charge_price)) + ) + + charge_price_column_names = [ + F.col(str(i)).alias(f"{CsvColumnNames.energy_price}{i}") + for i in range(1, 25 + 1) + ] + + csv_df = pivoted_df.select( + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + map_from_dict(CHARGE_TYPES)[F.col(DataProductColumnNames.charge_type)].alias( + CsvColumnNames.charge_type + ), + F.col(DataProductColumnNames.charge_owner_id).alias( + CsvColumnNames.charge_owner_id + ), + F.col(DataProductColumnNames.charge_code).alias(CsvColumnNames.charge_code), + F.col(DataProductColumnNames.resolution).alias(CsvColumnNames.resolution), + map_from_dict(TAX_INDICATORS)[F.col(DataProductColumnNames.is_tax)].alias( + CsvColumnNames.is_tax + ), + F.col(CsvColumnNames.time), + *charge_price_column_names, + ) + + return csv_df diff --git a/source/settlement_report_python/settlement_report_job/domain/charge_price_points/read_and_filter.py b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/read_and_filter.py new file mode 100644 index 0000000..81c620a --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/charge_price_points/read_and_filter.py @@ -0,0 +1,147 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame, functions as F + +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_charge_owner_and_tax_depending_on_market_role, +) +from settlement_report_job.domain.utils.join_metering_points_periods_and_charge_link_periods import ( + join_metering_points_periods_and_charge_link_periods, +) +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.repository_filtering import ( + read_metering_point_periods_by_calculation_ids, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +logger = Logger(__name__) + + +@use_span() +def read_and_filter( + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + repository: WholesaleRepository, +) -> DataFrame: + logger.info("Creating charge prices") + + charge_price_points = ( + repository.read_charge_price_points() + .where((F.col(DataProductColumnNames.charge_time) >= period_start)) + .where(F.col(DataProductColumnNames.charge_time) < period_end) + ) + + charge_price_points = _join_with_charge_link_and_metering_point_periods( + charge_price_points, + period_start, + period_end, + calculation_id_by_grid_area, + energy_supplier_ids, + repository, + ) + + charge_price_information_periods = ( + repository.read_charge_price_information_periods() + ) + + charge_price_points = charge_price_points.join( + charge_price_information_periods, + on=[ + DataProductColumnNames.calculation_id, + DataProductColumnNames.charge_key, + ], + how="inner", + ).select( + charge_price_points["*"], + charge_price_information_periods[DataProductColumnNames.is_tax], + charge_price_information_periods[DataProductColumnNames.resolution], + ) + + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.GRID_ACCESS_PROVIDER, + ]: + charge_price_points = filter_by_charge_owner_and_tax_depending_on_market_role( + charge_price_points, + requesting_actor_market_role, + requesting_actor_id, + ) + + return charge_price_points + + +def _join_with_charge_link_and_metering_point_periods( + charge_price_points: DataFrame, + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + repository: WholesaleRepository, +) -> DataFrame: + charge_link_periods = repository.read_charge_link_periods().where( + (F.col(DataProductColumnNames.from_date) < period_end) + & (F.col(DataProductColumnNames.to_date) > period_start) + ) + + metering_point_periods = read_metering_point_periods_by_calculation_ids( + repository=repository, + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area=calculation_id_by_grid_area, + energy_supplier_ids=energy_supplier_ids, + ) + + charge_link_periods_and_metering_point_periods = ( + join_metering_points_periods_and_charge_link_periods( + charge_link_periods, metering_point_periods + ) + ) + + charge_price_points = ( + charge_price_points.join( + charge_link_periods_and_metering_point_periods, + on=[ + DataProductColumnNames.calculation_id, + 
DataProductColumnNames.charge_key, + ], + how="inner", + ) + .where( + F.col(DataProductColumnNames.charge_time) + >= F.col(DataProductColumnNames.from_date) + ) + .where( + F.col(DataProductColumnNames.charge_time) + < F.col(DataProductColumnNames.to_date) + ) + .select( + charge_price_points["*"], + charge_link_periods_and_metering_point_periods[ + DataProductColumnNames.grid_area_code + ], + ) + ).distinct() + + return charge_price_points diff --git a/source/settlement_report_python/settlement_report_job/domain/energy_results/__init__.py b/source/settlement_report_python/settlement_report_job/domain/energy_results/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/energy_results/energy_results_factory.py b/source/settlement_report_python/settlement_report_job/domain/energy_results/energy_results_factory.py new file mode 100644 index 0000000..bbf4d66 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/energy_results/energy_results_factory.py @@ -0,0 +1,44 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame + +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +from settlement_report_job.domain.energy_results.read_and_filter import ( + read_and_filter_from_view, +) +from settlement_report_job.domain.energy_results.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.utils.settlement_report_args_utils import ( + should_have_result_file_per_grid_area, +) + + +def create_energy_results( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + energy = read_and_filter_from_view(args, repository) + + return prepare_for_csv( + energy, + should_have_result_file_per_grid_area(args), + args.requesting_actor_market_role, + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/energy_results/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/energy_results/order_by_columns.py new file mode 100644 index 0000000..c5e01ae --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/energy_results/order_by_columns.py @@ -0,0 +1,19 @@ +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns(requesting_actor_market_role: MarketRole) -> list: + order_by_column_names = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.time, + ] + + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + order_by_column_names.insert(1, CsvColumnNames.energy_supplier_id) + + return order_by_column_names diff --git 
a/source/settlement_report_python/settlement_report_job/domain/energy_results/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/energy_results/prepare_for_csv.py new file mode 100644 index 0000000..4cf6e7a --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/energy_results/prepare_for_csv.py @@ -0,0 +1,73 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +import settlement_report_job.domain.utils.map_to_csv_naming as market_naming + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + energy: DataFrame, + one_file_per_grid_area: bool, + requesting_actor_market_role: MarketRole, +) -> DataFrame: + select_columns = [ + F.col(DataProductColumnNames.grid_area_code).alias( + CsvColumnNames.grid_area_code + ), + map_from_dict(market_naming.CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS)[ + F.col(DataProductColumnNames.calculation_type) + ].alias(CsvColumnNames.calculation_type), + F.col(DataProductColumnNames.time).alias(CsvColumnNames.time), + F.col(DataProductColumnNames.resolution).alias(CsvColumnNames.resolution), + map_from_dict(market_naming.METERING_POINT_TYPES)[ + F.col(DataProductColumnNames.metering_point_type) + ].alias(CsvColumnNames.metering_point_type), + map_from_dict(market_naming.SETTLEMENT_METHODS)[ + F.col(DataProductColumnNames.settlement_method) + ].alias(CsvColumnNames.settlement_method), + F.col(DataProductColumnNames.quantity).alias(CsvColumnNames.energy_quantity), + ] + + if requesting_actor_market_role is MarketRole.DATAHUB_ADMINISTRATOR: + select_columns.insert( + 1, + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ), + ) + + if one_file_per_grid_area: + select_columns.append( + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + ) + + return energy.select(select_columns) diff --git a/source/settlement_report_python/settlement_report_job/domain/energy_results/read_and_filter.py b/source/settlement_report_python/settlement_report_job/domain/energy_results/read_and_filter.py new file mode 100644 index 0000000..e1768c4 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/energy_results/read_and_filter.py @@ -0,0 +1,81 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_energy_supplier_ids, + filter_by_grid_area_codes, + filter_by_calculation_id_by_grid_area, + read_and_filter_by_latest_calculations, +) + +log = Logger(__name__) + + +def _get_view_read_function( + requesting_actor_market_role: MarketRole, + repository: WholesaleRepository, +) -> Callable[[], DataFrame]: + if requesting_actor_market_role == MarketRole.GRID_ACCESS_PROVIDER: + return repository.read_energy + else: + return repository.read_energy_per_es + + +@use_span() +def read_and_filter_from_view( + args: SettlementReportArgs, repository: WholesaleRepository +) -> DataFrame: + read_from_repository_func = _get_view_read_function( + args.requesting_actor_market_role, repository + ) + + df = read_from_repository_func().where( + (F.col(DataProductColumnNames.time) >= args.period_start) + & (F.col(DataProductColumnNames.time) < args.period_end) + ) + + if args.energy_supplier_ids: + df = df.where(filter_by_energy_supplier_ids(args.energy_supplier_ids)) + + if args.calculation_type is CalculationType.BALANCE_FIXING and args.grid_area_codes: + df = df.where(filter_by_grid_area_codes(args.grid_area_codes)) + df = read_and_filter_by_latest_calculations( + df=df, + repository=repository, + grid_area_codes=args.grid_area_codes, + period_start=args.period_start, + period_end=args.period_end, + time_zone=args.time_zone, + time_column_name=DataProductColumnNames.time, + ) + elif args.calculation_id_by_grid_area: + # args.calculation_id_by_grid_area should never be null when not BALANCE_FIXING. 
+ df = df.where( + filter_by_calculation_id_by_grid_area(args.calculation_id_by_grid_area) + ) + + return df diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/__init__.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/clamp_period.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/clamp_period.py new file mode 100644 index 0000000..9bf134b --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/clamp_period.py @@ -0,0 +1,26 @@ +from datetime import datetime + +from pyspark.sql import DataFrame, functions as F + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def clamp_to_selected_period( + periods: DataFrame, selected_period_start: datetime, selected_period_end: datetime +) -> DataFrame: + periods = periods.withColumn( + DataProductColumnNames.to_date, + F.when( + F.col(DataProductColumnNames.to_date) > selected_period_end, + selected_period_end, + ).otherwise(F.col(DataProductColumnNames.to_date)), + ).withColumn( + DataProductColumnNames.from_date, + F.when( + F.col(DataProductColumnNames.from_date) < selected_period_start, + selected_period_start, + ).otherwise(F.col(DataProductColumnNames.from_date)), + ) + return periods diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/metering_point_periods_factory.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/metering_point_periods_factory.py new file mode 100644 index 0000000..a0ead43 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/metering_point_periods_factory.py @@ -0,0 +1,87 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
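A small usage sketch (editor's illustration, not part of this diff) of clamp_to_selected_period defined above, assuming a local SparkSession and that DataProductColumnNames.from_date/to_date map to the column names "from_date"/"to_date": a period that starts before and ends after the selected window is trimmed to the window bounds.

from datetime import datetime
from pyspark.sql import SparkSession
from settlement_report_job.domain.metering_point_periods.clamp_period import (
    clamp_to_selected_period,
)

spark = SparkSession.builder.getOrCreate()
periods = spark.createDataFrame(
    [(datetime(2023, 12, 15), datetime(2024, 2, 10))],
    ["from_date", "to_date"],
)

clamped = clamp_to_selected_period(
    periods,
    selected_period_start=datetime(2024, 1, 1),
    selected_period_end=datetime(2024, 2, 1),
)
clamped.show()  # expected: from_date 2024-01-01, to_date 2024-02-01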
+ +from pyspark.sql import DataFrame + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + +from settlement_report_job.domain.metering_point_periods.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.metering_point_periods.read_and_filter_wholesale import ( + read_and_filter as read_and_filter_wholesale, +) +from settlement_report_job.domain.metering_point_periods.read_and_filter_balance_fixing import ( + read_and_filter as read_and_filter_balance_fixing, +) +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def create_metering_point_periods( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + selected_columns = _get_select_columns(args.requesting_actor_market_role) + if args.calculation_type is CalculationType.BALANCE_FIXING: + metering_point_periods = read_and_filter_balance_fixing( + args.period_start, + args.period_end, + args.grid_area_codes, + args.energy_supplier_ids, + selected_columns, + args.time_zone, + repository, + ) + else: + metering_point_periods = read_and_filter_wholesale( + args.period_start, + args.period_end, + args.calculation_id_by_grid_area, + args.energy_supplier_ids, + args.requesting_actor_market_role, + args.requesting_actor_id, + selected_columns, + repository, + ) + + return prepare_for_csv( + metering_point_periods, + args.requesting_actor_market_role, + ) + + +def _get_select_columns(requesting_actor_market_role: MarketRole) -> list[str]: + select_columns = [ + DataProductColumnNames.metering_point_id, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + DataProductColumnNames.grid_area_code, + DataProductColumnNames.from_grid_area_code, + DataProductColumnNames.to_grid_area_code, + DataProductColumnNames.metering_point_type, + DataProductColumnNames.settlement_method, + ] + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + select_columns.append(DataProductColumnNames.energy_supplier_id) + return select_columns diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/order_by_columns.py new file mode 100644 index 0000000..dcbc8fa --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/order_by_columns.py @@ -0,0 +1,20 @@ +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns( + requesting_actor_market_role: MarketRole, +) -> list[str]: + order_by_column_names = [ + CsvColumnNames.grid_area_code_in_metering_points_csv, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.metering_point_from_date, + ] + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + order_by_column_names.insert(1, CsvColumnNames.energy_supplier_id) + + return order_by_column_names diff --git 
a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/prepare_for_csv.py new file mode 100644 index 0000000..f888429 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/prepare_for_csv.py @@ -0,0 +1,92 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.map_to_csv_naming import ( + METERING_POINT_TYPES, + SETTLEMENT_METHODS, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + metering_point_periods: DataFrame, + requesting_actor_market_role: MarketRole, +) -> DataFrame: + + columns = [ + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + F.col(DataProductColumnNames.metering_point_id).alias( + CsvColumnNames.metering_point_id + ), + F.col(DataProductColumnNames.from_date).alias( + CsvColumnNames.metering_point_from_date + ), + F.col(DataProductColumnNames.to_date).alias( + CsvColumnNames.metering_point_to_date + ), + F.col(DataProductColumnNames.grid_area_code).alias( + CsvColumnNames.grid_area_code_in_metering_points_csv + ), + map_from_dict(METERING_POINT_TYPES)[ + F.col(DataProductColumnNames.metering_point_type) + ].alias(CsvColumnNames.metering_point_type), + map_from_dict(SETTLEMENT_METHODS)[ + F.col(DataProductColumnNames.settlement_method) + ].alias(CsvColumnNames.settlement_method), + ] + if requesting_actor_market_role is MarketRole.GRID_ACCESS_PROVIDER: + columns.insert( + 5, + F.col(DataProductColumnNames.to_grid_area_code).alias( + CsvColumnNames.to_grid_area_code + ), + ) + columns.insert( + 6, + F.col(DataProductColumnNames.from_grid_area_code).alias( + CsvColumnNames.from_grid_area_code + ), + ) + + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + columns.append( + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ) + ) + + csv_df = metering_point_periods.select(columns) + + return csv_df diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_balance_fixing.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_balance_fixing.py new file mode 100644 index 0000000..28f6501 --- /dev/null +++ 
b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_balance_fixing.py @@ -0,0 +1,113 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span + +from settlement_report_job.domain.utils.factory_filters import ( + read_and_filter_by_latest_calculations, +) +from settlement_report_job.domain.utils.merge_periods import ( + merge_connected_periods, +) +from settlement_report_job.domain.metering_point_periods.clamp_period import ( + clamp_to_selected_period, +) +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.repository_filtering import ( + read_filtered_metering_point_periods_by_grid_area_codes, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +log = Logger(__name__) + + +@use_span() +def read_and_filter( + period_start: datetime, + period_end: datetime, + grid_area_codes: list[str], + energy_supplier_ids: list[str] | None, + select_columns: list[str], + time_zone: str, + repository: WholesaleRepository, +) -> DataFrame: + + metering_point_periods = read_filtered_metering_point_periods_by_grid_area_codes( + repository=repository, + period_start=period_start, + period_end=period_end, + grid_area_codes=grid_area_codes, + energy_supplier_ids=energy_supplier_ids, + ) + + metering_point_periods_daily = _explode_into_daily_period( + metering_point_periods, time_zone + ) + + metering_point_periods_from_latest_calculations = ( + read_and_filter_by_latest_calculations( + df=metering_point_periods_daily, + grid_area_codes=grid_area_codes, + period_start=period_start, + period_end=period_end, + time_zone=time_zone, + time_column_name=DataProductColumnNames.from_date, + repository=repository, + ) + ) + + metering_point_periods_from_latest_calculations = ( + metering_point_periods_from_latest_calculations.select(*select_columns) + ) + + metering_point_periods_from_latest_calculations = merge_connected_periods( + metering_point_periods_from_latest_calculations + ) + + metering_point_periods_from_latest_calculations = clamp_to_selected_period( + metering_point_periods_from_latest_calculations, period_start, period_end + ) + + return metering_point_periods_from_latest_calculations + + +def _explode_into_daily_period(df: DataFrame, time_zone: str) -> DataFrame: + df = df.withColumn( + "local_daily_from_date", + F.explode( + F.sequence( + F.from_utc_timestamp(DataProductColumnNames.from_date, time_zone), + F.date_sub( + F.from_utc_timestamp(DataProductColumnNames.to_date, time_zone), 1 + ), + F.expr("interval 1 day"), + ) + ), + ) + df = df.withColumn("local_daily_to_date", F.date_add("local_daily_from_date", 1)) + + df = df.withColumn( + DataProductColumnNames.from_date, + F.to_utc_timestamp("local_daily_from_date", time_zone), + ).withColumn( + 
DataProductColumnNames.to_date, + F.to_utc_timestamp("local_daily_to_date", time_zone), + ) + + return df diff --git a/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_wholesale.py b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_wholesale.py new file mode 100644 index 0000000..d7a7be0 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/metering_point_periods/read_and_filter_wholesale.py @@ -0,0 +1,101 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame + +from telemetry_logging import Logger, use_span + +from settlement_report_job.domain.utils.join_metering_points_periods_and_charge_link_periods import ( + join_metering_points_periods_and_charge_link_periods, +) +from settlement_report_job.domain.utils.merge_periods import ( + merge_connected_periods, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.metering_point_periods.clamp_period import ( + clamp_to_selected_period, +) +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.repository_filtering import ( + read_metering_point_periods_by_calculation_ids, + read_charge_link_periods, +) + +log = Logger(__name__) + + +@use_span() +def read_and_filter( + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + select_columns: list[str], + repository: WholesaleRepository, +) -> DataFrame: + + metering_point_periods = read_metering_point_periods_by_calculation_ids( + repository=repository, + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area=calculation_id_by_grid_area, + energy_supplier_ids=energy_supplier_ids, + ) + + if requesting_actor_market_role == MarketRole.SYSTEM_OPERATOR: + metering_point_periods = _filter_by_charge_owner( + metering_point_periods=metering_point_periods, + period_start=period_start, + period_end=period_end, + requesting_actor_market_role=requesting_actor_market_role, + requesting_actor_id=requesting_actor_id, + repository=repository, + ) + + metering_point_periods = metering_point_periods.select(*select_columns) + + metering_point_periods = merge_connected_periods(metering_point_periods) + + metering_point_periods = clamp_to_selected_period( + metering_point_periods, period_start, period_end + ) + + return metering_point_periods + + +def _filter_by_charge_owner( + metering_point_periods: DataFrame, + period_start: datetime, + period_end: datetime, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + repository: WholesaleRepository, +) -> DataFrame: + charge_link_periods = read_charge_link_periods( + repository=repository, + 
period_start=period_start, + period_end=period_end, + charge_owner_id=requesting_actor_id, + requesting_actor_market_role=requesting_actor_market_role, + ) + metering_point_periods = join_metering_points_periods_and_charge_link_periods( + charge_link_periods=charge_link_periods, + metering_point_periods=metering_point_periods, + ) + + return metering_point_periods diff --git a/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/__init__.py b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/monthly_amounts_factory.py b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/monthly_amounts_factory.py new file mode 100644 index 0000000..9c2669d --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/monthly_amounts_factory.py @@ -0,0 +1,42 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame + +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + +from settlement_report_job.domain.monthly_amounts.read_and_filter import ( + read_and_filter_from_view, +) +from settlement_report_job.domain.monthly_amounts.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.utils.settlement_report_args_utils import ( + should_have_result_file_per_grid_area, +) + + +def create_monthly_amounts( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + monthly_amounts = read_and_filter_from_view(args, repository) + + return prepare_for_csv( + monthly_amounts, + should_have_result_file_per_grid_area(args=args), + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/order_by_columns.py new file mode 100644 index 0000000..7c7a7fe --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/order_by_columns.py @@ -0,0 +1,20 @@ +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns(requesting_actor_market_role: MarketRole) -> list[str]: + order_by_column_names = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.resolution, + ] + + if requesting_actor_market_role not in [ + MarketRole.GRID_ACCESS_PROVIDER, + MarketRole.SYSTEM_OPERATOR, + ]: + order_by_column_names.insert(2, CsvColumnNames.charge_owner_id) + + return order_by_column_names diff --git 
a/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/prepare_for_csv.py new file mode 100644 index 0000000..4a53369 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/prepare_for_csv.py @@ -0,0 +1,72 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +import settlement_report_job.domain.utils.map_to_csv_naming as market_naming + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + monthly_amounts: DataFrame, + create_ephemeral_grid_area_column: bool, +) -> DataFrame: + select_columns = [ + map_from_dict(market_naming.CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS)[ + F.col(DataProductColumnNames.calculation_type) + ].alias(CsvColumnNames.calculation_type), + map_from_dict(market_naming.CALCULATION_TYPES_TO_PROCESS_VARIANT)[ + F.col(DataProductColumnNames.calculation_type) + ].alias(CsvColumnNames.correction_settlement_number), + F.col(DataProductColumnNames.grid_area_code).alias( + CsvColumnNames.grid_area_code + ), + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ), + F.col(DataProductColumnNames.time).alias(CsvColumnNames.time), + F.col(DataProductColumnNames.resolution).alias(CsvColumnNames.resolution), + F.col(DataProductColumnNames.quantity_unit).alias(CsvColumnNames.quantity_unit), + F.col(DataProductColumnNames.currency).alias(CsvColumnNames.currency), + F.col(DataProductColumnNames.amount).alias(CsvColumnNames.amount), + map_from_dict(market_naming.CHARGE_TYPES)[ + F.col(DataProductColumnNames.charge_type) + ].alias(CsvColumnNames.charge_type), + F.col(DataProductColumnNames.charge_code).alias(CsvColumnNames.charge_code), + F.col(DataProductColumnNames.charge_owner_id).alias( + CsvColumnNames.charge_owner_id + ), + ] + + if create_ephemeral_grid_area_column: + select_columns.append( + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + ) + + return monthly_amounts.select(select_columns) diff --git a/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/read_and_filter.py b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/read_and_filter.py new file mode 100644 index 0000000..e3abf58 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/monthly_amounts/read_and_filter.py @@ -0,0 +1,134 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# 
you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pyspark.sql import DataFrame +from pyspark.sql.functions import lit, col + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_calculation_id_by_grid_area, + filter_by_charge_owner_and_tax_depending_on_market_role, + filter_by_energy_supplier_ids, +) + +log = Logger(__name__) + + +@use_span() +def read_and_filter_from_view( + args: SettlementReportArgs, repository: WholesaleRepository +) -> DataFrame: + monthly_amounts_per_charge = repository.read_monthly_amounts_per_charge_v1() + monthly_amounts_per_charge = _filter_monthly_amounts_per_charge( + monthly_amounts_per_charge, args + ) + + total_monthly_amounts = repository.read_total_monthly_amounts_v1() + total_monthly_amounts = _filter_total_monthly_amounts(total_monthly_amounts, args) + total_monthly_amounts = _prepare_total_monthly_amounts_columns_for_union( + total_monthly_amounts, + monthly_amounts_per_charge_column_ordering=monthly_amounts_per_charge.columns, + ) + + monthly_amounts = monthly_amounts_per_charge.union(total_monthly_amounts) + monthly_amounts = monthly_amounts.withColumn( + DataProductColumnNames.resolution, lit("P1M") + ) + + return monthly_amounts + + +def _apply_shared_filters( + monthly_amounts: DataFrame, args: SettlementReportArgs +) -> DataFrame: + monthly_amounts = monthly_amounts.where( + (col(DataProductColumnNames.time) >= args.period_start) + & (col(DataProductColumnNames.time) < args.period_end) + ) + if args.calculation_id_by_grid_area: + # Can never be null, but mypy requires it be specified + monthly_amounts = monthly_amounts.where( + filter_by_calculation_id_by_grid_area(args.calculation_id_by_grid_area) + ) + if args.energy_supplier_ids: + monthly_amounts = monthly_amounts.where( + filter_by_energy_supplier_ids(args.energy_supplier_ids) + ) + return monthly_amounts + + +def _filter_monthly_amounts_per_charge( + monthly_amounts_per_charge: DataFrame, args: SettlementReportArgs +) -> DataFrame: + monthly_amounts_per_charge = _apply_shared_filters(monthly_amounts_per_charge, args) + + monthly_amounts_per_charge = ( + filter_by_charge_owner_and_tax_depending_on_market_role( + monthly_amounts_per_charge, + args.requesting_actor_market_role, + args.requesting_actor_id, + ) + ) + + return monthly_amounts_per_charge + + +def _filter_total_monthly_amounts( + total_monthly_amounts: DataFrame, args: SettlementReportArgs +) -> DataFrame: + total_monthly_amounts = _apply_shared_filters(total_monthly_amounts, args) + + if args.requesting_actor_market_role in [ + MarketRole.ENERGY_SUPPLIER, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + total_monthly_amounts = total_monthly_amounts.where( + 
col(DataProductColumnNames.charge_owner_id).isNull() + ) + elif args.requesting_actor_market_role in [ + MarketRole.GRID_ACCESS_PROVIDER, + MarketRole.SYSTEM_OPERATOR, + ]: + total_monthly_amounts = total_monthly_amounts.where( + col(DataProductColumnNames.charge_owner_id) == args.requesting_actor_id + ) + return total_monthly_amounts + + +def _prepare_total_monthly_amounts_columns_for_union( + base_total_monthly_amounts: DataFrame, + monthly_amounts_per_charge_column_ordering: list[str], +) -> DataFrame: + for null_column in [ + DataProductColumnNames.quantity_unit, + DataProductColumnNames.charge_type, + DataProductColumnNames.charge_code, + DataProductColumnNames.is_tax, + DataProductColumnNames.charge_owner_id, + # charge_owner_id is not always null, but it should not be part of total_monthly_amounts + ]: + base_total_monthly_amounts = base_total_monthly_amounts.withColumn( + null_column, lit(None) + ) + + return base_total_monthly_amounts.select( + monthly_amounts_per_charge_column_ordering, + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/__init__.py b/source/settlement_report_python/settlement_report_job/domain/time_series_points/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/time_series_points/order_by_columns.py new file mode 100644 index 0000000..67aa1ed --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/time_series_points/order_by_columns.py @@ -0,0 +1,17 @@ +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns(requesting_actor_market_role: MarketRole) -> list[str]: + order_by_column_names = [ + CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.time, + ] + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + order_by_column_names.insert(0, CsvColumnNames.energy_supplier_id) + + return order_by_column_names diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/time_series_points/prepare_for_csv.py new file mode 100644 index 0000000..a7553b2 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/time_series_points/prepare_for_csv.py @@ -0,0 +1,121 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
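A compact sketch (editor's illustration with made-up column names, not part of this diff) of the union-alignment technique used in _prepare_total_monthly_amounts_columns_for_union above: the narrower DataFrame gets null placeholder columns and is then selected in the wider DataFrame's column order, because DataFrame.union matches columns by position rather than by name.

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
per_charge = spark.createDataFrame([(42.0, "tariff")], ["amount", "charge_type"])
totals = spark.createDataFrame([(99.0,)], ["amount"])

for null_column in ["charge_type"]:
    totals = totals.withColumn(null_column, F.lit(None))

aligned = totals.select(per_charge.columns)  # reorder to match per_charge positionally
combined = per_charge.union(aligned)
combined.show()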
+ +from pyspark.sql import DataFrame, functions as F, Window + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.get_start_of_day import get_start_of_day +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.map_to_csv_naming import ( + METERING_POINT_TYPES, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + filtered_time_series_points: DataFrame, + metering_point_resolution: MeteringPointResolutionDataProductValue, + time_zone: str, + requesting_actor_market_role: MarketRole, +) -> DataFrame: + desired_number_of_quantity_columns = _get_desired_quantity_column_count( + metering_point_resolution + ) + + filtered_time_series_points = filtered_time_series_points.withColumn( + CsvColumnNames.time, + get_start_of_day(DataProductColumnNames.observation_time, time_zone), + ) + + win = Window.partitionBy( + DataProductColumnNames.grid_area_code, + DataProductColumnNames.energy_supplier_id, + DataProductColumnNames.metering_point_id, + DataProductColumnNames.metering_point_type, + CsvColumnNames.time, + ).orderBy(DataProductColumnNames.observation_time) + filtered_time_series_points = filtered_time_series_points.withColumn( + "chronological_order", F.row_number().over(win) + ) + + pivoted_df = ( + filtered_time_series_points.groupBy( + DataProductColumnNames.grid_area_code, + DataProductColumnNames.energy_supplier_id, + DataProductColumnNames.metering_point_id, + DataProductColumnNames.metering_point_type, + CsvColumnNames.time, + ) + .pivot( + "chronological_order", + list(range(1, desired_number_of_quantity_columns + 1)), + ) + .agg(F.first(DataProductColumnNames.quantity)) + ) + + quantity_column_names = [ + F.col(str(i)).alias(f"{CsvColumnNames.energy_quantity}{i}") + for i in range(1, desired_number_of_quantity_columns + 1) + ] + + csv_df = pivoted_df.select( + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ), + F.col(DataProductColumnNames.metering_point_id).alias( + CsvColumnNames.metering_point_id + ), + map_from_dict(METERING_POINT_TYPES)[ + F.col(DataProductColumnNames.metering_point_type) + ].alias(CsvColumnNames.metering_point_type), + F.col(CsvColumnNames.time), + *quantity_column_names, + ) + + if requesting_actor_market_role in [ + MarketRole.GRID_ACCESS_PROVIDER, + MarketRole.ENERGY_SUPPLIER, + ]: + csv_df = csv_df.drop(CsvColumnNames.energy_supplier_id) + + return csv_df + + +def _get_desired_quantity_column_count( + resolution: MeteringPointResolutionDataProductValue, +) -> int: + if resolution == MeteringPointResolutionDataProductValue.HOUR: + return 25 + elif resolution == MeteringPointResolutionDataProductValue.QUARTER: + return 25 * 4 + else: + raise ValueError(f"Unknown time series resolution: {resolution}") diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/read_and_filter.py 
b/source/settlement_report_python/settlement_report_job/domain/time_series_points/read_and_filter.py new file mode 100644 index 0000000..f06d729 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/time_series_points/read_and_filter.py @@ -0,0 +1,127 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.time_series_points.system_operator_filter import ( + filter_time_series_points_on_charge_owner, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_energy_supplier_ids, + filter_by_calculation_id_by_grid_area, + read_and_filter_by_latest_calculations, +) + +log = Logger(__name__) + + +@use_span() +def read_and_filter_for_balance_fixing( + period_start: datetime, + period_end: datetime, + grid_area_codes: list[str], + energy_supplier_ids: list[str] | None, + metering_point_resolution: MeteringPointResolutionDataProductValue, + time_zone: str, + repository: WholesaleRepository, +) -> DataFrame: + log.info("Creating time series points") + time_series_points = _read_from_view( + period_start, + period_end, + metering_point_resolution, + energy_supplier_ids, + repository, + ) + + time_series_points = read_and_filter_by_latest_calculations( + df=time_series_points, + repository=repository, + grid_area_codes=grid_area_codes, + period_start=period_start, + period_end=period_end, + time_zone=time_zone, + time_column_name=DataProductColumnNames.observation_time, + ) + + return time_series_points + + +@use_span() +def read_and_filter_for_wholesale( + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + metering_point_resolution: MeteringPointResolutionDataProductValue, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + repository: WholesaleRepository, +) -> DataFrame: + log.info("Creating time series points") + + time_series_points = _read_from_view( + period_start=period_start, + period_end=period_end, + resolution=metering_point_resolution, + energy_supplier_ids=energy_supplier_ids, + repository=repository, + ) + + time_series_points = time_series_points.where( + filter_by_calculation_id_by_grid_area(calculation_id_by_grid_area) + ) + + if requesting_actor_market_role is MarketRole.SYSTEM_OPERATOR: + time_series_points = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points, + system_operator_id=requesting_actor_id, + 
charge_link_periods=repository.read_charge_link_periods(), + charge_price_information_periods=repository.read_charge_price_information_periods(), + ) + + return time_series_points + + +@use_span() +def _read_from_view( + period_start: datetime, + period_end: datetime, + resolution: MeteringPointResolutionDataProductValue, + energy_supplier_ids: list[str] | None, + repository: WholesaleRepository, +) -> DataFrame: + time_series_points = repository.read_metering_point_time_series().where( + (F.col(DataProductColumnNames.observation_time) >= period_start) + & (F.col(DataProductColumnNames.observation_time) < period_end) + & (F.col(DataProductColumnNames.resolution) == resolution.value) + ) + + if energy_supplier_ids: + time_series_points = time_series_points.where( + filter_by_energy_supplier_ids(energy_supplier_ids) + ) + + return time_series_points diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/system_operator_filter.py b/source/settlement_report_python/settlement_report_job/domain/time_series_points/system_operator_filter.py new file mode 100644 index 0000000..eb047b6 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/time_series_points/system_operator_filter.py @@ -0,0 +1,49 @@ +from pyspark.sql import DataFrame, functions as F + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def filter_time_series_points_on_charge_owner( + time_series_points: DataFrame, + system_operator_id: str, + charge_link_periods: DataFrame, + charge_price_information_periods: DataFrame, +) -> DataFrame: + """ + Filters away all time series data that is not related to the system operator, and which is not a tax. + """ + + charge_price_information_periods = charge_price_information_periods.where( + (~F.col(DataProductColumnNames.is_tax)) + & (F.col(DataProductColumnNames.charge_owner_id) == system_operator_id) + ) + + filtered_charge_link_periods = charge_link_periods.join( + charge_price_information_periods, + on=[DataProductColumnNames.calculation_id, DataProductColumnNames.charge_key], + how="inner", + ).select( + charge_link_periods[DataProductColumnNames.calculation_id], + charge_link_periods[DataProductColumnNames.metering_point_id], + charge_link_periods[DataProductColumnNames.from_date], + charge_link_periods[DataProductColumnNames.to_date], + ) + + filtered_df = time_series_points.join( + filtered_charge_link_periods, + on=[ + time_series_points[DataProductColumnNames.calculation_id] + == filtered_charge_link_periods[DataProductColumnNames.calculation_id], + time_series_points[DataProductColumnNames.metering_point_id] + == filtered_charge_link_periods[DataProductColumnNames.metering_point_id], + F.col(DataProductColumnNames.observation_time) + >= F.col(DataProductColumnNames.from_date), + F.col(DataProductColumnNames.observation_time) + < F.col(DataProductColumnNames.to_date), + ], + how="leftsemi", + ) + + return filtered_df diff --git a/source/settlement_report_python/settlement_report_job/domain/time_series_points/time_series_points_factory.py b/source/settlement_report_python/settlement_report_job/domain/time_series_points/time_series_points_factory.py new file mode 100644 index 0000000..defc502 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/time_series_points/time_series_points_factory.py @@ -0,0 +1,99 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use 
this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.time_series_points.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.time_series_points.read_and_filter import ( + read_and_filter_for_wholesale, + read_and_filter_for_balance_fixing, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) + +log = Logger(__name__) + + +@use_span() +def create_time_series_points_for_balance_fixing( + period_start: datetime, + period_end: datetime, + grid_area_codes: list[str], + energy_supplier_ids: list[str] | None, + metering_point_resolution: MeteringPointResolutionDataProductValue, + time_zone: str, + requesting_market_role: MarketRole, + repository: WholesaleRepository, +) -> DataFrame: + log.info("Creating time series points") + + time_series_points = read_and_filter_for_balance_fixing( + period_start=period_start, + period_end=period_end, + grid_area_codes=grid_area_codes, + energy_supplier_ids=energy_supplier_ids, + metering_point_resolution=metering_point_resolution, + time_zone=time_zone, + repository=repository, + ) + + prepared_time_series_points = prepare_for_csv( + filtered_time_series_points=time_series_points, + metering_point_resolution=metering_point_resolution, + time_zone=time_zone, + requesting_actor_market_role=requesting_market_role, + ) + return prepared_time_series_points + + +@use_span() +def create_time_series_points_for_wholesale( + period_start: datetime, + period_end: datetime, + calculation_id_by_grid_area: dict[str, UUID], + energy_supplier_ids: list[str] | None, + metering_point_resolution: MeteringPointResolutionDataProductValue, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + time_zone: str, + repository: WholesaleRepository, +) -> DataFrame: + log.info("Creating time series points") + + time_series_points = read_and_filter_for_wholesale( + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area=calculation_id_by_grid_area, + energy_supplier_ids=energy_supplier_ids, + metering_point_resolution=metering_point_resolution, + requesting_actor_market_role=requesting_actor_market_role, + requesting_actor_id=requesting_actor_id, + repository=repository, + ) + + prepared_time_series_points = prepare_for_csv( + filtered_time_series_points=time_series_points, + metering_point_resolution=metering_point_resolution, + time_zone=time_zone, + requesting_actor_market_role=requesting_actor_market_role, + ) + return prepared_time_series_points diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/__init__.py b/source/settlement_report_python/settlement_report_job/domain/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/source/settlement_report_python/settlement_report_job/domain/utils/csv_column_names.py b/source/settlement_report_python/settlement_report_job/domain/utils/csv_column_names.py new file mode 100644 index 0000000..00aade7 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/csv_column_names.py @@ -0,0 +1,37 @@ +class CsvColumnNames: + amount = "AMOUNT" + calculation_type = "ENERGYBUSINESSPROCESS" + charge_code = "CHARGEID" + charge_link_from_date = "PERIODSTART" + charge_link_to_date = "PERIODEND" + charge_quantity = "CHARGEOCCURRENCES" + charge_owner_id = "CHARGEOWNER" + charge_type = "CHARGETYPE" + correction_settlement_number = "PROCESSVARIANT" + currency = "ENERGYCURRENCY" + energy_quantity = "ENERGYQUANTITY" + energy_supplier_id = "ENERGYSUPPLIERID" + from_grid_area_code = "FROMGRIDAREAID" + grid_area_code = "METERINGGRIDAREAID" + grid_area_code_in_metering_points_csv = "GRIDAREAID" + metering_point_from_date = "VALIDFROM" + metering_point_id = "METERINGPOINTID" + metering_point_to_date = "VALIDTO" + metering_point_type = "TYPEOFMP" + price = "PRICE" + quantity_unit = "MEASUREUNIT" + resolution = "RESOLUTIONDURATION" + settlement_method = "SETTLEMENTMETHOD" + time = "STARTDATETIME" + to_grid_area_code = "TOGRIDAREAID" + is_tax = "TAXINDICATOR" + energy_price = "ENERGYPRICE" + + +class EphemeralColumns: + # Columns that are added to the DataFrame for processing + # but not part of the input or output schema. + + grid_area_code_partitioning = "grid_area_code_partitioning" + chunk_index = "chunk_index_partition" + start_of_day = "start_of_day" diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/factory_filters.py b/source/settlement_report_python/settlement_report_job/domain/utils/factory_filters.py new file mode 100644 index 0000000..11f37ab --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/factory_filters.py @@ -0,0 +1,124 @@ +from datetime import datetime +from uuid import UUID +from pyspark.sql import DataFrame, Column, functions as F + +from telemetry_logging import Logger +from settlement_report_job.domain.utils.csv_column_names import EphemeralColumns +from settlement_report_job.domain.utils.get_start_of_day import ( + get_start_of_day, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.repository import ( + WholesaleRepository, +) +from settlement_report_job.infrastructure.wholesale.data_values.calculation_type import ( + CalculationTypeDataProductValue, +) + +log = Logger(__name__) + + +def read_and_filter_by_latest_calculations( + df: DataFrame, + repository: WholesaleRepository, + grid_area_codes: list[str], + period_start: datetime, + period_end: datetime, + time_zone: str, + time_column_name: str | Column, +) -> DataFrame: + latest_balance_fixing_calculations = repository.read_latest_calculations().where( + ( + F.col(DataProductColumnNames.calculation_type) + == CalculationTypeDataProductValue.BALANCE_FIXING.value + ) + & (F.col(DataProductColumnNames.grid_area_code).isin(grid_area_codes)) + & (F.col(DataProductColumnNames.start_of_day) >= period_start) + & (F.col(DataProductColumnNames.start_of_day) < period_end) + ) + + df = filter_by_latest_calculations( + df, + latest_balance_fixing_calculations, + df_time_column=time_column_name, + time_zone=time_zone, + ) + + return df + + +def 
filter_by_latest_calculations( + df: DataFrame, + latest_calculations: DataFrame, + df_time_column: str | Column, + time_zone: str, +) -> DataFrame: + df = df.withColumn( + EphemeralColumns.start_of_day, + get_start_of_day(df_time_column, time_zone), + ) + + return ( + df.join( + latest_calculations, + on=[ + df[DataProductColumnNames.calculation_id] + == latest_calculations[DataProductColumnNames.calculation_id], + df[DataProductColumnNames.grid_area_code] + == latest_calculations[DataProductColumnNames.grid_area_code], + df[EphemeralColumns.start_of_day] + == latest_calculations[DataProductColumnNames.start_of_day], + ], + how="inner", + ) + .select(df["*"]) + .drop(EphemeralColumns.start_of_day) + ) + + +def filter_by_calculation_id_by_grid_area( + calculation_id_by_grid_area: dict[str, UUID], +) -> Column: + calculation_id_by_grid_area_structs = [ + F.struct(F.lit(grid_area_code), F.lit(str(calculation_id))) + for grid_area_code, calculation_id in calculation_id_by_grid_area.items() + ] + + return F.struct( + F.col(DataProductColumnNames.grid_area_code), + F.col(DataProductColumnNames.calculation_id), + ).isin(calculation_id_by_grid_area_structs) + + +def filter_by_energy_supplier_ids(energy_supplier_ids: list[str]) -> Column: + return F.col(DataProductColumnNames.energy_supplier_id).isin(energy_supplier_ids) + + +def filter_by_grid_area_codes(grid_area_codes: list[str]) -> Column: + return F.col(DataProductColumnNames.grid_area_code).isin(grid_area_codes) + + +def filter_by_charge_owner_and_tax_depending_on_market_role( + df: DataFrame, + requesting_actor_market_role: MarketRole, + charge_owner_id: str, +) -> DataFrame: + if requesting_actor_market_role == MarketRole.SYSTEM_OPERATOR: + df = df.where( + (F.col(DataProductColumnNames.charge_owner_id) == charge_owner_id) + & (~F.col(DataProductColumnNames.is_tax)) + ) + + if requesting_actor_market_role == MarketRole.GRID_ACCESS_PROVIDER: + df = df.where( + ( + (F.col(DataProductColumnNames.charge_owner_id) == charge_owner_id) + & (~F.col(DataProductColumnNames.is_tax)) + ) + | (F.col(DataProductColumnNames.is_tax)) + ) + + return df diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/get_start_of_day.py b/source/settlement_report_python/settlement_report_job/domain/utils/get_start_of_day.py new file mode 100644 index 0000000..1f6a2a5 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/get_start_of_day.py @@ -0,0 +1,9 @@ +from pyspark.sql import Column +from pyspark.sql import functions as F + + +def get_start_of_day(col: Column | str, time_zone: str) -> Column: + col = F.col(col) if isinstance(col, str) else col + return F.to_utc_timestamp( + F.date_trunc("DAY", F.from_utc_timestamp(col, time_zone)), time_zone + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/join_metering_points_periods_and_charge_link_periods.py b/source/settlement_report_python/settlement_report_job/domain/utils/join_metering_points_periods_and_charge_link_periods.py new file mode 100644 index 0000000..f3a9301 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/join_metering_points_periods_and_charge_link_periods.py @@ -0,0 +1,52 @@ +from pyspark.sql import functions as F, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def join_metering_points_periods_and_charge_link_periods( + charge_link_periods: DataFrame, + metering_point_periods: DataFrame, +) -> DataFrame: + 
""" + Joins metering point periods and charge link periods and returns the joined DataFrame. + Periods are joined on calculation_id and metering_point_id. + The output DataFrame will contain the columns from_date and to_date, which are the intersection of the periods. + """ + link_from_date = "link_from_date" + link_to_date = "link_to_date" + metering_point_from_date = "metering_point_from_date" + metering_point_to_date = "metering_point_to_date" + + charge_link_periods = charge_link_periods.withColumnRenamed( + DataProductColumnNames.from_date, link_from_date + ).withColumnRenamed(DataProductColumnNames.to_date, link_to_date) + metering_point_periods = metering_point_periods.withColumnRenamed( + DataProductColumnNames.from_date, metering_point_from_date + ).withColumnRenamed(DataProductColumnNames.to_date, metering_point_to_date) + + joined = ( + metering_point_periods.join( + charge_link_periods, + on=[ + DataProductColumnNames.calculation_id, + DataProductColumnNames.metering_point_id, + ], + how="inner", + ) + .where( + (F.col(link_from_date) < F.col(metering_point_to_date)) + & (F.col(link_to_date) > F.col(metering_point_from_date)) + ) + .withColumn( + DataProductColumnNames.from_date, + F.greatest(F.col(link_from_date), F.col(metering_point_from_date)), + ) + .withColumn( + DataProductColumnNames.to_date, + F.least(F.col(link_to_date), F.col(metering_point_to_date)), + ) + ) + + return joined diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/map_from_dict.py b/source/settlement_report_python/settlement_report_job/domain/utils/map_from_dict.py new file mode 100644 index 0000000..4585dbb --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/map_from_dict.py @@ -0,0 +1,29 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools + +from pyspark.sql import Column +from pyspark.sql import functions as F + + +def map_from_dict(d: dict) -> Column: + """Converts a dictionary to a Spark map column + + Args: + d (dict): Dictionary to convert to a Spark map column + + Returns: + Column: Spark map column + """ + return F.create_map([F.lit(x) for x in itertools.chain(*d.items())]) diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/map_to_csv_naming.py b/source/settlement_report_python/settlement_report_job/domain/utils/map_to_csv_naming.py new file mode 100644 index 0000000..6016a2f --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/map_to_csv_naming.py @@ -0,0 +1,76 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from settlement_report_job.infrastructure.wholesale.data_values import ( + ChargeTypeDataProductValue, + CalculationTypeDataProductValue, + MeteringPointResolutionDataProductValue, + MeteringPointTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.data_values.settlement_method import ( + SettlementMethodDataProductValue, +) + +METERING_POINT_TYPES = { + MeteringPointTypeDataProductValue.VE_PRODUCTION.value: "D01", + MeteringPointTypeDataProductValue.NET_PRODUCTION.value: "D05", + MeteringPointTypeDataProductValue.SUPPLY_TO_GRID.value: "D06", + MeteringPointTypeDataProductValue.CONSUMPTION_FROM_GRID.value: "D07", + MeteringPointTypeDataProductValue.WHOLESALE_SERVICES_INFORMATION.value: "D08", + MeteringPointTypeDataProductValue.OWN_PRODUCTION.value: "D09", + MeteringPointTypeDataProductValue.NET_FROM_GRID.value: "D10", + MeteringPointTypeDataProductValue.NET_TO_GRID.value: "D11", + MeteringPointTypeDataProductValue.TOTAL_CONSUMPTION.value: "D12", + MeteringPointTypeDataProductValue.ELECTRICAL_HEATING.value: "D14", + MeteringPointTypeDataProductValue.NET_CONSUMPTION.value: "D15", + MeteringPointTypeDataProductValue.EFFECT_SETTLEMENT.value: "D19", + MeteringPointTypeDataProductValue.CONSUMPTION.value: "E17", + MeteringPointTypeDataProductValue.PRODUCTION.value: "E18", + MeteringPointTypeDataProductValue.EXCHANGE.value: "E20", +} + +SETTLEMENT_METHODS = { + SettlementMethodDataProductValue.NON_PROFILED.value: "E02", + SettlementMethodDataProductValue.FLEX.value: "D01", +} + +CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS = { + CalculationTypeDataProductValue.BALANCE_FIXING.value: "D04", + CalculationTypeDataProductValue.WHOLESALE_FIXING.value: "D05", + CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT.value: "D32", + CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT.value: "D32", + CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT.value: "D32", +} + +CALCULATION_TYPES_TO_PROCESS_VARIANT = { + CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT.value: "1ST", + CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT.value: "2ND", + CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT.value: "3RD", +} + + +RESOLUTION_NAMES = { + MeteringPointResolutionDataProductValue.HOUR.value: "TSSD60", + MeteringPointResolutionDataProductValue.QUARTER.value: "TSSD15", +} + +CHARGE_TYPES = { + ChargeTypeDataProductValue.SUBSCRIPTION.value: "D01", + ChargeTypeDataProductValue.FEE.value: "D02", + ChargeTypeDataProductValue.TARIFF.value: "D03", +} + +TAX_INDICATORS = { + True: 1, + False: 0, +} diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/market_role.py b/source/settlement_report_python/settlement_report_job/domain/utils/market_role.py new file mode 100644 index 0000000..44f12ca --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/market_role.py @@ -0,0 +1,27 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
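A short illustration (editor's sketch, not part of this diff) of how the two calculation-type dictionaries above work together in the report: all correction settlements share the ENERGYBUSINESSPROCESS code "D32", and CALCULATION_TYPES_TO_PROCESS_VARIANT is what tells the runs apart. The raw key string below is an assumption about the enum's .value.

import settlement_report_job.domain.utils.map_to_csv_naming as market_naming

calculation_type = "first_correction_settlement"  # hypothetical raw value
business_process = market_naming.CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS[calculation_type]
process_variant = market_naming.CALCULATION_TYPES_TO_PROCESS_VARIANT.get(calculation_type)
print(business_process, process_variant)  # D32 1ST (given the assumed key)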
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class MarketRole(Enum): + """ + The market role value affects what is included in the settlement report. + The 'market-role' command line argument must use one of these values. + """ + + DATAHUB_ADMINISTRATOR = "datahub_administrator" + ENERGY_SUPPLIER = "energy_supplier" + GRID_ACCESS_PROVIDER = "grid_access_provider" + SYSTEM_OPERATOR = "system_operator" diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/merge_periods.py b/source/settlement_report_python/settlement_report_job/domain/utils/merge_periods.py new file mode 100644 index 0000000..d136c6b --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/merge_periods.py @@ -0,0 +1,54 @@ +from pyspark.sql import DataFrame, functions as F, Window + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def merge_connected_periods(df: DataFrame) -> DataFrame: + """ + Merges connected and/or overlapping periods within each group of rows in the input DataFrame. + Args: + df: a dataframe that contains any number of columns plus the columns 'from_date' and 'to_date' + + Returns: + A DataFrame with the same columns as the input DataFrame. + Rows that had overlapping/connected periods are merged into single rows. + + """ + other_columns = [ + col + for col in df.columns + if col not in [DataProductColumnNames.from_date, DataProductColumnNames.to_date] + ] + window_spec = Window.partitionBy(other_columns).orderBy( + DataProductColumnNames.from_date + ) + + # Add columns to identify overlapping periods + df_with_next = df.withColumn( + "next_from_date", F.lead(DataProductColumnNames.from_date).over(window_spec) + ).withColumn( + "next_to_date", F.lead(DataProductColumnNames.to_date).over(window_spec) + ) + + # Add a column to identify the start of a new group of connected periods + df_with_group = df_with_next.withColumn( + "group", + F.sum( + F.when( + F.col(DataProductColumnNames.from_date) + > F.lag(DataProductColumnNames.to_date).over(window_spec), + 1, + ).otherwise(0) + ).over(window_spec.rowsBetween(Window.unboundedPreceding, Window.currentRow)), + ) + + # Merge overlapping periods within each group + other_columns.append("group") + merged_df = df_with_group.groupBy(other_columns).agg( + F.min(DataProductColumnNames.from_date).alias(DataProductColumnNames.from_date), + F.max(DataProductColumnNames.to_date).alias(DataProductColumnNames.to_date), + ) + + return merged_df diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/report_data_type.py b/source/settlement_report_python/settlement_report_job/domain/utils/report_data_type.py new file mode 100644 index 0000000..0ffb2b9 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/report_data_type.py @@ -0,0 +1,31 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ReportDataType(Enum): + """ + Types of data that can be included in a settlement report. + Used to distinguish between different types of data in the report. + """ + + TimeSeriesHourly = 1 + TimeSeriesQuarterly = 2 + MeteringPointPeriods = 3 + ChargeLinks = 4 + EnergyResults = 5 + WholesaleResults = 6 + MonthlyAmounts = 7 + ChargePricePoints = 8 diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/repository_filtering.py b/source/settlement_report_python/settlement_report_job/domain/utils/repository_filtering.py new file mode 100644 index 0000000..364bc4e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/repository_filtering.py @@ -0,0 +1,115 @@ +from datetime import datetime +from uuid import UUID + +from pyspark.sql import DataFrame, functions as F + +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_calculation_id_by_grid_area, + filter_by_charge_owner_and_tax_depending_on_market_role, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +def read_metering_point_periods_by_calculation_ids( + repository: WholesaleRepository, + period_start: datetime, + period_end: datetime, + energy_supplier_ids: list[str] | None, + calculation_id_by_grid_area: dict[str, UUID], +) -> DataFrame: + metering_point_periods = repository.read_metering_point_periods().where( + (F.col(DataProductColumnNames.from_date) < period_end) + & (F.col(DataProductColumnNames.to_date) > period_start) + ) + + metering_point_periods = metering_point_periods.where( + filter_by_calculation_id_by_grid_area(calculation_id_by_grid_area) + ) + + if energy_supplier_ids is not None: + metering_point_periods = metering_point_periods.where( + F.col(DataProductColumnNames.energy_supplier_id).isin(energy_supplier_ids) + ) + + return metering_point_periods + + +def read_filtered_metering_point_periods_by_grid_area_codes( + repository: WholesaleRepository, + period_start: datetime, + period_end: datetime, + energy_supplier_ids: list[str] | None, + grid_area_codes: list[str], +) -> DataFrame: + metering_point_periods = repository.read_metering_point_periods().where( + (F.col(DataProductColumnNames.from_date) < period_end) + & (F.col(DataProductColumnNames.to_date) > period_start) + ) + + metering_point_periods = metering_point_periods.where( + F.col(DataProductColumnNames.grid_area_code).isin(grid_area_codes) + ) + + if energy_supplier_ids is not None: + metering_point_periods = metering_point_periods.where( + F.col(DataProductColumnNames.energy_supplier_id).isin(energy_supplier_ids) + ) + + return metering_point_periods + + +def read_charge_link_periods( + repository: WholesaleRepository, + period_start: datetime, + period_end: datetime, + charge_owner_id: str, + requesting_actor_market_role: MarketRole, +) -> DataFrame: + charge_link_periods = repository.read_charge_link_periods().where( + 
(F.col(DataProductColumnNames.from_date) < period_end) + & (F.col(DataProductColumnNames.to_date) > period_start) + ) + + if requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.GRID_ACCESS_PROVIDER, + ]: + charge_price_information_periods = ( + repository.read_charge_price_information_periods() + ) + + charge_link_periods = _filter_by_charge_owner_and_tax( + charge_link_periods=charge_link_periods, + charge_price_information_periods=charge_price_information_periods, + charge_owner_id=charge_owner_id, + requesting_actor_market_role=requesting_actor_market_role, + ) + + return charge_link_periods + + +def _filter_by_charge_owner_and_tax( + charge_link_periods: DataFrame, + charge_price_information_periods: DataFrame, + charge_owner_id: str, + requesting_actor_market_role: MarketRole, +) -> DataFrame: + charge_price_information_periods = ( + filter_by_charge_owner_and_tax_depending_on_market_role( + charge_price_information_periods, + requesting_actor_market_role, + charge_owner_id, + ) + ) + + charge_link_periods = charge_link_periods.join( + charge_price_information_periods, + on=[DataProductColumnNames.calculation_id, DataProductColumnNames.charge_key], + how="inner", + ).select(charge_link_periods["*"]) + + return charge_link_periods diff --git a/source/settlement_report_python/settlement_report_job/domain/utils/settlement_report_args_utils.py b/source/settlement_report_python/settlement_report_job/domain/utils/settlement_report_args_utils.py new file mode 100644 index 0000000..79e9690 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/utils/settlement_report_args_utils.py @@ -0,0 +1,35 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +def should_have_result_file_per_grid_area( + args: SettlementReportArgs, +) -> bool: + exactly_one_grid_area_from_calc_ids = ( + args.calculation_id_by_grid_area is not None + and len(args.calculation_id_by_grid_area) == 1 + ) + + exactly_one_grid_area_from_grid_area_codes = ( + args.grid_area_codes is not None and len(args.grid_area_codes) == 1 + ) + + return ( + exactly_one_grid_area_from_calc_ids + or exactly_one_grid_area_from_grid_area_codes + or args.split_report_by_grid_area + ) diff --git a/source/settlement_report_python/settlement_report_job/domain/wholesale_results/__init__.py b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/domain/wholesale_results/order_by_columns.py b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/order_by_columns.py new file mode 100644 index 0000000..9e81770 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/order_by_columns.py @@ -0,0 +1,15 @@ +from pyspark.sql import functions as F +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames + + +def order_by_columns() -> list: + return [ + F.col(CsvColumnNames.grid_area_code), + F.col(CsvColumnNames.energy_supplier_id), + F.col(CsvColumnNames.metering_point_type), + F.col(CsvColumnNames.settlement_method), + F.col(CsvColumnNames.time), + F.col(CsvColumnNames.charge_owner_id), + F.col(CsvColumnNames.charge_type), + F.col(CsvColumnNames.charge_code), + ] diff --git a/source/settlement_report_python/settlement_report_job/domain/wholesale_results/prepare_for_csv.py b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/prepare_for_csv.py new file mode 100644 index 0000000..89c3f02 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/prepare_for_csv.py @@ -0,0 +1,80 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.domain.utils.map_from_dict import ( + map_from_dict, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +import settlement_report_job.domain.utils.map_to_csv_naming as market_naming + +log = Logger(__name__) + + +@use_span() +def prepare_for_csv( + wholesale: DataFrame, + one_file_per_grid_area: bool = False, +) -> DataFrame: + select_columns = [ + map_from_dict(market_naming.CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS)[ + F.col(DataProductColumnNames.calculation_type) + ].alias(CsvColumnNames.calculation_type), + map_from_dict(market_naming.CALCULATION_TYPES_TO_PROCESS_VARIANT)[ + F.col(DataProductColumnNames.calculation_type) + ].alias(CsvColumnNames.correction_settlement_number), + F.col(DataProductColumnNames.grid_area_code).alias( + CsvColumnNames.grid_area_code + ), + F.col(DataProductColumnNames.energy_supplier_id).alias( + CsvColumnNames.energy_supplier_id + ), + F.col(DataProductColumnNames.time).alias(CsvColumnNames.time), + F.col(DataProductColumnNames.resolution).alias(CsvColumnNames.resolution), + map_from_dict(market_naming.METERING_POINT_TYPES)[ + F.col(DataProductColumnNames.metering_point_type) + ].alias(CsvColumnNames.metering_point_type), + map_from_dict(market_naming.SETTLEMENT_METHODS)[ + F.col(DataProductColumnNames.settlement_method) + ].alias(CsvColumnNames.settlement_method), + F.col(DataProductColumnNames.quantity_unit).alias(CsvColumnNames.quantity_unit), + F.col(DataProductColumnNames.currency).alias(CsvColumnNames.currency), + F.col(DataProductColumnNames.quantity).alias(CsvColumnNames.energy_quantity), + F.col(DataProductColumnNames.price).alias(CsvColumnNames.price), + F.col(DataProductColumnNames.amount).alias(CsvColumnNames.amount), + map_from_dict(market_naming.CHARGE_TYPES)[ + F.col(DataProductColumnNames.charge_type) + ].alias(CsvColumnNames.charge_type), + F.col(DataProductColumnNames.charge_code).alias(CsvColumnNames.charge_code), + F.col(DataProductColumnNames.charge_owner_id).alias( + CsvColumnNames.charge_owner_id + ), + ] + + if one_file_per_grid_area: + select_columns.append( + F.col(DataProductColumnNames.grid_area_code).alias( + EphemeralColumns.grid_area_code_partitioning + ), + ) + + return wholesale.select(select_columns) diff --git a/source/settlement_report_python/settlement_report_job/domain/wholesale_results/read_and_filter.py b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/read_and_filter.py new file mode 100644 index 0000000..4e009cd --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/read_and_filter.py @@ -0,0 +1,59 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from uuid import UUID +from datetime import datetime + +from pyspark.sql import DataFrame, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.factory_filters import ( + filter_by_charge_owner_and_tax_depending_on_market_role, + filter_by_calculation_id_by_grid_area, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +log = Logger(__name__) + + +@use_span() +def read_and_filter_from_view( + energy_supplier_ids: list[str] | None, + calculation_id_by_grid_area: dict[str, UUID], + period_start: datetime, + period_end: datetime, + requesting_actor_market_role: MarketRole, + requesting_actor_id: str, + repository: WholesaleRepository, +) -> DataFrame: + df = repository.read_amounts_per_charge().where( + (F.col(DataProductColumnNames.time) >= period_start) + & (F.col(DataProductColumnNames.time) < period_end) + ) + + if energy_supplier_ids is not None: + df = df.where( + F.col(DataProductColumnNames.energy_supplier_id).isin(energy_supplier_ids) + ) + + df = df.where(filter_by_calculation_id_by_grid_area(calculation_id_by_grid_area)) + + df = filter_by_charge_owner_and_tax_depending_on_market_role( + df, requesting_actor_market_role, requesting_actor_id + ) + + return df diff --git a/source/settlement_report_python/settlement_report_job/domain/wholesale_results/wholesale_results_factory.py b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/wholesale_results_factory.py new file mode 100644 index 0000000..a48004b --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/domain/wholesale_results/wholesale_results_factory.py @@ -0,0 +1,51 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pyspark.sql import DataFrame +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from telemetry_logging import use_span + +from settlement_report_job.domain.utils.settlement_report_args_utils import ( + should_have_result_file_per_grid_area, +) +from settlement_report_job.domain.wholesale_results.read_and_filter import ( + read_and_filter_from_view, +) +from settlement_report_job.domain.wholesale_results.prepare_for_csv import ( + prepare_for_csv, +) + + +@use_span() +def create_wholesale_results( + args: SettlementReportArgs, + repository: WholesaleRepository, +) -> DataFrame: + wholesale = read_and_filter_from_view( + args.energy_supplier_ids, + args.calculation_id_by_grid_area, + args.period_start, + args.period_end, + args.requesting_actor_market_role, + args.requesting_actor_id, + repository, + ) + + return prepare_for_csv( + wholesale, + should_have_result_file_per_grid_area(args), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/__init__.py b/source/settlement_report_python/settlement_report_job/entry_points/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/entry_points/entry_point.py b/source/settlement_report_python/settlement_report_job/entry_points/entry_point.py new file mode 100644 index 0000000..a5050b6 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/entry_point.py @@ -0,0 +1,132 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from argparse import Namespace +from collections.abc import Callable + +from opentelemetry.trace import SpanKind + +import telemetry_logging.logging_configuration as config +from telemetry_logging.span_recording import span_record_exception +from settlement_report_job.entry_points.tasks import task_factory +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.entry_points.job_args.settlement_report_job_args import ( + parse_job_arguments, + parse_command_line_arguments, +) +from settlement_report_job.entry_points.tasks.task_type import TaskType +from settlement_report_job.entry_points.utils.get_dbutils import get_dbutils +from settlement_report_job.infrastructure.spark_initializor import initialize_spark + + +# The start_x() methods should only have its name updated in correspondence with the +# wheels entry point for it. Further the method must remain parameterless because +# it will be called from the entry point when deployed. 
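+#
+# For illustration only (hypothetical, not part of this change): the packaging
+# configuration is expected to reference each start_x() function below as a
+# console-script style entry point, e.g. a mapping along the lines of
+#   start_wholesale_results = settlement_report_job.entry_points.entry_point:start_wholesale_results
+# The actual entry point declarations live in the packaging metadata and are
+# not shown in this diff.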
+def start_hourly_time_series_points() -> None: + _start_task(TaskType.TimeSeriesHourly) + + +def start_quarterly_time_series_points() -> None: + _start_task(TaskType.TimeSeriesQuarterly) + + +def start_metering_point_periods() -> None: + _start_task(TaskType.MeteringPointPeriods) + + +def start_charge_link_periods() -> None: + _start_task(TaskType.ChargeLinks) + + +def start_charge_price_points() -> None: + _start_task(TaskType.ChargePricePoints) + + +def start_energy_results() -> None: + _start_task(TaskType.EnergyResults) + + +def start_wholesale_results() -> None: + _start_task(TaskType.WholesaleResults) + + +def start_monthly_amounts() -> None: + _start_task(TaskType.MonthlyAmounts) + + +def start_zip() -> None: + _start_task(TaskType.Zip) + + +def _start_task(task_type: TaskType) -> None: + applicationinsights_connection_string = os.getenv( + "APPLICATIONINSIGHTS_CONNECTION_STRING" + ) + + start_task_with_deps( + task_type=task_type, + applicationinsights_connection_string=applicationinsights_connection_string, + ) + + +def start_task_with_deps( + *, + task_type: TaskType, + cloud_role_name: str = "dbr-settlement-report", + applicationinsights_connection_string: str | None = None, + parse_command_line_args: Callable[..., Namespace] = parse_command_line_arguments, + parse_job_args: Callable[..., SettlementReportArgs] = parse_job_arguments, +) -> None: + """Start overload with explicit dependencies for easier testing.""" + config.configure_logging( + cloud_role_name=cloud_role_name, + tracer_name="settlement-report-job", + applicationinsights_connection_string=applicationinsights_connection_string, + extras={"Subsystem": "wholesale-aggregations"}, + ) + + with config.get_tracer().start_as_current_span( + __name__, kind=SpanKind.SERVER + ) as span: + # Try/except added to enable adding custom fields to the exception as + # the span attributes do not appear to be included in the exception. + try: + + # The command line arguments are parsed to have necessary information for + # coming log messages + command_line_args = parse_command_line_args() + + # Add settlement_report_id to structured logging data to be included in + # every log message. + config.add_extras({"settlement_report_id": command_line_args.report_id}) + span.set_attributes(config.get_extras()) + args = parse_job_args(command_line_args) + spark = initialize_spark() + + task = task_factory.create(task_type, spark, args) + task.execute() + + # Added as ConfigArgParse uses sys.exit() rather than raising exceptions + except SystemExit as e: + if e.code != 0: + span_record_exception(e, span) + sys.exit(e.code) + + except Exception as e: + span_record_exception(e, span) + sys.exit(4) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/__init__.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/args_helper.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/args_helper.py new file mode 100644 index 0000000..30769a9 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/job_args/args_helper.py @@ -0,0 +1,45 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime + +import configargparse + + +def valid_date(s: str) -> datetime: + """See https://stackoverflow.com/questions/25470844/specify-date-format-for-python-argparse-input-arguments""" + try: + return datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") + except ValueError: + msg = "not a valid date: {0!r}".format(s) + raise configargparse.ArgumentTypeError(msg) + + +def valid_energy_supplier_ids(s: str) -> list[str]: + if not s.startswith("[") or not s.endswith("]"): + msg = "Energy supplier IDs must be a list enclosed by an opening '[' and a closing ']'" + raise configargparse.ArgumentTypeError(msg) + + # 1. Remove enclosing list characters 2. Split each id 3. Remove possibly enclosing spaces. + tokens = [token.strip() for token in s.strip("[]").split(",")] + + # Energy supplier IDs must always consist of 13 or 16 digits + if any( + (len(token) != 13 and len(token) != 16) + or any(c < "0" or c > "9" for c in token) + for token in tokens + ): + msg = "Energy supplier IDs must consist of 13 or 16 digits" + raise configargparse.ArgumentTypeError(msg) + + return tokens diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/calculation_type.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/calculation_type.py new file mode 100644 index 0000000..f505f0f --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/job_args/calculation_type.py @@ -0,0 +1,25 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class CalculationType(Enum): + """The job parameter values must correspond with these values.""" + + BALANCE_FIXING = "balance_fixing" + WHOLESALE_FIXING = "wholesale_fixing" + FIRST_CORRECTION_SETTLEMENT = "first_correction_settlement" + SECOND_CORRECTION_SETTLEMENT = "second_correction_settlement" + THIRD_CORRECTION_SETTLEMENT = "third_correction_settlement" diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/environment_variables.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/environment_variables.py new file mode 100644 index 0000000..b71050e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/job_args/environment_variables.py @@ -0,0 +1,33 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from enum import Enum +from typing import Any + + +class EnvironmentVariable(Enum): + CATALOG_NAME = "CATALOG_NAME" + + +def get_catalog_name() -> str: + return get_env_variable_or_throw(EnvironmentVariable.CATALOG_NAME) + + +def get_env_variable_or_throw(variable: EnvironmentVariable) -> Any: + env_variable = os.getenv(variable.name) + if env_variable is None: + raise ValueError(f"Environment variable not found: {variable.name}") + + return env_variable diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_args.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_args.py new file mode 100644 index 0000000..c1db617 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_args.py @@ -0,0 +1,29 @@ +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.domain.utils.market_role import MarketRole + + +@dataclass +class SettlementReportArgs: + report_id: str + period_start: datetime + period_end: datetime + calculation_type: CalculationType + requesting_actor_market_role: MarketRole + requesting_actor_id: str + calculation_id_by_grid_area: Optional[dict[str, uuid.UUID]] + """ A dictionary containing grid area codes (keys) and calculation ids (values). None for balance fixing""" + grid_area_codes: Optional[list[str]] + """ None if NOT balance fixing""" + energy_supplier_ids: Optional[list[str]] + split_report_by_grid_area: bool + prevent_large_text_files: bool + time_zone: str + catalog_name: str + settlement_reports_output_path: str + """The path to the folder where the settlement reports are stored.""" + include_basis_data: bool diff --git a/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_job_args.py b/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_job_args.py new file mode 100644 index 0000000..37bf57e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/job_args/settlement_report_job_args.py @@ -0,0 +1,155 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import json +import sys +import uuid +from argparse import Namespace + +import configargparse +from configargparse import argparse + +from settlement_report_job.entry_points.job_args.args_helper import ( + valid_date, + valid_energy_supplier_ids, +) +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.infrastructure.paths import ( + get_settlement_reports_output_path, +) +from telemetry_logging import Logger, logging_configuration +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +import settlement_report_job.entry_points.job_args.environment_variables as env_vars + + +def parse_command_line_arguments() -> Namespace: + return _parse_args_or_throw(sys.argv[1:]) + + +def parse_job_arguments( + job_args: Namespace, +) -> SettlementReportArgs: + logger = Logger(__name__) + logger.info(f"Command line arguments: {repr(job_args)}") + + with logging_configuration.start_span("settlement_report.parse_job_arguments"): + + grid_area_codes = ( + _create_grid_area_codes(job_args.grid_area_codes) + if job_args.calculation_type is CalculationType.BALANCE_FIXING + else None + ) + + calculation_id_by_grid_area = ( + _create_calculation_ids_by_grid_area_code( + job_args.calculation_id_by_grid_area + ) + if job_args.calculation_type is not CalculationType.BALANCE_FIXING + else None + ) + + settlement_report_args = SettlementReportArgs( + report_id=job_args.report_id, + period_start=job_args.period_start, + period_end=job_args.period_end, + calculation_type=job_args.calculation_type, + requesting_actor_market_role=job_args.requesting_actor_market_role, + requesting_actor_id=job_args.requesting_actor_id, + calculation_id_by_grid_area=calculation_id_by_grid_area, + grid_area_codes=grid_area_codes, + energy_supplier_ids=job_args.energy_supplier_ids, + split_report_by_grid_area=job_args.split_report_by_grid_area, + prevent_large_text_files=job_args.prevent_large_text_files, + time_zone="Europe/Copenhagen", + catalog_name=env_vars.get_catalog_name(), + settlement_reports_output_path=get_settlement_reports_output_path( + env_vars.get_catalog_name() + ), + include_basis_data=job_args.include_basis_data, + ) + + return settlement_report_args + + +def _parse_args_or_throw(command_line_args: list[str]) -> argparse.Namespace: + p = configargparse.ArgParser( + description="Create settlement report", + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + + # Run parameters + p.add_argument("--report-id", type=str, required=True) + p.add_argument("--period-start", type=valid_date, required=True) + p.add_argument("--period-end", type=valid_date, required=True) + p.add_argument("--calculation-type", type=CalculationType, required=True) + p.add_argument("--requesting-actor-market-role", type=MarketRole, required=True) + p.add_argument("--requesting-actor-id", type=str, required=True) + p.add_argument("--calculation-id-by-grid-area", type=str, required=False) + p.add_argument("--grid-area-codes", type=str, required=False) + p.add_argument( + "--energy-supplier-ids", type=valid_energy_supplier_ids, required=False + ) + p.add_argument( + "--split-report-by-grid-area", action="store_true" + ) # true if present, false otherwise + p.add_argument( + "--prevent-large-text-files", action="store_true" + ) # true if present, false otherwise + p.add_argument( + "--include-basis-data", action="store_true" + ) # true if present, 
false otherwise + + args, unknown_args = p.parse_known_args(args=command_line_args) + if len(unknown_args): + unknown_args_text = ", ".join(unknown_args) + raise Exception(f"Unknown args: {unknown_args_text}") + + return args + + +def _create_grid_area_codes(grid_area_codes: str) -> list[str]: + if not grid_area_codes.startswith("[") or not grid_area_codes.endswith("]"): + msg = "Grid area codes must be a list enclosed by an opening '[' and a closing ']'" + raise configargparse.ArgumentTypeError(msg) + + # 1. Remove enclosing list characters 2. Split each grid area code 3. Remove possibly enclosing spaces. + tokens = [token.strip() for token in grid_area_codes.strip("[]").split(",")] + + # Grid area codes must always consist of 3 digits + if any( + len(token) != 3 or any(c < "0" or c > "9" for c in token) for token in tokens + ): + msg = "Grid area codes must consist of 3 digits" + raise configargparse.ArgumentTypeError(msg) + + return tokens + + +def _create_calculation_ids_by_grid_area_code(json_str: str) -> dict[str, uuid.UUID]: + try: + calculation_id_by_grid_area = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError( + f"Failed to parse `calculation_id_by_grid_area` json format as dict[str, uuid]: {e}" + ) + + for grid_area, calculation_id in calculation_id_by_grid_area.items(): + try: + calculation_id_by_grid_area[grid_area] = uuid.UUID(calculation_id) + except ValueError: + raise ValueError(f"Calculation ID for grid area {grid_area} is not a uuid") + + return calculation_id_by_grid_area diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/__init__.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_link_periods_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_link_periods_task.py new file mode 100644 index 0000000..912b7ac --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_link_periods_task.py @@ -0,0 +1,49 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.charge_link_periods.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import ( + TaskBase, +) +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.domain.charge_link_periods.charge_link_periods_factory import ( + create_charge_link_periods, +) +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + +from telemetry_logging import use_span + + +class ChargeLinkPeriodsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating charge links. 
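+
+        The task only writes output when basis data is requested
+        (args.include_basis_data is True); otherwise it returns immediately.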
+ """ + if not self.args.include_basis_data: + return + + repository = WholesaleRepository(self.spark, self.args.catalog_name) + charge_link_periods = create_charge_link_periods( + args=self.args, repository=repository + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=charge_link_periods, + report_data_type=ReportDataType.ChargeLinks, + order_by_columns=order_by_columns(self.args.requesting_actor_market_role), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_price_points_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_price_points_task.py new file mode 100644 index 0000000..5ac95e6 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/charge_price_points_task.py @@ -0,0 +1,48 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.charge_price_points.charge_price_points_factory import ( + create_charge_price_points, +) +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.domain.charge_price_points.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import ( + TaskBase, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from telemetry_logging import use_span + + +class ChargePricePointsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating charge prices. 
+ """ + if not self.args.include_basis_data: + return + + repository = WholesaleRepository(self.spark, self.args.catalog_name) + charge_price_points = create_charge_price_points( + args=self.args, repository=repository + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=charge_price_points, + report_data_type=ReportDataType.ChargePricePoints, + order_by_columns=order_by_columns(), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/energy_resuls_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/energy_resuls_task.py new file mode 100644 index 0000000..6e73302 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/energy_resuls_task.py @@ -0,0 +1,45 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.energy_results.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import TaskBase +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.energy_results.energy_results_factory import ( + create_energy_results, +) +from telemetry_logging import use_span +from settlement_report_job.domain.utils.market_role import MarketRole + + +class EnergyResultsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating energy results. 
+ """ + if self.args.requesting_actor_market_role == MarketRole.SYSTEM_OPERATOR: + return + + repository = WholesaleRepository(self.spark, self.args.catalog_name) + energy_results_df = create_energy_results(args=self.args, repository=repository) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=energy_results_df, + report_data_type=ReportDataType.EnergyResults, + order_by_columns=order_by_columns(self.args.requesting_actor_market_role), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/metering_point_periods_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/metering_point_periods_task.py new file mode 100644 index 0000000..17f890e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/metering_point_periods_task.py @@ -0,0 +1,48 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.metering_point_periods.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import ( + TaskBase, +) +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.domain.metering_point_periods.metering_point_periods_factory import ( + create_metering_point_periods, +) +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from telemetry_logging import use_span + + +class MeteringPointPeriodsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating metering point periods. 
+ """ + if not self.args.include_basis_data: + return + + repository = WholesaleRepository(self.spark, self.args.catalog_name) + charge_link_periods = create_metering_point_periods( + args=self.args, repository=repository + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=charge_link_periods, + report_data_type=ReportDataType.MeteringPointPeriods, + order_by_columns=order_by_columns(self.args.requesting_actor_market_role), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/monthly_amounts_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/monthly_amounts_task.py new file mode 100644 index 0000000..7bd1c78 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/monthly_amounts_task.py @@ -0,0 +1,43 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.monthly_amounts.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import TaskBase +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.monthly_amounts.monthly_amounts_factory import ( + create_monthly_amounts, +) +from telemetry_logging import use_span + + +class MonthlyAmountsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating wholesale results. 
+ """ + repository = WholesaleRepository(self.spark, self.args.catalog_name) + wholesale_results_df = create_monthly_amounts( + args=self.args, repository=repository + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=wholesale_results_df, + report_data_type=ReportDataType.MonthlyAmounts, + order_by_columns=order_by_columns(self.args.requesting_actor_market_role), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_base.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_base.py new file mode 100644 index 0000000..484fd2d --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from telemetry_logging import Logger + + +class TaskBase: + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + self.spark = spark + self.dbutils = dbutils + self.args = args + self.log = Logger(__name__) + + @abstractmethod + def execute(self) -> None: + pass diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_factory.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_factory.py new file mode 100644 index 0000000..993eee1 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_factory.py @@ -0,0 +1,69 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.entry_points.tasks.task_type import TaskType +from settlement_report_job.entry_points.tasks.charge_price_points_task import ( + ChargePricePointsTask, +) +from settlement_report_job.entry_points.tasks.charge_link_periods_task import ( + ChargeLinkPeriodsTask, +) +from settlement_report_job.entry_points.tasks.energy_resuls_task import ( + EnergyResultsTask, +) +from settlement_report_job.entry_points.tasks.monthly_amounts_task import ( + MonthlyAmountsTask, +) +from settlement_report_job.entry_points.tasks.task_base import ( + TaskBase, +) +from settlement_report_job.entry_points.tasks.metering_point_periods_task import ( + MeteringPointPeriodsTask, +) +from settlement_report_job.entry_points.tasks.time_series_points_task import ( + TimeSeriesPointsTask, +) +from settlement_report_job.entry_points.tasks.wholesale_results_task import ( + WholesaleResultsTask, +) +from settlement_report_job.entry_points.tasks.zip_task import ZipTask +from settlement_report_job.entry_points.utils.get_dbutils import get_dbutils + + +def create( + task_type: TaskType, + spark: SparkSession, + args: SettlementReportArgs, +) -> TaskBase: + dbutils = get_dbutils(spark) + if task_type is TaskType.MeteringPointPeriods: + return MeteringPointPeriodsTask(spark=spark, dbutils=dbutils, args=args) + elif task_type is TaskType.TimeSeriesQuarterly: + return TimeSeriesPointsTask( + spark=spark, + dbutils=dbutils, + args=args, + task_type=TaskType.TimeSeriesQuarterly, + ) + elif task_type is TaskType.TimeSeriesHourly: + return TimeSeriesPointsTask( + spark=spark, dbutils=dbutils, args=args, task_type=TaskType.TimeSeriesHourly + ) + elif task_type is TaskType.ChargeLinks: + return ChargeLinkPeriodsTask(spark=spark, dbutils=dbutils, 
args=args) + elif task_type is TaskType.ChargePricePoints: + return ChargePricePointsTask(spark=spark, dbutils=dbutils, args=args) + elif task_type is TaskType.EnergyResults: + return EnergyResultsTask(spark=spark, dbutils=dbutils, args=args) + elif task_type is TaskType.WholesaleResults: + return WholesaleResultsTask(spark=spark, dbutils=dbutils, args=args) + elif task_type is TaskType.MonthlyAmounts: + return MonthlyAmountsTask(spark=spark, dbutils=dbutils, args=args) + elif task_type is TaskType.Zip: + return ZipTask(spark=spark, dbutils=dbutils, args=args) + else: + raise ValueError(f"Unknown task type: {task_type}") diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_type.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_type.py new file mode 100644 index 0000000..848578e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/task_type.py @@ -0,0 +1,31 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class TaskType(Enum): + """ + Databricks tasks that can be executed. + """ + + TimeSeriesHourly = 1 + TimeSeriesQuarterly = 2 + MeteringPointPeriods = 3 + ChargeLinks = 4 + ChargePricePoints = 5 + EnergyResults = 6 + WholesaleResults = 7 + MonthlyAmounts = 8 + Zip = 9 diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/time_series_points_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/time_series_points_task.py new file mode 100644 index 0000000..2214b9f --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/time_series_points_task.py @@ -0,0 +1,88 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.time_series_points.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_type import TaskType +from settlement_report_job.entry_points.tasks.task_base import TaskBase +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.time_series_points.time_series_points_factory import ( + create_time_series_points_for_wholesale, + create_time_series_points_for_balance_fixing, +) +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from telemetry_logging import use_span +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) + + +class TimeSeriesPointsTask(TaskBase): + def __init__( + self, + spark: SparkSession, + dbutils: Any, + args: SettlementReportArgs, + task_type: TaskType, + ) -> None: + 
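+        # task_type selects between hourly and quarterly time series points
+        # and thereby the metering point resolution used when reading data.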
super().__init__(spark=spark, dbutils=dbutils, args=args) + self.task_type = task_type + + @use_span() + def execute( + self, + ) -> None: + """ + Entry point for the logic of creating time series. + """ + if not self.args.include_basis_data: + return + + if self.task_type is TaskType.TimeSeriesHourly: + report_type = ReportDataType.TimeSeriesHourly + metering_point_resolution = MeteringPointResolutionDataProductValue.HOUR + elif self.task_type is TaskType.TimeSeriesQuarterly: + report_type = ReportDataType.TimeSeriesQuarterly + metering_point_resolution = MeteringPointResolutionDataProductValue.QUARTER + else: + raise ValueError(f"Unsupported report data type: {self.task_type}") + + repository = WholesaleRepository(self.spark, self.args.catalog_name) + if self.args.calculation_type is CalculationType.BALANCE_FIXING: + time_series_points_df = create_time_series_points_for_balance_fixing( + period_start=self.args.period_start, + period_end=self.args.period_end, + grid_area_codes=self.args.grid_area_codes, + time_zone=self.args.time_zone, + energy_supplier_ids=self.args.energy_supplier_ids, + metering_point_resolution=metering_point_resolution, + requesting_market_role=self.args.requesting_actor_market_role, + repository=repository, + ) + else: + time_series_points_df = create_time_series_points_for_wholesale( + period_start=self.args.period_start, + period_end=self.args.period_end, + calculation_id_by_grid_area=self.args.calculation_id_by_grid_area, + time_zone=self.args.time_zone, + energy_supplier_ids=self.args.energy_supplier_ids, + metering_point_resolution=metering_point_resolution, + repository=repository, + requesting_actor_market_role=self.args.requesting_actor_market_role, + requesting_actor_id=self.args.requesting_actor_id, + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=time_series_points_df, + report_data_type=report_type, + order_by_columns=order_by_columns(self.args.requesting_actor_market_role), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/wholesale_results_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/wholesale_results_task.py new file mode 100644 index 0000000..41d57fb --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/wholesale_results_task.py @@ -0,0 +1,43 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.domain.wholesale_results.order_by_columns import ( + order_by_columns, +) +from settlement_report_job.entry_points.tasks.task_base import TaskBase +from settlement_report_job.infrastructure import csv_writer +from settlement_report_job.infrastructure.repository import WholesaleRepository +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.wholesale_results.wholesale_results_factory import ( + create_wholesale_results, +) +from telemetry_logging import use_span + + +class WholesaleResultsTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating wholesale results. 
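+
+        The results are read from the amounts-per-charge view and filtered
+        according to the report arguments before being written as csv.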
+ """ + repository = WholesaleRepository(self.spark, self.args.catalog_name) + wholesale_results_df = create_wholesale_results( + args=self.args, repository=repository + ) + + csv_writer.write( + dbutils=self.dbutils, + args=self.args, + df=wholesale_results_df, + report_data_type=ReportDataType.WholesaleResults, + order_by_columns=order_by_columns(), + ) diff --git a/source/settlement_report_python/settlement_report_job/entry_points/tasks/zip_task.py b/source/settlement_report_python/settlement_report_job/entry_points/tasks/zip_task.py new file mode 100644 index 0000000..1ea34ca --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/tasks/zip_task.py @@ -0,0 +1,38 @@ +from typing import Any + +from pyspark.sql import SparkSession + +from settlement_report_job.entry_points.tasks.task_base import TaskBase +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.infrastructure.paths import get_report_output_path + +from settlement_report_job.infrastructure.create_zip_file import create_zip_file +from telemetry_logging import use_span + + +class ZipTask(TaskBase): + def __init__( + self, spark: SparkSession, dbutils: Any, args: SettlementReportArgs + ) -> None: + super().__init__(spark=spark, dbutils=dbutils, args=args) + + @use_span() + def execute(self) -> None: + """ + Entry point for the logic of creating the final zip file. + """ + report_output_path = get_report_output_path(self.args) + files_to_zip = [ + f"{report_output_path}/{file_info.name}" + for file_info in self.dbutils.fs.ls(report_output_path) + ] + + self.log.info(f"Files to zip: {files_to_zip}") + zip_file_path = ( + f"{self.args.settlement_reports_output_path}/{self.args.report_id}.zip" + ) + self.log.info(f"Creating zip file: '{zip_file_path}'") + create_zip_file(self.dbutils, self.args.report_id, zip_file_path, files_to_zip) + self.log.info(f"Finished creating '{zip_file_path}'") diff --git a/source/settlement_report_python/settlement_report_job/entry_points/utils/__init__.py b/source/settlement_report_python/settlement_report_job/entry_points/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/entry_points/utils/get_dbutils.py b/source/settlement_report_python/settlement_report_job/entry_points/utils/get_dbutils.py new file mode 100644 index 0000000..b5dcdc3 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/entry_points/utils/get_dbutils.py @@ -0,0 +1,36 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from pyspark.sql import SparkSession + + +def get_dbutils(spark: SparkSession) -> Any: + """Get the DBUtils object from the SparkSession. + + Args: + spark (SparkSession): The SparkSession object. + + Returns: + DBUtils: The DBUtils object. 
+ """ + try: + from pyspark.dbutils import DBUtils # type: ignore + + dbutils = DBUtils(spark) + except ImportError: + raise ImportError( + "DBUtils is not available in local mode. This is expected when running tests." # noqa + ) + return dbutils diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/__init__.py b/source/settlement_report_python/settlement_report_job/infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/create_zip_file.py b/source/settlement_report_python/settlement_report_job/infrastructure/create_zip_file.py new file mode 100644 index 0000000..882bc68 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/create_zip_file.py @@ -0,0 +1,51 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import zipfile + +from typing import Any +from telemetry_logging import use_span + + +@use_span() +def create_zip_file( + dbutils: Any, report_id: str, save_path: str, files_to_zip: list[str] +) -> None: + """Creates a zip file from a list of files and saves it to the specified path. + + Notice that we have to create the zip file in /tmp and then move it to the desired + location. This is done as `direct-append` or `non-sequential` writes are not + supported in Databricks. + + Args: + dbutils (Any): The DBUtils object. + report_id (str): The report ID. + save_path (str): The path to save the zip file. + files_to_zip (list[str]): The list of files to zip. + + Raises: + Exception: If there are no files to zip. + Exception: If the save path does not end with .zip. + """ + if len(files_to_zip) == 0: + raise Exception("No files to zip") + if not save_path.endswith(".zip"): + raise Exception("Save path must end with .zip") + + tmp_path = f"/tmp/{report_id}.zip" + with zipfile.ZipFile(tmp_path, "a", zipfile.ZIP_DEFLATED) as ref: + for fp in files_to_zip: + file_name = fp.split("/")[-1] + ref.write(fp, arcname=file_name) + dbutils.fs.mv(f"file:{tmp_path}", save_path) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/csv_writer.py b/source/settlement_report_python/settlement_report_job/infrastructure/csv_writer.py new file mode 100644 index 0000000..ad25b4d --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/csv_writer.py @@ -0,0 +1,271 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from pyspark.sql import DataFrame, Window, functions as F + +from telemetry_logging import Logger, use_span +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.infrastructure.report_name_factory import FileNameFactory +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import EphemeralColumns +from settlement_report_job.infrastructure.paths import get_report_output_path + + +log = Logger(__name__) + + +@dataclass +class TmpFile: + src: Path + dst: Path + tmp_dst: Path + + +@use_span() +def write( + dbutils: Any, + args: SettlementReportArgs, + df: DataFrame, + report_data_type: ReportDataType, + order_by_columns: list[str], + rows_per_file: int = 1_000_000, +) -> list[str]: + + report_output_path = get_report_output_path(args) + spark_output_path = f"{report_output_path}/{_get_folder_name(report_data_type)}" + + partition_columns = [] + if EphemeralColumns.grid_area_code_partitioning in df.columns: + partition_columns.append(EphemeralColumns.grid_area_code_partitioning) + + if args.prevent_large_text_files: + partition_columns.append(EphemeralColumns.chunk_index) + + headers = _write_files( + df=df, + path=spark_output_path, + partition_columns=partition_columns, + order_by=order_by_columns, + rows_per_file=rows_per_file, + ) + + file_name_factory = FileNameFactory(report_data_type, args) + new_files = _get_new_files( + spark_output_path, + report_output_path, + file_name_factory, + partition_columns=partition_columns, + ) + files_paths = _merge_files( + dbutils=dbutils, + new_files=new_files, + headers=headers, + ) + + file_names = [os.path.basename(file_path) for file_path in files_paths] + + return file_names + + +def _get_folder_name(report_data_type: ReportDataType) -> str: + if report_data_type == ReportDataType.TimeSeriesHourly: + return "time_series_points_hourly" + elif report_data_type == ReportDataType.TimeSeriesQuarterly: + return "time_series_points_quarterly" + elif report_data_type == ReportDataType.MeteringPointPeriods: + return "metering_point_periods" + elif report_data_type == ReportDataType.ChargeLinks: + return "charge_link_periods" + elif report_data_type == ReportDataType.ChargePricePoints: + return "charge_price_points" + elif report_data_type == ReportDataType.EnergyResults: + return "energy_results" + elif report_data_type == ReportDataType.MonthlyAmounts: + return "monthly_amounts" + elif report_data_type == ReportDataType.WholesaleResults: + return "wholesale_results" + else: + raise ValueError(f"Unsupported report data type: {report_data_type}") + + +@use_span() +def _write_files( + df: DataFrame, + path: str, + partition_columns: list[str], + order_by: list[str], + rows_per_file: int, +) -> list[str]: + """Write a DataFrame to multiple files. + + Args: + df (DataFrame): DataFrame to write. + path (str): Path to write the files. + rows_per_file (int): Number of rows per file. + partition_columns: list[str]: Columns to partition by. + order_by (list[str]): Columns to order by. + + Returns: + list[str]: Headers for the csv file. 
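To make the chunking behind prevent_large_text_files concrete, here is a standalone toy example of the same window technique used below: each partition gets a running row number, and dividing by rows_per_file and rounding up yields a 1-based chunk index.

# Standalone toy example (made-up data) of the chunk-index technique used in _write_files.
from pyspark.sql import SparkSession, Window, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("804", i) for i in range(1, 8)], ["grid_area", "n"])

w = Window.partitionBy("grid_area").orderBy("n")
df = df.withColumn("chunk_index", F.ceil(F.row_number().over(w) / F.lit(3)))
# With rows_per_file=3: rows 1-3 get chunk 1, rows 4-6 get chunk 2, row 7 gets chunk 3.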
+ """ + if EphemeralColumns.chunk_index in partition_columns: + partition_columns_without_chunk = [ + col for col in partition_columns if col != EphemeralColumns.chunk_index + ] + w = Window().partitionBy(partition_columns_without_chunk).orderBy(order_by) + chunk_index_col = F.ceil((F.row_number().over(w)) / F.lit(rows_per_file)) + df = df.withColumn(EphemeralColumns.chunk_index, chunk_index_col) + + if len(order_by) > 0: + df = df.orderBy(*order_by) + + csv_writer_options = _get_csv_writer_options() + + if partition_columns: + df.write.mode("overwrite").options(**csv_writer_options).partitionBy( + partition_columns + ).csv(path) + else: + df.write.mode("overwrite").options(**csv_writer_options).csv(path) + + return [c for c in df.columns if c not in partition_columns] + + +@use_span() +def _get_new_files( + spark_output_path: str, + report_output_path: str, + file_name_factory: FileNameFactory, + partition_columns: list[str], +) -> list[TmpFile]: + """Get the new files to move to the final location. + + Args: + partition_columns: + spark_output_path (str): The path where the files are written. + report_output_path: The path where the files will be moved. + file_name_factory (FileNameFactory): Factory class for creating file names for the csv files. + + Returns: + list[dict[str, Path]]: List of dictionaries with the source and destination + paths for the new files. + """ + new_files = [] + + file_info_list = _get_file_info_list( + spark_output_path=spark_output_path, partition_columns=partition_columns + ) + + distinct_chunk_indices = set([chunk_index for _, _, chunk_index in file_info_list]) + include_chunk_index = len(distinct_chunk_indices) > 1 + + for f, grid_area, chunk_index in file_info_list: + file_name = file_name_factory.create( + grid_area_code=grid_area, + chunk_index=chunk_index if include_chunk_index else None, + ) + new_name = Path(report_output_path) / file_name + tmp_dst = Path("/tmp") / file_name + new_files.append(TmpFile(f, new_name, tmp_dst)) + + return new_files + + +def _get_file_info_list( + spark_output_path: str, partition_columns: list[str] +) -> list[tuple[Path, str | None, str | None]]: + file_info_list = [] + + files = [f for f in Path(spark_output_path).rglob("*.csv")] + + partition_by_grid_area = ( + EphemeralColumns.grid_area_code_partitioning in partition_columns + ) + partition_by_chunk_index = EphemeralColumns.chunk_index in partition_columns + + regex = spark_output_path + if partition_by_grid_area: + regex = f"{regex}/{EphemeralColumns.grid_area_code_partitioning}=(\\w{{3}})" + + if partition_by_chunk_index: + regex = f"{regex}/{EphemeralColumns.chunk_index}=(\\d+)" + + for f in files: + partition_match = re.match(regex, str(f)) + if partition_match is None: + raise ValueError(f"File {f} does not match the expected pattern") + + groups = partition_match.groups() + group_count = 0 + + if partition_by_grid_area: + grid_area = groups[group_count] + group_count += 1 + else: + grid_area = None + + if partition_by_chunk_index and len(files) > 1: + chunk_index = groups[group_count] + group_count += 1 + else: + chunk_index = None + + file_info_list.append((f, grid_area, chunk_index)) + + return file_info_list + + +@use_span() +def _merge_files( + dbutils: Any, new_files: list[TmpFile], headers: list[str] +) -> list[str]: + """Merges the new files and moves them to the final location. + + Args: + dbutils (Any): The DBUtils object. + new_files (list[dict[str, Path]]): List of dictionaries with the source and + destination paths for the new files. 
+ headers (list[str]): Headers for the csv file. + + Returns: + list[str]: List of the final file paths. + """ + print("Files to merge: " + str(new_files)) + for tmp_dst in set([f.tmp_dst for f in new_files]): + tmp_dst.parent.mkdir(parents=True, exist_ok=True) + with tmp_dst.open("w+") as f_tmp_dst: + print("Creating " + str(tmp_dst)) + f_tmp_dst.write(",".join(headers) + "\n") + + for _file in new_files: + with _file.src.open("r") as f_src: + with _file.tmp_dst.open("a") as f_tmp_dst: + f_tmp_dst.write(f_src.read()) + + for tmp_dst, dst in set([(f.tmp_dst, f.dst) for f in new_files]): + print("Moving " + str(tmp_dst) + " to " + str(dst)) + dbutils.fs.mv("file:" + str(tmp_dst), str(dst)) + + return list(set([str(_file.dst) for _file in new_files])) + + +def _get_csv_writer_options() -> dict[str, str]: + return {"timestampFormat": "yyyy-MM-dd'T'HH:mm:ss'Z'"} diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/paths.py b/source/settlement_report_python/settlement_report_job/infrastructure/paths.py new file mode 100644 index 0000000..b82beeb --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/paths.py @@ -0,0 +1,11 @@ +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +def get_settlement_reports_output_path(catalog_name: str) -> str: + return f"/Volumes/{catalog_name}/wholesale_settlement_report_output/settlement_reports" # noqa: E501 + + +def get_report_output_path(args: SettlementReportArgs) -> str: + return f"{args.settlement_reports_output_path}/{args.report_id}" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/report_name_factory.py b/source/settlement_report_python/settlement_report_job/infrastructure/report_name_factory.py new file mode 100644 index 0000000..876c6a6 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/report_name_factory.py @@ -0,0 +1,144 @@ +from datetime import timedelta +from zoneinfo import ZoneInfo + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +class MarketRoleInFileName: + """ + Market role identifiers used in the csv file name in the settlement report. + System operator and datahub admin are not included as they are not part of the file name. + """ + + ENERGY_SUPPLIER = "DDQ" + GRID_ACCESS_PROVIDER = "DDM" + + +class FileNameFactory: + def __init__(self, report_data_type: ReportDataType, args: SettlementReportArgs): + self.report_data_type = report_data_type + self.args = args + + def create( + self, + grid_area_code: str | None, + chunk_index: str | None, + ) -> str: + if self.report_data_type in { + ReportDataType.TimeSeriesHourly, + ReportDataType.TimeSeriesQuarterly, + ReportDataType.MeteringPointPeriods, + ReportDataType.ChargeLinks, + ReportDataType.ChargePricePoints, + }: + return self._create_basis_data_filename(grid_area_code, chunk_index) + if self.report_data_type in [ + ReportDataType.EnergyResults, + ReportDataType.WholesaleResults, + ReportDataType.MonthlyAmounts, + ]: + return self._create_result_filename(grid_area_code, chunk_index) + else: + raise NotImplementedError( + f"Report data type {self.report_data_type} is not supported." 
+ ) + + def _create_result_filename( + self, + grid_area_code: str | None, + chunk_index: str | None, + ) -> str: + filename_parts = [ + self._get_pre_fix(), + grid_area_code if grid_area_code is not None else "flere-net", + self._get_actor_id_in_file_name(), + self._get_market_role_in_file_name(), + self._get_start_date(), + self._get_end_date(), + chunk_index, + ] + + filename_parts_without_none = [ + part for part in filename_parts if part is not None + ] + + return "_".join(filename_parts_without_none) + ".csv" + + def _create_basis_data_filename( + self, + grid_area_code: str | None, + chunk_index: str | None, + ) -> str: + + filename_parts = [ + self._get_pre_fix(), + grid_area_code, + self._get_actor_id_in_file_name(), + self._get_market_role_in_file_name(), + self._get_start_date(), + self._get_end_date(), + chunk_index, + ] + + filename_parts_without_none = [ + part for part in filename_parts if part is not None + ] + + return "_".join(filename_parts_without_none) + ".csv" + + def _get_start_date(self) -> str: + time_zone_info = ZoneInfo(self.args.time_zone) + return self.args.period_start.astimezone(time_zone_info).strftime("%d-%m-%Y") + + def _get_end_date(self) -> str: + time_zone_info = ZoneInfo(self.args.time_zone) + return ( + self.args.period_end.astimezone(time_zone_info) - timedelta(days=1) + ).strftime("%d-%m-%Y") + + def _get_pre_fix(self) -> str: + if self.report_data_type == ReportDataType.TimeSeriesHourly: + return "TSSD60" + elif self.report_data_type == ReportDataType.TimeSeriesQuarterly: + return "TSSD15" + elif self.report_data_type == ReportDataType.MeteringPointPeriods: + return "MDMP" + elif self.report_data_type == ReportDataType.ChargeLinks: + return "CHARGELINK" + elif self.report_data_type == ReportDataType.ChargePricePoints: + return "CHARGEPRICE" + elif self.report_data_type == ReportDataType.EnergyResults: + return "RESULTENERGY" + elif self.report_data_type == ReportDataType.WholesaleResults: + return "RESULTWHOLESALE" + elif self.report_data_type == ReportDataType.MonthlyAmounts: + return "RESULTMONTHLY" + raise NotImplementedError( + f"Report data type {self.report_data_type} is not supported." + ) + + def _get_actor_id_in_file_name(self) -> str | None: + + if self.args.requesting_actor_market_role in [ + MarketRole.GRID_ACCESS_PROVIDER, + MarketRole.ENERGY_SUPPLIER, + ]: + return self.args.requesting_actor_id + elif ( + self.args.energy_supplier_ids is not None + and len(self.args.energy_supplier_ids) == 1 + ): + return self.args.energy_supplier_ids[0] + return None + + def _get_market_role_in_file_name(self) -> str | None: + if self.args.requesting_actor_market_role == MarketRole.ENERGY_SUPPLIER: + return MarketRoleInFileName.ENERGY_SUPPLIER + elif self.args.requesting_actor_market_role == MarketRole.GRID_ACCESS_PROVIDER: + return MarketRoleInFileName.GRID_ACCESS_PROVIDER + + return None diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/repository.py b/source/settlement_report_python/settlement_report_job/infrastructure/repository.py new file mode 100644 index 0000000..55599be --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/repository.py @@ -0,0 +1,103 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. 
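To see what the factory above actually produces, a hedged example with made-up identifiers: an hourly time-series file for grid area 804, requested by an energy supplier, covering July 2024. None parts (such as a missing chunk index) are dropped before the remaining parts are joined with underscores.

# Hypothetical example of the naming scheme above; every identifier is made up.
parts = [
    "TSSD60",         # prefix for ReportDataType.TimeSeriesHourly
    "804",            # grid area code
    "8100000000000",  # requesting energy supplier id (hypothetical GLN)
    "DDQ",            # MarketRoleInFileName.ENERGY_SUPPLIER
    "01-07-2024",     # period start converted to local time
    "31-07-2024",     # exclusive period end minus one day, in local time
]
file_name = "_".join(parts) + ".csv"
# -> TSSD60_804_8100000000000_DDQ_01-07-2024_31-07-2024.csv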
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pyspark.sql import DataFrame, SparkSession + +from settlement_report_job.infrastructure.wholesale.database_definitions import ( + WholesaleResultsDatabase, + WholesaleBasisDataDatabase, +) + + +class WholesaleRepository: + def __init__( + self, + spark: SparkSession, + catalog_name: str, + ) -> None: + self._spark = spark + self._catalog_name = catalog_name + + def read_metering_point_periods(self) -> DataFrame: + return self._read_view_or_table( + WholesaleBasisDataDatabase.DATABASE_NAME, + WholesaleBasisDataDatabase.METERING_POINT_PERIODS_VIEW_NAME, + ) + + def read_metering_point_time_series(self) -> DataFrame: + return self._read_view_or_table( + WholesaleBasisDataDatabase.DATABASE_NAME, + WholesaleBasisDataDatabase.TIME_SERIES_POINTS_VIEW_NAME, + ) + + def read_charge_price_points(self) -> DataFrame: + return self._read_view_or_table( + WholesaleBasisDataDatabase.DATABASE_NAME, + WholesaleBasisDataDatabase.CHARGE_PRICE_POINTS_VIEW_NAME, + ) + + def read_charge_link_periods(self) -> DataFrame: + return self._read_view_or_table( + WholesaleBasisDataDatabase.DATABASE_NAME, + WholesaleBasisDataDatabase.CHARGE_LINK_PERIODS_VIEW_NAME, + ) + + def read_charge_price_information_periods(self) -> DataFrame: + return self._read_view_or_table( + WholesaleBasisDataDatabase.DATABASE_NAME, + WholesaleBasisDataDatabase.CHARGE_PRICE_INFORMATION_PERIODS_VIEW_NAME, + ) + + def read_energy(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.ENERGY_V1_VIEW_NAME, + ) + + def read_latest_calculations(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.LATEST_CALCULATIONS_BY_DAY_VIEW_NAME, + ) + + def read_energy_per_es(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.ENERGY_PER_ES_V1_VIEW_NAME, + ) + + def read_amounts_per_charge(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.AMOUNTS_PER_CHARGE_VIEW_NAME, + ) + + def read_monthly_amounts_per_charge_v1(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.MONTHLY_AMOUNTS_PER_CHARGE_VIEW_NAME, + ) + + def read_total_monthly_amounts_v1(self) -> DataFrame: + return self._read_view_or_table( + WholesaleResultsDatabase.DATABASE_NAME, + WholesaleResultsDatabase.TOTAL_MONTHLY_AMOUNTS_VIEW_NAME, + ) + + def _read_view_or_table( + self, + database_name: str, + table_name: str, + ) -> DataFrame: + name = f"{self._catalog_name}.{database_name}.{table_name}" + return self._spark.read.format("delta").table(name) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/spark_initializor.py b/source/settlement_report_python/settlement_report_job/infrastructure/spark_initializor.py new file mode 100644 index 0000000..4ef48e1 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/spark_initializor.py @@ -0,0 +1,26 @@ +# Copyright 2020 Energinet 
DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyspark import SparkConf +from pyspark.sql.session import SparkSession + + +def initialize_spark() -> SparkSession: + # Set spark config with the session timezone so that datetimes are displayed consistently (in UTC) + spark_conf = ( + SparkConf(loadDefaults=True) + .set("spark.sql.session.timeZone", "UTC") + .set("spark.databricks.io.cache.enabled", "True") + ) + return SparkSession.builder.config(conf=spark_conf).getOrCreate() diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/__init__.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/__init__.py new file mode 100644 index 0000000..c3d13ef --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/__init__.py @@ -0,0 +1,6 @@ +""" +This sub-package contains the schemas, column names, and more from wholesale data products +on which the settlement report job depends. + +See the data products at https://energinet.atlassian.net/wiki/spaces/D3/pages/849903618/Wholesale+Data+Products. +""" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/column_names.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/column_names.py new file mode 100644 index 0000000..e1de3d9 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/column_names.py @@ -0,0 +1,35 @@ +class DataProductColumnNames: + amount = "amount" + balance_responsible_party_id = "balance_responsible_party_id" + calculation_id = "calculation_id" + calculation_period_end = "calculation_period_end" + calculation_period_start = "calculation_period_start" + calculation_type = "calculation_type" + calculation_version = "calculation_version" + charge_code = "charge_code" + charge_key = "charge_key" + charge_owner_id = "charge_owner_id" + charge_type = "charge_type" + charge_time = "charge_time" + charge_price = "charge_price" + currency = "currency" + energy_supplier_id = "energy_supplier_id" + from_date = "from_date" + from_grid_area_code = "from_grid_area_code" + grid_area_code = "grid_area_code" + is_tax = "is_tax" + metering_point_id = "metering_point_id" + metering_point_type = "metering_point_type" + observation_time = "observation_time" + parent_metering_point_id = "parent_metering_point_id" + price = "price" + quantity = "quantity" + quantity_qualities = "quantity_qualities" + quantity_unit = "quantity_unit" + resolution = "resolution" + result_id = "result_id" + settlement_method = "settlement_method" + start_of_day = "start_of_day" + time = "time" + to_grid_area_code = "to_grid_area_code" + to_date = "to_date" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/__init__.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/__init__.py new file mode 100644 index 0000000..4c16ff8 --- /dev/null +++ 
b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/__init__.py @@ -0,0 +1,6 @@ +from .calculation_type import CalculationTypeDataProductValue +from .charge_resolution import ChargeResolutionDataProductValue +from .charge_type import ChargeTypeDataProductValue +from .metering_point_resolution import MeteringPointResolutionDataProductValue +from .metering_point_type import MeteringPointTypeDataProductValue +from .settlement_method import SettlementMethodDataProductValue diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/calculation_type.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/calculation_type.py new file mode 100644 index 0000000..5e86d5c --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/calculation_type.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class CalculationTypeDataProductValue(Enum): + AGGREGATION = "aggregation" + BALANCE_FIXING = "balance_fixing" + WHOLESALE_FIXING = "wholesale_fixing" + FIRST_CORRECTION_SETTLEMENT = "first_correction_settlement" + SECOND_CORRECTION_SETTLEMENT = "second_correction_settlement" + THIRD_CORRECTION_SETTLEMENT = "third_correction_settlement" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_resolution.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_resolution.py new file mode 100644 index 0000000..18f02f0 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_resolution.py @@ -0,0 +1,24 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + + +class ChargeResolutionDataProductValue(Enum): + """ + Time resolution of the charges, which is read from the Wholesale data product + """ + + MONTH = "P1M" + DAY = "P1D" + HOUR = "PT1H" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_type.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_type.py new file mode 100644 index 0000000..a66e20e --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/charge_type.py @@ -0,0 +1,24 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + + +class ChargeTypeDataProductValue(Enum): + """ + Charge type which is read from the Wholesale data product + """ + + TARIFF = "tariff" + FEE = "fee" + SUBSCRIPTION = "subscription" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_resolution.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_resolution.py new file mode 100644 index 0000000..61ab96d --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_resolution.py @@ -0,0 +1,21 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + + +class MeteringPointResolutionDataProductValue(Enum): + """Resolution values as defined for metering points in the data product(s).""" + + HOUR = "PT1H" + QUARTER = "PT15M" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_type.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_type.py new file mode 100644 index 0000000..9840978 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/metering_point_type.py @@ -0,0 +1,33 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from enum import Enum + + +class MeteringPointTypeDataProductValue(Enum): + PRODUCTION = "production" + CONSUMPTION = "consumption" + EXCHANGE = "exchange" + # The following are for child metering points + VE_PRODUCTION = "ve_production" + NET_PRODUCTION = "net_production" + SUPPLY_TO_GRID = "supply_to_grid" + CONSUMPTION_FROM_GRID = "consumption_from_grid" + WHOLESALE_SERVICES_INFORMATION = "wholesale_services_information" + OWN_PRODUCTION = "own_production" + NET_FROM_GRID = "net_from_grid" + NET_TO_GRID = "net_to_grid" + TOTAL_CONSUMPTION = "total_consumption" + ELECTRICAL_HEATING = "electrical_heating" + NET_CONSUMPTION = "net_consumption" + EFFECT_SETTLEMENT = "effect_settlement" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/settlement_method.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/settlement_method.py new file mode 100644 index 0000000..ec5c757 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/data_values/settlement_method.py @@ -0,0 +1,19 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + + +class SettlementMethodDataProductValue(Enum): + FLEX = "flex" + NON_PROFILED = "non_profiled" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/database_definitions.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/database_definitions.py new file mode 100644 index 0000000..e456df3 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/database_definitions.py @@ -0,0 +1,32 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
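These data-product enums carry the raw string values stored in the wholesale views, so filters compare the .value of an enum member against the column. A minimal sketch (time_series_points_df is a placeholder DataFrame; the real filtering lives in the domain factories, which are not shown here):

# Sketch: comparing an enum value against a data product column on a placeholder DataFrame.
from pyspark.sql import functions as F

from settlement_report_job.infrastructure.wholesale.column_names import (
    DataProductColumnNames,
)
from settlement_report_job.infrastructure.wholesale.data_values import (
    MeteringPointResolutionDataProductValue,
)

hourly_points = time_series_points_df.where(  # placeholder DataFrame
    F.col(DataProductColumnNames.resolution)
    == MeteringPointResolutionDataProductValue.HOUR.value  # "PT1H"
)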
+ + +class WholesaleBasisDataDatabase: + DATABASE_NAME = "wholesale_basis_data" + METERING_POINT_PERIODS_VIEW_NAME = "metering_point_periods_v1" + TIME_SERIES_POINTS_VIEW_NAME = "time_series_points_v1" + CHARGE_PRICE_POINTS_VIEW_NAME = "charge_price_points_v1" + CHARGE_LINK_PERIODS_VIEW_NAME = "charge_link_periods_v1" + CHARGE_PRICE_INFORMATION_PERIODS_VIEW_NAME = "charge_price_information_periods_v1" + + +class WholesaleResultsDatabase: + DATABASE_NAME = "wholesale_results" + LATEST_CALCULATIONS_BY_DAY_VIEW_NAME = "latest_calculations_by_day_v1" + ENERGY_V1_VIEW_NAME = "energy_v1" + ENERGY_PER_ES_V1_VIEW_NAME = "energy_per_es_v1" + AMOUNTS_PER_CHARGE_VIEW_NAME = "amounts_per_charge_v1" # for some reason we call amounts per charge for wholesale results + MONTHLY_AMOUNTS_PER_CHARGE_VIEW_NAME = "monthly_amounts_per_charge_v1" + TOTAL_MONTHLY_AMOUNTS_VIEW_NAME = "total_monthly_amounts_v1" diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/__init__.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/__init__.py new file mode 100644 index 0000000..4bdacf2 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/__init__.py @@ -0,0 +1,6 @@ +from .charge_link_periods_v1 import charge_link_periods_v1 +from .charge_price_information_periods_v1 import charge_price_information_periods_v1 +from .metering_point_periods_v1 import metering_point_periods_v1 +from .metering_point_time_series_v1 import metering_point_time_series_v1 +from .monthly_amounts_per_charge_v1 import monthly_amounts_per_charge_v1 +from .amounts_per_charge_v1 import amounts_per_charge_v1 diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/amounts_per_charge_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/amounts_per_charge_v1.py new file mode 100644 index 0000000..fefe508 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/amounts_per_charge_v1.py @@ -0,0 +1,58 @@ +import pyspark.sql.types as t + +nullable = True + +amounts_per_charge_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'balance_fixing' | 'aggregation' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # UUID + t.StructField("result_id", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + t.StructField("energy_supplier_id", t.StringType(), not nullable), + # + t.StructField("charge_code", t.StringType(), not nullable), + # + t.StructField("charge_type", t.StringType(), not nullable), + # + t.StructField("charge_owner_id", t.StringType(), not nullable), + # + # 'PT15M' | 'PT1H' + t.StructField("resolution", t.StringType(), not nullable), + # + # 'kWh' + t.StructField("quantity_unit", t.StringType(), not nullable), + # + # 'consumption' | 'production' | 'exchange' + t.StructField("metering_point_type", t.StringType(), not nullable), + # + # 'flex' | 'non_profiled' | NULL + t.StructField("settlement_method", t.StringType(), nullable), + # + t.StructField("is_tax", t.BooleanType(), not nullable), + # + t.StructField("currency", t.StringType(), not nullable), + # + # UTC 
time + t.StructField("time", t.TimestampType(), not nullable), + # + t.StructField("quantity", t.DecimalType(18, 3), not nullable), + # + # [ 'measured' | 'missing' | 'calculated' | 'estimated' ] + # There is at least one element, and no element is included more than once. + t.StructField("quantity_qualities", t.ArrayType(t.StringType()), not nullable), + # + t.StructField("price", t.DecimalType(18, 6), nullable), + # + t.StructField("amount", t.DecimalType(18, 6), nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_link_periods_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_link_periods_v1.py new file mode 100644 index 0000000..67d6fa1 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_link_periods_v1.py @@ -0,0 +1,38 @@ +import pyspark.sql.types as t + +nullable = True + +charge_link_periods_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'wholesale_fixing' | 'aggregation' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + t.StructField("charge_key", t.StringType(), not nullable), + # + t.StructField("charge_code", t.StringType(), not nullable), + # + # 'subscription' | 'fee' | 'tariff' + t.StructField("charge_type", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("charge_owner_id", t.StringType(), not nullable), + # + # GSRN number + t.StructField("metering_point_id", t.StringType(), not nullable), + # + # The original type is integer, but in some contexts the quantity type is decimal. 
+ t.StructField("quantity", t.IntegerType(), not nullable), + # + # UTC time + t.StructField("from_date", t.TimestampType(), not nullable), + # + # UTC time + t.StructField("to_date", t.TimestampType(), nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_information_periods_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_information_periods_v1.py new file mode 100644 index 0000000..c336f59 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_information_periods_v1.py @@ -0,0 +1,37 @@ +import pyspark.sql.types as t + +nullable = True + +charge_price_information_periods_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'wholesale_fixing' | 'aggregation' |'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + t.StructField("charge_key", t.StringType(), not nullable), + # + t.StructField("charge_code", t.StringType(), not nullable), + # + # 'subscription' | 'fee' | 'tariff' + t.StructField("charge_type", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("charge_owner_id", t.StringType(), not nullable), + # + t.StructField("resolution", t.StringType(), not nullable), + # + # Taxation + t.StructField("is_tax", t.BooleanType(), not nullable), + # + # UTC time + t.StructField("from_date", t.TimestampType(), not nullable), + # + # UTC time + t.StructField("to_date", t.TimestampType(), nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_points_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_points_v1.py new file mode 100644 index 0000000..8fc2fd6 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/charge_price_points_v1.py @@ -0,0 +1,32 @@ +import pyspark.sql.types as t + +nullable = True + +charge_price_points_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'wholesale_fixing' | 'aggregation' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + t.StructField("charge_key", t.StringType(), not nullable), + # + t.StructField("charge_code", t.StringType(), not nullable), + # + # 'subscription' | 'fee' | 'tariff' + t.StructField("charge_type", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("charge_owner_id", t.StringType(), not nullable), + # + # The original type is integer, but in some contexts the quantity type is decimal. 
+ t.StructField("charge_price", t.DecimalType(), not nullable), + # + # UTC time + t.StructField("charge_time", t.TimestampType(), not nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_per_es_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_per_es_v1.py new file mode 100644 index 0000000..7a3d8b1 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_per_es_v1.py @@ -0,0 +1,53 @@ +import pyspark.sql.types as t + +nullable = True + +energy_per_es_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'balance_fixing' | 'aggregation' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + # UTC time + t.StructField("calculation_period_start", t.TimestampType(), not nullable), + # + # UTC time + t.StructField("calculation_period_end", t.TimestampType(), not nullable), + t.StructField("calculation_version", t.LongType(), not nullable), + # + # UUID + t.StructField("result_id", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("energy_supplier_id", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("balance_responsible_party_id", t.StringType(), not nullable), + # + # 'consumption' | 'production' + t.StructField("metering_point_type", t.StringType(), not nullable), + # + # 'flex' | 'non_profiled' + t.StructField("settlement_method", t.StringType(), nullable), + # + # 'PT15M' | 'PT1H' + t.StructField("resolution", t.StringType(), not nullable), + # + # UTC time + t.StructField("time", t.TimestampType(), not nullable), + # + t.StructField("quantity", t.DecimalType(18, 3), not nullable), + # + # 'kWh' + t.StructField("quantity_unit", t.StringType(), not nullable), + # + # [ 'measured' | 'missing' | 'calculated' | 'estimated' ] + # There is at least one element, and no element is included more than once. 
+ t.StructField("quantity_qualities", t.ArrayType(t.StringType()), not nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_v1.py new file mode 100644 index 0000000..1d8cb02 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/energy_v1.py @@ -0,0 +1,48 @@ +import pyspark.sql.types as t + +nullable = True + +energy_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'balance_fixing' | 'aggregation' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + # UTC time + t.StructField("calculation_period_start", t.TimestampType(), not nullable), + # + # UTC time + t.StructField("calculation_period_end", t.TimestampType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # UUID + t.StructField("result_id", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + # 'consumption' | 'production' | 'exchange' + t.StructField("metering_point_type", t.StringType(), not nullable), + # + # 'flex' | 'non_profiled' | NULL + t.StructField("settlement_method", t.StringType(), nullable), + # + # 'PT15M' | 'PT1H' + t.StructField("resolution", t.StringType(), not nullable), + # + # UTC time + t.StructField("time", t.TimestampType(), not nullable), + # + t.StructField("quantity", t.DecimalType(18, 3), not nullable), + # + # 'kWh' + t.StructField("quantity_unit", t.StringType(), not nullable), + # + # [ 'measured' | 'missing' | 'calculated' | 'estimated' ] + # There is at least one element, and no element is included more than once. 
+ t.StructField("quantity_qualities", t.ArrayType(t.StringType()), not nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/latest_calculations_by_day_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/latest_calculations_by_day_v1.py new file mode 100644 index 0000000..1f10299 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/latest_calculations_by_day_v1.py @@ -0,0 +1,19 @@ +import pyspark.sql.types as t + +nullable = True + +latest_calculations_by_day_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'aggregation' | 'balance_fixing' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + t.StructField("start_of_day", t.TimestampType(), not nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_periods_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_periods_v1.py new file mode 100644 index 0000000..3cf821a --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_periods_v1.py @@ -0,0 +1,54 @@ +import pyspark.sql.types as t + +nullable = True + +metering_point_periods_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'balance_fixing' | 'aggregation' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # GRSN number + t.StructField("metering_point_id", t.StringType(), not nullable), + # + # 'production' | 'consumption' | 'exchange' + # When wholesale calculations types also: + # 've_production' | 'net_production' | 'supply_to_grid' | + # 'consumption_from_grid' | 'wholesale_services_information' | + # 'own_production' | 'net_from_grid' 'net_to_grid' | 'total_consumption' | + # 'electrical_heating' | 'net_consumption' | 'effect_settlement' + t.StructField("metering_point_type", t.StringType(), not nullable), + # + # 'non_profiled' | 'flex' + t.StructField("settlement_method", t.StringType(), nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + t.StructField("resolution", t.StringType(), not nullable), + # + t.StructField("from_grid_area_code", t.StringType(), nullable), + # + t.StructField("to_grid_area_code", t.StringType(), nullable), + # + # GRSN number + t.StructField("parent_metering_point_id", t.StringType(), nullable), + # + # EIC or GLN number + t.StructField("energy_supplier_id", t.StringType(), nullable), + # + # EIC or GLN number + t.StructField("balance_responsible_party_id", t.StringType(), nullable), + # + # UTC time + t.StructField("from_date", t.TimestampType(), not nullable), + # + # UTC time + t.StructField("to_date", t.TimestampType(), nullable), + # + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_time_series_v1.py 
b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_time_series_v1.py new file mode 100644 index 0000000..b5c7b4f --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/metering_point_time_series_v1.py @@ -0,0 +1,43 @@ +import pyspark.sql.types as t + +nullable = True + + +metering_point_time_series_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'balance_fixing' | 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # GSRN number + t.StructField("metering_point_id", t.StringType(), not nullable), + # + # 'production' | 'consumption' | 'exchange' + # When wholesale calculations types also: + # 've_production' | 'net_production' | 'supply_to_grid' 'consumption_from_grid' | + # 'wholesale_services_information' | 'own_production' | 'net_from_grid' 'net_to_grid' | + # 'total_consumption' | 'electrical_heating' | 'net_consumption' | 'effect_settlement' + t.StructField("metering_point_type", t.StringType(), not nullable), + # + # 'PT15M' | 'PT1H' + t.StructField("resolution", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("energy_supplier_id", t.StringType(), nullable), + # + # UTC time + t.StructField( + "observation_time", + t.TimestampType(), + not nullable, + ), + t.StructField("quantity", t.DecimalType(18, 3), not nullable), + ] +) diff --git a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/monthly_amounts_per_charge_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/monthly_amounts_per_charge_v1.py new file mode 100644 index 0000000..9dc8b45 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/monthly_amounts_per_charge_v1.py @@ -0,0 +1,45 @@ +import pyspark.sql.types as t + +nullable = True + +monthly_amounts_per_charge_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # UUID + t.StructField("result_id", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("energy_supplier_id", t.StringType(), not nullable), + # + t.StructField("charge_code", t.StringType(), not nullable), + # + # 'tariff' | 'subscription' | 'fee' + t.StructField("charge_type", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("charge_owner_id", t.StringType(), not nullable), + # + # 'kWh' | 'pcs' + t.StructField("quantity_unit", t.StringType(), not nullable), + # + t.StructField("is_tax", t.BooleanType(), not nullable), + # + # 'DKK' + t.StructField("currency", t.StringType(), not nullable), + # + # UTC time + t.StructField("time", t.TimestampType(), not nullable), + # + t.StructField("amount", t.DecimalType(18, 6), nullable), + ] +) diff --git 
a/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/total_monthly_amounts_v1.py b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/total_monthly_amounts_v1.py new file mode 100644 index 0000000..cab6d54 --- /dev/null +++ b/source/settlement_report_python/settlement_report_job/infrastructure/wholesale/schemas/total_monthly_amounts_v1.py @@ -0,0 +1,35 @@ +import pyspark.sql.types as t + +nullable = True + +total_monthly_amounts_v1 = t.StructType( + [ + # UUID + t.StructField("calculation_id", t.StringType(), not nullable), + # + # 'wholesale_fixing' | 'first_correction_settlement' | + # 'second_correction_settlement' | 'third_correction_settlement' + t.StructField("calculation_type", t.StringType(), not nullable), + # + t.StructField("calculation_version", t.LongType(), not nullable), + # + # UUID + t.StructField("result_id", t.StringType(), not nullable), + # + t.StructField("grid_area_code", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("energy_supplier_id", t.StringType(), not nullable), + # + # EIC or GLN number + t.StructField("charge_owner_id", t.StringType(), nullable), + # + # 'DKK' + t.StructField("currency", t.StringType(), not nullable), + # + # UTC time + t.StructField("time", t.TimestampType(), not nullable), + # + t.StructField("amount", t.DecimalType(18, 6), nullable), + ] +) diff --git a/source/settlement_report_python/setup.py b/source/settlement_report_python/setup.py new file mode 100644 index 0000000..62463c1 --- /dev/null +++ b/source/settlement_report_python/setup.py @@ -0,0 +1,47 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
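These schema definitions also make it easy to seed deterministic test data; a minimal sketch building an empty DataFrame that matches the schema above (the write_* helpers used by the tests presumably do something along these lines before writing to Delta):

# Sketch: an empty DataFrame that conforms to one of the data product schemas.
from pyspark.sql import SparkSession

from settlement_report_job.infrastructure.wholesale.schemas.total_monthly_amounts_v1 import (
    total_monthly_amounts_v1,
)

spark = SparkSession.builder.getOrCreate()
empty_df = spark.createDataFrame([], schema=total_monthly_amounts_v1)
assert empty_df.schema == total_monthly_amounts_v1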
+from setuptools import setup, find_packages + +setup( + name="opengeh-settlement-report", + version=1.0, + description="Tools for settlement report", + long_description="", + long_description_content_type="text/markdown", + license="MIT", + packages=find_packages(), + # Make sure these packages are added to the docker container and pinned to the same versions + install_requires=[ + "ConfigArgParse==1.5.3", + "pyspark==3.5.1", + "delta-spark==3.1.0", + "python-dateutil==2.8.2", + "azure-monitor-opentelemetry==1.6.0", + "azure-core==1.30.0", + "opengeh-telemetry @ git+https://git@github.com/Energinet-DataHub/opengeh-python-packages@2.1.0#subdirectory=source/telemetry", + ], + entry_points={ + "console_scripts": [ + "create_hourly_time_series = settlement_report_job.entry_points.entry_point:start_hourly_time_series_points", + "create_quarterly_time_series = settlement_report_job.entry_points.entry_point:start_quarterly_time_series_points", + "create_metering_point_periods = settlement_report_job.entry_points.entry_point:start_metering_point_periods", + "create_charge_links = settlement_report_job.entry_points.entry_point:start_charge_link_periods", + "create_charge_price_points = settlement_report_job.entry_points.entry_point:start_charge_price_points", + "create_energy_results = settlement_report_job.entry_points.entry_point:start_energy_results", + "create_monthly_amounts = settlement_report_job.entry_points.entry_point:start_monthly_amounts", + "create_wholesale_results = settlement_report_job.entry_points.entry_point:start_wholesale_results", + "create_zip = settlement_report_job.entry_points.entry_point:start_zip", + ] + }, +) diff --git a/source/settlement_report_python/tests/.gitignore b/source/settlement_report_python/tests/.gitignore new file mode 100644 index 0000000..a3492fb --- /dev/null +++ b/source/settlement_report_python/tests/.gitignore @@ -0,0 +1,2 @@ +# Exclude folders generated by tests +**/__*__ diff --git a/source/settlement_report_python/tests/__init__.py b/source/settlement_report_python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/assertion.py b/source/settlement_report_python/tests/assertion.py new file mode 100644 index 0000000..023b70c --- /dev/null +++ b/source/settlement_report_python/tests/assertion.py @@ -0,0 +1,15 @@ +from pyspark.sql import SparkSession + + +def assert_file_names_and_columns( + path: str, + actual_files: list[str], + expected_columns: list[str], + expected_file_names: list[str], + spark: SparkSession, +): + assert set(actual_files) == set(expected_file_names) + for file_name in actual_files: + df = spark.read.csv(f"{path}/{file_name}", header=True) + assert df.count() > 0 + assert df.columns == expected_columns diff --git a/source/settlement_report_python/tests/conftest.py b/source/settlement_report_python/tests/conftest.py new file mode 100644 index 0000000..c7ec6e5 --- /dev/null +++ b/source/settlement_report_python/tests/conftest.py @@ -0,0 +1,508 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
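A sketch of how the assertion helper above is intended to be used in a test; the file name, column names, and the spark fixture are assumptions for illustration, and actual_files would normally be the list returned by csv_writer.write:

# Hypothetical test sketch; file and column names are made up.
from assertion import assert_file_names_and_columns  # tests import helpers by module name
from settlement_report_job.infrastructure.paths import get_report_output_path


def test_energy_results_files(spark, standard_wholesale_fixing_scenario_args):
    expected = ["RESULTENERGY_804_8100000000000_DDQ_01-07-2024_31-07-2024.csv"]  # hypothetical
    assert_file_names_and_columns(
        path=get_report_output_path(standard_wholesale_fixing_scenario_args),
        actual_files=expected,  # normally the return value of csv_writer.write(...)
        expected_columns=["column_a", "column_b"],  # hypothetical column names
        expected_file_names=expected,
        spark=spark,
    )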
+# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import uuid +import pytest +import logging +import yaml +from pathlib import Path +from typing import Callable, Generator + +from delta import configure_spark_with_delta_pip +from pyspark.sql import SparkSession + +from dbutils_fixture import DBUtilsFixture +from integration_test_configuration import IntegrationTestConfiguration +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.market_role import MarketRole + +from data_seeding import ( + standard_wholesale_fixing_scenario_data_generator, + standard_balance_fixing_scenario_data_generator, +) +from data_seeding.write_test_data import ( + write_metering_point_time_series_to_delta_table, + write_charge_link_periods_to_delta_table, + write_charge_price_points_to_delta_table, + write_charge_price_information_periods_to_delta_table, + write_energy_to_delta_table, + write_energy_per_es_to_delta_table, + write_latest_calculations_by_day_to_delta_table, + write_amounts_per_charge_to_delta_table, + write_metering_point_periods_to_delta_table, + write_monthly_amounts_per_charge_to_delta_table, + write_total_monthly_amounts_to_delta_table, +) + + +@pytest.fixture(scope="session") +def dbutils() -> DBUtilsFixture: + """ + Returns a DBUtilsFixture instance that can be used to mock dbutils. + """ + return DBUtilsFixture() + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_before_tests( + input_database_location: str, +): + + if os.path.exists(input_database_location): + shutil.rmtree(input_database_location) + + yield + + # Add cleanup code to be run after the tests + + +@pytest.fixture(scope="function") +def standard_wholesale_fixing_scenario_args( + settlement_reports_output_path: str, +) -> SettlementReportArgs: + return SettlementReportArgs( + report_id=str(uuid.uuid4()), + period_start=standard_wholesale_fixing_scenario_data_generator.FROM_DATE, + period_end=standard_wholesale_fixing_scenario_data_generator.TO_DATE, + calculation_type=CalculationType.WHOLESALE_FIXING, + calculation_id_by_grid_area={ + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0]: uuid.UUID( + standard_wholesale_fixing_scenario_data_generator.CALCULATION_ID + ), + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[1]: uuid.UUID( + standard_wholesale_fixing_scenario_data_generator.CALCULATION_ID + ), + }, + grid_area_codes=None, + split_report_by_grid_area=True, + prevent_large_text_files=False, + time_zone="Europe/Copenhagen", + catalog_name="spark_catalog", + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.SYSTEM_OPERATOR, # using system operator since it is more complex (requires filter based on charge owner) + requesting_actor_id=standard_wholesale_fixing_scenario_data_generator.CHARGE_OWNER_ID_WITHOUT_TAX, + settlement_reports_output_path=settlement_reports_output_path, + include_basis_data=True, + ) + + +@pytest.fixture(scope="function") +def standard_wholesale_fixing_scenario_datahub_admin_args( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> SettlementReportArgs: + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + return 
standard_wholesale_fixing_scenario_args + + +@pytest.fixture(scope="function") +def standard_wholesale_fixing_scenario_energy_supplier_args( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> SettlementReportArgs: + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.ENERGY_SUPPLIER + ) + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = energy_supplier_id + standard_wholesale_fixing_scenario_args.energy_supplier_ids = [energy_supplier_id] + return standard_wholesale_fixing_scenario_args + + +@pytest.fixture(scope="function") +def standard_wholesale_fixing_scenario_grid_access_provider_args( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> SettlementReportArgs: + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.GRID_ACCESS_PROVIDER + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = ( + standard_wholesale_fixing_scenario_data_generator.CHARGE_OWNER_ID_WITH_TAX + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + return standard_wholesale_fixing_scenario_args + + +@pytest.fixture(scope="function") +def standard_wholesale_fixing_scenario_system_operator_args( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> SettlementReportArgs: + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.SYSTEM_OPERATOR + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = ( + standard_wholesale_fixing_scenario_data_generator.CHARGE_OWNER_ID_WITHOUT_TAX + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + return standard_wholesale_fixing_scenario_args + + +@pytest.fixture(scope="function") +def standard_balance_fixing_scenario_args( + settlement_reports_output_path: str, +) -> SettlementReportArgs: + return SettlementReportArgs( + report_id=str(uuid.uuid4()), + period_start=standard_balance_fixing_scenario_data_generator.FROM_DATE, + period_end=standard_balance_fixing_scenario_data_generator.TO_DATE, + calculation_type=CalculationType.BALANCE_FIXING, + calculation_id_by_grid_area=None, + grid_area_codes=standard_balance_fixing_scenario_data_generator.GRID_AREAS, + split_report_by_grid_area=True, + prevent_large_text_files=False, + time_zone="Europe/Copenhagen", + catalog_name="spark_catalog", + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.SYSTEM_OPERATOR, + requesting_actor_id="1212121212121", + settlement_reports_output_path=settlement_reports_output_path, + include_basis_data=True, + ) + + +@pytest.fixture(scope="function") +def standard_balance_fixing_scenario_grid_access_provider_args( + standard_balance_fixing_scenario_args: SettlementReportArgs, +) -> SettlementReportArgs: + standard_balance_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.GRID_ACCESS_PROVIDER + ) + standard_balance_fixing_scenario_args.requesting_actor_id = ( + standard_wholesale_fixing_scenario_data_generator.CHARGE_OWNER_ID_WITH_TAX + ) + standard_balance_fixing_scenario_args.energy_supplier_ids = None + return standard_balance_fixing_scenario_args + + +@pytest.fixture(scope="session") +def standard_balance_fixing_scenario_data_written_to_delta( + spark: SparkSession, + input_database_location: str, +) -> None: + time_series_points_df = standard_balance_fixing_scenario_data_generator.create_metering_point_time_series( + spark + ) + 
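    # Pattern repeated below: each create_* call in the data generator builds a
    # deterministic DataFrame, and the matching write_*_to_delta_table helper
    # appends it to the shared Delta table that mimics the corresponding wholesale view.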
write_metering_point_time_series_to_delta_table( + spark, time_series_points_df, input_database_location + ) + + metering_point_periods = ( + standard_balance_fixing_scenario_data_generator.create_metering_point_periods( + spark + ) + ) + write_metering_point_periods_to_delta_table( + spark, metering_point_periods, input_database_location + ) + + energy_df = standard_balance_fixing_scenario_data_generator.create_energy(spark) + write_energy_to_delta_table(spark, energy_df, input_database_location) + + energy_per_es_df = ( + standard_balance_fixing_scenario_data_generator.create_energy_per_es(spark) + ) + write_energy_per_es_to_delta_table(spark, energy_per_es_df, input_database_location) + + latest_calculations_by_day = ( + standard_balance_fixing_scenario_data_generator.create_latest_calculations( + spark + ) + ) + write_latest_calculations_by_day_to_delta_table( + spark, latest_calculations_by_day, input_database_location + ) + + +@pytest.fixture(scope="session") +def standard_wholesale_fixing_scenario_data_written_to_delta( + spark: SparkSession, + input_database_location: str, +) -> None: + metering_point_periods = ( + standard_wholesale_fixing_scenario_data_generator.create_metering_point_periods( + spark + ) + ) + write_metering_point_periods_to_delta_table( + spark, metering_point_periods, input_database_location + ) + + time_series_points = standard_wholesale_fixing_scenario_data_generator.create_metering_point_time_series( + spark + ) + write_metering_point_time_series_to_delta_table( + spark, time_series_points, input_database_location + ) + + charge_link_periods = ( + standard_wholesale_fixing_scenario_data_generator.create_charge_link_periods( + spark + ) + ) + write_charge_link_periods_to_delta_table( + spark, charge_link_periods, input_database_location + ) + + charge_price_points = ( + standard_wholesale_fixing_scenario_data_generator.create_charge_price_points( + spark + ) + ) + write_charge_price_points_to_delta_table( + spark, charge_price_points, input_database_location + ) + + charge_price_information_periods = standard_wholesale_fixing_scenario_data_generator.create_charge_price_information_periods( + spark + ) + write_charge_price_information_periods_to_delta_table( + spark, charge_price_information_periods, input_database_location + ) + + energy = standard_wholesale_fixing_scenario_data_generator.create_energy(spark) + write_energy_to_delta_table(spark, energy, input_database_location) + + energy_per_es = ( + standard_wholesale_fixing_scenario_data_generator.create_energy_per_es(spark) + ) + write_energy_per_es_to_delta_table(spark, energy_per_es, input_database_location) + + amounts_per_charge = ( + standard_wholesale_fixing_scenario_data_generator.create_amounts_per_charge( + spark + ) + ) + write_amounts_per_charge_to_delta_table( + spark, amounts_per_charge, input_database_location + ) + + monthly_amounts_per_charge_df = standard_wholesale_fixing_scenario_data_generator.create_monthly_amounts_per_charge( + spark + ) + write_monthly_amounts_per_charge_to_delta_table( + spark, monthly_amounts_per_charge_df, input_database_location + ) + total_monthly_amounts_df = ( + standard_wholesale_fixing_scenario_data_generator.create_total_monthly_amounts( + spark + ) + ) + write_total_monthly_amounts_to_delta_table( + spark, total_monthly_amounts_df, input_database_location + ) + + +@pytest.fixture(scope="session") +def file_path_finder() -> Callable[[str], str]: + """ + Returns the path of the file. 
+ Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + + def finder(file: str) -> str: + return os.path.dirname(os.path.normpath(file)) + + return finder + + +@pytest.fixture(scope="session") +def source_path(file_path_finder: Callable[[str], str]) -> str: + """ + Returns the /source folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + return file_path_finder(f"{__file__}/../../..") + + +@pytest.fixture(scope="session") +def databricks_path(source_path: str) -> str: + """ + Returns the source/databricks folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + return f"{source_path}/databricks" + + +@pytest.fixture(scope="session") +def settlement_report_path(databricks_path: str) -> str: + """ + Returns the source/databricks/ folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + return f"{databricks_path}/settlement_report" + + +@pytest.fixture(scope="session") +def contracts_path(settlement_report_path: str) -> str: + """ + Returns the source/contract folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + return f"{settlement_report_path}/contracts" + + +@pytest.fixture(scope="session") +def test_files_folder_path(tests_path: str) -> str: + return f"{tests_path}/test_files" + + +@pytest.fixture(scope="session") +def settlement_reports_output_path(data_lake_path: str) -> str: + return f"{data_lake_path}/settlement_reports_output" + + +@pytest.fixture(scope="session") +def input_database_location(data_lake_path: str) -> str: + return f"{data_lake_path}/input_database" + + +@pytest.fixture(scope="session") +def data_lake_path(tests_path: str, worker_id: str) -> str: + return f"{tests_path}/__data_lake__/{worker_id}" + + +@pytest.fixture(scope="session") +def tests_path(settlement_report_path: str) -> str: + """ + Returns the tests folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. + """ + return f"{settlement_report_path}/tests" + + +@pytest.fixture(scope="session") +def settlement_report_job_container_path(databricks_path: str) -> str: + """ + Returns the /source folder path. + Please note that this only works if current folder haven't been changed prior using + `os.chdir()`. The correctness also relies on the prerequisite that this function is + actually located in a file located directly in the tests folder. 
+ """ + return f"{databricks_path}/settlement_report" + + +@pytest.fixture(scope="session") +def spark( + tests_path: str, +) -> Generator[SparkSession, None, None]: + warehouse_location = f"{tests_path}/__spark-warehouse__" + + session = configure_spark_with_delta_pip( + SparkSession.builder.config("spark.sql.warehouse.dir", warehouse_location) + .config("spark.sql.streaming.schemaInference", True) + .config("spark.ui.showConsoleProgress", "false") + .config("spark.ui.enabled", "false") + .config("spark.ui.dagGraph.retainedRootRDDs", "1") + .config("spark.ui.retainedJobs", "1") + .config("spark.ui.retainedStages", "1") + .config("spark.ui.retainedTasks", "1") + .config("spark.sql.ui.retainedExecutions", "1") + .config("spark.worker.ui.retainedExecutors", "1") + .config("spark.worker.ui.retainedDrivers", "1") + .config("spark.default.parallelism", 1) + .config("spark.driver.memory", "2g") + .config("spark.executor.memory", "2g") + .config("spark.rdd.compress", False) + .config("spark.shuffle.compress", False) + .config("spark.shuffle.spill.compress", False) + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + ).getOrCreate() + + yield session + session.stop() + + +@pytest.fixture(autouse=True) +def configure_dummy_logging() -> None: + """Ensure that logging hooks don't fail due to _TRACER_NAME not being set.""" + + from telemetry_logging.logging_configuration import configure_logging + + configure_logging( + cloud_role_name="any-cloud-role-name", tracer_name="any-tracer-name" + ) + + +@pytest.fixture(scope="session") +def integration_test_configuration(tests_path: str) -> IntegrationTestConfiguration: + """ + Load settings for integration tests either from a local YAML settings file or from environment variables. + Proceeds even if certain Azure-related keys are not present in the settings file. + """ + + settings_file_path = Path(tests_path) / "integrationtest.local.settings.yml" + + def load_settings_from_env() -> dict: + return { + key: os.getenv(key) + for key in [ + "AZURE_KEYVAULT_URL", + "AZURE_CLIENT_ID", + "AZURE_CLIENT_SECRET", + "AZURE_TENANT_ID", + "AZURE_SUBSCRIPTION_ID", + ] + if os.getenv(key) is not None + } + + settings = _load_settings_from_file(settings_file_path) or load_settings_from_env() + + # Set environment variables from loaded settings + for key, value in settings.items(): + if value is not None: + os.environ[key] = value + + if "AZURE_KEYVAULT_URL" in settings: + return IntegrationTestConfiguration( + azure_keyvault_url=settings["AZURE_KEYVAULT_URL"] + ) + + logging.error( + f"Integration test configuration could not be loaded from {settings_file_path} or environment variables." + ) + raise Exception( + "Failed to load integration test settings. Ensure that the Azure Key Vault URL is provided in the settings file or as an environment variable." 
+ ) + + +def _load_settings_from_file(file_path: Path) -> dict: + if file_path.exists(): + with file_path.open() as stream: + return yaml.safe_load(stream) + else: + return {} diff --git a/source/settlement_report_python/tests/data_seeding/__init__.py b/source/settlement_report_python/tests/data_seeding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/data_seeding/standard_balance_fixing_scenario_data_generator.py b/source/settlement_report_python/tests/data_seeding/standard_balance_fixing_scenario_data_generator.py new file mode 100644 index 0000000..42175cb --- /dev/null +++ b/source/settlement_report_python/tests/data_seeding/standard_balance_fixing_scenario_data_generator.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + MeteringPointResolutionDataProductValue, + MeteringPointTypeDataProductValue, + SettlementMethodDataProductValue, +) +from test_factories.default_test_data_spec import create_energy_results_data_spec +from test_factories import ( + metering_point_time_series_factory, + metering_point_periods_factory, + latest_calculations_factory, + energy_factory, +) + +GRID_AREAS = ["804", "805"] +CALCULATION_ID = "ba6a4ce2-b549-494b-ad4b-80a35a05a925" +CALCULATION_TYPE = CalculationTypeDataProductValue.BALANCE_FIXING +ENERGY_SUPPLIER_IDS = ["1000000000000", "2000000000000"] +FROM_DATE = datetime(2024, 1, 1, 23) +TO_DATE = FROM_DATE + timedelta(days=1) +"""TO_DATE is exclusive""" +METERING_POINT_TYPES = [ + MeteringPointTypeDataProductValue.CONSUMPTION, + MeteringPointTypeDataProductValue.EXCHANGE, +] +RESULT_ID = "12345678-4e15-434c-9d93-b03a6dd272a5" +CALCULATION_PERIOD_START = FROM_DATE +CALCULATION_PERIOD_END = TO_DATE +QUANTITY_UNIT = "kwh" +QUANTITY_QUALITIES = ["measured"] +BALANCE_RESPONSIBLE_PARTY_ID = "1234567890123" + + +@dataclass +class MeteringPointSpec: + metering_point_id: str + metering_point_type: MeteringPointTypeDataProductValue + grid_area_code: str + energy_supplier_id: str + resolution: MeteringPointResolutionDataProductValue + + +def create_metering_point_time_series(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with metering point time series data for testing purposes. + There is one row for each combination of resolution, grid area code, and energy supplier id. 
+ There is one calculation with two grid areas, and each grid area has two energy suppliers and each energy supplier + has one metering point in the grid area + """ + df = None + for metering_point in _get_all_metering_points(): + data_spec = ( + metering_point_time_series_factory.MeteringPointTimeSeriesTestDataSpec( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + metering_point_id=metering_point.metering_point_id, + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + resolution=metering_point.resolution, + grid_area_code=metering_point.grid_area_code, + energy_supplier_id=metering_point.energy_supplier_id, + from_date=FROM_DATE, + to_date=TO_DATE, + quantity=Decimal("1.005"), + ) + ) + next_df = metering_point_time_series_factory.create(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_metering_point_periods(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with metering point periods for testing purposes. + """ + rows = [] + for metering_point in _get_all_metering_points(): + rows.append( + metering_point_periods_factory.MeteringPointPeriodsRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + metering_point_id=metering_point.metering_point_id, + metering_point_type=metering_point.metering_point_type, + settlement_method=SettlementMethodDataProductValue.FLEX, + resolution=metering_point.resolution, + grid_area_code=metering_point.grid_area_code, + energy_supplier_id=metering_point.energy_supplier_id, + balance_responsible_party_id=BALANCE_RESPONSIBLE_PARTY_ID, + from_grid_area_code=None, + to_grid_area_code=None, + parent_metering_point_id=None, + from_date=FROM_DATE, + to_date=TO_DATE, + ) + ) + + return metering_point_periods_factory.create(spark, rows) + + +def create_latest_calculations(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with latest calculations data for testing purposes. + """ + + data_specs = [] + for grid_area_code in GRID_AREAS: + current_date = FROM_DATE + while current_date < TO_DATE: + data_specs.append( + latest_calculations_factory.LatestCalculationsPerDayRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + grid_area_code=grid_area_code, + start_of_day=current_date, + ) + ) + current_date += timedelta(days=1) + + return latest_calculations_factory.create(spark, data_specs) + + +def create_energy(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with energy data for testing purposes. + Mimics the wholesale_results.energy_v1 view. + """ + + df = None + for metering_point in _get_all_metering_points(): + data_spec = create_energy_results_data_spec( + grid_area_code=metering_point.grid_area_code, + metering_point_type=metering_point.metering_point_type, + resolution=metering_point.resolution, + energy_supplier_id=metering_point.energy_supplier_id, + ) + + next_df = energy_factory.create_energy_v1(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_energy_per_es(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with energy data for testing purposes. + Mimics the wholesale_results.energy_v1 view. 
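    The rows are built with energy_factory.create_energy_per_es_v1, i.e. the
    per-energy-supplier variant of the energy view.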
+ """ + + df = None + for metering_point in _get_all_metering_points(): + data_spec = create_energy_results_data_spec( + grid_area_code=metering_point.grid_area_code, + metering_point_type=metering_point.metering_point_type, + resolution=metering_point.resolution, + energy_supplier_id=metering_point.energy_supplier_id, + ) + next_df = energy_factory.create_energy_per_es_v1(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def _get_all_metering_points() -> list[MeteringPointSpec]: + metering_points = [] + count = 0 + for resolution in { + MeteringPointResolutionDataProductValue.HOUR, + MeteringPointResolutionDataProductValue.QUARTER, + }: + for grid_area_code in GRID_AREAS: + for energy_supplier_id in ENERGY_SUPPLIER_IDS: + for metering_point_type in METERING_POINT_TYPES: + metering_points.append( + MeteringPointSpec( + metering_point_id=str(1000000000000 + count), + metering_point_type=metering_point_type, + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + resolution=resolution, + ) + ) + + return metering_points diff --git a/source/settlement_report_python/tests/data_seeding/standard_wholesale_fixing_scenario_data_generator.py b/source/settlement_report_python/tests/data_seeding/standard_wholesale_fixing_scenario_data_generator.py new file mode 100644 index 0000000..1f6aae2 --- /dev/null +++ b/source/settlement_report_python/tests/data_seeding/standard_wholesale_fixing_scenario_data_generator.py @@ -0,0 +1,390 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + ChargeTypeDataProductValue, + ChargeResolutionDataProductValue, + MeteringPointResolutionDataProductValue, + MeteringPointTypeDataProductValue, + SettlementMethodDataProductValue, +) +from test_factories.default_test_data_spec import ( + create_energy_results_data_spec, + create_amounts_per_charge_row, + create_monthly_amounts_per_charge_row, + create_total_monthly_amounts_row, +) +from test_factories import ( + metering_point_periods_factory, + metering_point_time_series_factory, + charge_link_periods_factory, + charge_price_information_periods_factory, + energy_factory, + amounts_per_charge_factory, + monthly_amounts_per_charge_factory, + total_monthly_amounts_factory, + charge_price_points_factory, +) + +GRID_AREAS = ["804", "805"] +CALCULATION_ID = "12345678-6f20-40c5-9a95-f419a1245d7e" +CALCULATION_TYPE = CalculationTypeDataProductValue.WHOLESALE_FIXING +ENERGY_SUPPLIER_IDS = ["1000000000000", "2000000000000"] +FROM_DATE = datetime(2024, 1, 1, 23) +TO_DATE = FROM_DATE + timedelta(days=1) +"""TO_DATE is exclusive""" +METERING_POINT_TYPES = [ + MeteringPointTypeDataProductValue.CONSUMPTION, + MeteringPointTypeDataProductValue.EXCHANGE, +] +RESULT_ID = "12345678-4e15-434c-9d93-b03a6dd272a5" +CALCULATION_PERIOD_START = FROM_DATE +CALCULATION_PERIOD_END = TO_DATE +QUANTITY_UNIT = "kwh" +QUANTITY_QUALITIES = ["measured"] +BALANCE_RESPONSIBLE_PARTY_ID = "1234567890123" +CHARGE_OWNER_ID_WITHOUT_TAX = "5790001330552" +CHARGE_OWNER_ID_WITH_TAX = "5790001330553" + + +@dataclass +class MeteringPointSpec: + metering_point_id: str + metering_point_type: MeteringPointTypeDataProductValue + grid_area_code: str + energy_supplier_id: str + resolution: MeteringPointResolutionDataProductValue + + +@dataclass +class Charge: + charge_key: str + 
charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + is_tax: bool + + +def create_metering_point_periods(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with metering point periods for testing purposes. + """ + + rows = [] + for metering_point in _get_all_metering_points(): + rows.append( + metering_point_periods_factory.MeteringPointPeriodsRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + metering_point_id=metering_point.metering_point_id, + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + settlement_method=SettlementMethodDataProductValue.FLEX, + grid_area_code=metering_point.grid_area_code, + resolution=metering_point.resolution, + from_grid_area_code=None, + to_grid_area_code=None, + parent_metering_point_id=None, + energy_supplier_id=metering_point.energy_supplier_id, + balance_responsible_party_id=BALANCE_RESPONSIBLE_PARTY_ID, + from_date=FROM_DATE, + to_date=TO_DATE, + ) + ) + return metering_point_periods_factory.create(spark, rows) + + +def create_metering_point_time_series(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with metering point time series data for testing purposes. + There is one row for each combination of resolution, grid area code, and energy supplier id. + There is one calculation with two grid areas, and each grid area has two energy suppliers and each energy supplier + has one metering point in the grid area + """ + df = None + for metering_point in _get_all_metering_points(): + data_spec = ( + metering_point_time_series_factory.MeteringPointTimeSeriesTestDataSpec( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + metering_point_id=metering_point.metering_point_id, + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + resolution=metering_point.resolution, + grid_area_code=metering_point.grid_area_code, + energy_supplier_id=metering_point.energy_supplier_id, + from_date=FROM_DATE, + to_date=TO_DATE, + quantity=Decimal("1.005"), + ) + ) + next_df = metering_point_time_series_factory.create(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_charge_link_periods(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with charge link periods data for testing purposes. + """ + + rows = [] + for metering_point in _get_all_metering_points(): + for charge in _get_all_charges(): + rows.append( + charge_link_periods_factory.ChargeLinkPeriodsRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + charge_key=charge.charge_key, + charge_code=charge.charge_code, + charge_type=charge.charge_type, + charge_owner_id=charge.charge_owner_id, + metering_point_id=metering_point.metering_point_id, + quantity=1, + from_date=FROM_DATE, + to_date=TO_DATE, + ) + ) + + return charge_link_periods_factory.create(spark, rows) + + +def create_charge_price_points(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with charge prices data for testing purposes. 
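    Each charge from _get_all_charges gets a single price point of 10 at FROM_DATE.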
+ """ + + rows = [] + for charge in _get_all_charges(): + rows.append( + charge_price_points_factory.ChargePricePointsRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + charge_key=charge.charge_key, + charge_code=charge.charge_code, + charge_type=charge.charge_type, + charge_owner_id=charge.charge_owner_id, + charge_price=Decimal("10"), + charge_time=FROM_DATE, + ) + ) + + return charge_price_points_factory.create(spark, rows) + + +def create_charge_price_information_periods(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with charge price information periods data for testing purposes. + """ + rows = [] + for charge in _get_all_charges(): + rows.append( + charge_price_information_periods_factory.ChargePriceInformationPeriodsRow( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_version=1, + charge_key=charge.charge_key, + charge_code=charge.charge_code, + charge_type=charge.charge_type, + charge_owner_id=charge.charge_owner_id, + is_tax=charge.is_tax, + resolution=ChargeResolutionDataProductValue.HOUR, + from_date=FROM_DATE, + to_date=TO_DATE, + ) + ) + return charge_price_information_periods_factory.create(spark, rows) + + +def create_energy(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with energy data for testing purposes. + Mimics the wholesale_results.energy_v1 view. + """ + + df = None + for metering_point in _get_all_metering_points(): + data_spec = create_energy_results_data_spec( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_period_start=FROM_DATE, + calculation_period_end=TO_DATE, + grid_area_code=metering_point.grid_area_code, + metering_point_type=metering_point.metering_point_type, + resolution=metering_point.resolution, + energy_supplier_id=metering_point.energy_supplier_id, + ) + next_df = energy_factory.create_energy_v1(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_amounts_per_charge(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with amounts per charge data for testing purposes. + Mimics the wholesale_results.amounts_per_charge_v1 view. + """ + + rows = [] + for charge in _get_all_charges(): + for grid_area_code in GRID_AREAS: + rows.append( + create_amounts_per_charge_row( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + time=FROM_DATE, + grid_area_code=grid_area_code, + metering_point_type=METERING_POINT_TYPES[0], + resolution=ChargeResolutionDataProductValue.HOUR, + energy_supplier_id=ENERGY_SUPPLIER_IDS[0], + is_tax=charge.is_tax, + charge_owner_id=charge.charge_owner_id, + ) + ) + + return amounts_per_charge_factory.create(spark, rows) + + +def create_monthly_amounts_per_charge(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with amounts per charge data for testing purposes. + Mimics the wholesale_results.monthly_amounts_per_charge_v1 view. 
+ """ + + df = None + for grid_area_code in GRID_AREAS: + for energy_supplier_id in ENERGY_SUPPLIER_IDS: + for charge_owner_id in [ + CHARGE_OWNER_ID_WITH_TAX, + CHARGE_OWNER_ID_WITHOUT_TAX, + ]: + row = create_monthly_amounts_per_charge_row( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + time=FROM_DATE, + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + charge_owner_id=charge_owner_id, + ) + next_df = monthly_amounts_per_charge_factory.create(spark, row) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_total_monthly_amounts(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with amounts per charge data for testing purposes. + Mimics the wholesale_results.monthly_amounts_per_charge_v1 view. + """ + + df = None + for grid_area_code in GRID_AREAS: + for energy_supplier_id in ENERGY_SUPPLIER_IDS: + for charge_owner_id in [ + CHARGE_OWNER_ID_WITH_TAX, + CHARGE_OWNER_ID_WITHOUT_TAX, + None, + ]: + row = create_total_monthly_amounts_row( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + time=FROM_DATE, + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + charge_owner_id=charge_owner_id, + ) + next_df = total_monthly_amounts_factory.create(spark, row) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def create_energy_per_es(spark: SparkSession) -> DataFrame: + """ + Creates a DataFrame with energy data for testing purposes. + Mimics the wholesale_results.energy_v1 view. + """ + + df = None + for metering_point in _get_all_metering_points(): + data_spec = create_energy_results_data_spec( + calculation_id=CALCULATION_ID, + calculation_type=CALCULATION_TYPE, + calculation_period_start=FROM_DATE, + calculation_period_end=TO_DATE, + grid_area_code=metering_point.grid_area_code, + metering_point_type=metering_point.metering_point_type, + resolution=metering_point.resolution, + energy_supplier_id=metering_point.energy_supplier_id, + ) + next_df = energy_factory.create_energy_per_es_v1(spark, data_spec) + if df is None: + df = next_df + else: + df = df.union(next_df) + + return df + + +def _get_all_metering_points() -> list[MeteringPointSpec]: + metering_points = [] + count = 0 + for resolution in { + MeteringPointResolutionDataProductValue.HOUR, + MeteringPointResolutionDataProductValue.QUARTER, + }: + for grid_area_code in GRID_AREAS: + for energy_supplier_id in ENERGY_SUPPLIER_IDS: + for metering_point_type in METERING_POINT_TYPES: + metering_points.append( + MeteringPointSpec( + metering_point_id=str(1000000000000 + count), + metering_point_type=metering_point_type, + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + resolution=resolution, + ) + ) + + return metering_points + + +def _get_all_charges() -> list[Charge]: + + return [ + Charge( + charge_key=f"4000_{ChargeTypeDataProductValue.TARIFF.value}_5790001330552", + charge_code="4000", + charge_type=ChargeTypeDataProductValue.TARIFF, + charge_owner_id=CHARGE_OWNER_ID_WITHOUT_TAX, + is_tax=False, + ), + Charge( + charge_key=f"4001_{ChargeTypeDataProductValue.TARIFF.value}_5790001330553", + charge_code="4001", + charge_type=ChargeTypeDataProductValue.TARIFF, + charge_owner_id=CHARGE_OWNER_ID_WITH_TAX, + is_tax=True, + ), + ] diff --git a/source/settlement_report_python/tests/data_seeding/write_test_data.py b/source/settlement_report_python/tests/data_seeding/write_test_data.py new file mode 100644 index 0000000..549f10e --- /dev/null +++ 
b/source/settlement_report_python/tests/data_seeding/write_test_data.py @@ -0,0 +1,240 @@ +from pyspark.sql import SparkSession, DataFrame +from pyspark.sql.types import StructType + +from settlement_report_job.infrastructure.wholesale import ( + database_definitions, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + charge_link_periods_v1, + metering_point_periods_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + charge_price_information_periods_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + metering_point_time_series_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.amounts_per_charge_v1 import ( + amounts_per_charge_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.charge_price_points_v1 import ( + charge_price_points_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.energy_per_es_v1 import ( + energy_per_es_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.energy_v1 import ( + energy_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.latest_calculations_by_day_v1 import ( + latest_calculations_by_day_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + monthly_amounts_per_charge_v1, +) +from settlement_report_job.infrastructure.wholesale.schemas.total_monthly_amounts_v1 import ( + total_monthly_amounts_v1, +) + + +def write_latest_calculations_by_day_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.LATEST_CALCULATIONS_BY_DAY_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.LATEST_CALCULATIONS_BY_DAY_VIEW_NAME}", + schema=latest_calculations_by_day_v1, + ) + + +def write_amounts_per_charge_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.AMOUNTS_PER_CHARGE_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.AMOUNTS_PER_CHARGE_VIEW_NAME}", + schema=amounts_per_charge_v1, + ) + + +def write_monthly_amounts_per_charge_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.MONTHLY_AMOUNTS_PER_CHARGE_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.MONTHLY_AMOUNTS_PER_CHARGE_VIEW_NAME}", + schema=monthly_amounts_per_charge_v1, + ) + + +def write_total_monthly_amounts_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.TOTAL_MONTHLY_AMOUNTS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.TOTAL_MONTHLY_AMOUNTS_VIEW_NAME}", + schema=total_monthly_amounts_v1, + ) + + +def write_energy_to_delta_table( + spark: SparkSession, + df: DataFrame, + 
table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.ENERGY_V1_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.ENERGY_V1_VIEW_NAME}", + schema=energy_v1, + ) + + +def write_energy_per_es_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleResultsDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleResultsDatabase.ENERGY_PER_ES_V1_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleResultsDatabase.ENERGY_PER_ES_V1_VIEW_NAME}", + schema=energy_per_es_v1, + ) + + +def write_charge_price_information_periods_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleBasisDataDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleBasisDataDatabase.CHARGE_PRICE_INFORMATION_PERIODS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleBasisDataDatabase.CHARGE_PRICE_INFORMATION_PERIODS_VIEW_NAME}", + schema=charge_price_information_periods_v1, + ) + + +def write_charge_link_periods_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleBasisDataDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleBasisDataDatabase.CHARGE_LINK_PERIODS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleBasisDataDatabase.CHARGE_LINK_PERIODS_VIEW_NAME}", + schema=charge_link_periods_v1, + ) + + +def write_charge_price_points_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleBasisDataDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleBasisDataDatabase.CHARGE_PRICE_POINTS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleBasisDataDatabase.CHARGE_PRICE_POINTS_VIEW_NAME}", + schema=charge_price_points_v1, + ) + + +def write_metering_point_time_series_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleBasisDataDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleBasisDataDatabase.TIME_SERIES_POINTS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleBasisDataDatabase.TIME_SERIES_POINTS_VIEW_NAME}", + schema=metering_point_time_series_v1, + ) + + +def write_metering_point_periods_to_delta_table( + spark: SparkSession, + df: DataFrame, + table_location: str, +) -> None: + write_dataframe_to_table( + spark, + df=df, + database_name=database_definitions.WholesaleBasisDataDatabase.DATABASE_NAME, + table_name=database_definitions.WholesaleBasisDataDatabase.METERING_POINT_PERIODS_VIEW_NAME, + table_location=f"{table_location}/{database_definitions.WholesaleBasisDataDatabase.METERING_POINT_PERIODS_VIEW_NAME}", + schema=metering_point_periods_v1, + ) + + +def write_dataframe_to_table( + spark: SparkSession, + df: DataFrame, + database_name: str, + table_name: str, + table_location: str, + schema: 
StructType, + mode: str = "append", # Append because the tables are shared across tests +) -> None: + spark.sql(f"CREATE DATABASE IF NOT EXISTS {database_name}") + + sql_schema = _struct_type_to_sql_schema(schema) + + # Creating table if not exists - note that the table is shared across tests, and should therefore not be deleted first. + spark.sql( + f"CREATE TABLE IF NOT EXISTS {database_name}.{table_name} ({sql_schema}) USING DELTA LOCATION '{table_location}'" + ) + df.write.format("delta").option("overwriteSchema", "true").mode(mode).saveAsTable( + f"{database_name}.{table_name}" + ) + + +def _struct_type_to_sql_schema(schema: StructType) -> str: + schema_string = "" + for field in schema.fields: + field_name = field.name + field_type = field.dataType.simpleString() + + if not field.nullable: + field_type += " NOT NULL" + + schema_string += f"{field_name} {field_type}, " + + # Remove the trailing comma and space + schema_string = schema_string.rstrip(", ") + return schema_string diff --git a/source/settlement_report_python/tests/dbutils_fixture.py b/source/settlement_report_python/tests/dbutils_fixture.py new file mode 100644 index 0000000..fef9684 --- /dev/null +++ b/source/settlement_report_python/tests/dbutils_fixture.py @@ -0,0 +1,73 @@ +import os +import shutil + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class DBUtilsFileInfoFixture: + """ + This class mocks the DBUtils FileInfo object + """ + + path: str + name: str + size: int + modificationTime: int + + +class DBUtilsSecretsFixture: + def __init__(self): + self.secrets = self + + def get(self, scope, name): + return os.environ.get(name) + + +class DBUtilsFixture: + """ + This class is used for mocking the behaviour of DBUtils inside tests. + """ + + def __init__(self): + self.fs = self + self.secrets = DBUtilsSecretsFixture() + + def _clean_path(self, path: str): + return path.replace("file:", "") + + def cp(self, src: str, dest: str, recurse: bool = False): + copy_func = shutil.copytree if recurse else shutil.copy + copy_func(self._clean_path(src), self._clean_path(dest)) + + def ls(self, path: str): + _paths = Path(self._clean_path(path)).glob("*") + _objects = [ + DBUtilsFileInfoFixture( + str(p.absolute()), p.name, p.stat().st_size, int(p.stat().st_mtime) + ) + for p in _paths + ] + return _objects + + def mkdirs(self, path: str): + Path(self._clean_path(path)).mkdir(parents=True, exist_ok=True) + + def mv(self, src: str, dest: str, recurse: bool = False): + copy_func = shutil.copytree if recurse else shutil.copy + shutil.move( + self._clean_path(src), self._clean_path(dest), copy_function=copy_func + ) + + def put(self, path: str, content: str, overwrite: bool = False): + _f = Path(self._clean_path(path)) + + if _f.exists() and not overwrite: + raise FileExistsError("File already exists") + + _f.write_text(content, encoding="utf-8") + + def rm(self, path: str, recurse: bool = False): + deletion_func = shutil.rmtree if recurse else os.remove + deletion_func(self._clean_path(path)) diff --git a/source/settlement_report_python/tests/domain/__init__.py b/source/settlement_report_python/tests/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/domain/charge_link_periods/test_charge_link_periods_read_and_filter.py b/source/settlement_report_python/tests/domain/charge_link_periods/test_charge_link_periods_read_and_filter.py new file mode 100644 index 0000000..8c0dc0d --- /dev/null +++ 
b/source/settlement_report_python/tests/domain/charge_link_periods/test_charge_link_periods_read_and_filter.py @@ -0,0 +1,570 @@ +import uuid +from datetime import datetime +from functools import reduce +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, DataFrame +import test_factories.default_test_data_spec as default_data +import test_factories.charge_link_periods_factory as charge_link_periods_factory +import test_factories.metering_point_periods_factory as metering_point_periods_factory +import test_factories.charge_price_information_periods_factory as charge_price_information_periods_factory +from settlement_report_job.domain.charge_link_periods.read_and_filter import ( + read_and_filter, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +GRID_ACCESS_PROVIDER_ID = "4444444444444" +OTHER_ID = "9999999999999" +DEFAULT_CALCULATION_ID_BY_GRID_AREA = { + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) +} + +JAN_1ST = datetime(2023, 12, 31, 23) +JAN_2ND = datetime(2024, 1, 1, 23) +JAN_3RD = datetime(2024, 1, 2, 23) +JAN_4TH = datetime(2024, 1, 3, 23) +JAN_5TH = datetime(2024, 1, 4, 23) +JAN_6TH = datetime(2024, 1, 5, 23) +JAN_7TH = datetime(2024, 1, 6, 23) +JAN_8TH = datetime(2024, 1, 7, 23) +JAN_9TH = datetime(2024, 1, 8, 23) + + +def _get_repository_mock( + metering_point_period: DataFrame, + charge_link_periods: DataFrame, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + mock_repository.read_charge_link_periods.return_value = charge_link_periods + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +@pytest.mark.parametrize( + "charge_link_from_date,charge_link_to_date,is_included", + [ + pytest.param( + JAN_1ST, + JAN_2ND, + False, + id="charge link period stops before selected period", + ), + pytest.param( + JAN_1ST, + JAN_3RD, + True, + id="charge link starts before and ends within selected period", + ), + pytest.param( + JAN_3RD, + JAN_4TH, + True, + id="charge link period is within selected period", + ), + pytest.param( + JAN_3RD, + JAN_5TH, + True, + id="charge link starts within but stops after selected period", + ), + pytest.param( + JAN_4TH, JAN_5TH, False, id="charge link starts after selected period" + ), + ], +) +def test_read_and_filter__returns_charge_link_periods_that_overlap_with_selected_period( + spark: SparkSession, + charge_link_from_date: datetime, + charge_link_to_date: datetime, + is_included: bool, +) -> None: + # Arrange + period_start = JAN_2ND + period_end = JAN_4TH + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + from_date=charge_link_from_date, to_date=charge_link_to_date + ), + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + from_date=charge_link_from_date, to_date=charge_link_to_date + ), + ) + mock_repository = _get_repository_mock(metering_point_periods, 
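        # charge_price_information_periods is not passed here; this test only exercises
        # the period-overlap filtering for the DataHub administrator.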
charge_link_periods) + + # Act + actual_df = read_and_filter( + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert (actual_df.count() > 0) == is_included + + +def test_read_and_filter__returns_only_selected_grid_area( + spark: SparkSession, +) -> None: + # Arrange + selected_grid_area_code = "805" + not_selected_grid_area_code = "806" + selected_metering_point = "555" + not_selected_metering_point = "666" + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + grid_area_code=selected_grid_area_code, + metering_point_id=selected_metering_point, + ), + ).union( + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + grid_area_code=not_selected_grid_area_code, + metering_point_id=not_selected_metering_point, + ), + ) + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id=selected_metering_point, + ), + ).union( + charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id=not_selected_metering_point, + ), + ) + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + selected_grid_area_code: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + actual_grid_area_codes = ( + actual_df.select(DataProductColumnNames.grid_area_code).distinct().collect() + ) + assert len(actual_grid_area_codes) == 1 + assert actual_grid_area_codes[0][0] == selected_grid_area_code + + +def test_read_and_filter__returns_only_rows_from_selected_calculation_id( + spark: SparkSession, +) -> None: + # Arrange + selected_calculation_id = "11111111-9fc8-409a-a169-fbd49479d718" + not_selected_calculation_id = "22222222-9fc8-409a-a169-fbd49479d718" + expected_metering_point_id = "123456789012345678901234567" + other_metering_point_id = "765432109876543210987654321" + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + calculation_id=selected_calculation_id, + metering_point_id=expected_metering_point_id, + ), + ).union( + charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + calculation_id=not_selected_calculation_id, + metering_point_id=other_metering_point_id, + ), + ) + ) + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + calculation_id=selected_calculation_id, + metering_point_id=expected_metering_point_id, + ), + ).union( + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + calculation_id=not_selected_calculation_id, + metering_point_id=other_metering_point_id, + ), + ) + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + 
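        # Only link periods whose calculation_id matches the id mapped to the grid area
        # in calculation_id_by_grid_area below are expected to remain after filtering.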
period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(selected_calculation_id) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + actual_metering_point_ids = ( + actual_df.select(DataProductColumnNames.metering_point_id).distinct().collect() + ) + assert len(actual_metering_point_ids) == 1 + assert ( + actual_metering_point_ids[0][DataProductColumnNames.metering_point_id] + == expected_metering_point_id + ) + + +ENERGY_SUPPLIER_A = "1000000000000" +ENERGY_SUPPLIER_B = "2000000000000" +ENERGY_SUPPLIER_C = "3000000000000" +ENERGY_SUPPLIERS_ABC = [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B, ENERGY_SUPPLIER_C] +METERING_POINT_ID_ABC = ["123", "456", "789"] + + +@pytest.mark.parametrize( + "selected_energy_supplier_ids,expected_energy_supplier_ids", + [ + (None, ENERGY_SUPPLIERS_ABC), + ([ENERGY_SUPPLIER_B], [ENERGY_SUPPLIER_B]), + ( + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + ), + (ENERGY_SUPPLIERS_ABC, ENERGY_SUPPLIERS_ABC), + ], +) +def test_read_and_filter__returns_data_for_expected_energy_suppliers( + spark: SparkSession, + selected_energy_supplier_ids: list[str] | None, + expected_energy_supplier_ids: list[str], +) -> None: + # Arrange + metering_point_periods = reduce( + lambda df1, df2: df1.union(df2), + [ + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + energy_supplier_id=energy_supplier_id, + metering_point_id=metering_point_id, + ), + ) + for energy_supplier_id, metering_point_id in zip( + ENERGY_SUPPLIERS_ABC, METERING_POINT_ID_ABC + ) + ], + ) + charge_link_periods = reduce( + lambda df1, df2: df1.union(df2), + [ + charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id=metering_point_id, + ), + ) + for metering_point_id in METERING_POINT_ID_ABC + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=selected_energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert set( + row[DataProductColumnNames.energy_supplier_id] for row in actual_df.collect() + ) == set(expected_energy_supplier_ids) + + +@pytest.mark.parametrize( + "charge_owner_id,is_tax,return_rows", + [ + pytest.param( + SYSTEM_OPERATOR_ID, False, True, id="system operator without tax: include" + ), + pytest.param( + SYSTEM_OPERATOR_ID, True, False, id="system operator with tax: exclude" + ), + pytest.param( + OTHER_ID, False, False, id="other charge owner without tax: exclude" + ), + pytest.param(OTHER_ID, True, False, id="other charge owner with tax: exclude"), + ], +) +def test_read_and_filter__when_system_operator__returns_expected_charge_link_periods( + spark: SparkSession, + charge_owner_id: str, + is_tax: bool, + return_rows: bool, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + 
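        # The charge_owner_id / is_tax combination below drives whether the requesting
        # system operator may see the corresponding charge link period (see the
        # parametrize cases above).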
default_data.create_charge_price_information_periods_row( + charge_owner_id=charge_owner_id, + is_tax=is_tax, + ), + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_owner_id=charge_owner_id), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.SYSTEM_OPERATOR, + requesting_actor_id=SYSTEM_OPERATOR_ID, + repository=mock_repository, + ) + + # Assert + assert (actual.count() > 0) == return_rows + + +@pytest.mark.parametrize( + "charge_owner_id,is_tax,return_rows", + [ + pytest.param( + GRID_ACCESS_PROVIDER_ID, + False, + True, + id="grid access provider without tax: include", + ), + pytest.param( + OTHER_ID, False, False, id="other charge owner without tax: exclude" + ), + pytest.param(OTHER_ID, True, True, id="other charge owner with tax: include"), + ], +) +def test_read_and_filter__when_grid_access_provider__returns_expected_charge_link_periods( + spark: SparkSession, + charge_owner_id: str, + is_tax: bool, + return_rows: bool, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=charge_owner_id, + is_tax=is_tax, + ), + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_owner_id=charge_owner_id), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.GRID_ACCESS_PROVIDER, + requesting_actor_id=GRID_ACCESS_PROVIDER_ID, + repository=mock_repository, + ) + + # Assert + assert (actual.count() > 0) == return_rows + + +def test_read_and_filter__when_energy_supplier_changes_on_metering_point__returns_one_link_period( + spark: SparkSession, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + energy_supplier_id="1", from_date=JAN_1ST, to_date=JAN_2ND + ), + default_data.create_metering_point_periods_row( + energy_supplier_id="2", from_date=JAN_2ND, to_date=JAN_3RD + ), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + from_date=JAN_1ST, to_date=JAN_3RD, charge_owner_id=GRID_ACCESS_PROVIDER_ID + ), + ) + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + from_date=JAN_1ST, + to_date=JAN_3RD, + is_tax=True, + charge_owner_id=GRID_ACCESS_PROVIDER_ID, + ), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = read_and_filter( + period_start=JAN_1ST, + period_end=JAN_3RD, + 
calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.GRID_ACCESS_PROVIDER, + requesting_actor_id=GRID_ACCESS_PROVIDER_ID, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + assert actual.select(DataProductColumnNames.from_date).collect()[0][0] == JAN_1ST + assert actual.select(DataProductColumnNames.to_date).collect()[0][0] == JAN_3RD + + +def test_read_and_filter__when_datahub_user_and_energy_supplier_changes_on_metering_point__returns_two_link_periods( + spark: SparkSession, +) -> None: + # Arrange + es_id_a = "111" + es_id_b = "222" + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + energy_supplier_id=es_id_a, + from_date=JAN_1ST, + to_date=JAN_2ND, + ), + default_data.create_metering_point_periods_row( + energy_supplier_id=es_id_b, + from_date=JAN_2ND, + to_date=JAN_3RD, + ), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(from_date=JAN_1ST, to_date=JAN_3RD), + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual = read_and_filter( + period_start=JAN_1ST, + period_end=JAN_3RD, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == 2 + + actual_row_1 = actual.collect()[0] + assert actual_row_1[DataProductColumnNames.energy_supplier_id] == es_id_a + assert actual_row_1[DataProductColumnNames.from_date] == JAN_1ST + assert actual_row_1[DataProductColumnNames.to_date] == JAN_2ND + + actual_row_2 = actual.collect()[1] + assert actual_row_2[DataProductColumnNames.energy_supplier_id] == es_id_b + assert actual_row_2[DataProductColumnNames.from_date] == JAN_2ND + assert actual_row_2[DataProductColumnNames.to_date] == JAN_3RD + + +def test_read_and_filter__when_duplicate_metering_point_periods__returns_one_link_period_per_duplicate( + spark: SparkSession, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row(), + default_data.create_metering_point_periods_row(), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(), + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual = read_and_filter( + period_start=JAN_1ST, + period_end=JAN_3RD, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 diff --git a/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_prepare_for_csv.py b/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_prepare_for_csv.py new file mode 100644 index 0000000..f4d50ce --- /dev/null +++ b/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_prepare_for_csv.py @@ -0,0 +1,240 @@ +import uuid +from datetime import datetime, 
timedelta +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, DataFrame, functions as F + +import test_factories.default_test_data_spec as default_data +from settlement_report_job.domain.charge_price_points.prepare_for_csv import ( + prepare_for_csv, +) + + +import test_factories.charge_price_points_factory as charge_price_points_factory +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + ChargeResolutionDataProductValue, +) +from utils import Dates + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +ENERGY_SUPPLIER_IDS = ["1234567890123", "2345678901234"] +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +GRID_ACCESS_PROVIDER_ID = "4444444444444" +OTHER_ID = "9999999999999" +DEFAULT_CALCULATION_ID_BY_GRID_AREA = { + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) +} +DEFAULT_TIME_ZONE = "Europe/Copenhagen" + + +def _get_repository_mock( + metering_point_period: DataFrame, + charge_link_periods: DataFrame, + charge_price_points: DataFrame, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + mock_repository.read_charge_link_periods.return_value = charge_link_periods + mock_repository.read_charge_price_points.return_value = charge_price_points + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +@pytest.mark.parametrize( + "resolution", + [ + ChargeResolutionDataProductValue.DAY, + ChargeResolutionDataProductValue.MONTH, + ], +) +def test_when_resolution_is_day_or_month_return_only_value_in_energy_price_1( + spark: SparkSession, + resolution: ChargeResolutionDataProductValue, +) -> None: + # Arrange + filtered_charge_price_points = ( + charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(), + ) + .withColumn( + DataProductColumnNames.grid_area_code, + F.lit(default_data.DEFAULT_GRID_AREA_CODE), + ) + .withColumn(DataProductColumnNames.is_tax, F.lit(False)) + .withColumn( + DataProductColumnNames.resolution, + F.lit(resolution.value), + ) + ) + + # Act + result_df = prepare_for_csv( + filtered_charge_price_points=filtered_charge_price_points, + time_zone=DEFAULT_TIME_ZONE, + ) + + # Assert + assert result_df.count() == 1 + result = result_df.collect()[0] + assert result["ENERGYPRICE1"] == default_data.DEFAULT_CHARGE_PRICE + for i in range(2, 26): + assert result[f"ENERGYPRICE{i}"] is None + + +def test_when_resolution_is_hour_return_one_row_with_value_in_every_energy_price_except_25( + spark: SparkSession, +) -> None: + # Arrange + hours_in_day = [Dates.JAN_1ST + timedelta(hours=i) for i in range(24)] + charge_price_rows = [] + for i in range(24): + charge_price_rows.append( + default_data.create_charge_price_points_row( + charge_time=hours_in_day[i], + charge_price=default_data.DEFAULT_CHARGE_PRICE + i, + ) + ) + + filtered_charge_price_points = ( + charge_price_points_factory.create( + spark, + charge_price_rows, + ) + .withColumn( + DataProductColumnNames.grid_area_code, + F.lit(default_data.DEFAULT_GRID_AREA_CODE), + ) + 
.withColumn(DataProductColumnNames.is_tax, F.lit(False)) + .withColumn( + DataProductColumnNames.resolution, + F.lit(ChargeResolutionDataProductValue.HOUR.value), + ) + ) + + # Act + result_df = prepare_for_csv( + filtered_charge_price_points=filtered_charge_price_points, + time_zone=DEFAULT_TIME_ZONE, + ) + + # Assert + assert result_df.count() == 1 + result = result_df.collect()[0] + for i in range(1, 25): + assert ( + result[f"{CsvColumnNames.energy_price}{i}"] + == default_data.DEFAULT_CHARGE_PRICE + i - 1 + ) + assert result[f"{CsvColumnNames.energy_price}25"] is None + + +@pytest.mark.parametrize( + "is_tax, expected_tax_indicator", + [ + (True, 1), + (False, 0), + ], +) +def test_tax_indicator_is_converted_correctly( + spark: SparkSession, + is_tax: bool, + expected_tax_indicator: int, +) -> None: + # Arrange + filtered_charge_price_points = ( + charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(), + ) + .withColumn( + DataProductColumnNames.grid_area_code, + F.lit(default_data.DEFAULT_GRID_AREA_CODE), + ) + .withColumn(DataProductColumnNames.is_tax, F.lit(is_tax)) + .withColumn( + DataProductColumnNames.resolution, + F.lit(ChargeResolutionDataProductValue.DAY.value), + ) + ) + + # Act + result_df = prepare_for_csv( + filtered_charge_price_points=filtered_charge_price_points, + time_zone=DEFAULT_TIME_ZONE, + ) + + # Assert + assert result_df.collect()[0][CsvColumnNames.is_tax] == expected_tax_indicator + + +@pytest.mark.parametrize( + "daylight_savings, expected_energy_price_columns_with_value", + [ + ( + datetime(2023, 3, 25, 23), + 23, + ), + ( + datetime(2023, 10, 28, 22), + 25, + ), + ], +) +def test_when_daylight_savings_time_return_number_of_expected_rows( + spark: SparkSession, + daylight_savings: datetime, + expected_energy_price_columns_with_value: int, +) -> None: + # Arrange + hours_in_day = [daylight_savings + timedelta(hours=i) for i in range(25)] + charge_price_rows = [] + for i in range(25): + charge_price_rows.append( + default_data.create_charge_price_points_row( + charge_time=hours_in_day[i], + charge_price=default_data.DEFAULT_CHARGE_PRICE + i, + ) + ) + + filtered_charge_price_points = ( + charge_price_points_factory.create( + spark, + charge_price_rows, + ) + .withColumn( + DataProductColumnNames.grid_area_code, + F.lit(default_data.DEFAULT_GRID_AREA_CODE), + ) + .withColumn(DataProductColumnNames.is_tax, F.lit(False)) + .withColumn( + DataProductColumnNames.resolution, + F.lit(ChargeResolutionDataProductValue.HOUR.value), + ) + ) + + # Act + result_df = prepare_for_csv( + filtered_charge_price_points=filtered_charge_price_points, + time_zone=DEFAULT_TIME_ZONE, + ) + + # Assert + assert_count = 0 + result = result_df.collect()[0] + for i in range(1, 26): + if result[f"{CsvColumnNames.energy_price}{i}"] is not None: + assert_count += 1 + assert assert_count == expected_energy_price_columns_with_value diff --git a/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_read_and_filter.py b/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_read_and_filter.py new file mode 100644 index 0000000..ecaef16 --- /dev/null +++ b/source/settlement_report_python/tests/domain/charge_price_points/test_charge_price_points_read_and_filter.py @@ -0,0 +1,553 @@ +import uuid +from datetime import datetime +from unittest.mock import Mock +from uuid import UUID + +import pytest +from pyspark.sql import SparkSession, DataFrame +import test_factories.default_test_data_spec 
as default_data +import test_factories.charge_link_periods_factory as charge_link_periods_factory +import test_factories.metering_point_periods_factory as metering_point_periods_factory +import test_factories.charge_price_points_factory as charge_price_points_factory +import test_factories.charge_price_information_periods_factory as charge_price_information_periods_factory +from settlement_report_job.domain.charge_price_points.read_and_filter import ( + read_and_filter, +) +from settlement_report_job.domain.utils.market_role import MarketRole + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +ENERGY_SUPPLIER_IDS = ["1234567890123", "2345678901234"] +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +GRID_ACCESS_PROVIDER_ID = "4444444444444" +OTHER_ID = "9999999999999" +DEFAULT_CALCULATION_ID_BY_GRID_AREA = { + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) +} + +JAN_1ST = datetime(2023, 12, 31, 23) +JAN_2ND = datetime(2024, 1, 1, 23) +JAN_3RD = datetime(2024, 1, 2, 23) +JAN_4TH = datetime(2024, 1, 3, 23) +JAN_5TH = datetime(2024, 1, 4, 23) +JAN_6TH = datetime(2024, 1, 5, 23) +JAN_7TH = datetime(2024, 1, 6, 23) +JAN_8TH = datetime(2024, 1, 7, 23) +JAN_9TH = datetime(2024, 1, 8, 23) + + +def _get_repository_mock( + metering_point_period: DataFrame, + charge_link_periods: DataFrame, + charge_price_points: DataFrame, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + mock_repository.read_charge_link_periods.return_value = charge_link_periods + mock_repository.read_charge_price_points.return_value = charge_price_points + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +@pytest.mark.parametrize( + "args_energy_supplier_ids, expected_rows", + [ + pytest.param( + ["1"], + 1, + id="when the energy supplier id matches with a metering point period which matches with a charge link period, and the charge price's charge time is within the period, return 1 row", + ), + pytest.param( + ["2"], + 0, + id="when the energy supplier id matches with a metering point period which matches with a charge link period and charge price, but the charge time is not within the period, return 0 rows", + ), + ], +) +def test_when_energy_supplier_ids_contain_only_one_energy_supplier_id( + spark: SparkSession, + args_energy_supplier_ids: list[str] | None, + expected_rows: int, +) -> None: + # Arrange + energy_supplier_id_1 = "1" + energy_supplier_id_2 = "2" + + metering_point_id_1 = "1" + metering_point_id_2 = "2" + + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + metering_point_id=metering_point_id_1, + energy_supplier_id=energy_supplier_id_1, + from_date=JAN_1ST, + to_date=JAN_4TH, + ), + default_data.create_metering_point_periods_row( + metering_point_id=metering_point_id_2, + energy_supplier_id=energy_supplier_id_2, + from_date=JAN_3RD, + to_date=JAN_4TH, + ), + ], + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, + [ + default_data.create_charge_link_periods_row( + metering_point_id=metering_point_id_1, + from_date=JAN_1ST, + to_date=JAN_4TH, + ), + default_data.create_charge_link_periods_row( + metering_point_id=metering_point_id_2, +
from_date=JAN_3RD, + to_date=JAN_4TH, + ), + ], + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(charge_time=JAN_2ND), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=JAN_1ST, + period_end=JAN_4TH, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=args_energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows + + +@pytest.mark.parametrize( + "charge_time, expected_rows", + [ + pytest.param( + JAN_2ND, + 1, + id="Charge time is within the period of one of the energy suppliers, then 1 row is returned", + ), + pytest.param( + JAN_3RD, + 1, + id="Charge time is within the period of two of the energy suppliers (duplicated rows will be removed), then 1 row is returned", + ), + pytest.param( + JAN_5TH, + 0, + id="Charge time is outside the period of the energy suppliers, then 0 rows are returned", + ), + ], +) +def test_when_two_energy_suppliers_ids_with_different_periods( + spark: SparkSession, + charge_time: datetime, + expected_rows: int, +) -> None: + # Arrange + energy_supplier_id_1 = "1" + energy_supplier_id_2 = "2" + energy_supplier_ids = [energy_supplier_id_1, energy_supplier_id_2] + + metering_point_id_1 = "1" + metering_point_id_2 = "2" + + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + metering_point_id=metering_point_id_1, + energy_supplier_id=energy_supplier_id_1, + from_date=JAN_1ST, + to_date=JAN_4TH, + ), + default_data.create_metering_point_periods_row( + metering_point_id=metering_point_id_2, + energy_supplier_id=energy_supplier_id_1, + from_date=JAN_3RD, + to_date=JAN_4TH, + ), + ], + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, + [ + default_data.create_charge_link_periods_row( + metering_point_id=metering_point_id_1, + from_date=JAN_1ST, + to_date=JAN_4TH, + ), + default_data.create_charge_link_periods_row( + metering_point_id=metering_point_id_2, + from_date=JAN_3RD, + to_date=JAN_4TH, + ), + ], + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(charge_time=charge_time), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=JAN_1ST, + period_end=JAN_4TH, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows + + +@pytest.mark.parametrize( + "args_start_date, args_end_date, expected_rows", + [ + pytest.param( + JAN_2ND, + 
JAN_9TH, + 1, + id="when time is within the range, return 1 row", + ), + pytest.param( + JAN_5TH, + JAN_9TH, + 0, + id="when time is outside the range, return 0 rows", + ), + ], +) +def test_time_within_and_outside_of_date_range_scenarios( + spark: SparkSession, + args_start_date: datetime, + args_end_date: datetime, + expected_rows: int, +) -> None: + # Arrange + charge_time = JAN_3RD + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + from_date=JAN_1ST, + to_date=JAN_4TH, + ), + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(from_date=JAN_1ST, to_date=JAN_4TH), + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(charge_time=charge_time), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=args_start_date, + period_end=args_end_date, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows + + +@pytest.mark.parametrize( + "args_energy_supplier_ids, expected_rows", + [ + pytest.param( + ["1234567890123"], + 1, + id="when energy_supplier_id is in energy_supplier_ids, return 1 row", + ), + pytest.param( + ["2345678901234"], + 0, + id="when energy_supplier_id is not in energy_supplier_ids, return 0 rows", + ), + pytest.param( + None, + 1, + id="when energy_supplier_ids is None, return 1 row", + ), + ], +) +def test_energy_supplier_ids_scenarios( + spark: SparkSession, + args_energy_supplier_ids: list[str] | None, + expected_rows: int, +) -> None: + # Arrange + energy_supplier_id = "1234567890123" + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + energy_supplier_id=energy_supplier_id + ), + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, default_data.create_charge_link_periods_row() + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=args_energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows + + +@pytest.mark.parametrize( + "args_calculation_id_by_grid_area, expected_rows", + [ + pytest.param( + {"804": UUID(default_data.DEFAULT_CALCULATION_ID)}, + 1, + 
id="when calculation_id and grid_area_code are in calculation_id_by_grid_area, return 1 row", + ), + pytest.param( + {"500": UUID(default_data.DEFAULT_CALCULATION_ID)}, + 0, + id="when grid_area_code is not in calculation_id_by_grid_area, return 0 rows", + ), + pytest.param( + {"804": UUID("11111111-1111-2222-1111-111111111111")}, + 0, + id="when calculation_id is not in calculation_id_by_grid_area, return 0 rows", + ), + pytest.param( + {"500": UUID("11111111-1111-2222-1111-111111111111")}, + 0, + id="when calculation_id and grid_area_code are not in calculation_id_by_grid_area, return 0 rows", + ), + ], +) +def test_calculation_id_by_grid_area_scenarios( + spark: SparkSession, + args_calculation_id_by_grid_area: dict[str, UUID], + expected_rows: int, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + calculation_id=default_data.DEFAULT_CALCULATION_ID, grid_area_code="804" + ), + ) + + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + calculation_id=default_data.DEFAULT_CALCULATION_ID + ), + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row( + calculation_id=default_data.DEFAULT_CALCULATION_ID + ), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + calculation_id=default_data.DEFAULT_CALCULATION_ID + ), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=args_calculation_id_by_grid_area, + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows + + +@pytest.mark.parametrize( + "args_requesting_actor_market_role, args_requesting_actor_id, is_tax, expected_rows", + [ + pytest.param( + MarketRole.GRID_ACCESS_PROVIDER, + "1111111111111", + True, + 1, + id="When grid_access_provider and charge_owner_id equals requesting_actor_id and is_tax is True, return 1 row", + ), + pytest.param( + MarketRole.GRID_ACCESS_PROVIDER, + default_data.DEFAULT_CHARGE_OWNER_ID, + False, + 1, + id="When grid_access_provider and charge_owner_id equals requesting_actor_id and is_tax is False, return 1 row", + ), + pytest.param( + MarketRole.SYSTEM_OPERATOR, + default_data.DEFAULT_CHARGE_OWNER_ID, + True, + 0, + id="When system_operator and charge_owner_id equals requesting_actor_id and is_tax is True, return 0 rows", + ), + pytest.param( + MarketRole.SYSTEM_OPERATOR, + default_data.DEFAULT_CHARGE_OWNER_ID, + False, + 1, + id="When system_operator and charge_owner_id equals requesting_actor_id and is_tax is False, return 1 row", + ), + ], +) +def test_grid_access_provider_and_system_operator_scenarios( + spark: SparkSession, + args_requesting_actor_market_role: MarketRole, + args_requesting_actor_id: str, + is_tax: bool, + expected_rows: int, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + + charge_link_periods = charge_link_periods_factory.create(
+ spark, + default_data.create_charge_link_periods_row(), + ) + + charge_price_points = charge_price_points_factory.create( + spark, + default_data.create_charge_price_points_row(), + ) + + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(is_tax=is_tax), + ) + + mock_repository = _get_repository_mock( + metering_point_periods, + charge_link_periods, + charge_price_points, + charge_price_information_periods, + ) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + requesting_actor_market_role=args_requesting_actor_market_role, + requesting_actor_id=args_requesting_actor_id, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == expected_rows diff --git a/source/settlement_report_python/tests/domain/energy_results/test_energy_read_and_filter.py b/source/settlement_report_python/tests/domain/energy_results/test_energy_read_and_filter.py new file mode 100644 index 0000000..67a83ee --- /dev/null +++ b/source/settlement_report_python/tests/domain/energy_results/test_energy_read_and_filter.py @@ -0,0 +1,506 @@ +from datetime import timedelta +from functools import reduce +import uuid +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, functions as F +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, +) +from tests.test_factories import latest_calculations_factory +import tests.test_factories.default_test_data_spec as default_data +import tests.test_factories.energy_factory as energy_factory + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.energy_results.read_and_filter import ( + read_and_filter_from_view, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_TIME_ZONE = "Europe/Copenhagen" +DEFAULT_CALCULATION_ID = "12345678-6f20-40c5-9a95-f419a1245d7e" + + +@pytest.fixture(scope="session") +def energy_read_and_filter_mock_repository( + spark: SparkSession, +) -> Mock: + mock_repository = Mock() + + df_energy_v1 = None + df_energy_per_es_v1 = None + + # The inner-most loop generates 24 entries, so in total each view will contain: + # 2 * 3 * 24 = 144 rows. 
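+ # Note: the factories return one DataFrame per call, so the loops below union + # one DataFrame per (grid area, energy supplier) pair into each mocked view; + # every row shares DEFAULT_CALCULATION_ID so the calculation-id filters in the tests can match.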
+ for grid_area in ["804", "805"]: + for energy_supplier_id in ["1000000000000", "2000000000000", "3000000000000"]: + testing_spec = default_data.create_energy_results_data_spec( + energy_supplier_id=energy_supplier_id, + calculation_id=DEFAULT_CALCULATION_ID, + grid_area_code=grid_area, + ) + + if df_energy_v1 is None: + df_energy_v1 = energy_factory.create_energy_v1(spark, testing_spec) + else: + df_energy_v1 = df_energy_v1.union( + energy_factory.create_energy_v1(spark, testing_spec) + ) + + if df_energy_per_es_v1 is None: + df_energy_per_es_v1 = energy_factory.create_energy_per_es_v1( + spark, testing_spec + ) + else: + df_energy_per_es_v1 = df_energy_per_es_v1.union( + energy_factory.create_energy_per_es_v1(spark, testing_spec) + ) + + mock_repository.read_energy.return_value = df_energy_v1 + mock_repository.read_energy_per_es.return_value = df_energy_per_es_v1 + + return mock_repository + + +@pytest.mark.parametrize( + "requesting_energy_supplier_id, contains_data", + [("1000000000000", True), ("2000000000000", True), ("ID_WITH_NO_DATA", False)], +) +def test_read_and_filter_from_view__when_requesting_actor_is_energy_supplier__returns_results_only_for_that_energy_supplier( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + energy_read_and_filter_mock_repository: Mock, + requesting_energy_supplier_id: str, + contains_data: bool, +) -> None: + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.ENERGY_SUPPLIER + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = ( + requesting_energy_supplier_id + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = [ + requesting_energy_supplier_id + ] + + expected_columns = ( + energy_read_and_filter_mock_repository.read_energy_per_es.return_value.columns + ) + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_args, + repository=energy_read_and_filter_mock_repository, + ) + + # Assert + assert expected_columns == actual_df.columns + + assert ( + actual_df.filter( + f"{DataProductColumnNames.energy_supplier_id} != '{requesting_energy_supplier_id}'" + ).count() + == 0 + ) + + if contains_data: + assert actual_df.count() > 0 + else: + assert actual_df.count() == 0 + + +@pytest.mark.parametrize( + "energy_supplier_ids, contains_data", + [ + (None, True), + (["1000000000000"], True), + (["2000000000000", "3000000000000"], True), + (["'ID_WITH_NO_DATA'"], False), + ], +) +def test_read_and_filter_from_view__when_datahub_admin__returns_results_for_expected_energy_suppliers( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + energy_read_and_filter_mock_repository: Mock, + energy_supplier_ids: list[str] | None, + contains_data: bool, +) -> None: + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = energy_supplier_ids + + expected_columns = ( + energy_read_and_filter_mock_repository.read_energy_per_es.return_value.columns + ) + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_args, + repository=energy_read_and_filter_mock_repository, + ) + + # Assert + assert expected_columns == actual_df.columns + + if energy_supplier_ids is not None: + energy_supplier_ids_as_string = ", ".join(energy_supplier_ids) + assert ( + actual_df.filter( + f"{DataProductColumnNames.energy_supplier_id} not in ({energy_supplier_ids_as_string})" + ).count() + == 0 + ) 
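+ # The filter assert above guarantees that no rows for other energy suppliers leak through; + # the branch below only checks whether the requesting supplier has any rows at all.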
+ + if contains_data: + assert actual_df.count() > 0 + else: + assert actual_df.count() == 0 + + +@pytest.mark.parametrize( + "grid_area_code", + [ + ("804"), + ("805"), + ], +) +def test_read_and_filter_from_view__when_grid_access_provider__returns_expected_results( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + energy_read_and_filter_mock_repository: Mock, + grid_area_code: str, +) -> None: + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.GRID_ACCESS_PROVIDER + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area = { + grid_area_code: uuid.UUID(DEFAULT_CALCULATION_ID), + } + + expected_columns = ( + energy_read_and_filter_mock_repository.read_energy.return_value.columns + ) + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_args, + repository=energy_read_and_filter_mock_repository, + ) + + # Assert + assert expected_columns == actual_df.columns + + number_of_rows_from_non_chosen_grid_areas = actual_df.filter( + f"{DataProductColumnNames.grid_area_code} != '{grid_area_code}'" + ).count() + number_of_rows_returned = actual_df.count() + + assert number_of_rows_from_non_chosen_grid_areas == 0 + assert number_of_rows_returned > 0 + + +def test_read_and_filter_from_view__when_balance_fixing__returns_only_rows_from_latest_calculations( + spark: SparkSession, + standard_balance_fixing_scenario_args: SettlementReportArgs, +) -> None: + # Arrange + standard_balance_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_balance_fixing_scenario_args.grid_area_codes = ["804"] + + not_latest_calculation_id = "11111111-9fc8-409a-a169-fbd49479d718" + latest_calculation_id = "22222222-9fc8-409a-a169-fbd49479d718" + energy_df = reduce( + lambda df1, df2: df1.union(df2), + [ + energy_factory.create_energy_per_es_v1( + spark, + default_data.create_energy_results_data_spec( + calculation_id=calculation_id + ), + ) + for calculation_id in [latest_calculation_id, not_latest_calculation_id] + ], + ) + latest_calculations = latest_calculations_factory.create( + spark, + default_data.create_latest_calculations_per_day_row( + calculation_id=latest_calculation_id, + calculation_type=CalculationTypeDataProductValue.BALANCE_FIXING, + ), + ) + + mock_repository = Mock() + mock_repository.read_energy_per_es.return_value = energy_df + mock_repository.read_latest_calculations.return_value = latest_calculations + + # Act + actual_df = read_and_filter_from_view( + args=standard_balance_fixing_scenario_args, + repository=mock_repository, + ) + + # Assert + actual_calculation_ids = ( + actual_df.select(DataProductColumnNames.calculation_id).distinct().collect() + ) + assert len(actual_calculation_ids) == 1 + assert ( + actual_calculation_ids[0][DataProductColumnNames.calculation_id] + == latest_calculation_id + ) + + +def test_read_and_filter_from_view__when_balance_fixing_with_two_calculations_with_time_overlap__returns_only_latest_calculation_data( + spark: SparkSession, + standard_balance_fixing_scenario_args: SettlementReportArgs, +) -> None: + # Arrange + standard_balance_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + day_1 = DEFAULT_FROM_DATE + day_2 = day_1 + timedelta(days=1) + day_3 = day_1 + timedelta(days=2) + day_4 = day_1 + timedelta(days=3) # exclusive + calculation_id_1 = "11111111-9fc8-409a-a169-fbd49479d718" + 
calculation_id_2 = "22222222-9fc8-409a-a169-fbd49479d718" + calc_type = CalculationTypeDataProductValue.BALANCE_FIXING + + time_series_points = reduce( + lambda df1, df2: df1.union(df2), + [ + energy_factory.create_energy_per_es_v1( + spark, + default_data.create_energy_results_data_spec( + calculation_id=calc_id, + calculation_type=calc_type, + from_date=from_date, + to_date=to_date, + ), + ) + for calc_id, from_date, to_date in [ + (calculation_id_1, day_1, day_3), + (calculation_id_2, day_2, day_4), + ] + ], + ) + + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calc_id, + calculation_type=calc_type, + start_of_day=start_of_day, + ) + for calc_id, start_of_day in [ + (calculation_id_1, day_1), + (calculation_id_1, day_2), + (calculation_id_2, day_3), + ] + ], + ) + + mock_repository = Mock() + mock_repository.read_energy_per_es.return_value = time_series_points + mock_repository.read_latest_calculations.return_value = latest_calculations + + standard_balance_fixing_scenario_args.period_start = day_1 + standard_balance_fixing_scenario_args.period_end = day_4 + standard_balance_fixing_scenario_args.grid_area_codes = [ + default_data.DEFAULT_GRID_AREA_CODE + ] + + # Act + actual_df = read_and_filter_from_view( + args=standard_balance_fixing_scenario_args, + repository=mock_repository, + ) + + # Assert + + for day, expected_calculation_id in zip( + [day_1, day_2, day_3], [calculation_id_1, calculation_id_1, calculation_id_2] + ): + actual_calculation_ids = ( + actual_df.where( + (F.col(DataProductColumnNames.time) >= day) + & (F.col(DataProductColumnNames.time) < day + timedelta(days=1)) + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) + assert len(actual_calculation_ids) == 1 + assert actual_calculation_ids[0][0] == expected_calculation_id + + +def test_read_and_filter_from_view__when_balance_fixing__latest_calculation_for_grid_area( + spark: SparkSession, standard_balance_fixing_scenario_args: SettlementReportArgs +) -> None: + # Arrange + standard_balance_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + day_1 = DEFAULT_FROM_DATE + day_2 = day_1 + timedelta(days=1) # exclusive + grid_area_1 = "805" + grid_area_2 = "806" + calculation_id_1 = "11111111-9fc8-409a-a169-fbd49479d718" + calculation_id_2 = "22222222-9fc8-409a-a169-fbd49479d718" + calc_type = CalculationTypeDataProductValue.BALANCE_FIXING + + energy_per_es = reduce( + lambda df1, df2: df1.union(df2), + [ + energy_factory.create_energy_per_es_v1( + spark, + default_data.create_energy_results_data_spec( + calculation_id=calc_id, + calculation_type=calc_type, + grid_area_code=grid_area, + from_date=day_1, + to_date=day_2, + ), + ) + for calc_id, grid_area in [ + (calculation_id_1, grid_area_1), + (calculation_id_1, grid_area_2), + (calculation_id_2, grid_area_2), + ] + ], + ) + + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_1, + calculation_type=calc_type, + grid_area_code=grid_area_1, + start_of_day=day_1, + ), + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_2, + calculation_type=calc_type, + grid_area_code=grid_area_2, + start_of_day=day_1, + ), + ], + ) + + mock_repository = Mock() + mock_repository.read_energy_per_es.return_value = energy_per_es + mock_repository.read_latest_calculations.return_value = 
latest_calculations + + standard_balance_fixing_scenario_args.period_start = day_1 + standard_balance_fixing_scenario_args.period_end = day_2 + standard_balance_fixing_scenario_args.grid_area_codes = [grid_area_1, grid_area_2] + + # Act + actual_df = read_and_filter_from_view( + args=standard_balance_fixing_scenario_args, + repository=mock_repository, + ) + + # Assert + assert all( + row[DataProductColumnNames.calculation_id] == calculation_id_1 + for row in actual_df.where( + F.col(DataProductColumnNames.grid_area_code) == grid_area_1 + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) + + assert all( + row[DataProductColumnNames.calculation_id] == calculation_id_2 + for row in actual_df.where( + F.col(DataProductColumnNames.grid_area_code) == grid_area_2 + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) + + +def test_read_and_filter_from_view__when_balance_fixing__returns_only_balance_fixing_results( + spark: SparkSession, standard_balance_fixing_scenario_args: SettlementReportArgs +) -> None: + # Arrange + calculation_id_and_type = { + "11111111-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.AGGREGATION, + "22222222-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.BALANCE_FIXING, + "33333333-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.WHOLESALE_FIXING, + "44444444-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT, + "55555555-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT, + "66666666-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT, + } + energy_per_es = reduce( + lambda df1, df2: df1.union(df2), + [ + energy_factory.create_energy_per_es_v1( + spark, + default_data.create_energy_results_data_spec( + calculation_id=calc_id, calculation_type=calc_type + ), + ) + for calc_id, calc_type in calculation_id_and_type.items() + ], + ) + + latest_calculations = reduce( + lambda df1, df2: df1.union(df2), + [ + latest_calculations_factory.create( + spark, + default_data.create_latest_calculations_per_day_row( + calculation_id=calc_id, calculation_type=calc_type + ), + ) + for calc_id, calc_type in calculation_id_and_type.items() + ], + ) + + mock_repository = Mock() + mock_repository.read_energy_per_es.return_value = energy_per_es + mock_repository.read_latest_calculations.return_value = latest_calculations + + standard_balance_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_balance_fixing_scenario_args.grid_area_codes = [ + default_data.DEFAULT_GRID_AREA_CODE + ] + + # Act + actual_df = read_and_filter_from_view( + args=standard_balance_fixing_scenario_args, + repository=mock_repository, + ) + + # Assert + actual_calculation_ids = ( + actual_df.select(DataProductColumnNames.calculation_id).distinct().collect() + ) + assert len(actual_calculation_ids) == 1 + assert ( + actual_calculation_ids[0][DataProductColumnNames.calculation_id] + == "22222222-9fc8-409a-a169-fbd49479d718" + ) diff --git a/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_balance_fixing.py b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_balance_fixing.py new file mode 100644 index 0000000..ec4f11d --- /dev/null +++ 
b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_balance_fixing.py @@ -0,0 +1,108 @@ +from unittest.mock import Mock + +from pyspark.sql import SparkSession, DataFrame +import test_factories.default_test_data_spec as default_data +import test_factories.metering_point_periods_factory as input_metering_point_periods_factory +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames +from settlement_report_job.domain.metering_point_periods.metering_point_periods_factory import ( + create_metering_point_periods, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from test_factories import latest_calculations_factory +from utils import Dates as d + + +def _get_repository_mock( + metering_point_period: DataFrame, + latest_calculations: DataFrame, + charge_link_periods: DataFrame | None = None, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + mock_repository.read_latest_calculations.return_value = latest_calculations + + if charge_link_periods: + mock_repository.read_charge_link_periods.return_value = charge_link_periods + + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +def test_create_metering_point_periods__when_grid_access_provider__returns_expected_columns( + spark: SparkSession, + standard_balance_fixing_scenario_grid_access_provider_args: SettlementReportArgs, +) -> None: + + # Arrange + expected_columns = [ + "grid_area_code_partitioning", + "METERINGPOINTID", + "VALIDFROM", + "VALIDTO", + "GRIDAREAID", + "TOGRIDAREAID", + "FROMGRIDAREAID", + "TYPEOFMP", + "SETTLEMENTMETHOD", + ] + + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + latest_calculations = latest_calculations_factory.create( + spark, default_data.create_latest_calculations_per_day_row() + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = create_metering_point_periods( + args=standard_balance_fixing_scenario_grid_access_provider_args, + repository=mock_repository, + ) + + # Assert + assert actual.columns == expected_columns + + +def test_create_metering_point_periods__when_and_metering_point_period_exceeds_selected_period__returns_period_that_ends_on_the_selected_end_date( + spark: SparkSession, + standard_balance_fixing_scenario_args: SettlementReportArgs, +) -> None: + # Arrange + standard_balance_fixing_scenario_args.period_start = d.JAN_2ND + standard_balance_fixing_scenario_args.period_end = d.JAN_3RD + + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row(start_of_day=d.JAN_1ST), + default_data.create_latest_calculations_per_day_row(start_of_day=d.JAN_2ND), + default_data.create_latest_calculations_per_day_row(start_of_day=d.JAN_3RD), + ], + ) + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + from_date=d.JAN_1ST, to_date=d.JAN_4TH + ), + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = create_metering_point_periods( + args=standard_balance_fixing_scenario_args, 
+ repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + assert actual.collect()[0][CsvColumnNames.metering_point_from_date] == d.JAN_2ND + assert actual.collect()[0][CsvColumnNames.metering_point_to_date] == d.JAN_3RD diff --git a/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_wholesale.py b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_wholesale.py new file mode 100644 index 0000000..cbe07a4 --- /dev/null +++ b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_factory_wholesale.py @@ -0,0 +1,194 @@ +import uuid +from unittest.mock import Mock + +from pyspark.sql import SparkSession, DataFrame +import test_factories.default_test_data_spec as default_data +import test_factories.charge_link_periods_factory as input_charge_link_periods_factory +import test_factories.metering_point_periods_factory as input_metering_point_periods_factory +import test_factories.charge_price_information_periods_factory as input_charge_price_information_periods_factory +from settlement_report_job.domain.metering_point_periods.metering_point_periods_factory import ( + create_metering_point_periods, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointTypeDataProductValue, + SettlementMethodDataProductValue, +) + + +def _get_repository_mock( + metering_point_period: DataFrame, + charge_link_periods: DataFrame | None = None, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + + if charge_link_periods: + mock_repository.read_charge_link_periods.return_value = charge_link_periods + + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +def test_create_metering_point_periods__when_datahub_admin__returns_expected_values( + spark: SparkSession, + standard_wholesale_fixing_scenario_datahub_admin_args: SettlementReportArgs, +) -> None: + # Arrange + args = standard_wholesale_fixing_scenario_datahub_admin_args + args.period_start = default_data.DEFAULT_PERIOD_START + args.period_end = default_data.DEFAULT_PERIOD_END + args.calculation_id_by_grid_area = { + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID( + default_data.DEFAULT_CALCULATION_ID + ) + } + expected = { + "grid_area_code_partitioning": default_data.DEFAULT_GRID_AREA_CODE, + "METERINGPOINTID": default_data.DEFAULT_METERING_POINT_ID, + "VALIDFROM": default_data.DEFAULT_PERIOD_START, + "VALIDTO": default_data.DEFAULT_PERIOD_END, + "GRIDAREAID": default_data.DEFAULT_GRID_AREA_CODE, + "TYPEOFMP": "E17", + "SETTLEMENTMETHOD": "D01", + "ENERGYSUPPLIERID": default_data.DEFAULT_ENERGY_SUPPLIER_ID, + } + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + settlement_method=SettlementMethodDataProductValue.FLEX, + ), + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual = create_metering_point_periods( + args=args, + repository=mock_repository, + ) + + # Assert + print(actual.collect()) + print(expected) + assert 
actual.count() == 1 + assert actual.collect()[0].asDict() == expected + + +def test_create_metering_point_periods_for_wholesale__when_system_operator__returns_expected_columns( + spark: SparkSession, + standard_wholesale_fixing_scenario_system_operator_args: SettlementReportArgs, +) -> None: + # Arrange + expected_columns = [ + "grid_area_code_partitioning", + "METERINGPOINTID", + "VALIDFROM", + "VALIDTO", + "GRIDAREAID", + "TYPEOFMP", + "SETTLEMENTMETHOD", + "ENERGYSUPPLIERID", + ] + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + charge_price_information_periods = input_charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=standard_wholesale_fixing_scenario_system_operator_args.requesting_actor_id, + is_tax=False, + ), + ) + charge_link_periods = input_charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + charge_owner_id=standard_wholesale_fixing_scenario_system_operator_args.requesting_actor_id + ), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = create_metering_point_periods( + args=standard_wholesale_fixing_scenario_system_operator_args, + repository=mock_repository, + ) + + # Assert + assert actual.columns == expected_columns + + +def test_create_metering_point_periods__when_energy_supplier__returns_expected_columns( + spark: SparkSession, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, +) -> None: + + # Arrange + expected_columns = [ + "grid_area_code_partitioning", + "METERINGPOINTID", + "VALIDFROM", + "VALIDTO", + "GRIDAREAID", + "TYPEOFMP", + "SETTLEMENTMETHOD", + ] + + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual = create_metering_point_periods( + args=standard_wholesale_fixing_scenario_energy_supplier_args, + repository=mock_repository, + ) + + # Assert + assert actual.columns == expected_columns + + +def test_create_metering_point_periods__when_grid_access_provider__returns_expected_columns( + spark: SparkSession, + standard_wholesale_fixing_scenario_grid_access_provider_args: SettlementReportArgs, +) -> None: + + # Arrange + expected_columns = [ + "grid_area_code_partitioning", + "METERINGPOINTID", + "VALIDFROM", + "VALIDTO", + "GRIDAREAID", + "TOGRIDAREAID", + "FROMGRIDAREAID", + "TYPEOFMP", + "SETTLEMENTMETHOD", + ] + + metering_point_periods = input_metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual = create_metering_point_periods( + args=standard_wholesale_fixing_scenario_grid_access_provider_args, + repository=mock_repository, + ) + + # Assert + assert actual.columns == expected_columns diff --git a/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_balance_fixing.py b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_balance_fixing.py new file mode 100644 index 0000000..ca40729 --- /dev/null +++ 
b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_balance_fixing.py @@ -0,0 +1,359 @@ +from datetime import datetime +from unittest.mock import Mock + +from pyspark.sql import SparkSession, DataFrame + +import test_factories.default_test_data_spec as default_data +from settlement_report_job.domain.metering_point_periods.read_and_filter_balance_fixing import ( + read_and_filter, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from test_factories import latest_calculations_factory, metering_point_periods_factory +from utils import Dates as d, DEFAULT_TIME_ZONE + + +DEFAULT_SELECT_COLUMNS = [ + DataProductColumnNames.metering_point_id, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + DataProductColumnNames.grid_area_code, + DataProductColumnNames.from_grid_area_code, + DataProductColumnNames.to_grid_area_code, + DataProductColumnNames.metering_point_type, + DataProductColumnNames.settlement_method, + DataProductColumnNames.energy_supplier_id, +] + + +def _get_repository_mock( + metering_point_period: DataFrame, + latest_calculations: DataFrame, + charge_link_periods: DataFrame | None = None, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + mock_repository.read_latest_calculations.return_value = latest_calculations + + if charge_link_periods: + mock_repository.read_charge_link_periods.return_value = charge_link_periods + + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +def test_read_and_filter__when_duplicate_metering_point_periods__returns_one_period_per_duplicate( + spark: SparkSession, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + from_date=d.JAN_1ST, to_date=d.JAN_2ND + ), + default_data.create_metering_point_periods_row( + from_date=d.JAN_1ST, to_date=d.JAN_2ND + ), + ], + ) + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_1ST, + ), + ], + ) + mock_repository = _get_repository_mock( + metering_point_periods, latest_calculations=latest_calculations + ) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_2ND, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + + +def test_read_and_filter__when_metering_periods_with_gap__returns_separate_periods( + spark: SparkSession, +) -> None: + # Arrange + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_1ST, + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_2ND, + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_3RD, + ), + ], + ) + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + from_date=d.JAN_1ST, + to_date=d.JAN_2ND, + ), + default_data.create_metering_point_periods_row( + 
from_date=d.JAN_3RD, + to_date=d.JAN_4TH, + ), + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_4TH, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 2 + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.collect()[0][DataProductColumnNames.from_date] == d.JAN_1ST + assert actual.collect()[0][DataProductColumnNames.to_date] == d.JAN_2ND + assert actual.collect()[1][DataProductColumnNames.from_date] == d.JAN_3RD + assert actual.collect()[1][DataProductColumnNames.to_date] == d.JAN_4TH + + +def test_read_and_filter__when_period_exceeds_selection_period__returns_period_that_ends_on_the_selection_end_date( + spark: SparkSession, +) -> None: + # Arrange + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_1ST, + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_2ND, + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_3RD, + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=d.JAN_4TH, + ), + ], + ) + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + from_date=d.JAN_1ST, + to_date=d.JAN_4TH, + ), + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = read_and_filter( + period_start=d.JAN_2ND, + period_end=d.JAN_3RD, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.collect()[0][DataProductColumnNames.from_date] == d.JAN_2ND + assert actual.collect()[0][DataProductColumnNames.to_date] == d.JAN_3RD + + +def test_read_and_filter__when_calculation_overlap_in_time__returns_latest( + spark: SparkSession, +) -> None: + # Arrange + calculation_id_1 = "11111111-1111-1111-1111-11111111" + calculation_id_2 = "22222222-2222-2222-2222-22222222" + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_1, + start_of_day=d.JAN_1ST, + ), + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_2, + start_of_day=d.JAN_2ND, + ), + ], + ) + + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + calculation_id=calculation_id_1, + metering_point_id="1", + from_date=d.JAN_1ST, + to_date=d.JAN_3RD, + ), + default_data.create_metering_point_periods_row( + calculation_id=calculation_id_2, + metering_point_id="2", + from_date=d.JAN_1ST, + to_date=d.JAN_3RD, + ), + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_3RD, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + 
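+ # Calculation 1 is the latest for Jan 1st and calculation 2 for Jan 2nd, so the
+ # periods from both calculations are expected to survive the latest-calculation filter.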
+ # Assert + assert actual.count() == 2 + assert ( + actual.orderBy(DataProductColumnNames.from_date).collect()[0][ + DataProductColumnNames.metering_point_id + ] + == "1" + ) + assert ( + actual.orderBy(DataProductColumnNames.from_date).collect()[1][ + DataProductColumnNames.metering_point_id + ] + == "2" + ) + + +def test_read_and_filter__when_metering_point_period_is_shorter_in_newer_calculation__returns_the_shorter_period( + spark: SparkSession, +) -> None: + # Arrange + calculation_id_1 = "11111111-1111-1111-1111-11111111" + calculation_id_2 = "22222222-2222-2222-2222-22222222" + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_2, + start_of_day=d.JAN_1ST, + ), + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_2, + start_of_day=d.JAN_2ND, + ), + ], + ) + + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + calculation_id=calculation_id_1, + from_date=d.JAN_1ST, + to_date=d.JAN_3RD, + ), + default_data.create_metering_point_periods_row( + calculation_id=calculation_id_2, + from_date=d.JAN_2ND, + to_date=d.JAN_3RD, + ), + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_3RD, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + assert actual.collect()[0][DataProductColumnNames.from_date] == d.JAN_2ND + assert actual.collect()[0][DataProductColumnNames.to_date] == d.JAN_3RD + + +def test_read_and_filter__when_daylight_saving_time_returns_expected( + spark: SparkSession, +) -> None: + # Arrange + from_date = datetime(2024, 3, 30, 23) + to_date = datetime(2024, 4, 1, 22) + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + start_of_day=datetime(2024, 3, 30, 23), + ), + default_data.create_latest_calculations_per_day_row( + start_of_day=datetime(2024, 3, 31, 22), + ), + ], + ) + + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + from_date=from_date, + to_date=to_date, + ), + ], + ) + mock_repository = _get_repository_mock(metering_point_periods, latest_calculations) + + # Act + actual = read_and_filter( + period_start=from_date, + period_end=to_date, + grid_area_codes=default_data.DEFAULT_GRID_AREA_CODE, + energy_supplier_ids=None, + select_columns=DEFAULT_SELECT_COLUMNS, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + assert actual.collect()[0][DataProductColumnNames.from_date] == from_date + assert actual.collect()[0][DataProductColumnNames.to_date] == to_date diff --git a/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_wholesale.py b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_wholesale.py new file mode 100644 index 0000000..2a3a10b --- /dev/null +++ b/source/settlement_report_python/tests/domain/metering_point_periods/test_metering_point_periods_read_and_filter_wholesale.py @@ -0,0 +1,488 @@ +import uuid +from datetime 
import datetime +from functools import reduce +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, DataFrame +import test_factories.default_test_data_spec as default_data +import test_factories.charge_link_periods_factory as charge_link_periods_factory +import test_factories.metering_point_periods_factory as metering_point_periods_factory +import test_factories.charge_price_information_periods_factory as charge_price_information_periods_factory +from settlement_report_job.domain.metering_point_periods.read_and_filter_wholesale import ( + read_and_filter, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from utils import Dates as d + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +GRID_ACCESS_PROVIDER_ID = "4444444444444" +OTHER_ID = "9999999999999" +DEFAULT_CALCULATION_ID_BY_GRID_AREA = { + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) +} + + +DEFAULT_SELECT_COLUMNS = [ + DataProductColumnNames.metering_point_id, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + DataProductColumnNames.grid_area_code, + DataProductColumnNames.from_grid_area_code, + DataProductColumnNames.to_grid_area_code, + DataProductColumnNames.metering_point_type, + DataProductColumnNames.settlement_method, + DataProductColumnNames.energy_supplier_id, +] + + +def _get_repository_mock( + metering_point_period: DataFrame, + charge_link_periods: DataFrame | None = None, + charge_price_information_periods: DataFrame | None = None, +) -> Mock: + mock_repository = Mock() + mock_repository.read_metering_point_periods.return_value = metering_point_period + + if charge_link_periods: + mock_repository.read_charge_link_periods.return_value = charge_link_periods + + if charge_price_information_periods: + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_periods + ) + + return mock_repository + + +@pytest.mark.parametrize( + "from_date,to_date,is_included", + [ + pytest.param( + d.JAN_1ST, + d.JAN_2ND, + False, + id="metering point period stops before selected period", + ), + pytest.param( + d.JAN_1ST, + d.JAN_3RD, + True, + id="metering point starts before and ends within selected period", + ), + pytest.param( + d.JAN_3RD, + d.JAN_4TH, + True, + id="metering point period is within selected period", + ), + pytest.param( + d.JAN_3RD, + d.JAN_5TH, + True, + id="metering point starts within but stops after selected period", + ), + pytest.param( + d.JAN_4TH, + d.JAN_5TH, + False, + id="metering point starts after selected period", + ), + ], +) +def test_read_and_filter__returns_charge_link_periods_that_overlap_with_selected_period( + spark: SparkSession, + from_date: datetime, + to_date: datetime, + is_included: bool, +) -> None: + # Arrange + calculation_period_start = d.JAN_2ND + calculation_period_end = d.JAN_4TH + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + from_date=from_date, to_date=to_date + ), + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual_df = read_and_filter( + period_start=calculation_period_start, + period_end=calculation_period_end, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + 
energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + assert (actual_df.count() > 0) == is_included + + +def test_read_and_filter__returns_only_selected_grid_area( + spark: SparkSession, +) -> None: + # Arrange + selected_grid_area_code = "805" + not_selected_grid_area_code = "806" + selected_metering_point = "555" + not_selected_metering_point = "666" + + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + grid_area_code=selected_grid_area_code, + metering_point_id=selected_metering_point, + ), + ).union( + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + grid_area_code=not_selected_grid_area_code, + metering_point_id=not_selected_metering_point, + ), + ) + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + selected_grid_area_code: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + actual_grid_area_codes = ( + actual_df.select(DataProductColumnNames.grid_area_code).distinct().collect() + ) + assert len(actual_grid_area_codes) == 1 + assert actual_grid_area_codes[0][0] == selected_grid_area_code + + +def test_read_and_filter__returns_only_rows_from_selected_calculation_id( + spark: SparkSession, +) -> None: + # Arrange + selected_calculation_id = "11111111-9fc8-409a-a169-fbd49479d718" + not_selected_calculation_id = "22222222-9fc8-409a-a169-fbd49479d718" + expected_metering_point_id = "123456789012345678901234567" + other_metering_point_id = "765432109876543210987654321" + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + calculation_id=selected_calculation_id, + metering_point_id=expected_metering_point_id, + ), + ).union( + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + calculation_id=not_selected_calculation_id, + metering_point_id=other_metering_point_id, + ), + ) + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(selected_calculation_id) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + actual_metering_point_ids = ( + actual_df.select(DataProductColumnNames.metering_point_id).distinct().collect() + ) + assert len(actual_metering_point_ids) == 1 + assert ( + actual_metering_point_ids[0][DataProductColumnNames.metering_point_id] + == expected_metering_point_id + ) + + +ENERGY_SUPPLIER_A = "1000000000000" +ENERGY_SUPPLIER_B = "2000000000000" +ENERGY_SUPPLIER_C = "3000000000000" +ENERGY_SUPPLIERS_ABC = [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B, ENERGY_SUPPLIER_C] +METERING_POINT_ID_ABC = ["123", "456", "789"] + 
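+# The suppliers and metering point ids above are paired one-to-one (via zip) in the test
+# below, so every created metering point period belongs to exactly one energy supplier.
+# Illustrative sketch only (not used by the tests), assuming the factory accepts a list
+# of rows as it does elsewhere in this file:
+#
+#   metering_point_periods_factory.create(
+#       spark,
+#       [
+#           default_data.create_metering_point_periods_row(
+#               energy_supplier_id=es_id, metering_point_id=mp_id
+#           )
+#           for es_id, mp_id in zip(ENERGY_SUPPLIERS_ABC, METERING_POINT_ID_ABC)
+#       ],
+#   )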
+ +@pytest.mark.parametrize( + "selected_energy_supplier_ids,expected_energy_supplier_ids", + [ + (None, ENERGY_SUPPLIERS_ABC), + ([ENERGY_SUPPLIER_B], [ENERGY_SUPPLIER_B]), + ( + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + ), + (ENERGY_SUPPLIERS_ABC, ENERGY_SUPPLIERS_ABC), + ], +) +def test_read_and_filter__returns_data_for_expected_energy_suppliers( + spark: SparkSession, + selected_energy_supplier_ids: list[str] | None, + expected_energy_supplier_ids: list[str], +) -> None: + # Arrange + metering_point_periods = reduce( + lambda df1, df2: df1.union(df2), + [ + metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row( + energy_supplier_id=energy_supplier_id, + metering_point_id=metering_point_id, + ), + ) + for energy_supplier_id, metering_point_id in zip( + ENERGY_SUPPLIERS_ABC, METERING_POINT_ID_ABC + ) + ], + ) + mock_repository = _get_repository_mock(metering_point_periods) + + # Act + actual_df = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=selected_energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + assert set( + row[DataProductColumnNames.energy_supplier_id] for row in actual_df.collect() + ) == set(expected_energy_supplier_ids) + + +@pytest.mark.parametrize( + "charge_owner_id,is_tax,return_rows", + [ + pytest.param( + SYSTEM_OPERATOR_ID, False, True, id="system operator without tax: include" + ), + pytest.param( + SYSTEM_OPERATOR_ID, True, False, id="system operator with tax: exclude" + ), + pytest.param( + OTHER_ID, False, False, id="other charge owner without tax: exclude" + ), + pytest.param(OTHER_ID, True, False, id="other charge owner with tax: exclude"), + ], +) +def test_read_and_filter__when_system_operator__returns_expected_metering_points( + spark: SparkSession, + charge_owner_id: str, + is_tax: bool, + return_rows: bool, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + default_data.create_metering_point_periods_row(), + ) + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=charge_owner_id, + is_tax=is_tax, + ), + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_owner_id=charge_owner_id), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = read_and_filter( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.SYSTEM_OPERATOR, + requesting_actor_id=SYSTEM_OPERATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + assert (actual.count() > 0) == return_rows + + +def test_read_and_filter__when_balance_responsible_party_changes_on_metering_point__returns_single_period( + spark: SparkSession, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + balance_responsible_party_id="1", 
from_date=d.JAN_1ST, to_date=d.JAN_2ND + ), + default_data.create_metering_point_periods_row( + balance_responsible_party_id="2", from_date=d.JAN_2ND, to_date=d.JAN_3RD + ), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + from_date=d.JAN_1ST, + to_date=d.JAN_3RD, + charge_owner_id=GRID_ACCESS_PROVIDER_ID, + ), + ) + charge_price_information_periods = charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + from_date=d.JAN_1ST, + to_date=d.JAN_3RD, + is_tax=True, + charge_owner_id=GRID_ACCESS_PROVIDER_ID, + ), + ) + mock_repository = _get_repository_mock( + metering_point_periods, charge_link_periods, charge_price_information_periods + ) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_3RD, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.GRID_ACCESS_PROVIDER, + requesting_actor_id=GRID_ACCESS_PROVIDER_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 + assert actual.select(DataProductColumnNames.from_date).collect()[0][0] == d.JAN_1ST + assert actual.select(DataProductColumnNames.to_date).collect()[0][0] == d.JAN_3RD + + +def test_read_and_filter__when_datahub_user_and_energy_supplier_changes_on_metering_point__returns_two_link_periods( + spark: SparkSession, +) -> None: + # Arrange + es_id_a = "111" + es_id_b = "222" + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row( + energy_supplier_id=es_id_a, + from_date=d.JAN_1ST, + to_date=d.JAN_2ND, + ), + default_data.create_metering_point_periods_row( + energy_supplier_id=es_id_b, + from_date=d.JAN_2ND, + to_date=d.JAN_3RD, + ), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + from_date=d.JAN_1ST, to_date=d.JAN_3RD + ), + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_3RD, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == 2 + + actual_row_1 = actual.collect()[0] + assert actual_row_1[DataProductColumnNames.energy_supplier_id] == es_id_a + assert actual_row_1[DataProductColumnNames.from_date] == d.JAN_1ST + assert actual_row_1[DataProductColumnNames.to_date] == d.JAN_2ND + + actual_row_2 = actual.collect()[1] + assert actual_row_2[DataProductColumnNames.energy_supplier_id] == es_id_b + assert actual_row_2[DataProductColumnNames.from_date] == d.JAN_2ND + assert actual_row_2[DataProductColumnNames.to_date] == d.JAN_3RD + + +def test_read_and_filter__when_duplicate_metering_point_periods__returns_one_period_per_duplicate( + spark: SparkSession, +) -> None: + # Arrange + metering_point_periods = metering_point_periods_factory.create( + spark, + [ + default_data.create_metering_point_periods_row(), + default_data.create_metering_point_periods_row(), + ], + ) + charge_link_periods = charge_link_periods_factory.create( + 
spark, + default_data.create_charge_link_periods_row(), + ) + mock_repository = _get_repository_mock(metering_point_periods, charge_link_periods) + + # Act + actual = read_and_filter( + period_start=d.JAN_1ST, + period_end=d.JAN_3RD, + calculation_id_by_grid_area=DEFAULT_CALCULATION_ID_BY_GRID_AREA, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + select_columns=DEFAULT_SELECT_COLUMNS, + repository=mock_repository, + ) + + # Assert + assert actual.count() == 1 diff --git a/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_prepare_for_csv.py b/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_prepare_for_csv.py new file mode 100644 index 0000000..cb67bd4 --- /dev/null +++ b/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_prepare_for_csv.py @@ -0,0 +1,63 @@ +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.functions import lit + +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +import test_factories.default_test_data_spec as default_data +import test_factories.monthly_amounts_per_charge_factory as monthly_amounts_per_charge_factory + +from settlement_report_job.domain.monthly_amounts.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_TIME_ZONE = "Europe/Copenhagen" +DEFAULT_CALCULATION_ID = "12345678-6f20-40c5-9a95-f419a1245d7e" + + +@pytest.mark.parametrize("should_have_one_file_per_grid_area", [True, False]) +def test_prepare_for_csv__returns_expected_columns( + spark: SparkSession, + should_have_one_file_per_grid_area: bool, +) -> None: + # Arrange + monthly_amounts = monthly_amounts_per_charge_factory.create( + spark, default_data.create_monthly_amounts_per_charge_row() + ) + monthly_amounts = monthly_amounts.withColumn( + DataProductColumnNames.resolution, lit("P1M") + ) + + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + + if should_have_one_file_per_grid_area: + expected_columns.append(EphemeralColumns.grid_area_code_partitioning) + + # Act + actual_df = prepare_for_csv(monthly_amounts, should_have_one_file_per_grid_area) + + # Assert + assert expected_columns == actual_df.columns diff --git a/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_read_and_filter.py b/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_read_and_filter.py new file mode 100644 index 0000000..f70675f --- /dev/null +++ b/source/settlement_report_python/tests/domain/monthly_amounts/test_monthly_amounts_read_and_filter.py @@ -0,0 +1,424 @@ +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, functions as F +from 
settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from data_seeding import ( + standard_wholesale_fixing_scenario_data_generator, +) +import test_factories.default_test_data_spec as default_data +import test_factories.monthly_amounts_per_charge_factory as monthly_amounts_per_charge_factory +import test_factories.total_monthly_amounts_factory as total_monthly_amounts_factory + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.monthly_amounts.read_and_filter import ( + _filter_monthly_amounts_per_charge, + read_and_filter_from_view, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +GRID_ACCESS_PROVIDER_ID = "5555555555555" +SYSTEM_OPERATOR_ID = "3333333333333" +DEFAULT_ENERGY_SUPPLIER_ID = "2222222222222" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_TIME_ZONE = "Europe/Copenhagen" +DEFAULT_CALCULATION_ID = "12345678-6f20-40c5-9a95-f419a1245d7e" + + +@pytest.fixture(scope="session") +def monthly_amounts_read_and_filter_mock_repository( + spark: SparkSession, +) -> Mock: + mock_repository = Mock() + + monthly_amounts_per_charge = None + total_monthly_amounts = None + + for grid_area in ["804", "805"]: + for energy_supplier_id in [ + "1000000000000", + DEFAULT_ENERGY_SUPPLIER_ID, + "3000000000000", + ]: + for charge_owner_id in [ + DATAHUB_ADMINISTRATOR_ID, + GRID_ACCESS_PROVIDER_ID, + SYSTEM_OPERATOR_ID, + energy_supplier_id, + None, + ]: + for is_tax in [True, False]: + charge_owner_id_for_per_charge = ( + energy_supplier_id + if charge_owner_id is None + else charge_owner_id + ) + + testing_spec_monthly_per_charge = ( + default_data.create_monthly_amounts_per_charge_row( + energy_supplier_id=energy_supplier_id, + calculation_id=DEFAULT_CALCULATION_ID, + grid_area_code=grid_area, + charge_owner_id=charge_owner_id_for_per_charge, + is_tax=is_tax, + ) + ) + testing_spec_total_monthly = ( + default_data.create_total_monthly_amounts_row( + energy_supplier_id=energy_supplier_id, + calculation_id=DEFAULT_CALCULATION_ID, + grid_area_code=grid_area, + charge_owner_id=charge_owner_id, + ) + ) + + if monthly_amounts_per_charge is None: + monthly_amounts_per_charge = ( + monthly_amounts_per_charge_factory.create( + spark, testing_spec_monthly_per_charge + ) + ) + else: + monthly_amounts_per_charge = monthly_amounts_per_charge.union( + monthly_amounts_per_charge_factory.create( + spark, testing_spec_monthly_per_charge + ) + ) + + if total_monthly_amounts is None: + total_monthly_amounts = total_monthly_amounts_factory.create( + spark, testing_spec_total_monthly + ) + else: + total_monthly_amounts = total_monthly_amounts.union( + total_monthly_amounts_factory.create( + spark, testing_spec_total_monthly + ) + ) + + mock_repository.read_monthly_amounts_per_charge_v1.return_value = ( + monthly_amounts_per_charge + ) + mock_repository.read_total_monthly_amounts_v1.return_value = total_monthly_amounts + + return mock_repository + + +def get_expected_unordered_columns() -> list[str]: + return [ + DataProductColumnNames.calculation_id, + DataProductColumnNames.calculation_type, + DataProductColumnNames.calculation_version, + DataProductColumnNames.grid_area_code, + DataProductColumnNames.energy_supplier_id, + DataProductColumnNames.time, + DataProductColumnNames.resolution, + 
DataProductColumnNames.quantity_unit, + DataProductColumnNames.currency, + DataProductColumnNames.amount, + DataProductColumnNames.charge_type, + DataProductColumnNames.charge_code, + DataProductColumnNames.is_tax, + DataProductColumnNames.result_id, + DataProductColumnNames.charge_owner_id, + ] + + +def test_read_and_filter_from_view__returns_expected_columns( + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + monthly_amounts_read_and_filter_mock_repository: Mock, +) -> None: + # Arrange + expected_unordered_columns = get_expected_unordered_columns() + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_energy_supplier_args, + repository=monthly_amounts_read_and_filter_mock_repository, + ) + + # Assert + assert set(expected_unordered_columns) == set(actual_df.columns) + + +def test_read_and_filter_from_view__when_energy_supplier__returns_only_data_from_itself_but_all_charge_owners( + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + monthly_amounts_read_and_filter_mock_repository: Mock, +) -> None: + # Arrange + args = standard_wholesale_fixing_scenario_energy_supplier_args + expected_unordered_columns = get_expected_unordered_columns() + + # Act + actual_df = read_and_filter_from_view( + args=args, + repository=monthly_amounts_read_and_filter_mock_repository, + ) + + # Assert + assert set(expected_unordered_columns) == set(actual_df.columns) + assert ( + actual_df.where( + F.col(DataProductColumnNames.energy_supplier_id).isin( + [args.requesting_actor_id] + ) + ).count() + > 0 + ) + assert ( + actual_df.where( + ~F.col(DataProductColumnNames.energy_supplier_id).isin( + [args.requesting_actor_id] + ) + ).count() + == 0 + ) + assert ( + actual_df.select(F.col(DataProductColumnNames.charge_owner_id)) + .distinct() + .count() + > 1 + ) + + +def test_read_and_filter_from_view__when_datahub_administrator__returns_all_suppliers_and_charge_owners( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + monthly_amounts_read_and_filter_mock_repository: Mock, +) -> None: + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = ( + DATAHUB_ADMINISTRATOR_ID + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + expected_unordered_columns = get_expected_unordered_columns() + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_args, + repository=monthly_amounts_read_and_filter_mock_repository, + ) + + # Assert + assert set(expected_unordered_columns) == set(actual_df.columns) + assert ( + actual_df.select(F.col(DataProductColumnNames.energy_supplier_id)).count() > 1 + ) + assert ( + actual_df.select(F.col(DataProductColumnNames.charge_owner_id)) + .distinct() + .count() + > 1 + ) + + +@pytest.mark.parametrize( + "requesting_actor_market_role,actor_id", + [ + (MarketRole.GRID_ACCESS_PROVIDER, GRID_ACCESS_PROVIDER_ID), + (MarketRole.SYSTEM_OPERATOR, SYSTEM_OPERATOR_ID), + ], +) +def test_read_and_filter_from_view__when_grid_or_system_operator__returns_multiple_energy_suppliers( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + monthly_amounts_read_and_filter_mock_repository: Mock, + requesting_actor_market_role: MarketRole, + actor_id: str, +) -> None: + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + requesting_actor_market_role + ) + 
standard_wholesale_fixing_scenario_args.requesting_actor_id = actor_id + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area = dict( + list( + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area.items() + )[:-1] + ) + targeted_grid_area = list( + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area + )[0] + + expected_unordered_columns = get_expected_unordered_columns() + + # Act + actual_df = read_and_filter_from_view( + args=standard_wholesale_fixing_scenario_args, + repository=monthly_amounts_read_and_filter_mock_repository, + ) + + # Assert + assert set(expected_unordered_columns) == set(actual_df.columns) + assert actual_df.count() > 0 + assert ( + actual_df.where( + F.col(DataProductColumnNames.grid_area_code).isin([targeted_grid_area]) + ).count() + > 0 + ) + assert ( + actual_df.where( + ~F.col(DataProductColumnNames.grid_area_code).isin([targeted_grid_area]) + ).count() + == 0 + ) + assert ( + actual_df.select(F.col(DataProductColumnNames.energy_supplier_id)) + .distinct() + .count() + > 1 + ) + + +def test_filter_monthly_amounts_per_charge__when_grid_access_provider__returns_their_charges_and_correct_tax( + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> None: + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = MarketRole.GRID_ACCESS_PROVIDER + args.requesting_actor_id = GRID_ACCESS_PROVIDER_ID + args.energy_supplier_ids = None + targeted_grid_area = standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0] + calc_id = standard_wholesale_fixing_scenario_data_generator.CALCULATION_ID + + passing_row_because_of_charge_owner_id = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=False, + charge_owner_id=GRID_ACCESS_PROVIDER_ID, + ), + ) + passing_row_due_to_tax = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=True, + charge_owner_id="Not our requesting actor", + ), + ) + failing_row_due_to_charge_owner = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=False, + charge_owner_id="Not our requesting actor", + ), + ) + failing_row_due_to_grid_area = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + grid_area_code="Not our grid area", + is_tax=False, + charge_owner_id=GRID_ACCESS_PROVIDER_ID, + ), + ) + testing_data = ( + passing_row_because_of_charge_owner_id.union(passing_row_due_to_tax) + .union(failing_row_due_to_charge_owner) + .union(failing_row_due_to_grid_area) + ) + + expected_count = 2 + + # Act + actual_df = _filter_monthly_amounts_per_charge( + testing_data, + args, + ) + + # Assert + assert actual_df.count() == expected_count + assert ( + actual_df.where( + F.col(DataProductColumnNames.grid_area_code).isin([targeted_grid_area]) + ).count() + == expected_count + ) + assert actual_df.select(DataProductColumnNames.is_tax).distinct().count() == 2 + + +def test_filter_monthly_amounts_per_charge__when_system_operator__returns_their_charges_and_correct_tax( + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +) -> None: + # Arrange + args = standard_wholesale_fixing_scenario_args + 
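+ # The shared args fixture is mutated in place to impersonate the system operator;
+ # only non-tax charges owned by the system operator in the selected grid area are
+ # expected to pass the filter below.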
args.requesting_actor_market_role = MarketRole.SYSTEM_OPERATOR + args.requesting_actor_id = SYSTEM_OPERATOR_ID + args.energy_supplier_ids = None + targeted_grid_area = standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0] + calc_id = standard_wholesale_fixing_scenario_data_generator.CALCULATION_ID + + passing_row = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=False, + charge_owner_id=SYSTEM_OPERATOR_ID, + ), + ) + failing_row_due_to_tax = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=True, + charge_owner_id=SYSTEM_OPERATOR_ID, + ), + ) + failing_row_due_to_charge_owner = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + is_tax=False, + charge_owner_id="Not our requesting actor", + ), + ) + failing_row_due_to_grid_area = monthly_amounts_per_charge_factory.create( + spark, + default_data.create_monthly_amounts_per_charge_row( + calculation_id=calc_id, + grid_area_code="Not our grid area", + is_tax=False, + charge_owner_id=SYSTEM_OPERATOR_ID, + ), + ) + testing_data = ( + passing_row.union(failing_row_due_to_tax) + .union(failing_row_due_to_charge_owner) + .union(failing_row_due_to_grid_area) + ) + + expected_count = 1 + + # Act + actual_df = _filter_monthly_amounts_per_charge( + testing_data, + args, + ) + + # Assert + assert actual_df.count() == expected_count + assert ( + actual_df.where( + F.col(DataProductColumnNames.grid_area_code).isin([targeted_grid_area]) + ).count() + == expected_count + ) + assert ( + actual_df.filter(~F.col(DataProductColumnNames.is_tax)).count() + == expected_count + ) diff --git a/source/settlement_report_python/tests/domain/time_series_points/test_prepare_for_csv.py b/source/settlement_report_python/tests/domain/time_series_points/test_prepare_for_csv.py new file mode 100644 index 0000000..0b67510 --- /dev/null +++ b/source/settlement_report_python/tests/domain/time_series_points/test_prepare_for_csv.py @@ -0,0 +1,179 @@ +from datetime import datetime +from decimal import Decimal + +import pytest +from pyspark.sql import SparkSession, DataFrame +from pyspark.sql.functions import monotonically_increasing_id +import pyspark.sql.functions as F +from pyspark.sql.types import DecimalType + +import tests.test_factories.default_test_data_spec as default_data +import tests.test_factories.metering_point_time_series_factory as time_series_points_factory +from settlement_report_job.domain.utils.market_role import MarketRole + +from settlement_report_job.domain.time_series_points.prepare_for_csv import ( + prepare_for_csv, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) + +DEFAULT_TIME_ZONE = "Europe/Copenhagen" +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_MARKET_ROLE = MarketRole.GRID_ACCESS_PROVIDER + + +def _create_time_series_points_with_increasing_quantity( + spark: SparkSession, + from_date: datetime, + to_date: datetime, + resolution: 
MeteringPointResolutionDataProductValue, +) -> DataFrame: + spec = default_data.create_time_series_points_data_spec( + from_date=from_date, to_date=to_date, resolution=resolution + ) + df = time_series_points_factory.create(spark, spec) + return df.withColumn( # just set quantity equal to its row number + DataProductColumnNames.quantity, + monotonically_increasing_id().cast(DecimalType(18, 3)), + ) + + +@pytest.mark.parametrize( + "resolution", + [ + MeteringPointResolutionDataProductValue.HOUR, + MeteringPointResolutionDataProductValue.QUARTER, + ], +) +def test_prepare_for_csv__when_two_days_of_data__returns_two_rows( + spark: SparkSession, resolution: MeteringPointResolutionDataProductValue +) -> None: + # Arrange + expected_rows = DEFAULT_TO_DATE.day - DEFAULT_FROM_DATE.day + spec = default_data.create_time_series_points_data_spec( + from_date=DEFAULT_FROM_DATE, to_date=DEFAULT_TO_DATE, resolution=resolution + ) + df = time_series_points_factory.create(spark, spec) + + # Act + result_df = prepare_for_csv( + filtered_time_series_points=df, + metering_point_resolution=resolution, + time_zone=DEFAULT_TIME_ZONE, + requesting_actor_market_role=DEFAULT_MARKET_ROLE, + ) + + # Assert + assert result_df.count() == expected_rows + + +@pytest.mark.parametrize( + "resolution, energy_quantity_column_count", + [ + (MeteringPointResolutionDataProductValue.HOUR, 25), + (MeteringPointResolutionDataProductValue.QUARTER, 100), + ], +) +def test_prepare_for_csv__returns_expected_energy_quantity_columns( + spark: SparkSession, + resolution: MeteringPointResolutionDataProductValue, + energy_quantity_column_count: int, +) -> None: + # Arrange + expected_columns = [ + f"ENERGYQUANTITY{i}" for i in range(1, energy_quantity_column_count + 1) + ] + spec = default_data.create_time_series_points_data_spec(resolution=resolution) + df = time_series_points_factory.create(spark, spec) + + # Act + actual_df = prepare_for_csv( + filtered_time_series_points=df, + metering_point_resolution=resolution, + time_zone=DEFAULT_TIME_ZONE, + requesting_actor_market_role=DEFAULT_MARKET_ROLE, + ) + + # Assert + actual_columns = [ + col for col in actual_df.columns if col.startswith("ENERGYQUANTITY") + ] + assert set(actual_columns) == set(expected_columns) + + +@pytest.mark.parametrize( + "from_date,to_date,resolution,expected_columns_with_data", + [ + ( + # Entering daylight saving time for hourly resolution + datetime(2023, 3, 25, 23), + datetime(2023, 3, 27, 22), + MeteringPointResolutionDataProductValue.HOUR, + 23, + ), + ( + # Entering daylight saving time for quarterly resolution + datetime(2023, 3, 25, 23), + datetime(2023, 3, 27, 22), + MeteringPointResolutionDataProductValue.QUARTER, + 92, + ), + ( + # Exiting daylight saving time for hourly resolution + datetime(2023, 10, 28, 22), + datetime(2023, 10, 30, 23), + MeteringPointResolutionDataProductValue.HOUR, + 25, + ), + ( + # Exiting daylight saving time for quarterly resolution + datetime(2023, 10, 28, 22), + datetime(2023, 10, 30, 23), + MeteringPointResolutionDataProductValue.QUARTER, + 100, + ), + ], +) +def test_prepare_for_csv__when_daylight_saving_tim_transition__returns_expected_energy_quantities( + spark: SparkSession, + from_date: datetime, + to_date: datetime, + resolution: MeteringPointResolutionDataProductValue, + expected_columns_with_data: int, +) -> None: + # Arrange + df = _create_time_series_points_with_increasing_quantity( + spark=spark, + from_date=from_date, + to_date=to_date, + resolution=resolution, + ) + total_columns = ( + 25 if resolution == 
MeteringPointResolutionDataProductValue.HOUR else 100 + ) + + # Act + actual_df = prepare_for_csv( + filtered_time_series_points=df, + metering_point_resolution=resolution, + time_zone=DEFAULT_TIME_ZONE, + requesting_actor_market_role=DEFAULT_MARKET_ROLE, + ) + + # Assert + assert actual_df.count() == 2 + dst_day = actual_df.where(F.col(CsvColumnNames.time) == from_date).collect()[0] + for i in range(1, total_columns): + expected_value = None if i > expected_columns_with_data else Decimal(i - 1) + assert dst_day[f"ENERGYQUANTITY{i}"] == expected_value diff --git a/source/settlement_report_python/tests/domain/time_series_points/test_read_and_filter.py b/source/settlement_report_python/tests/domain/time_series_points/test_read_and_filter.py new file mode 100644 index 0000000..8f1e736 --- /dev/null +++ b/source/settlement_report_python/tests/domain/time_series_points/test_read_and_filter.py @@ -0,0 +1,640 @@ +import uuid +from datetime import datetime, timedelta +from functools import reduce +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession, functions as F +import tests.test_factories.default_test_data_spec as default_data +import tests.test_factories.metering_point_time_series_factory as time_series_points_factory +import tests.test_factories.charge_link_periods_factory as charge_link_periods_factory +import tests.test_factories.charge_price_information_periods_factory as charge_price_information_periods +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, +) + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.time_series_points.read_and_filter import ( + read_and_filter_for_wholesale, + read_and_filter_for_balance_fixing, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from tests.test_factories import latest_calculations_factory +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, +) + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_TIME_ZONE = "Europe/Copenhagen" + + +@pytest.mark.parametrize( + "resolution", + [ + MeteringPointResolutionDataProductValue.HOUR, + MeteringPointResolutionDataProductValue.QUARTER, + ], +) +def test_read_and_filter_for_wholesale__when_input_has_both_resolution_types__returns_only_data_with_expected_resolution( + spark: SparkSession, + resolution: MeteringPointResolutionDataProductValue, +) -> None: + # Arrange + hourly_metering_point_id = "1111111111111" + quarterly_metering_point_id = "1515151515115" + expected_metering_point_id = ( + hourly_metering_point_id + if resolution == MeteringPointResolutionDataProductValue.HOUR + else quarterly_metering_point_id + ) + spec_hour = default_data.create_time_series_points_data_spec( + metering_point_id=hourly_metering_point_id, + resolution=MeteringPointResolutionDataProductValue.HOUR, + ) + spec_quarter = default_data.create_time_series_points_data_spec( + metering_point_id=quarterly_metering_point_id, + resolution=MeteringPointResolutionDataProductValue.QUARTER, + ) + df = time_series_points_factory.create(spark, spec_hour).union( + time_series_points_factory.create(spark, spec_quarter) + ) + + mock_repository = Mock() + 
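+ # Only read_metering_point_time_series is stubbed in this test, so the function under
+ # test operates purely on the in-memory DataFrame built above.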
mock_repository.read_metering_point_time_series.return_value = df + + # Act + actual_df = read_and_filter_for_wholesale( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + metering_point_resolution=resolution, + repository=mock_repository, + ) + + # Assert + actual_metering_point_ids = ( + actual_df.select(DataProductColumnNames.metering_point_id).distinct().collect() + ) + assert len(actual_metering_point_ids) == 1 + assert ( + actual_metering_point_ids[0][DataProductColumnNames.metering_point_id] + == expected_metering_point_id + ) + + +def test_read_and_filter_for_wholesale__returns_only_days_within_selected_period( + spark: SparkSession, +) -> None: + # Arrange + data_from_date = datetime(2024, 1, 1, 23) + data_to_date = datetime(2024, 1, 31, 23) + number_of_days_in_period = 2 + number_of_hours_in_period = number_of_days_in_period * 24 + period_start = datetime(2024, 1, 10, 23) + period_end = period_start + timedelta(days=number_of_days_in_period) + + df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + from_date=data_from_date, to_date=data_to_date + ), + ) + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = df + + # Act + actual_df = read_and_filter_for_wholesale( + period_start=period_start, + period_end=period_end, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + energy_supplier_ids=None, + metering_point_resolution=MeteringPointResolutionDataProductValue.HOUR, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + repository=mock_repository, + ) + + # Assert + assert actual_df.count() == number_of_hours_in_period + actual_max_time = actual_df.orderBy( + DataProductColumnNames.observation_time, ascending=False + ).first()[DataProductColumnNames.observation_time] + actual_min_time = actual_df.orderBy( + DataProductColumnNames.observation_time, ascending=True + ).first()[DataProductColumnNames.observation_time] + assert actual_min_time == period_start + assert actual_max_time == period_end - timedelta(hours=1) + + +def test_read_and_filter_for_wholesale__returns_only_selected_grid_area( + spark: SparkSession, +) -> None: + # Arrange + selected_grid_area_code = "805" + not_selected_grid_area_code = "806" + df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + grid_area_code=selected_grid_area_code, + ), + ).union( + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + grid_area_code=not_selected_grid_area_code, + ), + ) + ) + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = df + + # Act + actual_df = read_and_filter_for_wholesale( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + selected_grid_area_code: uuid.UUID(default_data.DEFAULT_CALCULATION_ID) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + metering_point_resolution=MeteringPointResolutionDataProductValue.HOUR, + repository=mock_repository, + 
) + + # Assert + actual_grid_area_codes = ( + actual_df.select(DataProductColumnNames.grid_area_code).distinct().collect() + ) + assert len(actual_grid_area_codes) == 1 + assert actual_grid_area_codes[0][0] == selected_grid_area_code + + +def test_read_and_filter_for_wholesale__returns_only_metering_points_from_selected_calculation_id( + spark: SparkSession, +) -> None: + # Arrange + selected_calculation_id = "11111111-9fc8-409a-a169-fbd49479d718" + not_selected_calculation_id = "22222222-9fc8-409a-a169-fbd49479d718" + expected_metering_point_id = "123456789012345678901234567" + other_metering_point_id = "765432109876543210987654321" + df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=selected_calculation_id, + metering_point_id=expected_metering_point_id, + ), + ).union( + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=not_selected_calculation_id, + metering_point_id=other_metering_point_id, + ), + ) + ) + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = df + + # Act + actual_df = read_and_filter_for_wholesale( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID(selected_calculation_id) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + metering_point_resolution=MeteringPointResolutionDataProductValue.HOUR, + repository=mock_repository, + ) + + # Assert + actual_metering_point_ids = ( + actual_df.select(DataProductColumnNames.metering_point_id).distinct().collect() + ) + assert len(actual_metering_point_ids) == 1 + assert ( + actual_metering_point_ids[0][DataProductColumnNames.metering_point_id] + == expected_metering_point_id + ) + + +ENERGY_SUPPLIER_A = "1000000000000" +ENERGY_SUPPLIER_B = "2000000000000" +ENERGY_SUPPLIER_C = "3000000000000" +ENERGY_SUPPLIERS_ABC = [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B, ENERGY_SUPPLIER_C] + + +@pytest.mark.parametrize( + "selected_energy_supplier_ids,expected_energy_supplier_ids", + [ + (None, ENERGY_SUPPLIERS_ABC), + ([ENERGY_SUPPLIER_B], [ENERGY_SUPPLIER_B]), + ( + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + [ENERGY_SUPPLIER_A, ENERGY_SUPPLIER_B], + ), + (ENERGY_SUPPLIERS_ABC, ENERGY_SUPPLIERS_ABC), + ], +) +def test_read_and_filter_for_wholesale__returns_data_for_expected_energy_suppliers( + spark: SparkSession, + selected_energy_supplier_ids: list[str] | None, + expected_energy_supplier_ids: list[str], +) -> None: + # Arrange + df = reduce( + lambda df1, df2: df1.union(df2), + [ + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + energy_supplier_id=energy_supplier_id, + ), + ) + for energy_supplier_id in ENERGY_SUPPLIERS_ABC + ], + ) + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = df + + # Act + actual_df = read_and_filter_for_wholesale( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + energy_supplier_ids=selected_energy_supplier_ids, + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + requesting_actor_id=DATAHUB_ADMINISTRATOR_ID, + metering_point_resolution=MeteringPointResolutionDataProductValue.HOUR, + repository=mock_repository, + ) + + # 
Assert + assert set( + row[DataProductColumnNames.energy_supplier_id] for row in actual_df.collect() + ) == set(expected_energy_supplier_ids) + + +@pytest.mark.parametrize( + "charge_owner_id,return_rows", + [ + (SYSTEM_OPERATOR_ID, True), + (NOT_SYSTEM_OPERATOR_ID, False), + ], +) +def test_read_and_filter_for_wholesale__when_system_operator__returns_only_time_series_points_with_system_operator_as_charge_owner( + spark: SparkSession, + charge_owner_id: str, + return_rows: bool, +) -> None: + # Arrange + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec(), + ) + charge_price_information_period_df = charge_price_information_periods.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=SYSTEM_OPERATOR_ID + ), + ) + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_owner_id=SYSTEM_OPERATOR_ID), + ) + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = time_series_points_df + mock_repository.read_charge_price_information_periods.return_value = ( + charge_price_information_period_df + ) + mock_repository.read_charge_link_periods.return_value = charge_link_periods_df + + # Act + actual = read_and_filter_for_wholesale( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: uuid.UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + energy_supplier_ids=None, + requesting_actor_market_role=MarketRole.SYSTEM_OPERATOR, + requesting_actor_id=charge_owner_id, + metering_point_resolution=MeteringPointResolutionDataProductValue.HOUR, + repository=mock_repository, + ) + + # Assert + assert (actual.count() > 0) == return_rows + + +def test_read_and_filter_for_balance_fixing__returns_only_time_series_points_from_latest_calculations( + spark: SparkSession, +) -> None: + # Arrange + not_latest_calculation_id = "11111111-9fc8-409a-a169-fbd49479d718" + latest_calculation_id = "22222222-9fc8-409a-a169-fbd49479d718" + time_series_points_df = reduce( + lambda df1, df2: df1.union(df2), + [ + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=calculation_id + ), + ) + for calculation_id in [latest_calculation_id, not_latest_calculation_id] + ], + ) + latest_calculations = latest_calculations_factory.create( + spark, + default_data.create_latest_calculations_per_day_row( + calculation_id=latest_calculation_id, + calculation_type=CalculationTypeDataProductValue.BALANCE_FIXING, + ), + ) + + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = time_series_points_df + mock_repository.read_latest_calculations.return_value = latest_calculations + + # Act + actual_df = read_and_filter_for_balance_fixing( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + grid_area_codes=[default_data.DEFAULT_GRID_AREA_CODE], + energy_supplier_ids=None, + metering_point_resolution=default_data.DEFAULT_RESOLUTION, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + actual_calculation_ids = ( + actual_df.select(DataProductColumnNames.calculation_id).distinct().collect() + ) + assert len(actual_calculation_ids) == 1 + assert ( + actual_calculation_ids[0][DataProductColumnNames.calculation_id] + == latest_calculation_id + ) + + +def test_read_and_filter_for_balance_fixing__returns_only_balance_fixing_results( + 
spark: SparkSession, +) -> None: + # Arrange + calculation_id_and_type = { + "11111111-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.AGGREGATION, + "22222222-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.BALANCE_FIXING, + "33333333-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.WHOLESALE_FIXING, + "44444444-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT, + "55555555-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT, + "66666666-9fc8-409a-a169-fbd49479d718": CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT, + } + time_series_points = reduce( + lambda df1, df2: df1.union(df2), + [ + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=calc_id, calculation_type=calc_type + ), + ) + for calc_id, calc_type in calculation_id_and_type.items() + ], + ) + + latest_calculations = reduce( + lambda df1, df2: df1.union(df2), + [ + latest_calculations_factory.create( + spark, + default_data.create_latest_calculations_per_day_row( + calculation_id=calc_id, calculation_type=calc_type + ), + ) + for calc_id, calc_type in calculation_id_and_type.items() + ], + ) + + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = time_series_points + mock_repository.read_latest_calculations.return_value = latest_calculations + + # Act + actual_df = read_and_filter_for_balance_fixing( + period_start=DEFAULT_FROM_DATE, + period_end=DEFAULT_TO_DATE, + grid_area_codes=[default_data.DEFAULT_GRID_AREA_CODE], + energy_supplier_ids=None, + metering_point_resolution=default_data.DEFAULT_RESOLUTION, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + actual_calculation_ids = ( + actual_df.select(DataProductColumnNames.calculation_id).distinct().collect() + ) + assert len(actual_calculation_ids) == 1 + assert ( + actual_calculation_ids[0][DataProductColumnNames.calculation_id] + == "22222222-9fc8-409a-a169-fbd49479d718" + ) + + +def test_read_and_filter_for_balance_fixing__when_two_calculations_with_time_overlap__returns_only_latest_calculation_data( + spark: SparkSession, +) -> None: + # Arrange + day_1 = DEFAULT_FROM_DATE + day_2 = day_1 + timedelta(days=1) + day_3 = day_1 + timedelta(days=2) + day_4 = day_1 + timedelta(days=3) # exclusive + calculation_id_1 = "11111111-9fc8-409a-a169-fbd49479d718" + calculation_id_2 = "22222222-9fc8-409a-a169-fbd49479d718" + calc_type = CalculationTypeDataProductValue.BALANCE_FIXING + + time_series_points = reduce( + lambda df1, df2: df1.union(df2), + [ + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=calc_id, + calculation_type=calc_type, + from_date=from_date, + to_date=to_date, + ), + ) + for calc_id, from_date, to_date in [ + (calculation_id_1, day_1, day_3), + (calculation_id_2, day_2, day_4), + ] + ], + ) + + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calc_id, + calculation_type=calc_type, + start_of_day=start_of_day, + ) + for calc_id, start_of_day in [ + (calculation_id_1, day_1), + (calculation_id_1, day_2), + (calculation_id_2, day_3), + ] + ], + ) + + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = time_series_points + mock_repository.read_latest_calculations.return_value = latest_calculations + + # Act + actual_df = 
read_and_filter_for_balance_fixing( + period_start=day_1, + period_end=day_4, + grid_area_codes=[default_data.DEFAULT_GRID_AREA_CODE], + energy_supplier_ids=None, + metering_point_resolution=default_data.DEFAULT_RESOLUTION, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + + for day, expected_calculation_id in zip( + [day_1, day_2, day_3], [calculation_id_1, calculation_id_1, calculation_id_2] + ): + actual_calculation_ids = ( + actual_df.where( + (F.col(DataProductColumnNames.observation_time) >= day) + & ( + F.col(DataProductColumnNames.observation_time) + < day + timedelta(days=1) + ) + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) + assert len(actual_calculation_ids) == 1 + assert actual_calculation_ids[0][0] == expected_calculation_id + + +def test_read_and_filter_for_balance_fixing__latest_calculation_for_grid_area( + spark: SparkSession, +) -> None: + # Arrange + day_1 = DEFAULT_FROM_DATE + day_2 = day_1 + timedelta(days=1) # exclusive + grid_area_1 = "805" + grid_area_2 = "806" + calculation_id_1 = "11111111-9fc8-409a-a169-fbd49479d718" + calculation_id_2 = "22222222-9fc8-409a-a169-fbd49479d718" + calc_type = CalculationTypeDataProductValue.BALANCE_FIXING + + time_series_points = reduce( + lambda df1, df2: df1.union(df2), + [ + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=calc_id, + calculation_type=calc_type, + grid_area_code=grid_area, + from_date=day_1, + to_date=day_2, + ), + ) + for calc_id, grid_area in [ + (calculation_id_1, grid_area_1), + (calculation_id_1, grid_area_2), + (calculation_id_2, grid_area_2), + ] + ], + ) + + latest_calculations = latest_calculations_factory.create( + spark, + [ + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_1, + calculation_type=calc_type, + grid_area_code=grid_area_1, + start_of_day=day_1, + ), + default_data.create_latest_calculations_per_day_row( + calculation_id=calculation_id_2, + calculation_type=calc_type, + grid_area_code=grid_area_2, + start_of_day=day_1, + ), + ], + ) + + mock_repository = Mock() + mock_repository.read_metering_point_time_series.return_value = time_series_points + mock_repository.read_latest_calculations.return_value = latest_calculations + + # Act + actual_df = read_and_filter_for_balance_fixing( + period_start=day_1, + period_end=day_2, + grid_area_codes=[grid_area_1, grid_area_2], + energy_supplier_ids=None, + metering_point_resolution=default_data.DEFAULT_RESOLUTION, + time_zone=DEFAULT_TIME_ZONE, + repository=mock_repository, + ) + + # Assert + assert all( + row[DataProductColumnNames.calculation_id] == calculation_id_1 + for row in actual_df.where( + F.col(DataProductColumnNames.grid_area_code) == grid_area_1 + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) + + assert all( + row[DataProductColumnNames.calculation_id] == calculation_id_2 + for row in actual_df.where( + F.col(DataProductColumnNames.grid_area_code) == grid_area_2 + ) + .select(DataProductColumnNames.calculation_id) + .distinct() + .collect() + ) diff --git a/source/settlement_report_python/tests/domain/utils/__init__.py b/source/settlement_report_python/tests/domain/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/domain/utils/test_map_from_dict.py b/source/settlement_report_python/tests/domain/utils/test_map_from_dict.py new file mode 100644 index 0000000..6a71926 --- /dev/null 
+++ b/source/settlement_report_python/tests/domain/utils/test_map_from_dict.py @@ -0,0 +1,43 @@ +from pyspark.sql import SparkSession, functions as F + +from settlement_report_job.domain.utils.map_from_dict import map_from_dict + + +def test_map_from_dict__when_applied_to_new_col__returns_df_with_new_col( + spark: SparkSession, +): + # Arrange + df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["key", "value"]) + + # Act + mapper = map_from_dict({"a": "another_a"}) + actual = df.select("*", mapper[F.col("key")].alias("new_key")) + + # Assert + expected = spark.createDataFrame( + [("a", 1, "another_a"), ("b", 2, None), ("c", 3, None)], + ["key", "value", "new_key"], + ) + assert actual.collect() == expected.collect() + + +def test_map_from_dict__when_applied_as_overwrite__returns_df_with_overwritten_column( + spark: SparkSession, +): + # Arrange + df = spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["key", "value"]) + + # Act + mapper = map_from_dict({"a": "another_a"}) + actual = df.select(mapper[F.col("key")].alias("key"), "value") + + # Assert + expected = spark.createDataFrame( + [ + ("another_a", 1), + (None, 2), + (None, 3), + ], + ["key", "value"], + ) + assert actual.collect() == expected.collect() diff --git a/source/settlement_report_python/tests/domain/utils/test_map_to_csv_naming.py b/source/settlement_report_python/tests/domain/utils/test_map_to_csv_naming.py new file mode 100644 index 0000000..2ef07a0 --- /dev/null +++ b/source/settlement_report_python/tests/domain/utils/test_map_to_csv_naming.py @@ -0,0 +1,278 @@ +import pytest +from pyspark.sql import SparkSession, functions as F +from pyspark.sql.types import StructType, StructField, StringType + +from settlement_report_job.domain.utils.map_from_dict import map_from_dict +from settlement_report_job.infrastructure.wholesale.data_values import ( + SettlementMethodDataProductValue, + MeteringPointTypeDataProductValue, + ChargeTypeDataProductValue, + CalculationTypeDataProductValue, +) + +import settlement_report_job.domain.utils.map_to_csv_naming as mapping_dicts + + +@pytest.mark.parametrize( + "charge_type, expected_charge_type", + [ + pytest.param( + ChargeTypeDataProductValue.SUBSCRIPTION, + "D01", + id="when charge type is subscription, then charge type is D01", + ), + pytest.param( + ChargeTypeDataProductValue.FEE, + "D02", + id="when charge type is fee, then charge type is D02", + ), + pytest.param( + ChargeTypeDataProductValue.TARIFF, + "D03", + id="when charge type is tariff, then charge type is D03", + ), + ], +) +def test_mapping_of_charge_type( + spark: SparkSession, + charge_type: ChargeTypeDataProductValue, + expected_charge_type: str, +) -> None: + # Arrange + df = spark.createDataFrame( + data=[[charge_type.value]], + schema=StructType([StructField("charge_type", StringType(), True)]), + ) + + # Act + actual = df.select(map_from_dict(mapping_dicts.CHARGE_TYPES)[F.col("charge_type")]) + + # Assert + assert actual.collect()[0][0] == expected_charge_type + + +@pytest.mark.parametrize( + "calculation_type, expected_process_variant", + [ + pytest.param( + CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT, + "1ST", + id="when calculation type is first_correction_settlement, then process variant is 1ST", + ), + pytest.param( + CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT, + "2ND", + id="when calculation type is second_correction_settlement, then process variant is 2ND", + ), + pytest.param( + CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT, + "3RD", + id="when 
calculation type is third_correction_settlement, then process variant is 3RD", + ), + pytest.param( + CalculationTypeDataProductValue.WHOLESALE_FIXING, + None, + id="when calculation type is wholesale_fixing, then process variant is None", + ), + pytest.param( + CalculationTypeDataProductValue.BALANCE_FIXING, + None, + id="when calculation type is balance_fixing, then process variant is None", + ), + ], +) +def test_mapping_of_process_variant( + spark: SparkSession, + calculation_type: CalculationTypeDataProductValue, + expected_process_variant: str, +) -> None: + # Arrange + df = spark.createDataFrame([[calculation_type.value]], ["calculation_type"]) + + # Act + actual = df.select( + map_from_dict(mapping_dicts.CALCULATION_TYPES_TO_PROCESS_VARIANT)[ + F.col("calculation_type") + ] + ) + + # Assert + assert actual.collect()[0][0] == expected_process_variant + + +@pytest.mark.parametrize( + "calculation_type, expected_energy_business_process", + [ + pytest.param( + CalculationTypeDataProductValue.BALANCE_FIXING, + "D04", + id="when calculation type is balance_fixing, then energy business process is D04", + ), + pytest.param( + CalculationTypeDataProductValue.WHOLESALE_FIXING, + "D05", + id="when calculation type is wholesale_fixing, then energy business process is D05", + ), + pytest.param( + CalculationTypeDataProductValue.FIRST_CORRECTION_SETTLEMENT, + "D32", + id="when calculation type is first_correction_settlement, then energy business process is D32", + ), + pytest.param( + CalculationTypeDataProductValue.SECOND_CORRECTION_SETTLEMENT, + "D32", + id="when calculation type is second_correction_settlement, then energy business process is D32", + ), + pytest.param( + CalculationTypeDataProductValue.THIRD_CORRECTION_SETTLEMENT, + "D32", + id="when calculation type is third_correction_settlement, then energy business process is D32", + ), + ], +) +def test_mapping_of_energy_business_process( + spark: SparkSession, + calculation_type: CalculationTypeDataProductValue, + expected_energy_business_process: str, +) -> None: + # Arrange + df = spark.createDataFrame([[calculation_type.value]], ["calculation_type"]) + + # Act + actual = df.select( + map_from_dict(mapping_dicts.CALCULATION_TYPES_TO_ENERGY_BUSINESS_PROCESS)[ + F.col("calculation_type") + ] + ) + + # Assert + assert actual.collect()[0][0] == expected_energy_business_process + + +@pytest.mark.parametrize( + "metering_point_type, expected_metering_point_type", + [ + pytest.param( + MeteringPointTypeDataProductValue.CONSUMPTION, + "E17", + id="when metering point type is consumption, then type of mp is E17", + ), + pytest.param( + MeteringPointTypeDataProductValue.PRODUCTION, + "E18", + id="when metering point type is production, then type of mp is E18", + ), + pytest.param( + MeteringPointTypeDataProductValue.EXCHANGE, + "E20", + id="when metering point type is exchange, then type of mp is E20", + ), + pytest.param( + MeteringPointTypeDataProductValue.VE_PRODUCTION, + "D01", + id="when metering point type is ve_production, then type of mp is D01", + ), + pytest.param( + MeteringPointTypeDataProductValue.NET_PRODUCTION, + "D05", + id="when metering point type is net_production, then type of mp is D05", + ), + pytest.param( + MeteringPointTypeDataProductValue.SUPPLY_TO_GRID, + "D06", + id="when metering point type is supply_to_grid, then type of mp is D06", + ), + pytest.param( + MeteringPointTypeDataProductValue.CONSUMPTION_FROM_GRID, + "D07", + id="when metering point type is consumption_from_grid, then type of mp is D07", + ), + 
pytest.param( + MeteringPointTypeDataProductValue.WHOLESALE_SERVICES_INFORMATION, + "D08", + id="when metering point type is wholesale_services_information, then type of mp is D08", + ), + pytest.param( + MeteringPointTypeDataProductValue.OWN_PRODUCTION, + "D09", + id="when metering point type is own_production, then type of mp is D09", + ), + pytest.param( + MeteringPointTypeDataProductValue.NET_FROM_GRID, + "D10", + id="when metering point type is net_from_grid, then type of mp is D10", + ), + pytest.param( + MeteringPointTypeDataProductValue.NET_TO_GRID, + "D11", + id="when metering point type is net_to_grid, then type of mp is D11", + ), + pytest.param( + MeteringPointTypeDataProductValue.TOTAL_CONSUMPTION, + "D12", + id="when metering point type is total_consumption, then type of mp is D12", + ), + pytest.param( + MeteringPointTypeDataProductValue.ELECTRICAL_HEATING, + "D14", + id="when metering point type is electrical_heating, then type of mp is D14", + ), + pytest.param( + MeteringPointTypeDataProductValue.NET_CONSUMPTION, + "D15", + id="when metering point type is net_consumption, then type of mp is D15", + ), + pytest.param( + MeteringPointTypeDataProductValue.EFFECT_SETTLEMENT, + "D19", + id="when metering point type is effect_settlement, then type of mp is D19", + ), + ], +) +def test_mapping_of_metering_point_type( + spark: SparkSession, + metering_point_type: MeteringPointTypeDataProductValue, + expected_metering_point_type: str, +) -> None: + # Arrange + df = spark.createDataFrame([[metering_point_type.value]], ["metering_point_type"]) + + # Act + actual = df.select( + map_from_dict(mapping_dicts.METERING_POINT_TYPES)[F.col("metering_point_type")] + ) + + # Assert + assert actual.collect()[0][0] == expected_metering_point_type + + +@pytest.mark.parametrize( + "settlement_method, expected_settlement_method", + [ + pytest.param( + SettlementMethodDataProductValue.NON_PROFILED, + "E02", + id="when settlement method is non_profiled, then settlement method is E02", + ), + pytest.param( + SettlementMethodDataProductValue.FLEX, + "D01", + id="when settlement method is flex, then settlement method is D01", + ), + ], +) +def test_mapping_of_settlement_method( + spark: SparkSession, + settlement_method: SettlementMethodDataProductValue, + expected_settlement_method: str, +) -> None: + # Arrange + df = spark.createDataFrame([[settlement_method.value]], ["settlement_method"]) + + # Act + actual = df.select( + map_from_dict(mapping_dicts.SETTLEMENT_METHODS)[F.col("settlement_method")] + ) + + # Assert + assert actual.collect()[0][0] == expected_settlement_method diff --git a/source/settlement_report_python/tests/domain/utils/test_merge_periods.py b/source/settlement_report_python/tests/domain/utils/test_merge_periods.py new file mode 100644 index 0000000..a4f3983 --- /dev/null +++ b/source/settlement_report_python/tests/domain/utils/test_merge_periods.py @@ -0,0 +1,306 @@ +from datetime import datetime + +import pytest +from pyspark.sql import SparkSession, functions as F + +from settlement_report_job.domain.utils.merge_periods import ( + merge_connected_periods, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + +JAN_1ST = datetime(2023, 12, 31, 23) +JAN_2ND = datetime(2024, 1, 1, 23) +JAN_3RD = datetime(2024, 1, 2, 23) +JAN_4TH = datetime(2024, 1, 3, 23) +JAN_5TH = datetime(2024, 1, 4, 23) +JAN_6TH = datetime(2024, 1, 5, 23) +JAN_7TH = datetime(2024, 1, 6, 23) +JAN_8TH = datetime(2024, 1, 7, 23) +JAN_9TH = datetime(2024, 1, 8, 
23) + + +@pytest.mark.parametrize( + "periods,expected_periods", + [ + pytest.param( + [ + (JAN_1ST, JAN_3RD), + (JAN_2ND, JAN_4TH), + ], + [ + (JAN_1ST, JAN_4TH), + ], + id="two overlapping periods", + ), + pytest.param( + [ + (JAN_1ST, JAN_3RD), + (JAN_2ND, JAN_4TH), + (JAN_3RD, JAN_5TH), + ], + [ + (JAN_1ST, JAN_5TH), + ], + id="three overlapping periods", + ), + pytest.param( + [ + (JAN_1ST, JAN_3RD), + (JAN_2ND, JAN_4TH), + (JAN_5TH, JAN_6TH), + ], + [ + (JAN_1ST, JAN_4TH), + (JAN_5TH, JAN_6TH), + ], + id="two overlaps and one isolated", + ), + pytest.param( + [ + (JAN_1ST, JAN_3RD), + (JAN_2ND, JAN_4TH), + (JAN_5TH, JAN_7TH), + (JAN_6TH, JAN_8TH), + ], + [ + (JAN_1ST, JAN_4TH), + (JAN_5TH, JAN_8TH), + ], + id="two times two overlaps", + ), + pytest.param( + [ + (JAN_1ST, JAN_3RD), + (JAN_1ST, JAN_3RD), + ], + [ + (JAN_1ST, JAN_3RD), + ], + id="two perfect overlaps", + ), + ], +) +def test_merge_connecting_periods__when_overlapping_periods__returns_merged_periods( + spark: SparkSession, + periods: list[tuple[datetime, datetime]], + expected_periods: list[tuple[datetime, datetime]], +) -> None: + # Arrange + df = spark.createDataFrame( + [("1", from_date, to_date) for from_date, to_date in periods], + [ + DataProductColumnNames.charge_key, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + ], + ).orderBy(F.rand()) + + # Act + actual = merge_connected_periods(df) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == len(expected_periods) + for i, (expected_from, expected_to) in enumerate(expected_periods): + assert actual.collect()[i][DataProductColumnNames.from_date] == expected_from + assert actual.collect()[i][DataProductColumnNames.to_date] == expected_to + + +@pytest.mark.parametrize( + "periods,expected_periods", + [ + pytest.param( + [ + (JAN_1ST, JAN_2ND), + (JAN_3RD, JAN_4TH), + ], + [ + (JAN_1ST, JAN_2ND), + (JAN_3RD, JAN_4TH), + ], + id="no connected periods", + ), + pytest.param( + [ + (JAN_1ST, JAN_2ND), + (JAN_2ND, JAN_3RD), + (JAN_4TH, JAN_5TH), + (JAN_5TH, JAN_6TH), + ], + [ + (JAN_1ST, JAN_3RD), + (JAN_4TH, JAN_6TH), + ], + id="two connect and two others connected", + ), + pytest.param( + [ + (JAN_1ST, JAN_2ND), + (JAN_2ND, JAN_3RD), + (JAN_3RD, JAN_4TH), + (JAN_5TH, JAN_6TH), + ], + [ + (JAN_1ST, JAN_4TH), + (JAN_5TH, JAN_6TH), + ], + id="three connected and one not connected", + ), + ], +) +def test_merge_connecting_periods__when_connections_and_gaps_between_periods__returns_merged_rows( + spark: SparkSession, + periods: list[tuple[datetime, datetime]], + expected_periods: list[tuple[datetime, datetime]], +) -> None: + # Arrange + df = spark.createDataFrame( + [("1", from_date, to_date) for from_date, to_date in periods], + [ + DataProductColumnNames.charge_key, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + ], + ).orderBy(F.rand()) + + # Act + actual = merge_connected_periods(df) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == len(expected_periods) + for i, (expected_from, expected_to) in enumerate(expected_periods): + assert actual.collect()[i][DataProductColumnNames.from_date] == expected_from + assert actual.collect()[i][DataProductColumnNames.to_date] == expected_to + + +@pytest.mark.parametrize( + "periods,expected_periods", + [ + pytest.param( + [ + ("1", JAN_1ST, JAN_2ND), + ("2", JAN_2ND, JAN_3RD), + ], + [ + ("1", JAN_1ST, JAN_2ND), + ("2", JAN_2ND, JAN_3RD), + ], + id="connected but different group", + ), + 
pytest.param( + [ + ("1", JAN_1ST, JAN_2ND), + ("2", JAN_2ND, JAN_4TH), + ("1", JAN_2ND, JAN_3RD), + ], + [ + ("1", JAN_1ST, JAN_3RD), + ("2", JAN_2ND, JAN_4TH), + ], + id="one group has overlap and another group has no overlap", + ), + ], +) +def test_merge_connecting_periods__when_overlap_but_difference_groups__returns_without_merge( + spark: SparkSession, + periods: list[tuple[str, datetime, datetime]], + expected_periods: list[tuple[str, datetime, datetime]], +) -> None: + # Arrange + some_column_name = "some_column" + df = spark.createDataFrame( + [ + (some_column_value, from_date, to_date) + for some_column_value, from_date, to_date in periods + ], + [ + some_column_name, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + ], + ).orderBy(F.rand()) + + # Act + actual = merge_connected_periods(df) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == len(expected_periods) + for i, (expected_some_column_value, expected_from, expected_to) in enumerate( + expected_periods + ): + assert actual.collect()[i][some_column_name] == expected_some_column_value + assert actual.collect()[i][DataProductColumnNames.from_date] == expected_from + assert actual.collect()[i][DataProductColumnNames.to_date] == expected_to + + +@pytest.mark.parametrize( + "periods,expected_periods", + [ + pytest.param( + [ + ("A", "B", JAN_1ST, JAN_2ND), + ("A", "C", JAN_2ND, JAN_3RD), + ("B", "C", JAN_3RD, JAN_4TH), + ], + [ + ("A", "B", JAN_1ST, JAN_2ND), + ("A", "C", JAN_2ND, JAN_3RD), + ("B", "C", JAN_3RD, JAN_4TH), + ], + id="overlaps but not same group", + ), + pytest.param( + [ + ("A", "B", JAN_1ST, JAN_2ND), + ("A", "C", JAN_2ND, JAN_3RD), + ("A", "C", JAN_3RD, JAN_4TH), + ], + [ + ("A", "B", JAN_1ST, JAN_2ND), + ("A", "C", JAN_2ND, JAN_4TH), + ], + id="overlaps and including same group", + ), + ], +) +def test_merge_connecting_periods__when_multiple_other_columns_and_no___returns_expected( + spark: SparkSession, + periods: list[tuple[str, str, datetime, datetime]], + expected_periods: list[tuple[str, str, datetime, datetime]], +) -> None: + # Arrange + column_a = "column_a" + column_b = "column_b" + df = spark.createDataFrame( + [ + (col_a_value, col_b_value, from_date, to_date) + for col_a_value, col_b_value, from_date, to_date in periods + ], + [ + column_a, + column_b, + DataProductColumnNames.from_date, + DataProductColumnNames.to_date, + ], + ).orderBy(F.rand()) + + # Act + actual = merge_connected_periods(df) + + # Assert + actual = actual.orderBy(DataProductColumnNames.from_date) + assert actual.count() == len(expected_periods) + for i, ( + expected_col_a, + expected_col_b, + expected_from, + expected_to, + ) in enumerate(expected_periods): + assert actual.collect()[i][column_a] == expected_col_a + assert actual.collect()[i][column_b] == expected_col_b + assert actual.collect()[i][DataProductColumnNames.from_date] == expected_from + assert actual.collect()[i][DataProductColumnNames.to_date] == expected_to diff --git a/source/settlement_report_python/tests/domain/utils/test_system_operator_filter.py b/source/settlement_report_python/tests/domain/utils/test_system_operator_filter.py new file mode 100644 index 0000000..d1b7615 --- /dev/null +++ b/source/settlement_report_python/tests/domain/utils/test_system_operator_filter.py @@ -0,0 +1,358 @@ +from datetime import datetime +import pytest +from pyspark.sql import SparkSession +import tests.test_factories.default_test_data_spec as default_data +import 
tests.test_factories.metering_point_time_series_factory as time_series_points_factory +import tests.test_factories.charge_link_periods_factory as charge_link_periods_factory +import tests.test_factories.charge_price_information_periods_factory as charge_price_information_periods_factory + +from settlement_report_job.domain.time_series_points.system_operator_filter import ( + filter_time_series_points_on_charge_owner, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) + + +@pytest.mark.parametrize( + "mp_from_date, mp_to_date, charge_from_date, charge_to_date, expected_row_count", + [ + ( + # one day overlap charge starts later + datetime(2022, 1, 1, 23), + datetime(2022, 1, 3, 23), + datetime(2022, 1, 2, 23), + datetime(2022, 1, 4, 23), + 24, + ), + ( + # one day overlap metering point period starts later + datetime(2022, 1, 2, 23), + datetime(2022, 1, 4, 23), + datetime(2022, 1, 1, 23), + datetime(2022, 1, 3, 23), + 24, + ), + ( + # no overlap + datetime(2022, 1, 2, 23), + datetime(2022, 1, 4, 23), + datetime(2022, 1, 4, 23), + datetime(2022, 1, 6, 23), + 0, + ), + ], +) +def test_filter_time_series_points_on_charge_owner__returns_only_time_series_points_within_charge_period( + spark: SparkSession, + mp_from_date: datetime, + mp_to_date: datetime, + charge_from_date: datetime, + charge_to_date: datetime, + expected_row_count: int, +) -> None: + # Arrange + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + from_date=charge_from_date, to_date=charge_to_date + ), + ) + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + from_date=charge_from_date, to_date=charge_to_date + ), + ) + ) + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + from_date=mp_from_date, to_date=mp_to_date + ), + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=default_data.DEFAULT_CHARGE_OWNER_ID, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert actual.count() == expected_row_count + + +@pytest.mark.parametrize( + "calculation_id_charge_price_information, calculation_id_charge_link, calculation_id_metering_point, returns_data", + [ + ( + "11111111-1111-1111-1111-111111111111", + "11111111-1111-1111-1111-111111111111", + "11111111-1111-1111-1111-111111111111", + True, + ), + ( + "22222222-1111-1111-1111-111111111111", + "22222222-1111-1111-1111-111111111111", + "11111111-1111-1111-1111-111111111111", + False, + ), + ( + "22222222-1111-1111-1111-111111111111", + "11111111-1111-1111-1111-111111111111", + "11111111-1111-1111-1111-111111111111", + False, + ), + ], +) +def test_filter_time_series_points_on_charge_owner__returns_only_time_series_points_if_calculation_id_is_the_same_as_for_the_charge( + spark: SparkSession, + calculation_id_charge_price_information: str, + calculation_id_charge_link: str, + calculation_id_metering_point: str, + returns_data: bool, +) -> None: + # Arrange + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + calculation_id=calculation_id_charge_price_information, + ), + ) + ) + charge_link_periods_df = 
charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + calculation_id=calculation_id_charge_link, + ), + ) + + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + calculation_id=calculation_id_metering_point, + ), + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=default_data.DEFAULT_CHARGE_OWNER_ID, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert (actual.count() > 0) == returns_data + + +def test_filter_time_series_points_on_charge_owner__returns_only_time_series_points_where_the_charge_link_has_the_same_metering_point_id( + spark: SparkSession, +) -> None: + # Arrange + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row(), + ) + ) + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id="matching_metering_point_id" + ), + ) + + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + metering_point_id="matching_metering_point_id" + ), + ).union( + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + metering_point_id="non_matching_metering_point_id" + ), + ) + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=default_data.DEFAULT_CHARGE_OWNER_ID, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert ( + actual.select(DataProductColumnNames.metering_point_id).distinct().count() == 1 + ) + assert ( + actual.select(DataProductColumnNames.metering_point_id).distinct().first()[0] + == "matching_metering_point_id" + ) + + +def test_filter_time_series_points_on_charge_owner__when_multiple_links_matches_on_metering_point_id__returns_expected_number_of_time_series_points( + spark: SparkSession, +) -> None: + # Arrange + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_code="code1" + ), + ) + ).union( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_code="code2" + ), + ) + ) + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_code="code1"), + ).union( + charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_code="code2"), + ) + ) + + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec(), + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=default_data.DEFAULT_CHARGE_OWNER_ID, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert actual.count() == 24 + + +def test_filter_time_series_points_on_charge_owner__when_charge_owner_is_not_system_operator__returns_time_series_points_without_that_metering_point( + spark: 
SparkSession, +) -> None: + # Arrange + system_operator_id = "1234567890123" + not_system_operator_id = "9876543210123" + system_operator_metering_point_id = "1111111111111" + not_system_operator_metering_point_id = "2222222222222" + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=system_operator_id + ), + ) + ).union( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=not_system_operator_id + ), + ) + ) + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id=system_operator_metering_point_id, + charge_owner_id=system_operator_id, + ), + ).union( + charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row( + metering_point_id=not_system_operator_metering_point_id, + charge_owner_id=not_system_operator_id, + ), + ) + ) + + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + metering_point_id=system_operator_metering_point_id + ), + ).union( + time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec( + metering_point_id=not_system_operator_metering_point_id + ), + ) + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=system_operator_id, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert ( + actual.select(DataProductColumnNames.metering_point_id).distinct().count() == 1 + ) + assert ( + actual.select(DataProductColumnNames.metering_point_id).distinct().first()[0] + == system_operator_metering_point_id + ) + + +@pytest.mark.parametrize( + "is_tax, returns_rows", + [ + (True, False), + (False, True), + ], +) +def test_filter_time_series_points_on_charge_owner__returns_only_time_series_points_from_metering_points_without_tax_associated( + spark: SparkSession, is_tax: bool, returns_rows: bool +) -> None: + # Arrange + system_operator_id = "1234567890123" + charge_price_information_periods_df = ( + charge_price_information_periods_factory.create( + spark, + default_data.create_charge_price_information_periods_row( + charge_owner_id=system_operator_id, is_tax=is_tax + ), + ) + ) + charge_link_periods_df = charge_link_periods_factory.create( + spark, + default_data.create_charge_link_periods_row(charge_owner_id=system_operator_id), + ) + + time_series_points_df = time_series_points_factory.create( + spark, + default_data.create_time_series_points_data_spec(), + ) + + # Act + actual = filter_time_series_points_on_charge_owner( + time_series_points=time_series_points_df, + system_operator_id=system_operator_id, + charge_link_periods=charge_link_periods_df, + charge_price_information_periods=charge_price_information_periods_df, + ) + + # Assert + assert (actual.count() > 0) == returns_rows diff --git a/source/settlement_report_python/tests/domain/wholesale_results/test_wholesale_read_and_filter.py b/source/settlement_report_python/tests/domain/wholesale_results/test_wholesale_read_and_filter.py new file mode 100644 index 0000000..907d079 --- /dev/null +++ b/source/settlement_report_python/tests/domain/wholesale_results/test_wholesale_read_and_filter.py @@ -0,0 +1,247 @@ +from uuid import UUID +from 
datetime import datetime +from unittest.mock import Mock + +import pytest +from pyspark.sql import SparkSession + +import test_factories.default_test_data_spec as default_data +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.wholesale_results.read_and_filter import ( + read_and_filter_from_view, +) +from test_factories.default_test_data_spec import create_amounts_per_charge_row +from test_factories.amounts_per_charge_factory import create + + +DEFAULT_FROM_DATE = default_data.DEFAULT_FROM_DATE +DEFAULT_TO_DATE = default_data.DEFAULT_TO_DATE +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +SYSTEM_OPERATOR_ID = "3333333333333" +NOT_SYSTEM_OPERATOR_ID = "4444444444444" +DEFAULT_TIME_ZONE = "Europe/Copenhagen" +ENERGY_SUPPLIER_IDS = ["1234567890123", "2345678901234"] + + +@pytest.mark.parametrize( + "args_start_date, args_end_date, expected_rows", + [ + pytest.param( + datetime(2024, 1, 2, 23), + datetime(2024, 1, 10, 23), + 1, + id="when time is within the range, return 1 row", + ), + pytest.param( + datetime(2024, 1, 5, 23), + datetime(2024, 1, 10, 23), + 0, + id="when time is outside the range, return 0 rows", + ), + ], +) +def test_time_within_and_outside_of_date_range_scenarios( + spark: SparkSession, + args_start_date: datetime, + args_end_date: datetime, + expected_rows: int, +) -> None: + # Arrange + time = datetime(2024, 1, 3, 23) + + df = create(spark, create_amounts_per_charge_row(time=time)) + mock_repository = Mock() + mock_repository.read_amounts_per_charge.return_value = df + + # Act + actual = read_and_filter_from_view( + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + period_start=args_start_date, + period_end=args_end_date, + requesting_actor_market_role=MarketRole.ENERGY_SUPPLIER, + requesting_actor_id=default_data.DEFAULT_CHARGE_OWNER_ID, + repository=mock_repository, + ) + + # Assert + assert actual.count() == expected_rows + + +@pytest.mark.parametrize( + "args_energy_supplier_ids, expected_rows", + [ + pytest.param( + ["1234567890123"], + 1, + id="when energy_supplier_id is in energy_supplier_ids, return 1 row", + ), + pytest.param( + ["2345678901234"], + 0, + id="when energy_supplier_id is not in energy_supplier_ids, return 0 rows", + ), + pytest.param( + None, + 1, + id="when energy_supplier_ids is None, return 1 row", + ), + ], +) +def test_energy_supplier_ids_scenarios( + spark: SparkSession, + args_energy_supplier_ids: list[str] | None, + expected_rows: int, +) -> None: + # Arrange + energy_supplier_id = "1234567890123" + df = create( + spark, + create_amounts_per_charge_row(energy_supplier_id=energy_supplier_id), + ) + mock_repository = Mock() + mock_repository.read_amounts_per_charge.return_value = df + + # Act + actual = read_and_filter_from_view( + energy_supplier_ids=args_energy_supplier_ids, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + period_start=default_data.DEFAULT_FROM_DATE, + period_end=default_data.DEFAULT_TO_DATE, + requesting_actor_market_role=MarketRole.ENERGY_SUPPLIER, + requesting_actor_id=default_data.DEFAULT_CHARGE_OWNER_ID, + repository=mock_repository, + ) + + # Assert + assert actual.count() == expected_rows + + +@pytest.mark.parametrize( + "args_calculation_id_by_grid_area, expected_rows", + [ + pytest.param( + {"804": UUID(default_data.DEFAULT_CALCULATION_ID)}, + 1, + id="when calculation_id 
and grid_area_code are in calculation_id_by_grid_area, return 1 row", + ), + pytest.param( + {"500": UUID(default_data.DEFAULT_CALCULATION_ID)}, + 0, + id="when grid_area_code is not in calculation_id_by_grid_area, return 0 rows", + ), + pytest.param( + {"804": UUID("11111111-1111-2222-1111-111111111111")}, + 0, + id="when calculation_id is not in calculation_id_by_grid_area, return 0 rows", + ), + pytest.param( + {"500": UUID("11111111-1111-2222-1111-111111111111")}, + 0, + id="when calculation_id and grid_area_code are not in calculation_id_by_grid_area, return 0 rows", + ), + ], +) +def test_calculation_id_by_grid_area_scenarios( + spark: SparkSession, + args_calculation_id_by_grid_area: dict[str, UUID], + expected_rows: int, +) -> None: + # Arrange + df = create( + spark, + create_amounts_per_charge_row( + calculation_id=default_data.DEFAULT_CALCULATION_ID, grid_area_code="804" + ), + ) + mock_repository = Mock() + mock_repository.read_amounts_per_charge.return_value = df + + # Act + actual = read_and_filter_from_view( + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + calculation_id_by_grid_area=args_calculation_id_by_grid_area, + period_start=default_data.DEFAULT_FROM_DATE, + period_end=default_data.DEFAULT_TO_DATE, + requesting_actor_market_role=MarketRole.ENERGY_SUPPLIER, + requesting_actor_id=default_data.DEFAULT_CHARGE_OWNER_ID, + repository=mock_repository, + ) + + # Assert + assert actual.count() == expected_rows + + +@pytest.mark.parametrize( + "args_requesting_actor_market_role, args_requesting_actor_id, is_tax, expected_rows", + [ + pytest.param( + MarketRole.GRID_ACCESS_PROVIDER, + "1111111111111", + True, + 1, + id="When grid_access_provider and charge_owner_id equals requesting_actor_id and is_tax is True, return 1 row", + ), + pytest.param( + MarketRole.GRID_ACCESS_PROVIDER, + default_data.DEFAULT_CHARGE_OWNER_ID, + False, + 1, + id="When grid_access_provider and charge_owner_id equals requesting_actor_id and is_tax is False, return 1 row", + ), + pytest.param( + MarketRole.SYSTEM_OPERATOR, + default_data.DEFAULT_CHARGE_OWNER_ID, + True, + 0, + id="When system_operator and charge_owner_id equals requesting_actor_id and is_tax is True, return 0 rows", + ), + pytest.param( + MarketRole.SYSTEM_OPERATOR, + default_data.DEFAULT_CHARGE_OWNER_ID, + False, + 1, + id="When system_operator and charge_owner_id equals requesting_actor_id and is_tax is False, return 1 row", + ), + ], +) +def test_grid_access_provider_and_system_operator_scenarios( + spark: SparkSession, + args_requesting_actor_market_role: MarketRole, + args_requesting_actor_id: str, + is_tax: bool, + expected_rows: int, +) -> None: + # Arrange + df = create( + spark, + create_amounts_per_charge_row(is_tax=is_tax), + ) + mock_repository = Mock() + mock_repository.read_amounts_per_charge.return_value = df + + # Act + actual = read_and_filter_from_view( + energy_supplier_ids=ENERGY_SUPPLIER_IDS, + calculation_id_by_grid_area={ + default_data.DEFAULT_GRID_AREA_CODE: UUID( + default_data.DEFAULT_CALCULATION_ID + ) + }, + period_start=default_data.DEFAULT_FROM_DATE, + period_end=default_data.DEFAULT_TO_DATE, + requesting_actor_market_role=args_requesting_actor_market_role, + requesting_actor_id=args_requesting_actor_id, + repository=mock_repository, + ) + + # Assert + assert actual.count() == expected_rows diff --git a/source/settlement_report_python/tests/entry_points/__init__.py b/source/settlement_report_python/tests/entry_points/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/source/settlement_report_python/tests/entry_points/conftest.py b/source/settlement_report_python/tests/entry_points/conftest.py new file mode 100644 index 0000000..dcd04fd --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/conftest.py @@ -0,0 +1,49 @@ +import os +import subprocess +from typing import Generator + +import pytest + + +@pytest.fixture(scope="session") +def virtual_environment() -> Generator: + """Fixture ensuring execution in a virtual environment. + Uses `virtualenv` instead of conda environments due to problems + activating the virtual environment from pytest.""" + + # Create and activate the virtual environment + subprocess.call(["virtualenv", ".wholesale-pytest"]) + subprocess.call( + "source .wholesale-pytest/bin/activate", shell=True, executable="/bin/bash" + ) + + yield None + + # Deactivate virtual environment upon test suite tear down + subprocess.call("deactivate", shell=True, executable="/bin/bash") + + +@pytest.fixture(scope="session") +def installed_package( + virtual_environment: Generator, settlement_report_job_container_path: str +) -> None: + """Ensures that the wholesale package is installed (after building it).""" + + # Build the package wheel + os.chdir(settlement_report_job_container_path) + subprocess.call("python -m build --wheel", shell=True, executable="/bin/bash") + + # Uninstall the package in case it was left by a cancelled test suite + subprocess.call( + "pip uninstall -y package", + shell=True, + executable="/bin/bash", + ) + + # Install wheel, which will also create console scripts for invoking + # the entry points of the package + subprocess.call( + f"pip install {settlement_report_job_container_path}/dist/opengeh_settlement_report-1.0-py3-none-any.whl", + shell=True, + executable="/bin/bash", + ) diff --git a/source/settlement_report_python/tests/entry_points/job_args/test_settlement_report_args.py b/source/settlement_report_python/tests/entry_points/job_args/test_settlement_report_args.py new file mode 100644 index 0000000..836ab0f --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/job_args/test_settlement_report_args.py @@ -0,0 +1,503 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
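+ +# Note (added, descriptive only): the tests below exercise argument parsing for the settlement report job. +# The contract parameter reference files are patched into sys.argv, and the values produced by +# parse_command_line_arguments and parse_job_arguments are asserted against the expected job arguments.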
+import re +import uuid +from datetime import datetime +from unittest.mock import patch + +import pytest + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.entry_points.entry_point import ( + parse_job_arguments, + parse_command_line_arguments, +) + +from settlement_report_job.entry_points.job_args.environment_variables import ( + EnvironmentVariable, +) +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType + +DEFAULT_REPORT_ID = "12345678-9fc8-409a-a169-fbd49479d718" + + +def _get_contract_parameters(filename: str) -> list[str]: + """Get the parameters as they are expected to be received from the settlement report invoker.""" # noqa + with open(filename) as file: + text = file.read() + text = text.replace("{report-id}", DEFAULT_REPORT_ID) + lines = text.splitlines() + return list( + filter(lambda line: not line.startswith("#") and len(line) > 0, lines) + ) + + +def _substitute_requesting_actor_market_role( + sys_argv: list[str], market_role: str +) -> list[str]: + pattern = r"--requesting-actor-market-role=(\w+)" + + for i, item in enumerate(sys_argv): + if re.search(pattern, item): + sys_argv[i] = re.sub( + pattern, f"--requesting-actor-market-role={market_role}", item + ) + break + + return sys_argv + + +def _substitute_energy_supplier_ids( + sys_argv: list[str], energy_supplier_ids: str +) -> list[str]: + for i, item in enumerate(sys_argv): + if item.startswith("--energy-supplier-ids="): + sys_argv[i] = f"--energy-supplier-ids={energy_supplier_ids}" + break + return sys_argv + + +@pytest.fixture(scope="session") +def contract_parameters_for_balance_fixing(contracts_path: str) -> list[str]: + job_parameters = _get_contract_parameters( + f"{contracts_path}/settlement-report-balance-fixing-parameters-reference.txt" + ) + + return job_parameters + + +@pytest.fixture(scope="session") +def contract_parameters_for_wholesale(contracts_path: str) -> list[str]: + job_parameters = _get_contract_parameters( + f"{contracts_path}/settlement-report-wholesale-calculations-parameters-reference.txt" + ) + + return job_parameters + + +@pytest.fixture(scope="session") +def sys_argv_from_contract_for_wholesale( + contract_parameters_for_wholesale: list[str], +) -> list[str]: + return ["dummy_script_name"] + contract_parameters_for_wholesale + + +@pytest.fixture(scope="session") +def sys_argv_from_contract_for_balance_fixing( + contract_parameters_for_balance_fixing: list[str], +) -> list[str]: + return ["dummy_script_name"] + contract_parameters_for_balance_fixing + + +@pytest.fixture(scope="session") +def job_environment_variables() -> dict: + return { + EnvironmentVariable.CATALOG_NAME.name: "some_catalog", + } + + +def test_when_invoked_with_incorrect_parameters__fails( + job_environment_variables: dict, +) -> None: + # Arrange + with pytest.raises(SystemExit) as excinfo: + with patch("sys.argv", ["dummy_script", "--unexpected-arg"]): + with patch.dict("os.environ", job_environment_variables): + # Act + parse_command_line_arguments() + + # Assert + assert excinfo.value.code == 2 + + +def test_when_parameters_for_balance_fixing__parses_parameters_from_contract( + job_environment_variables: dict, + sys_argv_from_contract_for_balance_fixing: list[str], +) -> None: + """ + This test ensures that the settlement report job for balance fixing accepts + the arguments that are provided by the client. 
+ """ + # Arrange + with patch("sys.argv", sys_argv_from_contract_for_balance_fixing): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert - settlement report arguments + assert actual_args.report_id == DEFAULT_REPORT_ID + assert actual_args.period_start == datetime(2022, 5, 31, 22) + assert actual_args.period_end == datetime(2022, 6, 1, 22) + assert actual_args.calculation_type == CalculationType.BALANCE_FIXING + assert actual_args.grid_area_codes == ["804", "805"] + assert actual_args.energy_supplier_ids == ["1234567890123"] + assert actual_args.prevent_large_text_files is True + assert actual_args.split_report_by_grid_area is True + assert actual_args.time_zone == "Europe/Copenhagen" + assert actual_args.include_basis_data is True + + +def test_when_parameters_for_wholesale__parses_parameters_from_contract( + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], +) -> None: + """ + This test ensures that the settlement report job for wholesale calculations accepts + the arguments that are provided by the client. + """ + # Arrange + with patch("sys.argv", sys_argv_from_contract_for_wholesale): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert - settlement report arguments + assert actual_args.report_id == DEFAULT_REPORT_ID + assert actual_args.period_start == datetime(2022, 5, 31, 22) + assert actual_args.period_end == datetime(2022, 6, 1, 22) + assert actual_args.calculation_type == CalculationType.WHOLESALE_FIXING + assert actual_args.calculation_id_by_grid_area == { + "804": uuid.UUID("95bd2365-c09b-4ee7-8c25-8dd56b564811"), + "805": uuid.UUID("d3e2b83a-2fd9-4bcd-a6dc-41e4ce74cd6d"), + } + assert actual_args.energy_supplier_ids == ["1234567890123"] + assert actual_args.prevent_large_text_files is True + assert actual_args.split_report_by_grid_area is True + assert actual_args.time_zone == "Europe/Copenhagen" + assert actual_args.include_basis_data is True + + +@pytest.mark.parametrize( + "not_valid_calculation_id", + [ + "not_valid", + "", + None, + "c09b-4ee7-8c25-8dd56b564811", # too short + ], +) +def test_when_no_valid_calculation_id_for_grid_area__raises_uuid_value_error( + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + not_valid_calculation_id: str, +) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + pattern = r"--calculation-id-by-grid-area=(\{.*\})" + + for i, item in enumerate(test_sys_args): + if re.search(pattern, item): + test_sys_args[i] = re.sub( + pattern, + f'--calculation-id-by-grid-area={{"804": "{not_valid_calculation_id}"}}', # noqa + item, + ) + break + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + with pytest.raises(ValueError) as exc_info: + command_line_args = parse_command_line_arguments() + # Act + parse_job_arguments(command_line_args) + + # Assert + assert "Calculation ID for grid area 804 is not a uuid" in str(exc_info.value) + + +@pytest.mark.parametrize( + "prevent_large_text_files", + [ + True, + False, + ], +) +def test_returns_expected_value_for_prevent_large_text_files( + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + prevent_large_text_files: bool, +) -> None: + # Arrange + 
test_sys_args = sys_argv_from_contract_for_wholesale.copy() + if not prevent_large_text_files: + test_sys_args = [ + item + for item in sys_argv_from_contract_for_wholesale + if not item.startswith("--prevent-large-text-files") + ] + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.prevent_large_text_files is prevent_large_text_files + + +@pytest.mark.parametrize( + "split_report_by_grid_area", + [ + True, + False, + ], +) +def test_returns_expected_value_for_split_report_by_grid_area( + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + split_report_by_grid_area: bool, +) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + if not split_report_by_grid_area: + test_sys_args = [ + item + for item in sys_argv_from_contract_for_wholesale + if not item.startswith("--split-report-by-grid-area") + ] + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.split_report_by_grid_area is split_report_by_grid_area + + +@pytest.mark.parametrize( + "include_basis_data", + [ + True, + False, + ], +) +def test_returns_expected_value_for_include_basis_data( + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + include_basis_data: bool, +) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + if not include_basis_data: + test_sys_args = [ + item + for item in sys_argv_from_contract_for_wholesale + if not item.startswith("--include-basis-data") + ] + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.include_basis_data is include_basis_data + + +@pytest.mark.parametrize( + "energy_supplier_ids_argument, expected_energy_suppliers_ids", + [ + ("[1234567890123]", ["1234567890123"]), + ("[1234567890123]", ["1234567890123"]), + ("[1234567890123, 2345678901234]", ["1234567890123", "2345678901234"]), + ("[1234567890123,2345678901234]", ["1234567890123", "2345678901234"]), + ("[ 1234567890123,2345678901234 ]", ["1234567890123", "2345678901234"]), + ], +) +def test_when_energy_supplier_ids_are_specified__returns_expected_energy_supplier_ids( + sys_argv_from_contract_for_wholesale: list[str], + job_environment_variables: dict, + energy_supplier_ids_argument: str, + expected_energy_suppliers_ids: list[str], +) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + test_sys_args = _substitute_energy_supplier_ids( + test_sys_args, energy_supplier_ids_argument + ) + + with patch.dict("os.environ", job_environment_variables): + with patch("sys.argv", test_sys_args): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.energy_supplier_ids == expected_energy_suppliers_ids + + +@pytest.mark.parametrize( + "energy_supplier_ids_argument", + [ + "1234567890123", # not a list + "1234567890123 2345678901234", # not a list + "[123]", # neither 13 nor 16 characters + 
"[12345678901234]", # neither 13 nor 16 characters + ], +) +def test_when_invalid_energy_supplier_ids__raise_exception( + sys_argv_from_contract_for_wholesale: list[str], + job_environment_variables: dict, + energy_supplier_ids_argument: str, +) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + test_sys_args = _substitute_energy_supplier_ids( + test_sys_args, energy_supplier_ids_argument + ) + + with patch.dict("os.environ", job_environment_variables): + with patch("sys.argv", test_sys_args): + with pytest.raises(SystemExit) as error: + command_line_args = parse_command_line_arguments() + # Act + parse_job_arguments(command_line_args) + + # Assert + assert error.value.code != 0 + + +def test_when_no_energy_supplier_specified__returns_none_energy_supplier_ids( + sys_argv_from_contract_for_wholesale: list[str], + job_environment_variables: dict, +) -> None: + # Arrange + test_sys_args = [ + item + for item in sys_argv_from_contract_for_wholesale + if not item.startswith("--energy-supplier-ids") + ] + + with patch.dict("os.environ", job_environment_variables): + with patch("sys.argv", test_sys_args): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.energy_supplier_ids is None + + +class TestWhenInvokedWithValidMarketRole: + @pytest.mark.parametrize( + "market_role", + [market_role for market_role in MarketRole], + ) + def test_returns_expected_requesting_actor_market_role( + self, + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + market_role: MarketRole, + ) -> None: + # Arrange + test_sys_args = _substitute_requesting_actor_market_role( + sys_argv_from_contract_for_wholesale.copy(), market_role.value + ) + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + command_line_args = parse_command_line_arguments() + + # Act + actual_args = parse_job_arguments(command_line_args) + + # Assert + assert actual_args.requesting_actor_market_role == market_role + + +class TestWhenInvokedWithInvalidMarketRole: + + def test_raise_system_exit_with_non_zero_code( + self, + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + ) -> None: + # Arrange + test_sys_args = _substitute_requesting_actor_market_role( + sys_argv_from_contract_for_wholesale.copy(), "invalid_market_role" + ) + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + with pytest.raises(SystemExit) as error: + command_line_args = parse_command_line_arguments() + # Act + parse_job_arguments(command_line_args) + + # Assert + assert error.value.code != 0 + + +class TestWhenUnknownCalculationType: + def test_raise_system_exit_with_non_zero_code( + self, + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + ) -> None: + # Arrange + test_sys_args = sys_argv_from_contract_for_wholesale.copy() + unknown_calculation_type = "unknown_calculation_type" + pattern = r"--calculation-type=(\w+)" + + for i, item in enumerate(test_sys_args): + if re.search(pattern, item): + test_sys_args[i] = re.sub( + pattern, f"--calculation-type={unknown_calculation_type}", item + ) + break + + with patch("sys.argv", test_sys_args): + with patch.dict("os.environ", job_environment_variables): + with pytest.raises(SystemExit) as error: + command_line_args = parse_command_line_arguments() + # Act + parse_job_arguments(command_line_args) + + # 
Assert + assert error.value.code != 0 + + +class TestWhenMissingEnvVariables: + def test_raise_system_exit_with_non_zero_code( + self, + job_environment_variables: dict, + sys_argv_from_contract_for_wholesale: list[str], + ) -> None: + # Arrange + with patch("sys.argv", sys_argv_from_contract_for_wholesale): + for excluded_env_var in job_environment_variables.keys(): + env_variables_with_one_missing = { + key: value + for key, value in job_environment_variables.items() + if key != excluded_env_var + } + + with patch.dict("os.environ", env_variables_with_one_missing): + with pytest.raises(ValueError) as error: + command_line_args = parse_command_line_arguments() + # Act + parse_job_arguments(command_line_args) + + assert str(error.value).startswith("Environment variable not found") diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_link_periods.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_link_periods.py new file mode 100644 index 0000000..79cb3f0 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_link_periods.py @@ -0,0 +1,231 @@ +from pyspark.sql import SparkSession +import pytest + +from dbutils_fixture import DBUtilsFixture + +from data_seeding import standard_wholesale_fixing_scenario_data_generator +from assertion import assert_file_names_and_columns +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.entry_points.tasks.charge_link_periods_task import ( + ChargeLinkPeriodsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import get_actual_files + + +def test_execute_charge_link_periods__when_energy_supplier__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + expected_file_names = [ + f"CHARGELINK_804_{standard_wholesale_fixing_scenario_energy_supplier_args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + f"CHARGELINK_805_{standard_wholesale_fixing_scenario_energy_supplier_args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.charge_quantity, + CsvColumnNames.charge_link_from_date, + CsvColumnNames.charge_link_to_date, + ] + task = ChargeLinkPeriodsTask( + spark, dbutils, standard_wholesale_fixing_scenario_energy_supplier_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargeLinks, + args=standard_wholesale_fixing_scenario_energy_supplier_args, + ) + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_energy_supplier_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_charge_link_periods__when_grid_access_provider__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, 
+ standard_wholesale_fixing_scenario_grid_access_provider_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + expected_file_names = [ + f"CHARGELINK_804_{standard_wholesale_fixing_scenario_grid_access_provider_args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + f"CHARGELINK_805_{standard_wholesale_fixing_scenario_grid_access_provider_args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.charge_quantity, + CsvColumnNames.charge_link_from_date, + CsvColumnNames.charge_link_to_date, + ] + task = ChargeLinkPeriodsTask( + spark, dbutils, standard_wholesale_fixing_scenario_grid_access_provider_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargeLinks, + args=standard_wholesale_fixing_scenario_grid_access_provider_args, + ) + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_grid_access_provider_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_charge_link_periods__when_system_operator_or_datahub_admin_with_one_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.energy_supplier_ids = [energy_supplier_id] + expected_file_names = [ + f"CHARGELINK_804_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + f"CHARGELINK_805_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.charge_quantity, + CsvColumnNames.charge_link_from_date, + CsvColumnNames.charge_link_to_date, + CsvColumnNames.energy_supplier_id, + ] + task = ChargeLinkPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargeLinks, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_charge_link_periods__when_system_operator_or_datahub_admin_with_none_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = 
None + expected_file_names = [ + "CHARGELINK_804_02-01-2024_02-01-2024.csv", + "CHARGELINK_805_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.charge_quantity, + CsvColumnNames.charge_link_from_date, + CsvColumnNames.charge_link_to_date, + CsvColumnNames.energy_supplier_id, + ] + task = ChargeLinkPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargeLinks, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_charge_link_periods__when_include_basis_data_false__returns_no_file_paths( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.include_basis_data = False + task = ChargeLinkPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargeLinks, + args=args, + ) + assert actual_files is None or len(actual_files) == 0 diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_price_points.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_price_points.py new file mode 100644 index 0000000..5607ff4 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_charge_price_points.py @@ -0,0 +1,236 @@ +from pyspark.sql import SparkSession +import pytest + +from assertion import assert_file_names_and_columns +from data_seeding import standard_wholesale_fixing_scenario_data_generator +from dbutils_fixture import DBUtilsFixture + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.entry_points.tasks.charge_price_points_task import ( + ChargePricePointsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import get_actual_files + +expected_columns = [ + CsvColumnNames.charge_type, + CsvColumnNames.charge_owner_id, + CsvColumnNames.charge_code, + CsvColumnNames.resolution, + CsvColumnNames.is_tax, + CsvColumnNames.time, + f"{CsvColumnNames.energy_price}1", + f"{CsvColumnNames.energy_price}2", + f"{CsvColumnNames.energy_price}3", + f"{CsvColumnNames.energy_price}4", + f"{CsvColumnNames.energy_price}5", + f"{CsvColumnNames.energy_price}6", + f"{CsvColumnNames.energy_price}7", + f"{CsvColumnNames.energy_price}8", + f"{CsvColumnNames.energy_price}9", + f"{CsvColumnNames.energy_price}10", + f"{CsvColumnNames.energy_price}11", + f"{CsvColumnNames.energy_price}12", + f"{CsvColumnNames.energy_price}13", + f"{CsvColumnNames.energy_price}14", + f"{CsvColumnNames.energy_price}15", + f"{CsvColumnNames.energy_price}16", + f"{CsvColumnNames.energy_price}17", + f"{CsvColumnNames.energy_price}18", + 
f"{CsvColumnNames.energy_price}19", + f"{CsvColumnNames.energy_price}20", + f"{CsvColumnNames.energy_price}21", + f"{CsvColumnNames.energy_price}22", + f"{CsvColumnNames.energy_price}23", + f"{CsvColumnNames.energy_price}24", + f"{CsvColumnNames.energy_price}25", +] + + +def test_execute_charge_price_points__when_energy_supplier__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + expected_file_names = [ + f"CHARGEPRICE_804_{standard_wholesale_fixing_scenario_energy_supplier_args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + f"CHARGEPRICE_805_{standard_wholesale_fixing_scenario_energy_supplier_args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + ] + + task = ChargePricePointsTask( + spark, dbutils, standard_wholesale_fixing_scenario_energy_supplier_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargePricePoints, + args=standard_wholesale_fixing_scenario_energy_supplier_args, + ) + + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_energy_supplier_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_charge_price_points__when_grid_access_provider__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_grid_access_provider_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + expected_file_names = [ + f"CHARGEPRICE_804_{standard_wholesale_fixing_scenario_grid_access_provider_args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + f"CHARGEPRICE_805_{standard_wholesale_fixing_scenario_grid_access_provider_args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + ] + + task = ChargePricePointsTask( + spark, dbutils, standard_wholesale_fixing_scenario_grid_access_provider_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargePricePoints, + args=standard_wholesale_fixing_scenario_grid_access_provider_args, + ) + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_grid_access_provider_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_charge_price_points__when_system_operator_or_datahub_admin_with_one_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.energy_supplier_ids = [energy_supplier_id] + + expected_file_names = [ + f"CHARGEPRICE_804_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + f"CHARGEPRICE_805_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + ] + + task = ChargePricePointsTask( + spark, dbutils, 
standard_wholesale_fixing_scenario_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargePricePoints, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_charge_price_points__when_system_operator_or_datahub_admin_with_none_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = None + expected_file_names = [ + "CHARGEPRICE_804_02-01-2024_02-01-2024.csv", + "CHARGEPRICE_805_02-01-2024_02-01-2024.csv", + ] + + task = ChargePricePointsTask( + spark, dbutils, standard_wholesale_fixing_scenario_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargePricePoints, + args=args, + ) + + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_charge_price_points__when_include_basis_data_false__returns_no_file_paths( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.include_basis_data = False + task = ChargePricePointsTask( + spark, dbutils, standard_wholesale_fixing_scenario_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.ChargePricePoints, + args=args, + ) + assert actual_files is None or len(actual_files) == 0 diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_energy_results.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_energy_results.py new file mode 100644 index 0000000..9c14f75 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_energy_results.py @@ -0,0 +1,215 @@ +import pytest +from pyspark.sql import SparkSession + +from data_seeding import standard_wholesale_fixing_scenario_data_generator +from dbutils_fixture import DBUtilsFixture + +from assertion import assert_file_names_and_columns +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.entry_points.tasks.energy_resuls_task import ( + EnergyResultsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import get_actual_files, cleanup_output_path + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + 
settlement_reports_output_path=settlement_reports_output_path, + ) + + +def test_execute_energy_results__when_standard_wholesale_fixing_scenario__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.energy_supplier_ids = ["1000000000000"] + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_1000000000000_02-01-2024_02-01-2024.csv", + "RESULTENERGY_805_1000000000000_02-01-2024_02-01-2024.csv", + ] + task = EnergyResultsTask(spark, dbutils, standard_wholesale_fixing_scenario_args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.EnergyResults, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_energy_results__when_split_report_by_grid_area_is_false__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area = { + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[ + 0 + ]: standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area[ + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0] + ] + } + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + standard_wholesale_fixing_scenario_args.split_report_by_grid_area = True + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_02-01-2024_02-01-2024.csv", + ] + task = EnergyResultsTask(spark, dbutils, standard_wholesale_fixing_scenario_args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.EnergyResults, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_energy_results__when_standard_wholesale_fixing_scenario_grid_access__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + 
standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.GRID_ACCESS_PROVIDER + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = "1234567890123" + + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_1234567890123_DDM_02-01-2024_02-01-2024.csv", + "RESULTENERGY_805_1234567890123_DDM_02-01-2024_02-01-2024.csv", + ] + task = EnergyResultsTask(spark, dbutils, standard_wholesale_fixing_scenario_args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.EnergyResults, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_energy_results__when_standard_wholesale_fixing_scenario_energy_supplier__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.ENERGY_SUPPLIER + ) + standard_wholesale_fixing_scenario_args.requesting_actor_id = "1000000000000" + standard_wholesale_fixing_scenario_args.energy_supplier_ids = ["1000000000000"] + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_1000000000000_DDQ_02-01-2024_02-01-2024.csv", + "RESULTENERGY_805_1000000000000_DDQ_02-01-2024_02-01-2024.csv", + ] + task = EnergyResultsTask(spark, dbutils, standard_wholesale_fixing_scenario_args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.EnergyResults, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_hourly_time_series_points.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_hourly_time_series_points.py new file mode 100644 index 0000000..0ad3026 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_hourly_time_series_points.py @@ -0,0 +1,72 @@ +import pytest +from pyspark.sql import SparkSession + +from dbutils_fixture import DBUtilsFixture + +from assertion import assert_file_names_and_columns +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from 
settlement_report_job.entry_points.tasks.task_type import TaskType +from settlement_report_job.entry_points.tasks.time_series_points_task import ( + TimeSeriesPointsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import cleanup_output_path, get_actual_files + + +# NOTE: The tests in test_execute_quarterly_time_series_points.py should cover execute_hourly also, so we don't need to test +# all the same things again here also. + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + settlement_reports_output_path=settlement_reports_output_path, + ) + + +def test_execute_hourly_time_series_points__when_standard_wholesale_fixing_scenario__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + expected_file_names = [ + "TSSD60_804_02-01-2024_02-01-2024.csv", + "TSSD60_805_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.energy_supplier_id, + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 26)] + task = TimeSeriesPointsTask( + spark, + dbutils, + standard_wholesale_fixing_scenario_args, + TaskType.TimeSeriesHourly, + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesHourly, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_metering_point_periods.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_metering_point_periods.py new file mode 100644 index 0000000..61dfa2a --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_metering_point_periods.py @@ -0,0 +1,294 @@ +from pyspark.sql import SparkSession +import pytest + +from data_seeding import standard_wholesale_fixing_scenario_data_generator +from assertion import assert_file_names_and_columns + +from dbutils_fixture import DBUtilsFixture +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.entry_points.tasks.metering_point_periods_task import ( + MeteringPointPeriodsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import get_start_date, get_end_date, cleanup_output_path, get_actual_files + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + settlement_reports_output_path=settlement_reports_output_path, + ) + + +def _get_expected_columns(requesting_actor_market_role: MarketRole) -> list[str]: + if requesting_actor_market_role == MarketRole.GRID_ACCESS_PROVIDER: + return [ + CsvColumnNames.metering_point_id, + 
CsvColumnNames.metering_point_from_date, + CsvColumnNames.metering_point_to_date, + CsvColumnNames.grid_area_code_in_metering_points_csv, + CsvColumnNames.to_grid_area_code, + CsvColumnNames.from_grid_area_code, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + ] + elif requesting_actor_market_role is MarketRole.ENERGY_SUPPLIER: + return [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_from_date, + CsvColumnNames.metering_point_to_date, + CsvColumnNames.grid_area_code_in_metering_points_csv, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + ] + elif requesting_actor_market_role in [ + MarketRole.SYSTEM_OPERATOR, + MarketRole.DATAHUB_ADMINISTRATOR, + ]: + return [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_from_date, + CsvColumnNames.metering_point_to_date, + CsvColumnNames.grid_area_code_in_metering_points_csv, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_supplier_id, + ] + + +def test_execute_metering_point_periods__when_energy_supplier__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_energy_supplier_args + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + grid_area_code_1 = grid_area_codes[0] + grid_area_code_2 = grid_area_codes[1] + expected_file_names = [ + f"MDMP_{grid_area_code_1}_{args.requesting_actor_id}_DDQ_{start_time}_{end_time}.csv", + f"MDMP_{grid_area_code_2}_{args.requesting_actor_id}_DDQ_{start_time}_{end_time}.csv", + ] + expected_columns = _get_expected_columns( + standard_wholesale_fixing_scenario_energy_supplier_args.requesting_actor_market_role + ) + task = MeteringPointPeriodsTask( + spark, dbutils, standard_wholesale_fixing_scenario_energy_supplier_args + ) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=standard_wholesale_fixing_scenario_energy_supplier_args, + ) + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_energy_supplier_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_metering_point_periods__when_grid_access_provider__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_grid_access_provider_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_grid_access_provider_args + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + expected_file_names = [ + f"MDMP_{grid_area_codes[0]}_{args.requesting_actor_id}_DDM_{start_time}_{end_time}.csv", + f"MDMP_{grid_area_codes[1]}_{args.requesting_actor_id}_DDM_{start_time}_{end_time}.csv", + ] + expected_columns = _get_expected_columns( + standard_wholesale_fixing_scenario_grid_access_provider_args.requesting_actor_market_role + ) + task = MeteringPointPeriodsTask( + spark, dbutils, standard_wholesale_fixing_scenario_grid_access_provider_args + ) 
+ + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=standard_wholesale_fixing_scenario_grid_access_provider_args, + ) + assert_file_names_and_columns( + path=get_report_output_path( + standard_wholesale_fixing_scenario_grid_access_provider_args + ), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_metering_point_periods__when_system_operator_or_datahub_admin_with_one_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.energy_supplier_ids = [energy_supplier_id] + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + expected_file_names = [ + f"MDMP_{grid_area_codes[0]}_{energy_supplier_id}_{start_time}_{end_time}.csv", + f"MDMP_{grid_area_codes[1]}_{energy_supplier_id}_{start_time}_{end_time}.csv", + ] + expected_columns = _get_expected_columns(args.requesting_actor_market_role) + task = MeteringPointPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_metering_point_periods__when_system_operator_or_datahub_admin_with_none_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = None + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + expected_file_names = [ + f"MDMP_{grid_area_codes[0]}_{start_time}_{end_time}.csv", + f"MDMP_{grid_area_codes[1]}_{start_time}_{end_time}.csv", + ] + expected_columns = _get_expected_columns( + standard_wholesale_fixing_scenario_args.requesting_actor_market_role + ) + task = MeteringPointPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def 
test_execute_metering_point_periods__when_balance_fixing__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_balance_fixing_scenario_args: SettlementReportArgs, + standard_balance_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_balance_fixing_scenario_args + args.energy_supplier_ids = None + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + expected_file_names = [ + f"MDMP_{args.grid_area_codes[0]}_{start_time}_{end_time}.csv", + f"MDMP_{args.grid_area_codes[1]}_{start_time}_{end_time}.csv", + ] + expected_columns = _get_expected_columns(args.requesting_actor_market_role) + task = MeteringPointPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_metering_point_periods__when_include_basis_data_false__returns_no_file_paths( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.include_basis_data = False + task = MeteringPointPeriodsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MeteringPointPeriods, + args=standard_wholesale_fixing_scenario_args, + ) + assert actual_files is None or len(actual_files) == 0 diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_monthly_amounts.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_monthly_amounts.py new file mode 100644 index 0000000..85f4798 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_monthly_amounts.py @@ -0,0 +1,234 @@ +from pyspark.sql import SparkSession + +import pytest + +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.tasks.monthly_amounts_task import ( + MonthlyAmountsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from assertion import assert_file_names_and_columns +from dbutils_fixture import DBUtilsFixture +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from data_seeding import standard_wholesale_fixing_scenario_data_generator +from utils import cleanup_output_path, get_actual_files + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + settlement_reports_output_path=settlement_reports_output_path, + ) + + +def test_execute_monthly_amounts__when_standard_wholesale_fixing_scenario__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_energy_supplier_args + expected_columns = [ + CsvColumnNames.calculation_type, + 
CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + + expected_file_names = [ + f"RESULTMONTHLY_804_{args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + f"RESULTMONTHLY_805_{args.requesting_actor_id}_DDQ_02-01-2024_02-01-2024.csv", + ] + + task = MonthlyAmountsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MonthlyAmounts, + args=standard_wholesale_fixing_scenario_energy_supplier_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_monthly_amounts__when_administrator_and_split_report_by_grid_area_is_false__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_datahub_admin_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + args = standard_wholesale_fixing_scenario_datahub_admin_args + args.split_report_by_grid_area = False + args.energy_supplier_ids = [ + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ] + + # Arrange + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + + expected_file_names = [ + f"RESULTMONTHLY_flere-net_{args.energy_supplier_ids[0]}_02-01-2024_02-01-2024.csv", + ] + + task = MonthlyAmountsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MonthlyAmounts, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_monthly_amounts__when_grid_access_provider__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_grid_access_provider_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + args = standard_wholesale_fixing_scenario_grid_access_provider_args + + # Get just one of the grid_areas of the dictionary. 
+ for key, value in args.calculation_id_by_grid_area.items(): + target_grid_area = key + target_calc_id = value + break + args.calculation_id_by_grid_area = {target_grid_area: target_calc_id} + + # Arrange + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + + expected_file_names = [ + f"RESULTMONTHLY_{target_grid_area}_{args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + ] + + task = MonthlyAmountsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MonthlyAmounts, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_monthly_amounts__when_system_operator__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_system_operator_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + args = standard_wholesale_fixing_scenario_system_operator_args + + # Get just one of the grid_areas of the dictionary. + for key, value in args.calculation_id_by_grid_area.items(): + target_grid_area = key + target_calc_id = value + break + args.calculation_id_by_grid_area = {target_grid_area: target_calc_id} + + # Arrange + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + + expected_file_names = [ + f"RESULTMONTHLY_{target_grid_area}_02-01-2024_02-01-2024.csv", + ] + + task = MonthlyAmountsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.MonthlyAmounts, + args=args, + ) + + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_quarterly_time_series_points.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_quarterly_time_series_points.py new file mode 100644 index 0000000..c6639fa --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_quarterly_time_series_points.py @@ -0,0 +1,268 @@ +from pyspark.sql import SparkSession +import pytest + + +from data_seeding import ( + standard_wholesale_fixing_scenario_data_generator, + standard_balance_fixing_scenario_data_generator, +) +from assertion import assert_file_names_and_columns + +from dbutils_fixture import DBUtilsFixture +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from 
settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.entry_points.tasks.task_type import TaskType +from settlement_report_job.entry_points.tasks.time_series_points_task import ( + TimeSeriesPointsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import cleanup_output_path, get_actual_files + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + settlement_reports_output_path=settlement_reports_output_path, + ) + + +def test_execute_quarterly_time_series_points__when_energy_supplier__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.requesting_actor_id = energy_supplier_id + args.energy_supplier_ids = [energy_supplier_id] + expected_file_names = [ + f"TSSD15_804_{energy_supplier_id}_DDQ_02-01-2024_02-01-2024.csv", + f"TSSD15_805_{energy_supplier_id}_DDQ_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 101)] + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=standard_wholesale_fixing_scenario_args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_quarterly_time_series_points__when_grid_access_provider__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = MarketRole.GRID_ACCESS_PROVIDER + args.energy_supplier_ids = None + expected_file_names = [ + f"TSSD15_804_{args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + f"TSSD15_805_{args.requesting_actor_id}_DDM_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 101)] + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def 
test_execute_quarterly_time_series_points__when_system_operator_or_datahub_admin_with_one_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + energy_supplier_id = ( + standard_wholesale_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.energy_supplier_ids = [energy_supplier_id] + expected_file_names = [ + f"TSSD15_804_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + f"TSSD15_805_{energy_supplier_id}_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.energy_supplier_id, + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 101)] + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [MarketRole.SYSTEM_OPERATOR, MarketRole.DATAHUB_ADMINISTRATOR], +) +def test_execute_quarterly_time_series_points__when_system_operator_or_datahub_admin_with_none_energy_supplier_id__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = None + expected_file_names = [ + "TSSD15_804_02-01-2024_02-01-2024.csv", + "TSSD15_805_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.energy_supplier_id, + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 101)] + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_quarterly_time_series_points__when_include_basis_data_false__returns_no_file_paths( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.include_basis_data = False + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=args, + ) + assert actual_files is None or len(actual_files) == 0 + + +def test_execute_quarterly_time_series_points__when_energy_supplier_and_balance_fixing__returns_expected( + spark: 
SparkSession, + dbutils: DBUtilsFixture, + standard_balance_fixing_scenario_args: SettlementReportArgs, + standard_balance_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_balance_fixing_scenario_args + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + energy_supplier_id = ( + standard_balance_fixing_scenario_data_generator.ENERGY_SUPPLIER_IDS[0] + ) + args.requesting_actor_id = energy_supplier_id + args.energy_supplier_ids = [energy_supplier_id] + expected_file_names = [ + f"TSSD15_804_{energy_supplier_id}_DDQ_02-01-2024_02-01-2024.csv", + f"TSSD15_805_{energy_supplier_id}_DDQ_02-01-2024_02-01-2024.csv", + ] + expected_columns = [ + CsvColumnNames.metering_point_id, + CsvColumnNames.metering_point_type, + CsvColumnNames.time, + ] + [f"ENERGYQUANTITY{i}" for i in range(1, 101)] + task = TimeSeriesPointsTask(spark, dbutils, args, TaskType.TimeSeriesQuarterly) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.TimeSeriesQuarterly, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) diff --git a/source/settlement_report_python/tests/entry_points/tasks/test_execute_wholesale_results.py b/source/settlement_report_python/tests/entry_points/tasks/test_execute_wholesale_results.py new file mode 100644 index 0000000..8e0f1bc --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/tasks/test_execute_wholesale_results.py @@ -0,0 +1,302 @@ +from pyspark.sql import SparkSession +import pytest + +from dbutils_fixture import DBUtilsFixture +from assertion import assert_file_names_and_columns + +from data_seeding.standard_wholesale_fixing_scenario_data_generator import ( + CHARGE_OWNER_ID_WITHOUT_TAX, +) +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, +) +from settlement_report_job.entry_points.tasks.wholesale_results_task import ( + WholesaleResultsTask, +) +from settlement_report_job.infrastructure.paths import get_report_output_path +from utils import ( + get_market_role_in_file_name, + get_start_date, + get_end_date, + cleanup_output_path, + get_actual_files, +) + + +@pytest.fixture(scope="function", autouse=True) +def reset_task_values(settlement_reports_output_path: str): + yield + cleanup_output_path( + settlement_reports_output_path=settlement_reports_output_path, + ) + + +def test_execute_wholesale_results__when_energy_supplier_and_split_by_grid_area_is_false__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_energy_supplier_args + args.split_report_by_grid_area = False + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + + market_role_in_file_name = get_market_role_in_file_name( + args.requesting_actor_market_role + ) + + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + + energy_supplier_id = args.energy_supplier_ids[0] + + expected_file_names = [ + 
f"RESULTWHOLESALE_flere-net_{energy_supplier_id}_{market_role_in_file_name}_{start_time}_{end_time}.csv", + ] + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.energy_quantity, + CsvColumnNames.price, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + task = WholesaleResultsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.WholesaleResults, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_execute_wholesale_results__when_energy_supplier_and_split_by_grid_area_is_true__returns_expected( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_energy_supplier_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_energy_supplier_args + args.split_report_by_grid_area = True + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + + market_role_in_file_name = get_market_role_in_file_name( + args.requesting_actor_market_role + ) + + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + grid_area_code_1 = grid_area_codes[0] + grid_area_code_2 = grid_area_codes[1] + + energy_supplier_id = args.energy_supplier_ids[0] + + expected_file_names = [ + f"RESULTWHOLESALE_{grid_area_code_1}_{energy_supplier_id}_{market_role_in_file_name}_{start_time}_{end_time}.csv", + f"RESULTWHOLESALE_{grid_area_code_2}_{energy_supplier_id}_{market_role_in_file_name}_{start_time}_{end_time}.csv", + ] + + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.energy_quantity, + CsvColumnNames.price, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + task = WholesaleResultsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.WholesaleResults, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +@pytest.mark.parametrize( + "market_role", + [ + pytest.param( + MarketRole.SYSTEM_OPERATOR, id="system_operator return correct file names" + ), + pytest.param( + MarketRole.DATAHUB_ADMINISTRATOR, + id="datahub_administrator return correct file names", + ), + ], +) +def test_when_market_role_is( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: 
SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, + market_role: MarketRole, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.split_report_by_grid_area = True + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = None + args.requesting_actor_id = CHARGE_OWNER_ID_WITHOUT_TAX + + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + grid_area_code_1 = grid_area_codes[0] + grid_area_code_2 = grid_area_codes[1] + + expected_file_names = [ + f"RESULTWHOLESALE_{grid_area_code_1}_{start_time}_{end_time}.csv", + f"RESULTWHOLESALE_{grid_area_code_2}_{start_time}_{end_time}.csv", + ] + + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.energy_quantity, + CsvColumnNames.price, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + task = WholesaleResultsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.WholesaleResults, + args=args, + ) + assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_when_market_role_is_grid_access_provider_return_correct_file_name( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + standard_wholesale_fixing_scenario_data_written_to_delta: None, +): + # Arrange + args = standard_wholesale_fixing_scenario_args + args.split_report_by_grid_area = True + args.requesting_actor_market_role = MarketRole.GRID_ACCESS_PROVIDER + args.energy_supplier_ids = None + + market_role_in_file_name = get_market_role_in_file_name( + args.requesting_actor_market_role + ) + + start_time = get_start_date(args.period_start) + end_time = get_end_date(args.period_end) + + grid_area_codes = list(args.calculation_id_by_grid_area.keys()) + grid_area_code_1 = grid_area_codes[0] + grid_area_code_2 = grid_area_codes[1] + + expected_file_names = [ + f"RESULTWHOLESALE_{grid_area_code_1}_{args.requesting_actor_id}_{market_role_in_file_name}_{start_time}_{end_time}.csv", + f"RESULTWHOLESALE_{grid_area_code_2}_{args.requesting_actor_id}_{market_role_in_file_name}_{start_time}_{end_time}.csv", + ] + + expected_columns = [ + CsvColumnNames.calculation_type, + CsvColumnNames.correction_settlement_number, + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.quantity_unit, + CsvColumnNames.currency, + CsvColumnNames.energy_quantity, + CsvColumnNames.price, + CsvColumnNames.amount, + CsvColumnNames.charge_type, + CsvColumnNames.charge_code, + CsvColumnNames.charge_owner_id, + ] + task = WholesaleResultsTask(spark, dbutils, args) + + # Act + task.execute() + + # Assert + actual_files = get_actual_files( + report_data_type=ReportDataType.WholesaleResults, + args=args, + ) + 
assert_file_names_and_columns( + path=get_report_output_path(args), + actual_files=actual_files, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) diff --git a/source/settlement_report_python/tests/entry_points/test_entry_points.py b/source/settlement_report_python/tests/entry_points/test_entry_points.py new file mode 100644 index 0000000..0c13e56 --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/test_entry_points.py @@ -0,0 +1,69 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +from typing import Any + +import pytest +from settlement_report_job.entry_points import entry_point as module + + +# IMPORTANT: +# If we add/remove tests here, we also update the "retry logic" in '.docker/entrypoint.sh', +# which depends on the number of "entry point tests". + + +def assert_entry_point_exists(entry_point_name: str) -> Any: + # Load the entry point function from the installed wheel + try: + entry_point = importlib.metadata.entry_points( + group="console_scripts", name=entry_point_name + ) + if not entry_point: + assert False, f"The {entry_point_name} entry point was not found." + # Check if the module exists + module_name = entry_point[0].module + function_name = entry_point[0].value.split(":")[1] + if not hasattr( + module, + function_name, + ): + assert ( + False + ), f"The entry point module function {function_name} does not exist in entry_point.py." + + importlib.import_module(module_name) + except importlib.metadata.PackageNotFoundError: + assert False, f"The {entry_point_name} entry point was not found." 
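+# A console_scripts entry point maps a name to "module:function". As a minimal,
+# hypothetical sketch (the target below is illustrative, not taken from this repo),
+# the wheel metadata could declare e.g.
+#     create_energy_results = "settlement_report_job.entry_points.entry_point:create_energy_results"
+# and importlib.metadata resolves it as shown by this helper (not a test, so it does
+# not affect the entry point test count mentioned above):
+def _resolve_entry_point_example(entry_point_name: str) -> tuple[str, str]:
+    """Illustrative only: return (module, function) for a console_scripts entry point."""
+    entry_points = importlib.metadata.entry_points(
+        group="console_scripts", name=entry_point_name
+    )
+    entry_point = list(entry_points)[0]
+    return entry_point.module, entry_point.value.split(":")[1]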
+ + +@pytest.mark.parametrize( + "entry_point_name", + [ + "create_hourly_time_series", + "create_quarterly_time_series", + "create_charge_links", + "create_charge_price_points", + "create_energy_results", + "create_monthly_amounts", + "create_wholesale_results", + "create_metering_point_periods", + "create_zip", + ], +) +def test__installed_package__can_load_entry_point( + installed_package: None, + entry_point_name: str, +) -> None: + assert_entry_point_exists(entry_point_name) diff --git a/source/settlement_report_python/tests/entry_points/utils/test_get_dbutils.py b/source/settlement_report_python/tests/entry_points/utils/test_get_dbutils.py new file mode 100644 index 0000000..cce42ac --- /dev/null +++ b/source/settlement_report_python/tests/entry_points/utils/test_get_dbutils.py @@ -0,0 +1,10 @@ +import pytest +from pyspark.sql import SparkSession + +from settlement_report_job.entry_points.utils.get_dbutils import get_dbutils + + +def test_get_dbutils__when_run_locally__raise_exception(spark: SparkSession): + # Act + with pytest.raises(Exception): + get_dbutils(spark) diff --git a/source/settlement_report_python/tests/infrastructure/__init__.py b/source/settlement_report_python/tests/infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/infrastructure/test_create_zip_file.py b/source/settlement_report_python/tests/infrastructure/test_create_zip_file.py new file mode 100644 index 0000000..8e3e9f8 --- /dev/null +++ b/source/settlement_report_python/tests/infrastructure/test_create_zip_file.py @@ -0,0 +1,61 @@ +from pathlib import Path +import pytest +from tempfile import TemporaryDirectory +from pyspark.sql import SparkSession +from settlement_report_job.infrastructure.create_zip_file import create_zip_file + + +def test_create_zip_file__when_dbutils_is_none__raise_exception(): + # Arrange + dbutils = None + report_id = "report_id" + save_path = "save_path.zip" + files_to_zip = ["file1", "file2"] + + # Act + with pytest.raises(Exception): + create_zip_file(dbutils, report_id, save_path, files_to_zip) + + +def test_create_zip_file__when_save_path_is_not_zip__raise_exception(): + # Arrange + dbutils = None + report_id = "report_id" + save_path = "save_path" + files_to_zip = ["file1", "file2"] + + # Act + with pytest.raises(Exception): + create_zip_file(dbutils, report_id, save_path, files_to_zip) + + +def test_create_zip_file__when_no_files_to_zip__raise_exception(): + # Arrange + dbutils = None + report_id = "report_id" + save_path = "save_path.zip" + files_to_zip = ["file1", "file2"] + + # Act + with pytest.raises(Exception): + create_zip_file(dbutils, report_id, save_path, files_to_zip) + + +def test_create_zip_file__when_files_to_zip__create_zip_file(dbutils): + # Arrange + tmp_dir = TemporaryDirectory() + with open(f"{tmp_dir.name}/file1", "w") as f: + f.write("content1") + with open(f"{tmp_dir.name}/file2", "w") as f: + f.write("content2") + + report_id = "report_id" + save_path = f"{tmp_dir.name}/save_path.zip" + files_to_zip = [f"{tmp_dir.name}/file1", f"{tmp_dir.name}/file2"] + + # Act + create_zip_file(dbutils, report_id, save_path, files_to_zip) + + # Assert + assert Path(save_path).exists() + tmp_dir.cleanup() diff --git a/source/settlement_report_python/tests/infrastructure/test_csv_writer.py b/source/settlement_report_python/tests/infrastructure/test_csv_writer.py new file mode 100644 index 0000000..5712011 --- /dev/null +++ b/source/settlement_report_python/tests/infrastructure/test_csv_writer.py @@ -0,0 
+1,843 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory + +from assertion import assert_file_names_and_columns +from settlement_report_job.infrastructure import csv_writer + +from pyspark.sql import SparkSession, DataFrame +import pyspark.sql.functions as F +from settlement_report_job.domain.utils.market_role import ( + MarketRole, +) +from settlement_report_job.domain.energy_results.prepare_for_csv import ( + prepare_for_csv, +) +from data_seeding import ( + standard_wholesale_fixing_scenario_data_generator, +) +from settlement_report_job.infrastructure.csv_writer import _write_files +from test_factories.default_test_data_spec import ( + create_energy_results_data_spec, +) +from dbutils_fixture import DBUtilsFixture +from functools import reduce +import pytest + +from settlement_report_job.domain.utils.report_data_type import ReportDataType + +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +import test_factories.time_series_points_csv_factory as time_series_points_factory +import test_factories.energy_factory as energy_factory +from settlement_report_job.domain.utils.csv_column_names import CsvColumnNames +from settlement_report_job.infrastructure.paths import get_report_output_path +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointResolutionDataProductValue, + MeteringPointTypeDataProductValue, +) +import settlement_report_job.domain.time_series_points.order_by_columns as time_series_points_order_by_columns +import settlement_report_job.domain.energy_results.order_by_columns as energy_order_by_columns + + +def _read_csv_file( + directory: str, + file_name: str, + spark: SparkSession, +) -> DataFrame: + file_name = f"{directory}/{file_name}" + return spark.read.csv(file_name, header=True) + + +@pytest.mark.parametrize( + "resolution,grid_area_codes,expected_file_count", + [ + (MeteringPointResolutionDataProductValue.HOUR, ["804", "805"], 2), + (MeteringPointResolutionDataProductValue.QUARTER, ["804", "805"], 2), + (MeteringPointResolutionDataProductValue.HOUR, ["804"], 1), + (MeteringPointResolutionDataProductValue.QUARTER, ["804", "805", "806"], 3), + ], +) +def test_write__returns_files_corresponding_to_grid_area_codes( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + resolution: MeteringPointResolutionDataProductValue, + grid_area_codes: list[str], + expected_file_count: int, +): + # Arrange + report_data_type = ( + ReportDataType.TimeSeriesHourly + if resolution == MeteringPointResolutionDataProductValue.HOUR + else ReportDataType.TimeSeriesQuarterly + ) + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + grid_area_codes=grid_area_codes, + resolution=resolution, + ) + 
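+    # The CSV writer splits its output per grid area, so each parametrized case above
+    # expects exactly len(grid_area_codes) files (2, 2, 1 and 3 respectively).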
df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + ) + + # Assert + assert len(result_files) == expected_file_count + + +def test_write__when_higher_default_parallelism__number_of_files_is_unchanged( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + spark.conf.set("spark.sql.shuffle.partitions", "10") + spark.conf.set("spark.default.parallelism", "10") + report_data_type = ReportDataType.TimeSeriesHourly + expected_file_count = 2 + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + grid_area_codes=["804", "805"], + ) + df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + ) + + # Assert + assert len(result_files) == expected_file_count + + +@pytest.mark.parametrize( + "number_of_rows,rows_per_file,expected_file_count", + [ + (201, 100, 3), + (101, 100, 2), + (100, 100, 1), + (99, 100, 1), + ], +) +def test_write__when_prevent_large_files_is_enabled__writes_expected_number_of_files( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + number_of_rows: int, + rows_per_file: int, + expected_file_count: int, +): + # Arrange + report_data_type = ReportDataType.TimeSeriesHourly + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=number_of_rows, + ) + df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=rows_per_file, + ) + + # Assert + assert df_prepared_time_series_points.count() == number_of_rows + assert len(result_files) == expected_file_count + + +@pytest.mark.parametrize( + "number_of_metering_points,number_of_days_for_each_mp,rows_per_file,expected_file_count", + [ + (21, 10, 100, 3), + (11, 10, 100, 2), + (9, 10, 100, 1), + ], +) +def test_write__files_have_correct_ordering_for_each_file( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + number_of_metering_points: int, + number_of_days_for_each_mp: int, + rows_per_file: int, + expected_file_count: int, +): + # Arrange + expected_order_by = [ + 
CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.time, + ] + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=number_of_metering_points, + num_days_per_metering_point=number_of_days_for_each_mp, + ) + df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + df_prepared_time_series_points = df_prepared_time_series_points.orderBy(F.rand()) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=ReportDataType.TimeSeriesHourly, + order_by_columns=expected_order_by, + rows_per_file=rows_per_file, + ) + + # Assert + assert len(result_files) == expected_file_count + + # Assert that the files are ordered by metering_point_type, metering_point_id, start_of_day + # Asserting that the dataframe is unchanged + for file_name in result_files: + directory = get_report_output_path(standard_wholesale_fixing_scenario_args) + df_actual = _read_csv_file(directory, file_name, spark) + df_expected = df_actual.orderBy(expected_order_by) + assert df_actual.collect() == df_expected.collect() + + +@pytest.mark.parametrize( + "number_of_rows,grid_area_codes,expected_file_count", + [ + (20, ["804"], 1), + (20, ["804", "805"], 2), + ], +) +def test_write__files_have_correct_ordering_for_each_grid_area_code_file( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + number_of_rows: int, + grid_area_codes: list[str], + expected_file_count: int, +): + # Arrange + expected_order_by = [ + CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.time, + ] + report_data_type = ReportDataType.TimeSeriesHourly + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + grid_area_codes=grid_area_codes, + num_metering_points=number_of_rows, + ) + df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + df_prepared_time_series_points = df_prepared_time_series_points.orderBy(F.rand()) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=expected_order_by, + ) + + # Assert + assert len(result_files) == expected_file_count + + # Assert that the files are ordered by metering_point_type, metering_point_id, start_of_day + # Asserting that the dataframe is unchanged + for file_name in result_files: + directory = get_report_output_path(standard_wholesale_fixing_scenario_args) + df_actual = _read_csv_file(directory, file_name, spark) + df_expected = df_actual.orderBy(expected_order_by) + assert df_actual.collect() == df_expected.collect() + + +def test_write__files_have_correct_ordering_for_multiple_metering_point_types( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_file_count = 3 + individual_dataframes = [] + expected_order_by = [ + CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.time, + ] + report_data_type = ReportDataType.TimeSeriesQuarterly + 
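+    # Sanity check of the expected count: with prevent_large_text_files enabled and
+    # rows_per_file=10 (below), the writer emits ceil(total_rows / 10) chunks. The two
+    # specs below create 10 + 20 metering points (presumably one CSV row each here),
+    # i.e. 30 rows and therefore 3 files.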
standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec_consumption = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=10, + ) + test_spec_production = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + metering_point_type=MeteringPointTypeDataProductValue.PRODUCTION, + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=20, + ) + df_prepared_time_series_points_consumption = time_series_points_factory.create( + spark, test_spec_consumption + ) + df_prepared_time_series_points_production = time_series_points_factory.create( + spark, test_spec_production + ) + df_prepared_time_series_points = df_prepared_time_series_points_consumption.union( + df_prepared_time_series_points_production + ).orderBy(F.rand()) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=expected_order_by, + rows_per_file=10, + ) + result_files.sort() + + # Assert + assert len(result_files) == expected_file_count + + # Assert that the files are ordered by metering_point_type, metering_point_id, start_of_day + # Asserting that the dataframe is unchanged + directory = get_report_output_path(standard_wholesale_fixing_scenario_args) + for file in result_files: + individual_dataframes.append(_read_csv_file(directory, file, spark)) + df_actual = reduce(DataFrame.unionByName, individual_dataframes) + df_expected = df_actual.orderBy(expected_order_by) + assert df_actual.collect() == df_expected.collect() + + +@pytest.mark.parametrize( + "number_of_rows,rows_per_file,expected_file_count", + [ + (201, 100, 3), + (101, 100, 2), + (99, 100, 1), + ], +) +def test_write__files_have_correct_sorting_across_multiple_files( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + number_of_rows: int, + rows_per_file: int, + expected_file_count: int, +): + # Arrange + individual_dataframes = [] + expected_order_by = [ + CsvColumnNames.metering_point_type, + CsvColumnNames.metering_point_id, + CsvColumnNames.time, + ] + report_data_type = ReportDataType.TimeSeriesHourly + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=number_of_rows, + ) + df_prepared_time_series_points = time_series_points_factory.create(spark, test_spec) + df_prepared_time_series_points = df_prepared_time_series_points.orderBy(F.rand()) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=expected_order_by, + rows_per_file=rows_per_file, + ) + result_files.sort() + + # Assert + assert len(result_files) == expected_file_count + + # Assert that the files are ordered by metering_point_type, metering_point_id, start_of_day + # Asserting that the dataframe is unchanged + directory = get_report_output_path(standard_wholesale_fixing_scenario_args) + for file in result_files: + individual_dataframes.append(_read_csv_file(directory, file, spark)) + df_actual = reduce(DataFrame.unionByName, 
individual_dataframes) + df_expected = df_actual.orderBy(expected_order_by) + assert df_actual.collect() == df_expected.collect() + + +def test_write__when_prevent_large_files__chunk_index_start_at_1( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_file_count = 3 + report_data_type = ReportDataType.TimeSeriesQuarterly + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec_consumption = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=30, + ) + df_prepared_time_series_points = time_series_points_factory.create( + spark, test_spec_consumption + ) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=10, + ) + + # Assert + assert len(result_files) == expected_file_count + for result_file in result_files: + file_name = result_file[:-4] + file_name_components = file_name.split("_") + + chunk_id_if_present = file_name_components[-1] + assert int(chunk_id_if_present) >= 1 and int(chunk_id_if_present) < 4 + + +def test_write__when_prevent_large_files_but_too_few_rows__chunk_index_should_be_excluded( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_file_count = 1 + report_data_type = ReportDataType.TimeSeriesQuarterly + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec_consumption = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + metering_point_type=MeteringPointTypeDataProductValue.CONSUMPTION, + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=30, + ) + df_prepared_time_series_points = time_series_points_factory.create( + spark, test_spec_consumption + ) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df_prepared_time_series_points, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=31, + ) + + # Assert + assert len(result_files) == expected_file_count + file_name_components = result_files[0][:-4].split("_") + + assert not file_name_components[ + -1 + ].isdigit(), ( + "A valid integer indicating a present chunk index was found when not expected!" 
+ ) + + +def test_write__when_prevent_large_files_and_multiple_grid_areas_but_too_few_rows__chunk_index_should_be_excluded( + dbutils: DBUtilsFixture, + spark: SparkSession, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_file_count = 2 + report_data_type = ReportDataType.TimeSeriesQuarterly + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + test_spec_consumption = time_series_points_factory.TimeSeriesPointsCsvTestDataSpec( + start_of_day=standard_wholesale_fixing_scenario_args.period_start, + num_metering_points=10, + grid_area_codes=["804", "805"], + ) + prepared_time_series_point = time_series_points_factory.create( + spark, test_spec_consumption, add_grid_area_code_partitioning_column=True + ) + + # Act + result_files = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=prepared_time_series_point, + report_data_type=report_data_type, + order_by_columns=time_series_points_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=31, + ) + + # Assert + assert len(result_files) == expected_file_count + for result_file in result_files: + file_name_components = result_file[:-4].split("_") + chunk_id_if_present = file_name_components[-1] + + assert ( + not chunk_id_if_present.isdigit() + ), "A valid integer indicating a present chunk index was found when not expected!" + + +def test_write__when_energy_and_split_report_by_grid_area_is_false__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_02-01-2024_02-01-2024.csv", + ] + + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area = { + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[ + 0 + ]: standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area[ + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0] + ] + } + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + standard_wholesale_fixing_scenario_args.split_report_by_grid_area = True + + df = prepare_for_csv( + energy_factory.create_energy_per_es_v1( + spark, create_energy_results_data_spec(grid_area_code="804") + ), + standard_wholesale_fixing_scenario_args.split_report_by_grid_area, + standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ) + + # Act + actual_file_names = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df, + report_data_type=ReportDataType.EnergyResults, + order_by_columns=energy_order_by_columns.order_by_columns( + requesting_actor_market_role=standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=10000, + ) + + # Assert + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_file_names, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + 
spark=spark, + ) + + +def test_write__when_energy_supplier_and_split_per_grid_area_is_false__returns_correct_columns_and_files( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_flere-net_1000000000000_DDQ_02-01-2024_02-01-2024.csv", + ] + + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.ENERGY_SUPPLIER + ) + energy_supplier_id = "1000000000000" + standard_wholesale_fixing_scenario_args.requesting_actor_id = energy_supplier_id + standard_wholesale_fixing_scenario_args.energy_supplier_ids = [energy_supplier_id] + standard_wholesale_fixing_scenario_args.split_report_by_grid_area = False + + df = prepare_for_csv( + energy_factory.create_energy_per_es_v1( + spark, + create_energy_results_data_spec( + grid_area_code="804", energy_supplier_id=energy_supplier_id + ), + ).union( + energy_factory.create_energy_per_es_v1( + spark, + create_energy_results_data_spec( + grid_area_code="805", energy_supplier_id=energy_supplier_id + ), + ) + ), + False, + standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ) + + # Act + actual_file_names = csv_writer.write( + dbutils=dbutils, + args=standard_wholesale_fixing_scenario_args, + df=df, + report_data_type=ReportDataType.EnergyResults, + order_by_columns=energy_order_by_columns.order_by_columns( + standard_wholesale_fixing_scenario_args.requesting_actor_market_role, + ), + rows_per_file=10000, + ) + + # Assert + assert_file_names_and_columns( + path=get_report_output_path(standard_wholesale_fixing_scenario_args), + actual_files=actual_file_names, + expected_columns=expected_columns, + expected_file_names=expected_file_names, + spark=spark, + ) + + +def test_write__when_energy_and_prevent_large_files__returns_expected_number_of_files_and_content( + spark: SparkSession, + dbutils: DBUtilsFixture, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, +): + # Arrange + expected_file_count = 4 # corresponding to the number of grid areas in standard_wholesale_fixing_scenario + expected_columns = [ + CsvColumnNames.grid_area_code, + CsvColumnNames.energy_supplier_id, + CsvColumnNames.calculation_type, + CsvColumnNames.time, + CsvColumnNames.resolution, + CsvColumnNames.metering_point_type, + CsvColumnNames.settlement_method, + CsvColumnNames.energy_quantity, + ] + + expected_file_names = [ + "RESULTENERGY_804_02-01-2024_02-01-2024_1.csv", + "RESULTENERGY_804_02-01-2024_02-01-2024_2.csv", + "RESULTENERGY_804_02-01-2024_02-01-2024_3.csv", + "RESULTENERGY_804_02-01-2024_02-01-2024_4.csv", + ] + + standard_wholesale_fixing_scenario_args.requesting_actor_market_role = ( + MarketRole.DATAHUB_ADMINISTRATOR + ) + standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area = { + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[ + 0 + ]: standard_wholesale_fixing_scenario_args.calculation_id_by_grid_area[ + standard_wholesale_fixing_scenario_data_generator.GRID_AREAS[0] + ] + } + standard_wholesale_fixing_scenario_args.energy_supplier_ids = None + standard_wholesale_fixing_scenario_args.split_report_by_grid_area = False + standard_wholesale_fixing_scenario_args.prevent_large_text_files = True + + df = 
energy_factory.create_energy_per_es_v1(
+        spark, create_energy_results_data_spec(grid_area_code="804")
+    )
+
+    df = prepare_for_csv(
+        df,
+        True,
+        standard_wholesale_fixing_scenario_args.requesting_actor_market_role,
+    )
+
+    # Act
+    actual_file_names = csv_writer.write(
+        dbutils=dbutils,
+        args=standard_wholesale_fixing_scenario_args,
+        df=df,
+        report_data_type=ReportDataType.EnergyResults,
+        order_by_columns=energy_order_by_columns.order_by_columns(
+            standard_wholesale_fixing_scenario_args.requesting_actor_market_role,
+        ),
+        rows_per_file=df.count() // expected_file_count + 1,
+    )
+
+    # Assert
+    assert_file_names_and_columns(
+        path=get_report_output_path(standard_wholesale_fixing_scenario_args),
+        actual_files=actual_file_names,
+        expected_columns=expected_columns,
+        expected_file_names=expected_file_names,
+        spark=spark,
+    )
+
+
+def test_write_files__csv_separator_is_comma_and_decimals_use_points(
+    spark: SparkSession,
+):
+    # Arrange
+    df = spark.createDataFrame([("a", 1.1), ("b", 2.2), ("c", 3.3)], ["key", "value"])
+    tmp_dir = TemporaryDirectory()
+    csv_path = f"{tmp_dir.name}/csv_file"
+
+    # Act
+    columns = _write_files(
+        df,
+        csv_path,
+        partition_columns=[],
+        order_by=[],
+        rows_per_file=1000,
+    )
+
+    # Assert
+    assert Path(csv_path).exists()
+
+    for x in Path(csv_path).iterdir():
+        if x.is_file() and x.name[-4:] == ".csv":
+            with x.open(mode="r") as f:
+                all_lines_written = f.readlines()
+
+            assert all_lines_written[0] == "a,1.1\n"
+            assert all_lines_written[1] == "b,2.2\n"
+            assert all_lines_written[2] == "c,3.3\n"
+
+    assert columns == ["key", "value"]
+
+    tmp_dir.cleanup()
+
+
+def test_write_files__when_order_by_specified_on_multiple_partitions(
+    spark: SparkSession,
+):
+    # Arrange
+    df = spark.createDataFrame(
+        [("b", 2.2), ("b", 1.1), ("c", 3.3)],
+        ["key", "value"],
+    )
+    tmp_dir = TemporaryDirectory()
+    csv_path = f"{tmp_dir.name}/csv_file"
+
+    # Act
+    columns = _write_files(
+        df,
+        csv_path,
+        partition_columns=["key"],
+        order_by=["value"],
+        rows_per_file=1000,
+    )
+
+    # Assert
+    assert Path(csv_path).exists()
+
+    for x in Path(csv_path).iterdir():
+        if x.is_file() and x.name[-4:] == ".csv":
+            with x.open(mode="r") as f:
+                all_lines_written = f.readlines()
+
+            if len(all_lines_written) == 1:
+                assert all_lines_written[0] == "c;3,3\n"
+            elif len(all_lines_written) == 2:
+                assert all_lines_written[0] == "b;1,1\n"
+                assert all_lines_written[1] == "b;2,2\n"
+            else:
+                raise AssertionError("Found unexpected csv file.")
+
+    assert columns == ["value"]
+
+    tmp_dir.cleanup()
+
+
+def test_write_files__when_df_includes_timestamps__creates_csv_without_milliseconds(
+    spark: SparkSession,
+):
+    # Arrange
+    df = spark.createDataFrame(
+        [
+            ("a", datetime(2024, 10, 21, 12, 10, 30, 0)),
+            ("b", datetime(2024, 10, 21, 12, 10, 30, 30)),
+            ("c", datetime(2024, 10, 21, 12, 10, 30, 123)),
+        ],
+        ["key", "value"],
+    )
+    tmp_dir = TemporaryDirectory()
+    csv_path = f"{tmp_dir.name}/csv_file"
+
+    # Act
+    columns = _write_files(
+        df,
+        csv_path,
+        partition_columns=[],
+        order_by=[],
+        rows_per_file=1000,
+    )
+
+    # Assert
+    assert Path(csv_path).exists()
+
+    for x in Path(csv_path).iterdir():
+        if x.is_file() and x.name[-4:] == ".csv":
+            with x.open(mode="r") as f:
+                all_lines_written = f.readlines()
+
+            assert all_lines_written[0] == "a,2024-10-21T12:10:30Z\n"
+            assert all_lines_written[1] == "b,2024-10-21T12:10:30Z\n"
+            assert all_lines_written[2] == "c,2024-10-21T12:10:30Z\n"
+
+    assert columns == ["key", "value"]
+
+    tmp_dir.cleanup()
diff --git
a/source/settlement_report_python/tests/infrastructure/test_report_name_factory.py b/source/settlement_report_python/tests/infrastructure/test_report_name_factory.py new file mode 100644 index 0000000..ebe6a88 --- /dev/null +++ b/source/settlement_report_python/tests/infrastructure/test_report_name_factory.py @@ -0,0 +1,519 @@ +import uuid +from datetime import datetime + +import pytest +from pyspark.sql import SparkSession + +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.infrastructure.report_name_factory import ( + FileNameFactory, + ReportDataType, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) + + +@pytest.fixture(scope="function") +def default_settlement_report_args() -> SettlementReportArgs: + """ + Note: Some tests depend on the values of `period_start` and `period_end` + """ + return SettlementReportArgs( + report_id=str(uuid.uuid4()), + requesting_actor_id="4123456789012", + period_start=datetime(2024, 6, 30, 22, 0, 0), + period_end=datetime(2024, 7, 31, 22, 0, 0), + calculation_type=CalculationType.WHOLESALE_FIXING, + calculation_id_by_grid_area={ + "016": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373") + }, + grid_area_codes=None, + split_report_by_grid_area=True, + prevent_large_text_files=False, + time_zone="Europe/Copenhagen", + catalog_name="catalog_name", + energy_supplier_ids=["1234567890123"], + requesting_actor_market_role=MarketRole.DATAHUB_ADMINISTRATOR, + settlement_reports_output_path="some_output_volume_path", + include_basis_data=True, + ) + + +@pytest.mark.parametrize( + "report_data_type,expected_pre_fix", + [ + (ReportDataType.TimeSeriesHourly, "TSSD60"), + (ReportDataType.TimeSeriesQuarterly, "TSSD15"), + (ReportDataType.ChargeLinks, "CHARGELINK"), + (ReportDataType.EnergyResults, "RESULTENERGY"), + (ReportDataType.WholesaleResults, "RESULTWHOLESALE"), + (ReportDataType.MonthlyAmounts, "RESULTMONTHLY"), + ], +) +def test_create__when_energy_supplier__returns_expected_file_name( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + expected_pre_fix: str, +): + # Arrange + args = default_settlement_report_args + energy_supplier_id = "1234567890123" + grid_area_code = "123" + args.requesting_actor_id = energy_supplier_id + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + args.energy_supplier_ids = [energy_supplier_id] + sut = FileNameFactory(report_data_type, args) + + # Act + actual = sut.create(grid_area_code, chunk_index=None) + + # Assert + assert ( + actual + == f"{expected_pre_fix}_{grid_area_code}_{energy_supplier_id}_DDQ_01-07-2024_31-07-2024.csv" + ) + + +def test_create__when_grid_access_provider__returns_expected_file_name( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, +): + # Arrange + args = default_settlement_report_args + grid_area_code = "123" + requesting_actor_id = "1111111111111" + args.requesting_actor_market_role = MarketRole.GRID_ACCESS_PROVIDER + args.requesting_actor_id = requesting_actor_id + args.energy_supplier_ids = None + sut = FileNameFactory(ReportDataType.TimeSeriesHourly, args) + + # Act + actual = sut.create(grid_area_code, chunk_index=None) + + # Assert + assert ( + actual + == f"TSSD60_{grid_area_code}_{requesting_actor_id}_DDM_01-07-2024_31-07-2024.csv" + ) + + +@pytest.mark.parametrize( + "market_role, 
energy_supplier_id, expected_file_name", + [ + (MarketRole.SYSTEM_OPERATOR, None, "TSSD60_123_01-07-2024_31-07-2024.csv"), + ( + MarketRole.DATAHUB_ADMINISTRATOR, + None, + "TSSD60_123_01-07-2024_31-07-2024.csv", + ), + ( + MarketRole.SYSTEM_OPERATOR, + "1987654321123", + "TSSD60_123_1987654321123_01-07-2024_31-07-2024.csv", + ), + ( + MarketRole.DATAHUB_ADMINISTRATOR, + "1987654321123", + "TSSD60_123_1987654321123_01-07-2024_31-07-2024.csv", + ), + ], +) +def test_create__when_system_operator_or_datahub_admin__returns_expected_file_name( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + market_role: MarketRole, + energy_supplier_id: str, + expected_file_name: str, +): + # Arrange + args = default_settlement_report_args + args.requesting_actor_market_role = market_role + args.energy_supplier_ids = [energy_supplier_id] if energy_supplier_id else None + grid_area_code = "123" + sut = FileNameFactory(ReportDataType.TimeSeriesHourly, args) + + # Act + actual = sut.create(grid_area_code, chunk_index=None) + + # Assert + assert actual == expected_file_name + + +def test_create__when_split_index_is_set__returns_file_name_that_include_split_index( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, +): + # Arrange + args = default_settlement_report_args + energy_supplier_id = "222222222222" + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + args.requesting_actor_id = energy_supplier_id + args.energy_supplier_ids = [energy_supplier_id] + sut = FileNameFactory(ReportDataType.TimeSeriesHourly, args) + + # Act + actual = sut.create(grid_area_code="123", chunk_index="17") + + # Assert + assert actual == f"TSSD60_123_{energy_supplier_id}_DDQ_01-07-2024_31-07-2024_17.csv" + + +@pytest.mark.parametrize( + "period_start,period_end,expected_start_date,expected_end_date", + [ + ( + datetime(2024, 2, 29, 23, 0, 0), + datetime(2024, 3, 31, 22, 0, 0), + "01-03-2024", + "31-03-2024", + ), + ( + datetime(2024, 9, 30, 22, 0, 0), + datetime(2024, 10, 31, 23, 0, 0), + "01-10-2024", + "31-10-2024", + ), + ], +) +def test_create__when_daylight_saving_time__returns_expected_dates_in_file_name( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + period_start: datetime, + period_end: datetime, + expected_start_date: str, + expected_end_date: str, +): + # Arrange + args = default_settlement_report_args + args.period_start = period_start + args.period_end = period_end + args.energy_supplier_ids = None + sut = FileNameFactory(ReportDataType.TimeSeriesHourly, args) + + # Act + actual = sut.create(grid_area_code="123", chunk_index="17") + + # Assert + assert actual == f"TSSD60_123_{expected_start_date}_{expected_end_date}_17.csv" + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_energy_supplier_requests_report_not_combined( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.split_report_by_grid_area = True + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + 
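+    # Naming grammar exercised by this and the following tests (informal sketch):
+    #   <prefix>_<grid_area | "flere-net">[_<energy_supplier_id | actor_id>][_<DDQ | DDM>]_<start>_<end>[_<chunk_index>].csv
+    # "flere-net" (Danish for "multiple grids") replaces the grid area code when the
+    # report combines several grid areas; the DDQ/DDM token is appended for energy
+    # suppliers and grid access providers respectively, while system operators and
+    # DataHub administrators get no role token.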
args.energy_supplier_ids = args.requesting_actor_id + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code="123", chunk_index=None) + + # Assert + assert ( + actual + == f"{pre_fix}_123_{args.requesting_actor_id}_DDQ_01-07-2024_31-07-2024.csv" + ) + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_energy_supplier_requests_report_combined( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.calculation_id_by_grid_area = { + "123": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + "456": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + + args.split_report_by_grid_area = False + args.requesting_actor_market_role = MarketRole.ENERGY_SUPPLIER + args.energy_supplier_ids = [args.requesting_actor_id] + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code=None, chunk_index=None) + + # Assert + assert ( + actual + == f"{pre_fix}_flere-net_{args.requesting_actor_id}_DDQ_01-07-2024_31-07-2024.csv" + ) + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_grid_access_provider_requests_report( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.requesting_actor_market_role = MarketRole.GRID_ACCESS_PROVIDER + args.calculation_id_by_grid_area = { + "456": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + args.energy_supplier_ids = [args.requesting_actor_id] + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code="456", chunk_index=None) + + # Assert + assert ( + actual + == f"{pre_fix}_456_{args.requesting_actor_id}_DDM_01-07-2024_31-07-2024.csv" + ) + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_datahub_administrator_requests_report_single_grid( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.requesting_actor_market_role = MarketRole.DATAHUB_ADMINISTRATOR + args.energy_supplier_ids = None + args.calculation_id_by_grid_area = { + "456": 
uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code="456", chunk_index=None) + + # Assert + assert actual == f"{pre_fix}_456_01-07-2024_31-07-2024.csv" + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_datahub_administrator_requests_report_multi_grid_not_combined( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.calculation_id_by_grid_area = { + "123": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + "456": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + args.split_report_by_grid_area = True + args.requesting_actor_market_role = MarketRole.DATAHUB_ADMINISTRATOR + args.energy_supplier_ids = None + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code="456", chunk_index=None) + + # Assert + assert actual == f"{pre_fix}_456_01-07-2024_31-07-2024.csv" + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_datahub_administrator_requests_report_multi_grid_single_provider_combined( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + energy_supplier_id = "1234567890123" + args.calculation_id_by_grid_area = { + "123": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + "456": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + args.split_report_by_grid_area = False + args.requesting_actor_market_role = MarketRole.DATAHUB_ADMINISTRATOR + args.energy_supplier_ids = [energy_supplier_id] + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code=None, chunk_index=None) + + # Assert + assert ( + actual == f"{pre_fix}_flere-net_{energy_supplier_id}_01-07-2024_31-07-2024.csv" + ) + + +@pytest.mark.parametrize( + "report_data_type, pre_fix", + [ + pytest.param( + ReportDataType.EnergyResults, + "RESULTENERGY", + id="returns correct energy file name", + ), + pytest.param( + ReportDataType.WholesaleResults, + "RESULTWHOLESALE", + id="returns correct wholesale file name", + ), + pytest.param( + ReportDataType.MonthlyAmounts, + "RESULTMONTHLY", + id="returns correct monthly amounts file name", + ), + ], +) +def test_create__when_datahub_administrator_requests_result_report_multi_grid_all_providers_combined( + spark: SparkSession, + default_settlement_report_args: SettlementReportArgs, + report_data_type: ReportDataType, + pre_fix: str, +): + # Arrange + args = default_settlement_report_args + args.calculation_id_by_grid_area = { + 
"123": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + "456": uuid.UUID("32e49805-20ef-4db2-ac84-c4455de7a373"), + } + args.split_report_by_grid_area = False + args.requesting_actor_market_role = MarketRole.DATAHUB_ADMINISTRATOR + args.energy_supplier_ids = None + + factory = FileNameFactory(report_data_type, args) + + # Act + actual = factory.create(grid_area_code=None, chunk_index=None) + + # Assert + assert actual == f"{pre_fix}_flere-net_01-07-2024_31-07-2024.csv" diff --git a/source/settlement_report_python/tests/integration_test/__init__.py b/source/settlement_report_python/tests/integration_test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/integration_test/test_azure_monitor.py b/source/settlement_report_python/tests/integration_test/test_azure_monitor.py new file mode 100644 index 0000000..7e8763b --- /dev/null +++ b/source/settlement_report_python/tests/integration_test/test_azure_monitor.py @@ -0,0 +1,249 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import time +import uuid +from datetime import timedelta +from typing import cast, Callable +from unittest.mock import Mock, patch + +import pytest +from azure.monitor.query import LogsQueryClient, LogsQueryResult +from settlement_report_job.entry_points.job_args.calculation_type import CalculationType +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.entry_points.entry_point import ( + start_task_with_deps, +) +from settlement_report_job.entry_points.tasks.task_type import TaskType +from integration_test_configuration import IntegrationTestConfiguration + + +class TestWhenInvokedWithArguments: + def test_add_info_log_record_to_azure_monitor_with_expected_settings( + self, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + integration_test_configuration: IntegrationTestConfiguration, + ) -> None: + """ + Assert that the settlement report job adds log records to Azure Monitor with the expected settings: + | where AppRoleName == "dbr-settlement-report" + | where SeverityLevel == 1 + | where Message startswith_cs "Command line arguments" + | where OperationId != "00000000000000000000000000000000" + | where Properties.Subsystem == "wholesale-aggregations" + - custom field "settlement_report_id" = + - custom field "CategoryName" = "Energinet.DataHub." + + + Debug level is not tested as it is not intended to be logged by default. 
+ """ + + # Arrange + valid_task_type = TaskType.TimeSeriesHourly + standard_wholesale_fixing_scenario_args.report_id = str(uuid.uuid4()) + applicationinsights_connection_string = ( + integration_test_configuration.get_applicationinsights_connection_string() + ) + os.environ["CATALOG_NAME"] = "test_catalog" + task_factory_mock = Mock() + self.prepare_command_line_arguments(standard_wholesale_fixing_scenario_args) + + # Act + with patch( + "settlement_report_job.entry_points.tasks.task_factory.create", + task_factory_mock, + ): + with patch( + "settlement_report_job.entry_points.tasks.time_series_points_task.TimeSeriesPointsTask.execute", + return_value=None, + ): + start_task_with_deps( + task_type=valid_task_type, + applicationinsights_connection_string=applicationinsights_connection_string, + ) + + # Assert + # noinspection PyTypeChecker + logs_client = LogsQueryClient(integration_test_configuration.credential) + + query = f""" + AppTraces + | where AppRoleName == "dbr-settlement-report" + | where SeverityLevel == 1 + | where Message startswith_cs "Command line arguments" + | where OperationId != "00000000000000000000000000000000" + | where Properties.Subsystem == "wholesale-aggregations" + | where Properties.settlement_report_id == "{standard_wholesale_fixing_scenario_args.report_id}" + | where Properties.CategoryName == "Energinet.DataHub.settlement_report_job.entry_points.job_args.settlement_report_job_args" + | count + """ + + workspace_id = integration_test_configuration.get_analytics_workspace_id() + + def assert_logged(): + actual = logs_client.query_workspace( + workspace_id, query, timespan=timedelta(minutes=5) + ) + assert_row_count(actual, 1) + + # Assert, but timeout if not succeeded + wait_for_condition( + assert_logged, timeout=timedelta(minutes=3), step=timedelta(seconds=10) + ) + + def test_add_exception_log_record_to_azure_monitor_with_unexpected_settings( + self, + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + integration_test_configuration: IntegrationTestConfiguration, + ) -> None: + """ + Assert that the settlement report job adds log records to Azure Monitor with the expected settings: + | where AppRoleName == "dbr-settlement-report" + | where ExceptionType == "argparse.ArgumentTypeError" + | where OuterMessage startswith_cs "Grid area codes must consist of 3 digits" + | where OperationId != "00000000000000000000000000000000" + | where Properties.Subsystem == "wholesale-aggregations" + - custom field "settlement_report_id" = + - custom field "CategoryName" = "Energinet.DataHub." + + + Debug level is not tested as it is not intended to be logged by default. 
+ """ + + # Arrange + valid_task_type = TaskType.TimeSeriesHourly + standard_wholesale_fixing_scenario_args.report_id = str(uuid.uuid4()) + standard_wholesale_fixing_scenario_args.calculation_type = ( + CalculationType.BALANCE_FIXING + ) + standard_wholesale_fixing_scenario_args.grid_area_codes = [ + "8054" + ] # Should produce an error with balance fixing + applicationinsights_connection_string = ( + integration_test_configuration.get_applicationinsights_connection_string() + ) + os.environ["CATALOG_NAME"] = "test_catalog" + task_factory_mock = Mock() + self.prepare_command_line_arguments(standard_wholesale_fixing_scenario_args) + + # Act + with pytest.raises(SystemExit): + with patch( + "settlement_report_job.entry_points.tasks.task_factory.create", + task_factory_mock, + ): + with patch( + "settlement_report_job.entry_points.tasks.time_series_points_task.TimeSeriesPointsTask.execute", + return_value=None, + ): + start_task_with_deps( + task_type=valid_task_type, + applicationinsights_connection_string=applicationinsights_connection_string, + ) + + # Assert + # noinspection PyTypeChecker + logs_client = LogsQueryClient(integration_test_configuration.credential) + + query = f""" + AppExceptions + | where AppRoleName == "dbr-settlement-report" + | where ExceptionType == "argparse.ArgumentTypeError" + | where OuterMessage startswith_cs "Grid area codes must consist of 3 digits" + | where OperationId != "00000000000000000000000000000000" + | where Properties.Subsystem == "wholesale-aggregations" + | where Properties.settlement_report_id == "{standard_wholesale_fixing_scenario_args.report_id}" + | where Properties.CategoryName == "Energinet.DataHub.telemetry_logging.span_recording" + | count + """ + + workspace_id = integration_test_configuration.get_analytics_workspace_id() + + def assert_logged(): + actual = logs_client.query_workspace( + workspace_id, query, timespan=timedelta(minutes=5) + ) + # There should be two counts, one from the arg_parser and one + assert_row_count(actual, 1) + + # Assert, but timeout if not succeeded + wait_for_condition( + assert_logged, timeout=timedelta(minutes=3), step=timedelta(seconds=10) + ) + + @staticmethod + def prepare_command_line_arguments( + standard_wholesale_fixing_scenario_args: SettlementReportArgs, + ) -> None: + standard_wholesale_fixing_scenario_args.report_id = str( + uuid.uuid4() + ) # Ensure unique report id + sys.argv = [] + sys.argv.append( + "--entry-point=execute_wholesale_results" + ) # Workaround as the parse command line arguments starts with the second argument + sys.argv.append( + f"--report-id={str(standard_wholesale_fixing_scenario_args.report_id)}" + ) + sys.argv.append( + f"--period-start={str(standard_wholesale_fixing_scenario_args.period_start.strftime('%Y-%m-%dT%H:%M:%SZ'))}" + ) + sys.argv.append( + f"--period-end={str(standard_wholesale_fixing_scenario_args.period_end.strftime('%Y-%m-%dT%H:%M:%SZ'))}" + ) + sys.argv.append( + f"--calculation-type={str(standard_wholesale_fixing_scenario_args.calculation_type.value)}" + ) + sys.argv.append("--requesting-actor-market-role=datahub_administrator") + sys.argv.append("--requesting-actor-id=1234567890123") + sys.argv.append( + f"--grid-area-codes={str(standard_wholesale_fixing_scenario_args.grid_area_codes)}" + ) + sys.argv.append( + '--calculation-id-by-grid-area={"804": "bf6e1249-d4c2-4ec2-8ce5-4c7fe8756253"}' + ) + + +def wait_for_condition(callback: Callable, *, timeout: timedelta, step: timedelta): + """ + Wait for a condition to be met, or timeout. 
+ The function keeps invoking the callback until it returns without raising an exception. + """ + start_time = time.time() + while True: + elapsed_ms = int((time.time() - start_time) * 1000) + # noinspection PyBroadException + try: + callback() + print(f"Condition met in {elapsed_ms} ms") + return + except Exception: + if elapsed_ms > timeout.total_seconds() * 1000: + print( + f"Condition failed to be met before timeout. Timed out after {elapsed_ms} ms", + file=sys.stderr, + ) + raise + time.sleep(step.seconds) + print(f"Condition not met after {elapsed_ms} ms. Retrying...") + + +def assert_row_count(actual, expected_count): + actual = cast(LogsQueryResult, actual) + table = actual.tables[0] + row = table.rows[0] + value = row["Count"] + count = cast(int, value) + assert count == expected_count diff --git a/source/settlement_report_python/tests/integration_test_configuration.py b/source/settlement_report_python/tests/integration_test_configuration.py new file mode 100644 index 0000000..5078121 --- /dev/null +++ b/source/settlement_report_python/tests/integration_test_configuration.py @@ -0,0 +1,48 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from azure.identity import DefaultAzureCredential +from azure.keyvault.secrets import SecretClient + + +class IntegrationTestConfiguration: + def __init__(self, azure_keyvault_url: str): + self._credential = DefaultAzureCredential() + self._azure_keyvault_url = azure_keyvault_url + + # noinspection PyTypeChecker + # From https://youtrack.jetbrains.com/issue/PY-59279/Type-checking-detects-an-error-when-passing-an-instance-implicitly-conforming-to-a-Protocol-to-a-function-expecting-that: + # DefaultAzureCredential does not conform to protocol TokenCredential, because its method get_token is missing + # the arguments claims and tenant_id. Surely, they might appear among the arguments passed as **kwargs, but it's + # not guaranteed. In other words, you can make a call to get_token which will typecheck fine for + # DefaultAzureCredential, but not for TokenCredential. 
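+        # Typical usage from the integration tests (sketch; the vault URL comes from the
+        # AZURE_KEYVAULT_URL setting shown in integrationtest.local.settings.sample.yml):
+        #     config = IntegrationTestConfiguration(azure_keyvault_url="https://<vault>.vault.azure.net")
+        #     connection_string = config.get_applicationinsights_connection_string()
+        #     workspace_id = config.get_analytics_workspace_id()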
+ self._secret_client = SecretClient( + vault_url=self._azure_keyvault_url, + credential=self._credential, + ) + + @property + def credential(self) -> DefaultAzureCredential: + return self._credential + + def get_analytics_workspace_id(self) -> str: + return self._get_secret_value("AZURE-LOGANALYTICS-WORKSPACE-ID") + + def get_applicationinsights_connection_string(self) -> str: + # This is the name of the secret in Azure Key Vault in the integration test environment + return self._get_secret_value("AZURE-APPINSIGHTS-CONNECTIONSTRING") + + def _get_secret_value(self, secret_name: str) -> str: + secret = self._secret_client.get_secret(secret_name) + return secret.value diff --git a/source/settlement_report_python/tests/integrationtest.local.settings.sample.yml b/source/settlement_report_python/tests/integrationtest.local.settings.sample.yml new file mode 100644 index 0000000..721c8f7 --- /dev/null +++ b/source/settlement_report_python/tests/integrationtest.local.settings.sample.yml @@ -0,0 +1,13 @@ +# This file contains the configuration settings for the integration tests. +# Create a copy of this file and remove the .sample extension and update the values. +# Read more about integration testing in python in Confluence. + +# Required +AZURE_KEYVAULT_URL: + +# Optional - only required if you want to run the integration tests in PyCharm +# Service Principal Credentials +AZURE_CLIENT_ID: +AZURE_TENANT_ID: +AZURE_CLIENT_SECRET: +AZURE_SUBSCRIPTION_ID: diff --git a/source/settlement_report_python/tests/test_factories/__init__.py b/source/settlement_report_python/tests/test_factories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/source/settlement_report_python/tests/test_factories/amounts_per_charge_factory.py b/source/settlement_report_python/tests/test_factories/amounts_per_charge_factory.py new file mode 100644 index 0000000..69ffaa6 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/amounts_per_charge_factory.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal +from typing import Union, List + +from pyspark.sql import SparkSession, DataFrame + + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + ChargeTypeDataProductValue, + ChargeResolutionDataProductValue, + MeteringPointTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.data_values.settlement_method import ( + SettlementMethodDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas.amounts_per_charge_v1 import ( + amounts_per_charge_v1, +) + + +@dataclass +class AmountsPerChargeRow: + """ + Data specification for creating wholesale test data. 
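+ Each field maps one-to-one to a column in the amounts_per_charge_v1 schema; enum-valued fields are written using their .value.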
+ """ + + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + result_id: str + grid_area_code: str + energy_supplier_id: str + charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + resolution: ChargeResolutionDataProductValue + quantity_unit: str + metering_point_type: MeteringPointTypeDataProductValue + settlement_method: SettlementMethodDataProductValue | None + is_tax: bool + currency: str + time: datetime + quantity: Decimal + quantity_qualities: list[str] + price: Decimal + amount: Decimal + + +def create( + spark: SparkSession, + data_spec: Union[AmountsPerChargeRow, List[AmountsPerChargeRow]], +) -> DataFrame: + if not isinstance(data_spec, list): + data_specs = [data_spec] + else: + data_specs = data_spec + + rows = [] + for spec in data_specs: + row = { + DataProductColumnNames.calculation_id: spec.calculation_id, + DataProductColumnNames.calculation_type: spec.calculation_type.value, + DataProductColumnNames.calculation_version: spec.calculation_version, + DataProductColumnNames.result_id: spec.result_id, + DataProductColumnNames.grid_area_code: spec.grid_area_code, + DataProductColumnNames.energy_supplier_id: spec.energy_supplier_id, + DataProductColumnNames.charge_code: spec.charge_code, + DataProductColumnNames.charge_type: spec.charge_type.value, + DataProductColumnNames.charge_owner_id: spec.charge_owner_id, + DataProductColumnNames.resolution: spec.resolution.value, + DataProductColumnNames.quantity_unit: spec.quantity_unit, + DataProductColumnNames.metering_point_type: spec.metering_point_type.value, + DataProductColumnNames.settlement_method: getattr( + spec.settlement_method, "value", None + ), + DataProductColumnNames.is_tax: spec.is_tax, + DataProductColumnNames.currency: spec.currency, + DataProductColumnNames.time: spec.time, + DataProductColumnNames.quantity: spec.quantity, + DataProductColumnNames.quantity_qualities: spec.quantity_qualities, + DataProductColumnNames.price: spec.price, + DataProductColumnNames.amount: spec.amount, + } + rows.append(row) + + return spark.createDataFrame(rows, amounts_per_charge_v1) diff --git a/source/settlement_report_python/tests/test_factories/charge_link_periods_factory.py b/source/settlement_report_python/tests/test_factories/charge_link_periods_factory.py new file mode 100644 index 0000000..bd00cbb --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/charge_link_periods_factory.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass +from datetime import datetime + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + ChargeTypeDataProductValue, + CalculationTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + charge_link_periods_v1, +) + + +@dataclass +class ChargeLinkPeriodsRow: + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + charge_key: str + charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + metering_point_id: str + quantity: int + from_date: datetime + to_date: datetime + + +def create( + spark: SparkSession, + rows: ChargeLinkPeriodsRow | list[ChargeLinkPeriodsRow], +) -> DataFrame: + if not isinstance(rows, list): + rows = [rows] + + row_list = [] + for row in rows: + row_list.append( + { + 
DataProductColumnNames.calculation_id: row.calculation_id, + DataProductColumnNames.calculation_type: row.calculation_type.value, + DataProductColumnNames.calculation_version: row.calculation_version, + DataProductColumnNames.charge_key: row.charge_key, + DataProductColumnNames.charge_code: row.charge_code, + DataProductColumnNames.charge_type: row.charge_type.value, + DataProductColumnNames.charge_owner_id: row.charge_owner_id, + DataProductColumnNames.metering_point_id: row.metering_point_id, + DataProductColumnNames.quantity: row.quantity, + DataProductColumnNames.from_date: row.from_date, + DataProductColumnNames.to_date: row.to_date, + } + ) + + return spark.createDataFrame(row_list, charge_link_periods_v1) diff --git a/source/settlement_report_python/tests/test_factories/charge_price_information_periods_factory.py b/source/settlement_report_python/tests/test_factories/charge_price_information_periods_factory.py new file mode 100644 index 0000000..46c55c5 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/charge_price_information_periods_factory.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from datetime import datetime + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + ChargeTypeDataProductValue, + ChargeResolutionDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + charge_price_information_periods_v1, +) + + +@dataclass +class ChargePriceInformationPeriodsRow: + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + charge_key: str + charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + resolution: ChargeResolutionDataProductValue + is_tax: bool + from_date: datetime + to_date: datetime + + +def create( + spark: SparkSession, + rows: ChargePriceInformationPeriodsRow | list[ChargePriceInformationPeriodsRow], +) -> DataFrame: + if not isinstance(rows, list): + rows = [rows] + + row_list = [] + for row in rows: + row_list.append( + { + DataProductColumnNames.calculation_id: row.calculation_id, + DataProductColumnNames.calculation_type: row.calculation_type.value, + DataProductColumnNames.calculation_version: row.calculation_version, + DataProductColumnNames.charge_key: row.charge_key, + DataProductColumnNames.charge_code: row.charge_code, + DataProductColumnNames.charge_type: row.charge_type.value, + DataProductColumnNames.charge_owner_id: row.charge_owner_id, + DataProductColumnNames.resolution: row.resolution.value, + DataProductColumnNames.is_tax: row.is_tax, + DataProductColumnNames.from_date: row.from_date, + DataProductColumnNames.to_date: row.to_date, + } + ) + + return spark.createDataFrame(row_list, charge_price_information_periods_v1) diff --git a/source/settlement_report_python/tests/test_factories/charge_price_points_factory.py b/source/settlement_report_python/tests/test_factories/charge_price_points_factory.py new file mode 100644 index 0000000..bdfed6c --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/charge_price_points_factory.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + 
DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + ChargeTypeDataProductValue, + CalculationTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas.charge_price_points_v1 import ( + charge_price_points_v1, +) + + +@dataclass +class ChargePricePointsRow: + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + charge_key: str + charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + charge_price: Decimal + charge_time: datetime + + +def create( + spark: SparkSession, + rows: ChargePricePointsRow | list[ChargePricePointsRow], +) -> DataFrame: + if not isinstance(rows, list): + rows = [rows] + + row_list = [] + for row in rows: + row_list.append( + { + DataProductColumnNames.calculation_id: row.calculation_id, + DataProductColumnNames.calculation_type: row.calculation_type.value, + DataProductColumnNames.calculation_version: row.calculation_version, + DataProductColumnNames.charge_key: row.charge_key, + DataProductColumnNames.charge_code: row.charge_code, + DataProductColumnNames.charge_type: row.charge_type.value, + DataProductColumnNames.charge_owner_id: row.charge_owner_id, + DataProductColumnNames.charge_price: row.charge_price, + DataProductColumnNames.charge_time: row.charge_time, + } + ) + + return spark.createDataFrame(row_list, charge_price_points_v1) diff --git a/source/settlement_report_python/tests/test_factories/default_test_data_spec.py b/source/settlement_report_python/tests/test_factories/default_test_data_spec.py new file mode 100644 index 0000000..f7a668d --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/default_test_data_spec.py @@ -0,0 +1,361 @@ +from datetime import datetime, timedelta +from decimal import Decimal + +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointTypeDataProductValue, + ChargeTypeDataProductValue, + ChargeResolutionDataProductValue, + MeteringPointResolutionDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.data_values.calculation_type import ( + CalculationTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.data_values.settlement_method import ( + SettlementMethodDataProductValue, +) +from test_factories.charge_price_information_periods_factory import ( + ChargePriceInformationPeriodsRow, +) +from test_factories.latest_calculations_factory import LatestCalculationsPerDayRow +from test_factories.metering_point_periods_factory import MeteringPointPeriodsRow +from test_factories.metering_point_time_series_factory import ( + MeteringPointTimeSeriesTestDataSpec, +) + +from test_factories.charge_price_points_factory import ChargePricePointsRow +from test_factories.monthly_amounts_per_charge_factory import MonthlyAmountsPerChargeRow +from test_factories.total_monthly_amounts_factory import TotalMonthlyAmountsRow +from test_factories.charge_link_periods_factory import ChargeLinkPeriodsRow +from test_factories.energy_factory import EnergyTestDataSpec +from test_factories.amounts_per_charge_factory import AmountsPerChargeRow + +DEFAULT_FROM_DATE = datetime(2024, 1, 1, 23) +DEFAULT_TO_DATE = DEFAULT_FROM_DATE + timedelta(days=1) +DATAHUB_ADMINISTRATOR_ID = "1234567890123" +DEFAULT_PERIOD_START = DEFAULT_FROM_DATE +DEFAULT_PERIOD_END = DEFAULT_TO_DATE +DEFAULT_CALCULATION_ID = "11111111-1111-1111-1111-111111111111" +DEFAULT_CALCULATION_VERSION = 1 +DEFAULT_METERING_POINT_ID = "3456789012345" 
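+# These module-level defaults are illustrative test values; tests override the
+# fields they need via keyword arguments to the create_* helpers below.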
+DEFAULT_METERING_TYPE = MeteringPointTypeDataProductValue.CONSUMPTION +DEFAULT_RESOLUTION = MeteringPointResolutionDataProductValue.HOUR +DEFAULT_GRID_AREA_CODE = "804" +DEFAULT_FROM_GRID_AREA_CODE = None +DEFAULT_TO_GRID_AREA_CODE = None +DEFAULT_ENERGY_SUPPLIER_ID = "1234567890123" +DEFAULT_CHARGE_CODE = "41000" +DEFAULT_CHARGE_TYPE = ChargeTypeDataProductValue.TARIFF +DEFAULT_CHARGE_OWNER_ID = "3333333333333" +DEFAULT_CHARGE_PRICE = Decimal("10.000") + +# For energy results +DEFAULT_RESULT_ID = "12345678-4e15-434c-9d93-b03a6dd272a5" +DEFAULT_SETTLEMENT_METHOD = None +DEFAULT_QUANTITY_UNIT = "kwh" +DEFAULT_QUANTITY_QUALITIES = ["measured"] +DEFAULT_BALANCE_RESPONSIBLE_PARTY_ID = "1234567890123" + + +def create_charge_link_periods_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + charge_code: str = DEFAULT_CHARGE_CODE, + charge_type: ChargeTypeDataProductValue = DEFAULT_CHARGE_TYPE, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + metering_point_id: str = DEFAULT_METERING_POINT_ID, + from_date: datetime = DEFAULT_PERIOD_START, + to_date: datetime = DEFAULT_PERIOD_END, + quantity: int = 1, +) -> ChargeLinkPeriodsRow: + charge_key = f"{charge_code}-{charge_type}-{charge_owner_id}" + return ChargeLinkPeriodsRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + charge_key=charge_key, + charge_code=charge_code, + charge_type=charge_type, + charge_owner_id=charge_owner_id, + metering_point_id=metering_point_id, + from_date=from_date, + to_date=to_date, + quantity=quantity, + ) + + +def create_charge_price_points_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + charge_code: str = DEFAULT_CHARGE_CODE, + charge_type: ChargeTypeDataProductValue = DEFAULT_CHARGE_TYPE, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + charge_price: Decimal = DEFAULT_CHARGE_PRICE, + charge_time: datetime = DEFAULT_PERIOD_START, +) -> ChargePricePointsRow: + charge_key = f"{charge_code}-{charge_type}-{charge_owner_id}" + return ChargePricePointsRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + charge_key=charge_key, + charge_code=charge_code, + charge_type=charge_type, + charge_owner_id=charge_owner_id, + charge_price=charge_price, + charge_time=charge_time, + ) + + +def create_charge_price_information_periods_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + charge_code: str = DEFAULT_CHARGE_CODE, + charge_type: ChargeTypeDataProductValue = DEFAULT_CHARGE_TYPE, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + resolution: ChargeResolutionDataProductValue = ChargeResolutionDataProductValue.HOUR, + is_tax: bool = False, + from_date: datetime = DEFAULT_PERIOD_START, + to_date: datetime = DEFAULT_PERIOD_END, +) -> ChargePriceInformationPeriodsRow: + charge_key = f"{charge_code}-{charge_type}-{charge_owner_id}" + + return ChargePriceInformationPeriodsRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + 
charge_key=charge_key, + charge_code=charge_code, + charge_type=charge_type, + charge_owner_id=charge_owner_id, + resolution=resolution, + is_tax=is_tax, + from_date=from_date, + to_date=to_date, + ) + + +def create_metering_point_periods_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + metering_point_id: str = DEFAULT_METERING_POINT_ID, + metering_point_type: MeteringPointTypeDataProductValue = DEFAULT_METERING_TYPE, + settlement_method: SettlementMethodDataProductValue = DEFAULT_SETTLEMENT_METHOD, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + resolution: MeteringPointResolutionDataProductValue = DEFAULT_RESOLUTION, + from_grid_area_code: str = DEFAULT_FROM_GRID_AREA_CODE, + to_grid_area_code: str = DEFAULT_TO_GRID_AREA_CODE, + parent_metering_point_id: str | None = None, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + balance_responsible_party_id: str = DEFAULT_BALANCE_RESPONSIBLE_PARTY_ID, + from_date: datetime = DEFAULT_PERIOD_START, + to_date: datetime = DEFAULT_PERIOD_END, +) -> MeteringPointPeriodsRow: + return MeteringPointPeriodsRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + metering_point_id=metering_point_id, + metering_point_type=metering_point_type, + settlement_method=settlement_method, + grid_area_code=grid_area_code, + resolution=resolution, + from_grid_area_code=from_grid_area_code, + to_grid_area_code=to_grid_area_code, + parent_metering_point_id=parent_metering_point_id, + energy_supplier_id=energy_supplier_id, + balance_responsible_party_id=balance_responsible_party_id, + from_date=from_date, + to_date=to_date, + ) + + +def create_time_series_points_data_spec( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + metering_point_id: str = DEFAULT_METERING_POINT_ID, + metering_point_type: MeteringPointTypeDataProductValue = DEFAULT_METERING_TYPE, + resolution: MeteringPointResolutionDataProductValue = DEFAULT_RESOLUTION, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + from_date: datetime = DEFAULT_PERIOD_START, + to_date: datetime = DEFAULT_PERIOD_END, + quantity: Decimal = Decimal("1.005"), +) -> MeteringPointTimeSeriesTestDataSpec: + return MeteringPointTimeSeriesTestDataSpec( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + metering_point_id=metering_point_id, + metering_point_type=metering_point_type, + resolution=resolution, + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + from_date=from_date, + to_date=to_date, + quantity=quantity, + ) + + +def create_amounts_per_charge_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + charge_code: str = DEFAULT_CHARGE_CODE, + charge_type: ChargeTypeDataProductValue = DEFAULT_CHARGE_TYPE, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + resolution: ChargeResolutionDataProductValue = 
ChargeResolutionDataProductValue.HOUR, + quantity_unit: str = "kWh", + metering_point_type: MeteringPointTypeDataProductValue = DEFAULT_METERING_TYPE, + settlement_method: SettlementMethodDataProductValue = DEFAULT_SETTLEMENT_METHOD, + is_tax: bool = False, + currency: str = "DKK", + time: datetime = DEFAULT_PERIOD_START, + quantity: Decimal = Decimal("1.005"), + quantity_qualities: list[str] = ["measured"], + price: Decimal = Decimal("0.005"), + amount: Decimal = Decimal("0.005"), +) -> AmountsPerChargeRow: + return AmountsPerChargeRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + result_id="result_id_placeholder", # Add appropriate value + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + charge_code=charge_code, + charge_type=charge_type, + charge_owner_id=charge_owner_id, + resolution=resolution, + quantity_unit=quantity_unit, + metering_point_type=metering_point_type, + settlement_method=settlement_method, + is_tax=is_tax, + currency=currency, + time=time, + quantity=quantity, + quantity_qualities=quantity_qualities, + price=price, + amount=amount, + ) + + +def create_monthly_amounts_per_charge_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + charge_code: str = DEFAULT_CHARGE_CODE, + charge_type: ChargeTypeDataProductValue = DEFAULT_CHARGE_TYPE, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + quantity_unit: str = "kWh", + is_tax: bool = False, + currency: str = "DKK", + time: datetime = DEFAULT_PERIOD_START, + amount: Decimal = Decimal("0.005"), +) -> MonthlyAmountsPerChargeRow: + return MonthlyAmountsPerChargeRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + result_id="result_id_placeholder", # Add appropriate value + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + charge_code=charge_code, + charge_type=charge_type, + charge_owner_id=charge_owner_id, + quantity_unit=quantity_unit, + is_tax=is_tax, + currency=currency, + time=time, + amount=amount, + ) + + +def create_total_monthly_amounts_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + charge_owner_id: str = DEFAULT_CHARGE_OWNER_ID, + currency: str = "DKK", + time: datetime = DEFAULT_PERIOD_START, + amount: Decimal = Decimal("0.005"), +) -> TotalMonthlyAmountsRow: + return TotalMonthlyAmountsRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + result_id="result_id_placeholder", # Add appropriate value + grid_area_code=grid_area_code, + energy_supplier_id=energy_supplier_id, + charge_owner_id=charge_owner_id, + currency=currency, + time=time, + amount=amount, + ) + + +def create_latest_calculations_per_day_row( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.BALANCE_FIXING, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + grid_area_code: str = 
DEFAULT_GRID_AREA_CODE, + start_of_day: datetime = DEFAULT_PERIOD_START, +) -> LatestCalculationsPerDayRow: + + return LatestCalculationsPerDayRow( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_version=calculation_version, + grid_area_code=grid_area_code, + start_of_day=start_of_day, + ) + + +def create_energy_results_data_spec( + calculation_id: str = DEFAULT_CALCULATION_ID, + calculation_type: CalculationTypeDataProductValue = CalculationTypeDataProductValue.WHOLESALE_FIXING, + calculation_period_start: datetime = DEFAULT_PERIOD_START, + calculation_period_end: datetime = DEFAULT_PERIOD_END, + calculation_version: int = DEFAULT_CALCULATION_VERSION, + result_id: str = DEFAULT_RESULT_ID, + grid_area_code: str = DEFAULT_GRID_AREA_CODE, + metering_point_type: MeteringPointTypeDataProductValue = DEFAULT_METERING_TYPE, + settlement_method: str = DEFAULT_SETTLEMENT_METHOD, + resolution: MeteringPointResolutionDataProductValue = DEFAULT_RESOLUTION, + quantity: Decimal = Decimal("1.005"), + quantity_unit: str = DEFAULT_QUANTITY_UNIT, + quantity_qualities: list[str] = DEFAULT_QUANTITY_QUALITIES, + from_date: datetime = DEFAULT_PERIOD_START, + to_date: datetime = DEFAULT_PERIOD_END, + energy_supplier_id: str = DEFAULT_ENERGY_SUPPLIER_ID, + balance_responsible_party_id: str = DEFAULT_BALANCE_RESPONSIBLE_PARTY_ID, +) -> EnergyTestDataSpec: + return EnergyTestDataSpec( + calculation_id=calculation_id, + calculation_type=calculation_type, + calculation_period_start=calculation_period_start, + calculation_period_end=calculation_period_end, + calculation_version=calculation_version, + result_id=result_id, + grid_area_code=grid_area_code, + metering_point_type=metering_point_type, + settlement_method=settlement_method, + resolution=resolution, + quantity=quantity, + quantity_unit=quantity_unit, + quantity_qualities=quantity_qualities, + from_date=from_date, + to_date=to_date, + energy_supplier_id=energy_supplier_id, + balance_responsible_party_id=balance_responsible_party_id, + ) diff --git a/source/settlement_report_python/tests/test_factories/energy_factory.py b/source/settlement_report_python/tests/test_factories/energy_factory.py new file mode 100644 index 0000000..9a54144 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/energy_factory.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + MeteringPointTypeDataProductValue, + MeteringPointResolutionDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas.energy_v1 import energy_v1 +from settlement_report_job.infrastructure.wholesale.schemas.energy_per_es_v1 import ( + energy_per_es_v1, +) + + +@dataclass +class EnergyTestDataSpec: + """ + Data specification for creating energy test data. + Time series points are create between from_date and to_date with the specified resolution. 
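+ One row is created for each resolution step between from_date (inclusive) and to_date (exclusive).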
+ """ + + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_period_start: datetime + calculation_period_end: datetime + calculation_version: int + result_id: str + grid_area_code: str + metering_point_type: MeteringPointTypeDataProductValue + settlement_method: str + resolution: MeteringPointResolutionDataProductValue + quantity: Decimal + quantity_unit: str + quantity_qualities: list[str] + from_date: datetime + to_date: datetime + energy_supplier_id: str + balance_responsible_party_id: str + + +def _get_base_energy_rows_from_spec(data_spec: EnergyTestDataSpec): + rows = [] + resolution = ( + timedelta(hours=1) if data_spec.resolution == "PT1H" else timedelta(minutes=15) + ) + current_time = data_spec.from_date + while current_time < data_spec.to_date: + rows.append( + { + DataProductColumnNames.calculation_id: data_spec.calculation_id, + DataProductColumnNames.calculation_type: data_spec.calculation_type.value, + DataProductColumnNames.calculation_period_start: data_spec.calculation_period_start, + DataProductColumnNames.calculation_period_end: data_spec.calculation_period_end, + DataProductColumnNames.calculation_version: data_spec.calculation_version, + DataProductColumnNames.result_id: data_spec.result_id, + DataProductColumnNames.grid_area_code: data_spec.grid_area_code, + DataProductColumnNames.energy_supplier_id: data_spec.energy_supplier_id, + DataProductColumnNames.balance_responsible_party_id: data_spec.balance_responsible_party_id, + DataProductColumnNames.metering_point_type: data_spec.metering_point_type.value, + DataProductColumnNames.settlement_method: data_spec.settlement_method, + DataProductColumnNames.resolution: data_spec.resolution.value, + DataProductColumnNames.time: current_time, + DataProductColumnNames.quantity: data_spec.quantity, + DataProductColumnNames.quantity_unit: data_spec.quantity_unit, + DataProductColumnNames.quantity_qualities: data_spec.quantity_qualities, + } + ) + current_time += resolution + + return rows + + +def create_energy_per_es_v1( + spark: SparkSession, + data_spec: EnergyTestDataSpec, +) -> DataFrame: + rows = _get_base_energy_rows_from_spec(data_spec) + return spark.createDataFrame(rows, energy_per_es_v1) + + +def create_energy_v1(spark: SparkSession, data_spec: EnergyTestDataSpec) -> DataFrame: + rows = _get_base_energy_rows_from_spec(data_spec) + for row in rows: + del row[DataProductColumnNames.energy_supplier_id] + del row[DataProductColumnNames.balance_responsible_party_id] + + return spark.createDataFrame(rows, energy_v1) diff --git a/source/settlement_report_python/tests/test_factories/latest_calculations_factory.py b/source/settlement_report_python/tests/test_factories/latest_calculations_factory.py new file mode 100644 index 0000000..47ec119 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/latest_calculations_factory.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from datetime import datetime + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.schemas.latest_calculations_by_day_v1 import ( + latest_calculations_by_day_v1, +) + + +@dataclass +class LatestCalculationsPerDayRow: + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + grid_area_code: str 
+ start_of_day: datetime + + +def create( + spark: SparkSession, + rows: LatestCalculationsPerDayRow | list[LatestCalculationsPerDayRow], +) -> DataFrame: + if not isinstance(rows, list): + rows = [rows] + + data_rows = [] + for row in rows: + data_rows.append( + { + DataProductColumnNames.calculation_id: row.calculation_id, + DataProductColumnNames.calculation_type: row.calculation_type.value, + DataProductColumnNames.calculation_version: row.calculation_version, + DataProductColumnNames.grid_area_code: row.grid_area_code, + DataProductColumnNames.start_of_day: row.start_of_day, + } + ) + + return spark.createDataFrame(data_rows, latest_calculations_by_day_v1) diff --git a/source/settlement_report_python/tests/test_factories/metering_point_periods_factory.py b/source/settlement_report_python/tests/test_factories/metering_point_periods_factory.py new file mode 100644 index 0000000..f28b482 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/metering_point_periods_factory.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from datetime import datetime + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + MeteringPointTypeDataProductValue, + MeteringPointResolutionDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.data_values.settlement_method import ( + SettlementMethodDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + metering_point_periods_v1, +) + + +@dataclass +class MeteringPointPeriodsRow: + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + metering_point_id: str + metering_point_type: MeteringPointTypeDataProductValue + settlement_method: SettlementMethodDataProductValue + grid_area_code: str + resolution: MeteringPointResolutionDataProductValue + from_grid_area_code: str | None + to_grid_area_code: str | None + parent_metering_point_id: str | None + energy_supplier_id: str + balance_responsible_party_id: str + from_date: datetime + to_date: datetime + + +def create( + spark: SparkSession, + rows: MeteringPointPeriodsRow | list[MeteringPointPeriodsRow], +) -> DataFrame: + if not isinstance(rows, list): + rows = [rows] + + row_list = [] + for row in rows: + row_list.append( + { + DataProductColumnNames.calculation_id: row.calculation_id, + DataProductColumnNames.calculation_type: row.calculation_type.value, + DataProductColumnNames.calculation_version: row.calculation_version, + DataProductColumnNames.metering_point_id: row.metering_point_id, + DataProductColumnNames.metering_point_type: row.metering_point_type.value, + DataProductColumnNames.settlement_method: ( + row.settlement_method.value if row.settlement_method else None + ), + DataProductColumnNames.grid_area_code: row.grid_area_code, + DataProductColumnNames.resolution: row.resolution.value, + DataProductColumnNames.from_grid_area_code: row.from_grid_area_code, + DataProductColumnNames.to_grid_area_code: row.to_grid_area_code, + DataProductColumnNames.parent_metering_point_id: row.parent_metering_point_id, + DataProductColumnNames.energy_supplier_id: row.energy_supplier_id, + DataProductColumnNames.balance_responsible_party_id: row.balance_responsible_party_id, + DataProductColumnNames.from_date: row.from_date, + DataProductColumnNames.to_date: row.to_date, + } + ) + + return 
spark.createDataFrame(row_list, metering_point_periods_v1) diff --git a/source/settlement_report_python/tests/test_factories/metering_point_time_series_factory.py b/source/settlement_report_python/tests/test_factories/metering_point_time_series_factory.py new file mode 100644 index 0000000..16c5072 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/metering_point_time_series_factory.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta + +from pyspark.sql import SparkSession, DataFrame +from pyspark.sql.types import DecimalType + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + MeteringPointTypeDataProductValue, + MeteringPointResolutionDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas import ( + metering_point_time_series_v1, +) + + +@dataclass +class MeteringPointTimeSeriesTestDataSpec: + """ + Data specification for creating a metering point time series test data. + Time series points are create between from_date and to_date with the specified resolution. + """ + + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + metering_point_id: str + metering_point_type: MeteringPointTypeDataProductValue + resolution: MeteringPointResolutionDataProductValue + grid_area_code: str + energy_supplier_id: str + from_date: datetime + to_date: datetime + quantity: DecimalType(18, 3) + + +def create( + spark: SparkSession, data_spec: MeteringPointTimeSeriesTestDataSpec +) -> DataFrame: + rows = [] + resolution = ( + timedelta(hours=1) + if data_spec.resolution == MeteringPointResolutionDataProductValue.HOUR + else timedelta(minutes=15) + ) + current_time = data_spec.from_date + while current_time < data_spec.to_date: + rows.append( + { + DataProductColumnNames.calculation_id: data_spec.calculation_id, + DataProductColumnNames.calculation_type: data_spec.calculation_type.value, + DataProductColumnNames.calculation_version: data_spec.calculation_version, + DataProductColumnNames.metering_point_id: data_spec.metering_point_id, + DataProductColumnNames.metering_point_type: data_spec.metering_point_type.value, + DataProductColumnNames.resolution: data_spec.resolution.value, + DataProductColumnNames.grid_area_code: data_spec.grid_area_code, + DataProductColumnNames.energy_supplier_id: data_spec.energy_supplier_id, + DataProductColumnNames.observation_time: current_time, + DataProductColumnNames.quantity: data_spec.quantity, + } + ) + current_time += resolution + + return spark.createDataFrame(rows, metering_point_time_series_v1) diff --git a/source/settlement_report_python/tests/test_factories/monthly_amounts_per_charge_factory.py b/source/settlement_report_python/tests/test_factories/monthly_amounts_per_charge_factory.py new file mode 100644 index 0000000..4d9bf7d --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/monthly_amounts_per_charge_factory.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, + ChargeTypeDataProductValue, +) +from 
settlement_report_job.infrastructure.wholesale.schemas import ( + monthly_amounts_per_charge_v1, +) + + +@dataclass +class MonthlyAmountsPerChargeRow: + """ + Data specification for creating wholesale test data. + """ + + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + result_id: str + grid_area_code: str + energy_supplier_id: str + charge_code: str + charge_type: ChargeTypeDataProductValue + charge_owner_id: str + quantity_unit: str + is_tax: bool + currency: str + time: datetime + amount: Decimal + + +def create(spark: SparkSession, data_spec: MonthlyAmountsPerChargeRow) -> DataFrame: + row = { + DataProductColumnNames.calculation_id: data_spec.calculation_id, + DataProductColumnNames.calculation_type: data_spec.calculation_type.value, + DataProductColumnNames.calculation_version: data_spec.calculation_version, + DataProductColumnNames.result_id: data_spec.result_id, + DataProductColumnNames.grid_area_code: data_spec.grid_area_code, + DataProductColumnNames.energy_supplier_id: data_spec.energy_supplier_id, + DataProductColumnNames.charge_code: data_spec.charge_code, + DataProductColumnNames.charge_type: data_spec.charge_type.value, + DataProductColumnNames.charge_owner_id: data_spec.charge_owner_id, + DataProductColumnNames.quantity_unit: data_spec.quantity_unit, + DataProductColumnNames.is_tax: data_spec.is_tax, + DataProductColumnNames.currency: data_spec.currency, + DataProductColumnNames.time: data_spec.time, + DataProductColumnNames.amount: data_spec.amount, + } + + return spark.createDataFrame([row], monthly_amounts_per_charge_v1) diff --git a/source/settlement_report_python/tests/test_factories/time_series_points_csv_factory.py b/source/settlement_report_python/tests/test_factories/time_series_points_csv_factory.py new file mode 100644 index 0000000..c8ca340 --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/time_series_points_csv_factory.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from datetime import datetime, timedelta + +from pyspark.sql import SparkSession, DataFrame + +from settlement_report_job.domain.utils.csv_column_names import ( + CsvColumnNames, + EphemeralColumns, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + MeteringPointTypeDataProductValue, + MeteringPointResolutionDataProductValue, +) + +DEFAULT_METERING_POINT_TYPE = MeteringPointTypeDataProductValue.CONSUMPTION +DEFAULT_START_OF_DAY = datetime(2024, 1, 1, 23) +DEFAULT_GRID_AREA_CODES = ["804"] +DEFAULT_ENERGY_QUANTITY = 235.0 +DEFAULT_RESOLUTION = MeteringPointResolutionDataProductValue.HOUR +DEFAULT_NUM_METERING_POINTS = 10 +DEFAULT_NUM_DAYS_PER_METERING_POINT = 1 + + +@dataclass +class TimeSeriesPointsCsvTestDataSpec: + metering_point_type: MeteringPointTypeDataProductValue = DEFAULT_METERING_POINT_TYPE + start_of_day: datetime = DEFAULT_START_OF_DAY + grid_area_codes: list = field(default_factory=lambda: DEFAULT_GRID_AREA_CODES) + energy_quantity: float = DEFAULT_ENERGY_QUANTITY + resolution: MeteringPointResolutionDataProductValue = DEFAULT_RESOLUTION + num_metering_points: int = DEFAULT_NUM_METERING_POINTS + num_days_per_metering_point: int = DEFAULT_NUM_DAYS_PER_METERING_POINT + + +def create( + spark: SparkSession, + data_spec: TimeSeriesPointsCsvTestDataSpec, + add_grid_area_code_partitioning_column: bool = False, +) -> DataFrame: + rows = [] + counter = 0 + for grid_area_code in data_spec.grid_area_codes: + for _ in range(data_spec.num_metering_points): + counter += 1 + for i in 
range(data_spec.num_days_per_metering_point): + row = { + CsvColumnNames.energy_supplier_id: "1234567890123", + CsvColumnNames.metering_point_id: str(1000000000000 + counter), + CsvColumnNames.metering_point_type: data_spec.metering_point_type.value, + EphemeralColumns.grid_area_code_partitioning: grid_area_code, + CsvColumnNames.time: data_spec.start_of_day + timedelta(days=i), + } + if add_grid_area_code_partitioning_column: + row[EphemeralColumns.grid_area_code_partitioning] = grid_area_code + + for j in range( + 25 + if data_spec.resolution.value + == MeteringPointResolutionDataProductValue.HOUR + else 100 + ): + row[f"{CsvColumnNames.energy_quantity}{j + 1}"] = ( + data_spec.energy_quantity + ) + rows.append(row) + + df = spark.createDataFrame(rows) + return df diff --git a/source/settlement_report_python/tests/test_factories/total_monthly_amounts_factory.py b/source/settlement_report_python/tests/test_factories/total_monthly_amounts_factory.py new file mode 100644 index 0000000..ffd783d --- /dev/null +++ b/source/settlement_report_python/tests/test_factories/total_monthly_amounts_factory.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from datetime import datetime +from decimal import Decimal + +from pyspark.sql import SparkSession, DataFrame + + +from settlement_report_job.infrastructure.wholesale.column_names import ( + DataProductColumnNames, +) +from settlement_report_job.infrastructure.wholesale.data_values import ( + CalculationTypeDataProductValue, +) +from settlement_report_job.infrastructure.wholesale.schemas.total_monthly_amounts_v1 import ( + total_monthly_amounts_v1, +) + + +@dataclass +class TotalMonthlyAmountsRow: + """ + Data specification for creating wholesale test data. + """ + + calculation_id: str + calculation_type: CalculationTypeDataProductValue + calculation_version: int + result_id: str + grid_area_code: str + energy_supplier_id: str + charge_owner_id: str + currency: str + time: datetime + amount: Decimal + + +def create(spark: SparkSession, data_spec: TotalMonthlyAmountsRow) -> DataFrame: + row = { + DataProductColumnNames.calculation_id: data_spec.calculation_id, + DataProductColumnNames.calculation_type: data_spec.calculation_type.value, + DataProductColumnNames.calculation_version: data_spec.calculation_version, + DataProductColumnNames.result_id: data_spec.result_id, + DataProductColumnNames.grid_area_code: data_spec.grid_area_code, + DataProductColumnNames.energy_supplier_id: data_spec.energy_supplier_id, + DataProductColumnNames.charge_owner_id: data_spec.charge_owner_id, + DataProductColumnNames.currency: data_spec.currency, + DataProductColumnNames.time: data_spec.time, + DataProductColumnNames.amount: data_spec.amount, + } + + assert row[DataProductColumnNames.calculation_id] is not None + assert row[DataProductColumnNames.calculation_type] is not None + assert row[DataProductColumnNames.calculation_version] is not None + assert row[DataProductColumnNames.result_id] is not None + assert row[DataProductColumnNames.grid_area_code] is not None + assert row[DataProductColumnNames.energy_supplier_id] is not None + + assert row[DataProductColumnNames.currency] is not None + assert row[DataProductColumnNames.time] is not None + assert row[DataProductColumnNames.amount] is not None + + return spark.createDataFrame([row], total_monthly_amounts_v1) diff --git a/source/settlement_report_python/tests/utils.py b/source/settlement_report_python/tests/utils.py new file mode 100644 index 0000000..c51c2b1 --- /dev/null +++ 
b/source/settlement_report_python/tests/utils.py @@ -0,0 +1,107 @@ +# Copyright 2020 Energinet DataHub A/S +# +# Licensed under the Apache License, Version 2.0 (the "License2"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +from datetime import timedelta, datetime +from zoneinfo import ZoneInfo + +from settlement_report_job.domain.utils.market_role import MarketRole +from settlement_report_job.domain.utils.report_data_type import ReportDataType +from settlement_report_job.infrastructure.report_name_factory import ( + MarketRoleInFileName, +) +from settlement_report_job.entry_points.job_args.settlement_report_args import ( + SettlementReportArgs, +) +from settlement_report_job.infrastructure import paths + + +class Dates: + JAN_1ST = datetime(2023, 12, 31, 23) + JAN_2ND = datetime(2024, 1, 1, 23) + JAN_3RD = datetime(2024, 1, 2, 23) + JAN_4TH = datetime(2024, 1, 3, 23) + JAN_5TH = datetime(2024, 1, 4, 23) + JAN_6TH = datetime(2024, 1, 5, 23) + JAN_7TH = datetime(2024, 1, 6, 23) + JAN_8TH = datetime(2024, 1, 7, 23) + JAN_9TH = datetime(2024, 1, 8, 23) + + +DEFAULT_TIME_ZONE = "Europe/Copenhagen" + + +def cleanup_output_path(settlement_reports_output_path: str) -> None: + if os.path.exists(settlement_reports_output_path): + shutil.rmtree(settlement_reports_output_path) + os.makedirs(settlement_reports_output_path) + + +def get_actual_files( + report_data_type: ReportDataType, args: SettlementReportArgs +) -> list[str]: + path = paths.get_report_output_path(args) + if not os.path.isdir(path): + return [] + + return [ + f + for f in os.listdir(path) + if os.path.isfile(os.path.join(path, f)) + and f.startswith(_get_file_prefix(report_data_type)) + and f.endswith(".csv") + ] + + +def _get_file_prefix(report_data_type) -> str: + if report_data_type == ReportDataType.TimeSeriesHourly: + return "TSSD60" + elif report_data_type == ReportDataType.TimeSeriesQuarterly: + return "TSSD15" + elif report_data_type == ReportDataType.MeteringPointPeriods: + return "MDMP" + elif report_data_type == ReportDataType.ChargeLinks: + return "CHARGELINK" + elif report_data_type == ReportDataType.ChargePricePoints: + return "CHARGEPRICE" + elif report_data_type == ReportDataType.EnergyResults: + return "RESULTENERGY" + elif report_data_type == ReportDataType.WholesaleResults: + return "RESULTWHOLESALE" + elif report_data_type == ReportDataType.MonthlyAmounts: + return "RESULTMONTHLY" + raise NotImplementedError(f"Report data type {report_data_type} is not supported.") + + +def get_start_date(period_start: datetime) -> str: + time_zone_info = ZoneInfo(DEFAULT_TIME_ZONE) + return period_start.astimezone(time_zone_info).strftime("%d-%m-%Y") + + +def get_end_date(period_end: datetime) -> str: + time_zone_info = ZoneInfo(DEFAULT_TIME_ZONE) + return (period_end.astimezone(time_zone_info) - timedelta(days=1)).strftime( + "%d-%m-%Y" + ) + + +def get_market_role_in_file_name( + requesting_actor_market_role: MarketRole, +) -> str | None: + if requesting_actor_market_role == MarketRole.ENERGY_SUPPLIER: + return MarketRoleInFileName.ENERGY_SUPPLIER + elif 
requesting_actor_market_role == MarketRole.GRID_ACCESS_PROVIDER: + return MarketRoleInFileName.GRID_ACCESS_PROVIDER + + return None
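
A minimal usage sketch (not part of the diff above) showing how the new test factories and the default_test_data_spec helpers are typically combined in a test. The spark fixture and the test name are assumptions made for illustration only.

from pyspark.sql import SparkSession

from settlement_report_job.infrastructure.wholesale.column_names import (
    DataProductColumnNames,
)
from test_factories import amounts_per_charge_factory
from test_factories.default_test_data_spec import create_amounts_per_charge_row


def test_amounts_per_charge_factory_produces_expected_row(spark: SparkSession) -> None:
    # Use the shared defaults and override only the field under test.
    row = create_amounts_per_charge_row(grid_area_code="805")

    # The factory turns the dataclass into a DataFrame with the amounts_per_charge_v1 schema.
    df = amounts_per_charge_factory.create(spark, row)

    assert df.count() == 1
    assert df.collect()[0][DataProductColumnNames.grid_area_code] == "805"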