Skip to content

Commit

Permalink
Fortran serialization codegen tests (#198)
Browse files Browse the repository at this point in the history
Co-authored-by: samkellerhals <[email protected]>
  • Loading branch information
ChristopherBignamini and samkellerhals authored May 9, 2023
1 parent 3fbfadb commit c21a8ae
Show file tree
Hide file tree
Showing 4 changed files with 767 additions and 13 deletions.
11 changes: 6 additions & 5 deletions pyutils/src/icon4py/f2ser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,12 @@ def _parse_derived_types(self, derived_types: dict) -> dict:
MissingDerivedTypeError: If the type definition for a derived type could not be found in any of the dependency files.
"""
derived_type_defs = {}
for dep in self.dependencies:
parsed = crack(dep)
for block in parsed["body"]:
if block["block"] == "type":
derived_type_defs[block["name"]] = block["vars"]
if self.dependencies:
for dep in self.dependencies:
parsed = crack(dep)
for block in parsed["body"]:
if block["block"] == "type":
derived_type_defs[block["name"]] = block["vars"]

for _, subroutine_vars in derived_types.items():
for _, intent_vars in subroutine_vars.items():
Expand Down
19 changes: 17 additions & 2 deletions pyutils/tests/f2ser/test_f2ser_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,32 @@ def test_cli_no_deps(no_deps_source_file, outfile, cli):
assert result.exit_code == 0


def test_cli_wrong_deps(diffusion_granule, samples_path, outfile, cli):
inp = str(diffusion_granule)
deps = [str(samples_path / "wrong_derived_types_example.f90")]
args = [inp, outfile, "-d", *deps]
result = cli.invoke(main, args)
assert result.exit_code == 2
assert "Invalid value for '--dependencies' / '-d'" in result.output


def test_cli_missing_deps(diffusion_granule, outfile, cli):
inp = str(diffusion_granule)
args = [inp, outfile]
result = cli.invoke(main, args)
assert isinstance(result.exception, MissingDerivedTypeError)


def test_cli_wrong_source(outfile, cli):
inp = str("foo.90")
args = [inp, outfile]
result = cli.invoke(main, args)
assert "Invalid value for 'GRANULE_PATH'" in result.output


def test_cli_missing_source(not_existing_diffusion_granule, outfile, cli):
inp = str(not_existing_diffusion_granule)
args = [inp, outfile]
result = cli.invoke(main, args)
error_search = result.stdout.find("Invalid value for 'GRANULE_PATH'")
assert error_search != -1
assert isinstance(result.exception, SystemExit)
assert "Invalid value for 'GRANULE_PATH'" in result.output
94 changes: 88 additions & 6 deletions pyutils/tests/f2ser/test_f2ser_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,100 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import pytest

from icon4py.f2ser.deserialise import ParsedGranuleDeserialiser
from icon4py.f2ser.parse import GranuleParser
from icon4py.liskov.codegen.serialisation.generate import (
SerialisationCodeGenerator,
)
from icon4py.liskov.codegen.shared.types import GeneratedCode


def test_deserialiser_diffusion_codegen(diffusion_granule, diffusion_granule_deps):
parser = GranuleParser(diffusion_granule, diffusion_granule_deps)
parsed = parser()
deserialiser = ParsedGranuleDeserialiser(parsed, directory=".", prefix="test")
interface = deserialiser()
generator = SerialisationCodeGenerator(interface)
generated = generator()
parsed = GranuleParser(diffusion_granule, diffusion_granule_deps)()
interface = ParsedGranuleDeserialiser(parsed, directory=".", prefix="test")()
generated = SerialisationCodeGenerator(interface)()
assert len(generated) == 3


@pytest.fixture
def expected_no_deps_serialization_directives():
serialization_directives = [
GeneratedCode(
startln=12,
source="\n"
' !$ser init directory="." prefix="test"\n'
"\n"
" !$ser savepoint no_deps_init_in\n"
"\n"
" PRINT *, 'Serializing a=a'\n"
"\n"
" !$ser data a=a\n"
"\n"
" PRINT *, 'Serializing b=b'\n"
"\n"
" !$ser data b=b",
),
GeneratedCode(
startln=14,
source="\n"
" !$ser savepoint no_deps_init_out\n"
"\n"
" PRINT *, 'Serializing c=c'\n"
"\n"
" !$ser data c=c\n"
"\n"
" PRINT *, 'Serializing b=b'\n"
"\n"
" !$ser data b=b",
),
GeneratedCode(
startln=20,
source="\n"
" !$ser savepoint no_deps_run_in\n"
"\n"
" PRINT *, 'Serializing a=a'\n"
"\n"
" !$ser data a=a\n"
"\n"
" PRINT *, 'Serializing b=b'\n"
"\n"
" !$ser data b=b",
),
GeneratedCode(
startln=22,
source="\n"
" !$ser savepoint no_deps_run_out\n"
"\n"
" PRINT *, 'Serializing c=c'\n"
"\n"
" !$ser data c=c\n"
"\n"
" PRINT *, 'Serializing b=b'\n"
"\n"
" !$ser data b=b",
),
]
return serialization_directives


def test_deserialiser_directives_no_deps_codegen(
no_deps_source_file, expected_no_deps_serialization_directives
):
parsed = GranuleParser(no_deps_source_file)()
interface = ParsedGranuleDeserialiser(parsed, directory=".", prefix="test")()
generated = SerialisationCodeGenerator(interface)()
assert generated == expected_no_deps_serialization_directives


def test_deserialiser_directives_diffusion_codegen(
diffusion_granule, diffusion_granule_deps, samples_path
):
parsed = GranuleParser(diffusion_granule, diffusion_granule_deps)()
interface = ParsedGranuleDeserialiser(parsed, directory=".", prefix="test")()
generated = SerialisationCodeGenerator(interface)()
reference_savepoint = (
samples_path / "expected_diffusion_granule_savepoint.f90"
).read_text()
assert generated[0].source == reference_savepoint.rstrip()
Loading

0 comments on commit c21a8ae

Please sign in to comment.