Skip to content

Commit

Permalink
Added Patching Support for torch.nn.Sequential Containers (#88)
Browse files Browse the repository at this point in the history
* Added Support For torch.nn.Sequential Containers to memtorch.mn.Module.patch_model
* Added Support For torch.nn.Sequential Containers to memtorch.bh.nonideality.apply_nonidealities
  • Loading branch information
coreylammie authored Sep 10, 2021
1 parent bb8836e commit e6825bb
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 197 deletions.
223 changes: 85 additions & 138 deletions memtorch/bh/nonideality/NonIdeality.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,73 +42,64 @@ def apply_nonidealities(model, non_idealities, **kwargs):
torch.nn.Module
Patched instance.
"""
for i, (name, m) in enumerate(list(model.named_modules())):
if type(m) in supported_module_parameters.values():
if "cpu" not in memtorch.__version__ and len(name.split(".")) > 1:
name = name.split(".")[1]

def apply_patched_module(model, patched_module, name, m):
if name.__contains__("."):
sequence_container, module = name.split(".")
if module.isdigit():
module = int(module)
model._modules[sequence_container][module] = patched_module
else:
setattr(
model._modules[sequence_container],
"%s" % module,
patched_module,
)
else:
model._modules[name] = patched_module

return model

for _, (name, m) in enumerate(list(model.named_modules())):
if type(m) in supported_module_parameters.values():
for non_ideality in non_idealities:
if non_ideality == NonIdeality.FiniteConductanceStates:
required(
kwargs,
["conductance_states"],
"memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_finite_conductance_states(
m, kwargs["conductance_states"]
),
)
else:
setattr(
model,
name,
apply_finite_conductance_states(
m, kwargs["conductance_states"]
),
)
model = apply_patched_module(
model,
apply_finite_conductance_states(
m, kwargs["conductance_states"]
),
name,
m,
)
elif non_ideality == NonIdeality.DeviceFaults:
required(
kwargs,
["lrs_proportion", "hrs_proportion", "electroform_proportion"],
"memtorch.bh.nonideality.NonIdeality.DeviceFaults",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_device_faults(
m,
kwargs["lrs_proportion"],
kwargs["hrs_proportion"],
kwargs["electroform_proportion"],
),
)
else:
setattr(
model,
name,
apply_device_faults(
m,
kwargs["lrs_proportion"],
kwargs["hrs_proportion"],
kwargs["electroform_proportion"],
),
)
model = apply_patched_module(
model,
apply_device_faults(
m,
kwargs["lrs_proportion"],
kwargs["hrs_proportion"],
kwargs["electroform_proportion"],
),
name,
m,
)
elif non_ideality == NonIdeality.NonLinear:
if "simulate" in kwargs:
if kwargs["simulate"] == True:
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_non_linear(m, simulate=True),
)
else:
setattr(model, name, apply_non_linear(m, simulate=True))
model = apply_patched_module(
model, apply_non_linear(m, simulate=True), name, m
)
else:
required(
kwargs,
Expand All @@ -119,28 +110,17 @@ def apply_nonidealities(model, non_idealities, **kwargs):
],
"memtorch.bh.nonideality.NonIdeality.NonLinear",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
)
else:
setattr(
model,
name,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
)
model = apply_patched_module(
model,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
name,
m,
)
else:
required(
kwargs,
Expand All @@ -151,84 +131,51 @@ def apply_nonidealities(model, non_idealities, **kwargs):
],
"memtorch.bh.nonideality.NonIdeality.NonLinear",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
)
else:
setattr(
model,
name,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
)
model = apply_patched_module(
model,
apply_non_linear(
m,
kwargs["sweep_duration"],
kwargs["sweep_voltage_signal_amplitude"],
kwargs["sweep_voltage_signal_frequency"],
),
name,
m,
)
elif non_ideality == NonIdeality.Endurance:
required(
kwargs,
["x", "endurance_model", "endurance_model_kwargs"],
"memtorch.bh.nonideality.Endurance",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_endurance_model(
layer=m,
x=kwargs["x"],
endurance_model=kwargs["endurance_model"],
**kwargs["endurance_model_kwargs"]
),
)
else:
setattr(
model,
name,
apply_endurance_model(
layer=m,
x=kwargs["x"],
endurance_model=kwargs["endurance_model"],
**kwargs["endurance_model_kwargs"]
),
)
model = apply_patched_module(
model,
apply_endurance_model(
layer=m,
x=kwargs["x"],
endurance_model=kwargs["endurance_model"],
**kwargs["endurance_model_kwargs"]
),
name,
m,
)
elif non_ideality == NonIdeality.Retention:
required(
kwargs,
["time", "retention_model", "retention_model_kwargs"],
"memtorch.bh.nonideality.Retention",
)
if hasattr(model, "module"):
setattr(
model.module,
name,
apply_retention_model(
layer=m,
time=kwargs["time"],
retention_model=kwargs["retention_model"],
**kwargs["retention_model_kwargs"]
),
)
else:
setattr(
model,
name,
apply_retention_model(
layer=m,
time=kwargs["time"],
retention_model=kwargs["retention_model"],
**kwargs["retention_model_kwargs"]
),
)
model = apply_patched_module(
model,
apply_retention_model(
layer=m,
time=kwargs["time"],
retention_model=kwargs["retention_model"],
**kwargs["retention_model_kwargs"]
),
name,
m,
)

return model

Expand Down
92 changes: 34 additions & 58 deletions memtorch/mn/Module.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,69 +92,45 @@ def patch_model(
Patched torch.nn.Module.
"""
model.map = mapping_routine
for i, (name, m) in enumerate(list(model.named_modules())):
for _, (name, m) in enumerate(list(model.named_modules())):
for parameter in module_parameters_to_patch:
if isinstance(m, parameter):
if "cpu" not in memtorch.__version__ and len(name.split(".")) > 1:
name = name.split(".")[1]

parameter_type = str(type(m))
patch = supported_module_parameters.get(parameter_type)
assert (
parameter_type in supported_module_parameters
), "Patching of %s is not currently supported" % type(m)
if hasattr(model, "module"):
setattr(
model.module,
name,
patch(
m,
memristor_model=memristor_model,
memristor_model_params=memristor_model_params,
mapping_routine=mapping_routine,
transistor=transistor,
programming_routine=programming_routine,
programming_routine_params=programming_routine_params,
p_l=p_l,
scheme=scheme,
tile_shape=tile_shape,
max_input_voltage=max_input_voltage,
scaling_routine=scaling_routine,
scaling_routine_params=scaling_routine_params,
ADC_resolution=ADC_resolution,
ADC_overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
use_bindings=use_bindings,
verbose=verbose,
**kwargs
),
)
patched_module = patch(
m,
memristor_model=memristor_model,
memristor_model_params=memristor_model_params,
mapping_routine=mapping_routine,
transistor=transistor,
programming_routine=programming_routine,
programming_routine_params=programming_routine_params,
p_l=p_l,
scheme=scheme,
tile_shape=tile_shape,
max_input_voltage=max_input_voltage,
scaling_routine=scaling_routine,
scaling_routine_params=scaling_routine_params,
ADC_resolution=ADC_resolution,
ADC_overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
use_bindings=use_bindings,
verbose=verbose,
**kwargs
)
if name.__contains__("."):
sequence_container, module = name.split(".")
if module.isdigit():
module = int(module)
model._modules[sequence_container][module] = patched_module
else:
setattr(
model._modules[sequence_container],
"%s" % module,
patched_module,
)
else:
setattr(
model,
name,
patch(
m,
memristor_model=memristor_model,
memristor_model_params=memristor_model_params,
mapping_routine=mapping_routine,
transistor=transistor,
programming_routine=programming_routine,
programming_routine_params=programming_routine_params,
p_l=p_l,
scheme=scheme,
tile_shape=tile_shape,
max_input_voltage=max_input_voltage,
scaling_routine=scaling_routine,
scaling_routine_params=scaling_routine_params,
ADC_resolution=ADC_resolution,
ADC_overflow_rate=ADC_overflow_rate,
quant_method=quant_method,
use_bindings=use_bindings,
verbose=verbose,
**kwargs
),
)
model._modules[name] = patched_module

def tune_(self, tune_kwargs=None):
"""Method to tune a memristive layer.
Expand Down
2 changes: 1 addition & 1 deletion memtorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.1.3"
__version__ = "1.1.3-cpu"

0 comments on commit e6825bb

Please sign in to comment.