Skip to content

Commit

Permalink
require aircraft_type and wingspan in EmpiricalGrid params
Browse files Browse the repository at this point in the history
  • Loading branch information
zebengberg committed Sep 1, 2023
1 parent 4fcf06d commit 4d6e641
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
### Internals

- Add `compute_tau_cirrus_in_model_init` parameter to `CocipParams`. This controls whether the cirrus optical depth is computed in `Cocip.__init__` or in `Cocip.eval`. When set to `"auto"` (the default), `tau_cirrus` is computed in `Cocip.__init__` if and only if the `met` parameter is dask-backed.
- Change data requirements for the `EmpiricalGrid` aircraft performance model: the `data` parameter must now also include `aircraft_type` and `wingspan` columns.

## v0.47.0

Expand Down
60 changes: 34 additions & 26 deletions pycontrails/ext/empirical_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from pandas.core.groupby import DataFrameGroupBy

from pycontrails.core.aircraft_performance import (
AircraftPerformanceGrid,
Expand All @@ -25,12 +24,14 @@ class EmpiricalGridParams(AircraftPerformanceGridParams):
#: Random state to use for sampling
random_state: int | np.random.Generator | None = None

#: Empirical data to use for sampling. Must have columns:
#: Empirical data to use for sampling. Must include columns:
#: - altitude_ft
#: - true_airspeed
#: - aircraft_mass
#: - fuel_flow
#: - engine_efficiency
#: - aircraft_type
#: - wingspan
#: If None, an error will be raised at runtime.
data: pd.DataFrame | None = None

Expand All @@ -53,7 +54,14 @@ class EmpiricalGrid(AircraftPerformanceGrid):
source: GeoVectorDataset
default_params = EmpiricalGridParams

variables = "true_airspeed", "aircraft_mass", "fuel_flow", "engine_efficiency"
variables = (
"true_airspeed",
"aircraft_mass",
"fuel_flow",
"engine_efficiency",
"aircraft_type",
"wingspan",
)

@overload
def eval(self, source: GeoVectorDataset, **params: Any) -> GeoVectorDataset:
Expand Down Expand Up @@ -87,42 +95,40 @@ def eval(
self.require_source_type(GeoVectorDataset)

altitude_ft = self.source.altitude_ft.copy()
altitude_ft.round(-3, out=altitude_ft)
dtype = altitude_ft.dtype
altitude_ft.round(-3, out=altitude_ft) # round to flight levels

# Take only the columns that are not already in the source
columns = sorted(set(self.variables).difference(self.source))

# Initialize the variables in the source with NaNs
self.source.update({k: np.full(len(self.source), np.nan, dtype=dtype) for k in columns})

# Fill the variables with sampled data
self._sample(altitude_ft, columns)
# Fill the source with sampled data at each flight level
self._sample(altitude_ft)

return self.source

def _get_grouped(self, columns: list[str]) -> DataFrameGroupBy:
"""Group the data by altitude and return the groupby object."""
def _query_data(self) -> pd.DataFrame:
"""Query ``self.params["data"]`` for the source aircraft type."""

df = self.params["data"]
if df is None:
# Take only the columns that are not already in the source
columns = sorted(set(self.variables).difference(self.source))
data = self.params["data"]
if data is None:
raise ValueError("No data provided")

try:
df = df[["altitude_ft"] + columns]
except KeyError as e:
raise ValueError(f"Column {e} not in data") from e
aircraft_type = self.source.attrs.get("aircraft_type", self.params["aircraft_type"])
data = data.query(f"aircraft_type == '{aircraft_type}'")
assert not data.empty, f"No data for aircraft type: {aircraft_type}"

# Round to flight levels
df["altitude_ft"] = df["altitude_ft"].round(-3)
return df.groupby("altitude_ft")
data.loc[:, "altitude_ft"] = data["altitude_ft"].round(-3)

return data[["altitude_ft"] + columns].drop(columns=["aircraft_type"])

def _sample(self, altitude_ft: npt.NDArray[np.float_], columns: list[str]) -> None:
def _sample(self, altitude_ft: npt.NDArray[np.float_]) -> None:
"""Sample the data and update the source."""

grouped = self._get_grouped(columns) # move to init if the groupby is expensive
df = self._query_data()
grouped = df.groupby("altitude_ft")
rng = self.params["random_state"]

other = {k: np.full_like(altitude_ft, np.nan) for k in df}

for altitude, group in grouped:
filt = altitude_ft == altitude
n = filt.sum()
Expand All @@ -131,4 +137,6 @@ def _sample(self, altitude_ft: npt.NDArray[np.float_], columns: list[str]) -> No

sample = group.sample(n=n, replace=True, random_state=rng)
for k, v in sample.items():
self.source[k][filt] = v
other[k][filt] = v

self.source.update(other) # type: ignore[arg-type]

0 comments on commit 4d6e641

Please sign in to comment.