Skip to content

Commit

Permalink
finalized spreading to multiindex
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Aug 30, 2024
1 parent 5ad13d9 commit 91d53f5
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 85 deletions.
7 changes: 6 additions & 1 deletion pymrio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
import sys

from pymrio.core.fileio import *
from pymrio.core.mriosystem import Extension, IOSystem, concate_extension, extension_convert
from pymrio.core.mriosystem import (
Extension,
IOSystem,
concate_extension,
extension_convert,
)
from pymrio.tools.ioclass import ClassificationData, get_classification
from pymrio.tools.iodownloader import (
download_eora26,
Expand Down
17 changes: 12 additions & 5 deletions pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3246,6 +3246,7 @@ def remove_extension(self, ext):

return self


def extension_convert(
*extensions,
df_map,
Expand All @@ -3267,7 +3268,7 @@ def extension_convert(
Parameters
----------
extensions : list of extensions
Extensions to convert. All extensions passed must
have an index structure (index names) as described in df_map.
Expand Down Expand Up @@ -3373,7 +3374,7 @@ def extension_convert(
ignore_columns.append(extension_col_name)

gather = []

for ext in extensions:
gather.append(
ext.convert(
Expand All @@ -3389,10 +3390,16 @@ def extension_convert(

result_ext = concate_extension(*gather, name=new_extension_name)


for df, df_name in zip(result_ext.get_DataFrame(data=True, with_unit=True), result_ext.get_DataFrame(data=False, with_unit=True)):
for df, df_name in zip(
result_ext.get_DataFrame(data=True, with_unit=True),
result_ext.get_DataFrame(data=False, with_unit=True),
):
if df_name == "unit":
setattr(result_ext, df_name, df.groupby(level=df.index.names).agg(lambda x: ",".join(set(x))))
setattr(
result_ext,
df_name,
df.groupby(level=df.index.names).agg(lambda x: ",".join(set(x))),
)
else:
setattr(result_ext, df_name, df.groupby(level=df.index.names).agg(agg_func))

Expand Down
52 changes: 38 additions & 14 deletions pymrio/tools/ioutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,7 @@ def convert(
Extension for extensions:
# TODO: check if in the other docstring and then remove
extension ... extension name
unit_orig ... the original unit (optional, for double check with the unit)
unit_new ... the new unit to be set for the extension
Expand Down Expand Up @@ -1120,7 +1121,9 @@ def convert(
raise ValueError(f"Column {col} contains more then one '__'")
if bridge.orig not in df_map.columns:
raise ValueError(f"Column {bridge.orig} not in df_map")
elif bridge.orig not in df_orig.index.names:
elif (bridge.orig not in df_orig.index.names) and (
bridge.orig not in df_orig.columns.names
):
raise ValueError(f"Column {bridge.orig} not in df_orig")
else:
bridges.append(bridge)
Expand Down Expand Up @@ -1167,30 +1170,54 @@ def convert(

# renaming part, checks if the old name (bridge.orig) is in the current index
# and renames by the new one (bridge.new)

already_renamed = dict()
for bridge in bridges:
# encountering a bridge with the same orig name but which should
# lead to two new index levels
if bridge.orig in already_renamed.keys():
# duplicate the index level
df_collected.reset_index(level=already_renamed[bridge.orig].new, inplace=True)
df_collected[bridge.new] = df_cur_map.index.get_level_values(bridge.raw)[0]

if df_collected.index.name is None:
df_collected.set_index(already_renamed[bridge.orig].new, drop=True, append=False, inplace=True)
_index_order = list(df_collected.index.names)
df_collected.reset_index(
level=already_renamed[bridge.orig].new, inplace=True
)
df_collected[bridge.new] = df_cur_map.index.get_level_values(
bridge.raw
)[0]
if (len(df_collected.index.names) == 1) and (
df_collected.index.names[0] is None
):
df_collected.set_index(
already_renamed[bridge.orig].new,
drop=True,
append=False,
inplace=True,
)
else:
df_collected.set_index(already_renamed[bridge.orig].new, drop=True, append=True, inplace=True)
df_collected.set_index(
already_renamed[bridge.orig].new,
drop=True,
append=True,
inplace=True,
)
df_collected.set_index(bridge.new, drop=True, append=True, inplace=True)
df_collected.index = df_collected.index.reorder_levels(
_index_order + [bridge.new]
)

continue

for idx_old_names in df_collected.index.names:
if bridge.orig in idx_old_names:
# rename the index names
if isinstance(df_collected.index, pd.MultiIndex):
df_collected.index = df_collected.index.set_names( bridge.new, level=idx_old_names)
df_collected.index = df_collected.index.set_names(
bridge.new, level=idx_old_names
)
else:
df_collected.index = df_collected.index.set_names( bridge.new, level=None)
df_collected.index = df_collected.index.set_names(
bridge.new, level=None
)

# rename the actual index values
df_collected.reset_index(level=bridge.new, inplace=True)
Expand All @@ -1200,8 +1227,6 @@ def convert(
df_collected.loc[:, bridge.new] = df_collected.loc[
:, bridge.new
].str.replace(pat=old_row_name, repl=new_row_name, regex=True)
# CONT: Make a test case/method where a matching line gets extended into more index columns
# CONT: Ensure that the spread keeps the order as in the original mapping

# put the index back
if df_collected.index.name is None:
Expand Down Expand Up @@ -1241,8 +1266,7 @@ def convert(
]
try:
all_result = all_result.reorder_levels(new_index + orig_index_not_bridged)
except TypeError:
# case where there is only one index level
except TypeError: # case where there is only one index level
pass

return all_result.groupby(by=all_result.index.names).agg(agg_func)
130 changes: 83 additions & 47 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def test_extension_convert(fix_testmrio):
],
)
tt_pre.pre_calc = tt_pre.emissions.convert(
df_map, extension_name="emissions_new_pre_calc"
df_map, new_extension_name="emissions_new_pre_calc"
)
tt_pre.calc_all()

Expand Down Expand Up @@ -608,7 +608,7 @@ def test_extension_convert(fix_testmrio):
tt_post.calc_all()

tt_post.post_calc = tt_post.emissions.convert(
df_map, extension_name="emissions_new_post_calc"
df_map, new_extension_name="emissions_new_post_calc"
)

pdt.assert_series_equal(
Expand Down Expand Up @@ -640,7 +640,7 @@ def test_extension_convert(fix_testmrio):


def test_extension_convert_function(fix_testmrio):
"""Testing the convert function for a list of extensions """
"""Testing the convert function for a list of extensions"""

tt_pre = fix_testmrio.testmrio.copy()

Expand All @@ -656,55 +656,91 @@ def test_extension_convert_function(fix_testmrio):
"unit_new",
],
data=[
["Emissions", "emis.*", "air|water", "total_sum_tonnes", "total", 1e-3, "kg", "t"],
["Emissions", "emission_type2", "water", "water_emissions", "water", 1000, "kg", "g"],
[
"Emissions",
"emis.*",
"air|water",
"total_sum_tonnes",
"total",
1e-3,
"kg",
"t",
],
[
"Emissions",
"emission_type2",
"water",
"water_emissions",
"water",
1000,
"kg",
"g",
],
],
)

# CONT: Something wrong with setting the index to a multiindex when compartment is passed
# Next steps: run this in interpreter (with autoreload) and set breakpoint in extension_convert
# Seems to be in gather, but after that in the aggregation or Concatenate we get a problem

# x = tt_pre.extension_convert(df_map, extension_name="emissions_new_pre_calc")
# x = tt_pre.extension_convert(df_map, new_extension_name="emissions_new_pre_calc")

# Doing the same extension twice
ext_double = pymrio.extension_convert(tt_pre.emissions, tt_pre.emissions, df_map=df_map_double, new_extension_name="emissions_new_pre_calc")

assert ext_double.unit.loc["total_sum_tonnes", "unit"] == "t"
assert ext_double.unit.loc["water_emissions", "unit"] == "g"

pdt.assert_series_equal(
ext_double.F.loc["total_sum_tonnes"],
tt_pre.emissions.F.sum(axis=0) * 1e-3 * 2,
check_names=False,
)

pdt.assert_series_equal(
ext_double.F.loc["water_emissions"],
tt_pre.emissions.F.loc["emission_type2",:].iloc[0,:] * 1000 * 2,
check_names=False,
)


tt_pre.emission_new = ext_double

df_map_add_across = pd.DataFrame(
columns=[
"extension",
"stressor",
"compartment",
"total__stressor",
"factor",
"unit_orig",
"unit_new",
],
data=[
["Emissions", "emission_type2", ".*", "water", 1, "kg", "kg"],
["emission_new_pre_calc", "water_emissions", ".*", "water", 1E-3, "g", "kg"],
],
ext_double = pymrio.extension_convert(
tt_pre.emissions,
tt_pre.emissions,
df_map=df_map_double,
new_extension_name="emissions_new_pre_calc",
)

ext_across = pymrio.extension_convert(tt_pre.emissions, ext_double, df_map=df_map_add_across, new_extension_name="add_across")
# TODO: check return type and update test
# assert ext_double.unit.loc["total_sum_tonnes", "unit"] == "t"
# assert ext_double.unit.loc["water_emissions", "unit"] == "g"

# pdt.assert_series_equal(
# ext_double.F.loc["total_sum_tonnes"],
# tt_pre.emissions.F.sum(axis=0) * 1e-3 * 2,
# check_names=False,
# )
#
# pdt.assert_series_equal(
# ext_double.F.loc["water_emissions"],
# tt_pre.emissions.F.loc["emission_type2", :].iloc[0, :] * 1000 * 2,
# check_names=False,
# )
#
# tt_pre.emission_new = ext_double
#
# df_map_add_across = pd.DataFrame(
# columns=[
# "extension",
# "stressor",
# "compartment",
# "total__stressor",
# "factor",
# "unit_orig",
# "unit_new",
# ],
# data=[
# ["Emissions", "emission_type2", ".*", "water", 1, "kg", "kg"],
# [
# "emission_new_pre_calc",
# "water_emissions",
# ".*",
# "water",
# 1e-3,
# "g",
# "kg",
# ],
# ],
# )
#
# ext_across = pymrio.extension_convert(
# tt_pre.emissions,
# ext_double,
# df_map=df_map_add_across,
# new_extension_name="add_across",
# )

# CONT:
# make a second extension and check running over 2
Expand Down Expand Up @@ -756,35 +792,35 @@ def test_extension_convert_test_unit_fail(fix_testmrio):

with pytest.raises(ValueError):
fix_testmrio.testmrio.emissions.convert(
df_fail1, extension_name="emissions_new"
df_fail1, new_extension_name="emissions_new"
)

with pytest.raises(ValueError):
fix_testmrio.testmrio.emissions.convert(
df_fail2, extension_name="emissions_new", unit_column_orig="unit_emis"
df_fail2, new_extension_name="emissions_new", unit_column_orig="unit_emis"
)

with pytest.raises(ValueError):
fix_testmrio.testmrio.emissions.convert(
df_wo_unit, extension_name="emissions_new", unit_column_orig="unit_emis"
df_wo_unit, new_extension_name="emissions_new", unit_column_orig="unit_emis"
)

wounit = fix_testmrio.testmrio.emissions.convert(
df_wo_unit,
extension_name="emissions_new",
new_extension_name="emissions_new",
unit_column_orig=None,
unit_column_new=None,
)
assert wounit.unit == None

with pytest.raises(ValueError):
fix_testmrio.testmrio.emissions.convert(
df_new_unit, extension_name="emissions_new", unit_column_new="unit_new"
df_new_unit, new_extension_name="emissions_new", unit_column_new="unit_new"
)

newunit = fix_testmrio.testmrio.emissions.convert(
df_new_unit,
extension_name="emissions_new",
new_extension_name="emissions_new",
unit_column_orig=None,
unit_column_new="set_unit",
)
Expand Down
Loading

0 comments on commit 91d53f5

Please sign in to comment.