Skip to content

Commit

Permalink
pyfabm variable clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
jornbr committed Oct 16, 2024
1 parent d5c3bdb commit 0265bde
Showing 1 changed file with 70 additions and 60 deletions.
130 changes: 70 additions & 60 deletions src/pyfabm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import ctypes
import re
import logging
import enum
from typing import (
MutableMapping,
Optional,
Tuple,
Iterable,
Union,
Callable,
Any,
Mapping,
Sequence,
TypeVar,
Expand Down Expand Up @@ -455,6 +455,15 @@ def log(msg: str):
INTERIOR_DEPENDENCY = 7
HORIZONTAL_DEPENDENCY = 8
SCALAR_DEPENDENCY = 9


class DataType(enum.IntEnum):
REAL = 1
INTEGER = 2
LOGICAL = 3
STRING = 4


ATTRIBUTE_LENGTH = 256

DISPLAY_MINIMUM = 0
Expand Down Expand Up @@ -561,19 +570,19 @@ def __init__(self, model: "Model", variable_pointer: ctypes.c_void_p):
self.model = model
self._pvariable = variable_pointer

def __getitem__(self, key: str):
def __getitem__(self, key: str) -> Union[float, int, bool]:
typecode = self.model.fabm.variable_get_property_type(
self._pvariable, key.encode("ascii")
)
if typecode == 1:
if typecode == DataType.REAL:
return self.model.fabm.variable_get_real_property(
self._pvariable, key.encode("ascii"), -1.0
)
elif typecode == 2:
elif typecode == DataType.INTEGER:
return self.model.fabm.variable_get_integer_property(
self._pvariable, key.encode("ascii"), 0
)
elif typecode == 3:
elif typecode == DataType.LOGICAL:
return (
self.model.fabm.variable_get_logical_property(
self._pvariable, key.encode("ascii"), 0
Expand All @@ -587,29 +596,12 @@ class Variable(object):
def __init__(
self,
model: "Model",
name: Optional[str] = None,
units: Optional[str] = None,
long_name: Optional[str] = None,
name: str,
units: str,
long_name: str,
path: Optional[str] = None,
variable_pointer: Optional[ctypes.c_void_p] = None,
):
self.model = model
self._pvariable = variable_pointer
if variable_pointer:
strname = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
strunits = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
strlong_name = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
self.model.fabm.variable_get_metadata(
variable_pointer, ATTRIBUTE_LENGTH, strname, strunits, strlong_name
)
if name is None:
name = strname.value.decode("ascii")
if units is None:
units = strunits.value.decode("ascii")
if long_name is None:
long_name = strlong_name.value.decode("ascii")
self.properties = VariableProperties(self.model, self._pvariable)

self.name = name
self.units = units
self.units_unicode = None if units is None else createPrettyUnit(units)
Expand All @@ -618,49 +610,67 @@ def __init__(

@property
def long_path(self) -> str:
if self._pvariable is None:
return self.long_name
return self.long_name

@property
def output_name(self) -> str:
"""Name suitable for output (alphanumeric characters and underscores only)"""
return re.sub(r"\W", "_", self.name)

@property
def options(self) -> Optional[Sequence]:
pass

def __repr__(self) -> str:
return f"<{self.name}={self.value}>"


class VariableFromPointer(Variable):
def __init__(self, model: "Model", variable_pointer: ctypes.c_void_p):
strname = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
strunits = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
strlong_name = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
model.fabm.variable_get_metadata(
variable_pointer, ATTRIBUTE_LENGTH, strname, strunits, strlong_name
)
name = strname.value.decode("ascii")
units = strunits.value.decode("ascii")
long_name = strlong_name.value.decode("ascii")

super().__init__(model, name, units, long_name)

self._pvariable = variable_pointer
self.properties = VariableProperties(self.model, self._pvariable)

@property
def long_path(self) -> str:
strlong_name = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
self.model.fabm.variable_get_long_path(
self._pvariable, ATTRIBUTE_LENGTH, strlong_name
)
return strlong_name.value.decode("ascii")

@property
def output_name(self) -> str:
"""Name suitable for output
(alphanumeric characters and underscores only)"""
return re.sub(r"\W", "_", self.name)

@property
def missing_value(self) -> Optional[float]:
"""Value that indicates missing data, for instance, on land.
`None` if not set."""
if self._pvariable is not None:
return self.model.fabm.variable_get_missing_value(self._pvariable)

@property
def options(self) -> Optional[Sequence]:
pass
return self.model.fabm.variable_get_missing_value(self._pvariable)

def getRealProperty(self, name, default=-1.0) -> float:
return self.model.fabm.variable_get_real_property(
self._pvariable, name.encode("ascii"), default
)

def __repr__(self) -> str:
return f"<{self.name}={self.value}>"


class Dependency(Variable):
class Dependency(VariableFromPointer):
def __init__(
self,
model: "Model",
variable_pointer: ctypes.c_void_p,
shape: Tuple[int],
link_function: Callable[[ctypes.c_void_p, ctypes.c_void_p, np.ndarray], None],
):
super().__init__(model, variable_pointer=variable_pointer)
super().__init__(model, variable_pointer)
self._is_set = False
self._link_function = link_function
self._shape = shape
Expand Down Expand Up @@ -689,11 +699,11 @@ def required(self) -> bool:
return self.model.fabm.variable_is_required(self._pvariable) != 0


class StateVariable(Variable):
class StateVariable(VariableFromPointer):
def __init__(
self, model: "Model", variable_pointer: ctypes.c_void_p, data: np.ndarray
):
super().__init__(model, variable_pointer=variable_pointer)
super().__init__(model, variable_pointer)
self._data = data

@property
Expand Down Expand Up @@ -723,15 +733,15 @@ def no_precipitation_dilution(self) -> bool:
)


class DiagnosticVariable(Variable):
class DiagnosticVariable(VariableFromPointer):
def __init__(
self,
model: "Model",
variable_pointer: ctypes.c_void_p,
index: int,
horizontal: bool,
):
super().__init__(model, variable_pointer=variable_pointer)
super().__init__(model, variable_pointer)
self._data = None
self._horizontal = horizontal
self._index = index + 1
Expand Down Expand Up @@ -767,7 +777,7 @@ def __init__(
index: int,
units: Optional[str] = None,
long_name: Optional[str] = None,
type: Optional[int] = None,
type: Optional[DataType] = None,
has_default: bool = False,
):
super().__init__(model, name, units, long_name)
Expand All @@ -777,22 +787,22 @@ def __init__(

def _get_value(self, *, default: bool = False):
default = 1 if default else 0
if self._type == 1:
if self._type == DataType.REAL:
return self.model.fabm.get_real_parameter(
self.model.pmodel, self._index, default
)
elif self._type == 2:
elif self._type == DataType.INTEGER:
return self.model.fabm.get_integer_parameter(
self.model.pmodel, self._index, default
)
elif self._type == 3:
elif self._type == DataType.LOGICAL:
return (
self.model.fabm.get_logical_parameter(
self.model.pmodel, self._index, default
)
!= 0
)
elif self._type == 4:
elif self._type == DataType.STRING:
result = ctypes.create_string_buffer(ATTRIBUTE_LENGTH)
self.model.fabm.get_string_parameter(
self.model.pmodel, self._index, default, ATTRIBUTE_LENGTH, result
Expand All @@ -804,22 +814,22 @@ def value(self) -> Union[float, int, bool, str]:
return self._get_value()

@value.setter
def value(self, value):
def value(self, value: Union[float, int, bool, str]):
settings = self.model._save_state()

if self._type == 1:
if self._type == DataType.REAL:
self.model.fabm.set_real_parameter(
self.model.pmodel, self.name.encode("ascii"), value
)
elif self._type == 2:
elif self._type == DataType.INTEGER:
self.model.fabm.set_integer_parameter(
self.model.pmodel, self.name.encode("ascii"), value
)
elif self._type == 3:
elif self._type == DataType.LOGICAL:
self.model.fabm.set_logical_parameter(
self.model.pmodel, self.name.encode("ascii"), value
)
elif self._type == 4:
elif self._type == DataType.STRING:
self.model.fabm.set_string_parameter(
self.model.pmodel, self.name.encode("ascii"), value.encode("ascii")
)
Expand Down Expand Up @@ -911,7 +921,7 @@ def clear(self):
self._lookup_ci = None


class Coupling(Variable):
class Coupling(VariableFromPointer):
def __init__(self, model: "Model", index: int):
self._ptarget = ctypes.c_void_p()
self._psource = ctypes.c_void_p()
Expand All @@ -921,7 +931,7 @@ def __init__(self, model: "Model", index: int):
ctypes.byref(self._psource),
ctypes.byref(self._ptarget),
)
super().__init__(model, variable_pointer=self._psource)
super().__init__(model, self._psource)
self._options = None

@property
Expand Down

0 comments on commit 0265bde

Please sign in to comment.