Skip to content

Commit

Permalink
ensure DataFrame columns are ordered consistently with EmpircalGrid.v…
Browse files Browse the repository at this point in the history
…ariables
  • Loading branch information
zebengberg committed Sep 8, 2023
1 parent 4d6e641 commit bceb00d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions pycontrails/ext/empirical_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _query_data(self) -> pd.DataFrame:
"""Query ``self.params["data"]`` for the source aircraft type."""

# Take only the columns that are not already in the source
columns = sorted(set(self.variables).difference(self.source))
columns = [v for v in self.variables if v not in self.source]
data = self.params["data"]
if data is None:
raise ValueError("No data provided")
Expand All @@ -118,7 +118,7 @@ def _query_data(self) -> pd.DataFrame:
# Round to flight levels
data.loc[:, "altitude_ft"] = data["altitude_ft"].round(-3)

return data[["altitude_ft"] + columns].drop(columns=["aircraft_type"])
return data[["altitude_ft", *columns]].drop(columns=["aircraft_type"])

def _sample(self, altitude_ft: npt.NDArray[np.float_]) -> None:
"""Sample the data and update the source."""
Expand All @@ -127,7 +127,9 @@ def _sample(self, altitude_ft: npt.NDArray[np.float_]) -> None:
grouped = df.groupby("altitude_ft")
rng = self.params["random_state"]

other = {k: np.full_like(altitude_ft, np.nan) for k in df}
source = self.source
for k in df:
source[k] = np.full_like(altitude_ft, np.nan)

for altitude, group in grouped:
filt = altitude_ft == altitude
Expand All @@ -137,6 +139,4 @@ def _sample(self, altitude_ft: npt.NDArray[np.float_]) -> None:

sample = group.sample(n=n, replace=True, random_state=rng)
for k, v in sample.items():
other[k][filt] = v

self.source.update(other) # type: ignore[arg-type]
source[k][filt] = v

0 comments on commit bceb00d

Please sign in to comment.