Reduce overhead of Function call
ricardoV94 committed Nov 21, 2024
1 parent 8593f34 commit 862f158
Showing 1 changed file with 61 additions and 63 deletions: pytensor/compile/function/types.py
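For context (this note is not part of the diff): a compiled PyTensor Function is typically invoked many times in a tight Python loop, so the per-call bookkeeping in Function.__call__ dominates for small graphs. A minimal sketch of the kind of call whose overhead this commit trims, using only the public pytensor.function API:

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pytensor.function([x], 2 * x)  # compiled once

for i in range(10_000):
    f(i)  # every iteration goes through Function.__call__; that per-call cost is what is reduced here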
@@ -540,15 +540,42 @@ def __contains__(self, item):
self._value = ValueAttribute()
self._container = ContainerAttribute()

# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage)
update_storage = [
container
for inp, container in zip(
self.maker.expanded_inputs, input_storage, strict=True
)
if inp.update is not None
]
# Updates are the last inner outputs that are not returned by Function.__call__
self.n_returned_outputs = len(self.output_storage) - len(update_storage)

# Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
self.update_input_storage: tuple[int, Container] = ()
if getattr(vm, "need_update_inputs", True):
self.update_input_storage = tuple(
zip(
range(self.n_returned_outputs, len(output_storage)),
update_storage,
strict=True,
)
)

# This is used only when `vm.need_update_inputs` is `False`, because
# we're using one of the VM objects and it is putting updates back into
# the input containers all by itself.
self.n_returned_outputs = len(self.output_storage) - sum(
inp.update is not None for inp in self.maker.expanded_inputs
)
# In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
# After the call, we want to erase (some of) these references, to allow Python to GC them if unused
# Required input containers are the non-default inputs; they must always be provided again, so we GC them
self.clear_input_output_storage_data = [
container.storage for container in input_storage if container.required
]
if getattr(vm, "allow_gc", False):
# If the vm allows us to GC the outputs, do it
self.clear_input_output_storage_data += [
container.storage
for container, variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=True
)
if variable.owner is not None # Not a constant output
]

for node in self.maker.fgraph.apply_nodes:
if isinstance(node.op, HasInnerGraph):
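The hunk above moves decisions that used to be made on every call (which trailing outputs are updates, which storage cells must be cleared) into the constructor, so __call__ only walks precomputed lists. Below is a standalone toy sketch of that precompute-once pattern; the class, names, and single-element-list "storage cells" are illustrative stand-ins, not PyTensor's actual Container or Function classes:

class ToyFunction:
    """Toy illustration of precomputing per-call bookkeeping (not the real class)."""

    def __init__(self, n_explicit_inputs, n_returned, n_updates, run_graph):
        # Single-element lists stand in for Container storage cells; the last
        # `n_updates` inputs play the role of shared variables with updates.
        self.input_storage = [[None] for _ in range(n_explicit_inputs + n_updates)]
        self.output_storage = [[None] for _ in range(n_returned + n_updates)]
        self.run_graph = run_graph
        self.n_returned_outputs = n_returned
        # Decided once: which trailing outputs are written back into which input cells.
        self.update_input_storage = tuple(
            zip(range(n_returned, n_returned + n_updates),
                self.input_storage[n_explicit_inputs:])
        )
        # Decided once: which cells to clear after each call (the explicit inputs).
        self.clear_storage = self.input_storage[:n_explicit_inputs]

    def __call__(self, *args):
        for cell, arg in zip(self.input_storage, args):
            cell[0] = arg
        self.run_graph(self.input_storage, self.output_storage)  # stands in for the vm
        outputs = [cell[0] for cell in self.output_storage]
        for i, cell in self.update_input_storage:   # write updates back to the input cells
            cell[0] = outputs[i]
        for cell in self.clear_storage:             # drop references so Python can GC them
            cell[0] = None
        return outputs[: self.n_returned_outputs]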
@@ -747,7 +774,7 @@ def checkSV(sv_ori, sv_rpl):
elif isinstance(profile, str):
profile = pytensor.compile.profiling.ProfileStats(message=profile)

f_cpy = maker.__class__(
f_cpy = type(maker)(
inputs=ins,
outputs=outs,
fgraph=fg_cpy,
@@ -765,6 +792,8 @@ def checkSV(sv_ori, sv_rpl):
# check that.
accept_inplace=True,
no_fgraph_prep=True,
output_keys=maker.output_keys,
name=name,
).create(input_storage, storage_map=new_storage_map)

for in_ori, in_cpy, ori, cpy in zip(
@@ -796,9 +825,6 @@ def checkSV(sv_ori, sv_rpl):
in_cpy.variable = swap[in_ori.variable]

f_cpy.trust_input = self.trust_input
f_cpy.unpack_single = self.unpack_single
f_cpy.name = name
f_cpy.maker.fgraph.name = name
return f_cpy

def _restore_defaults(self):
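A hedged usage sketch for the copy-related changes above: the maker is now constructed via type(maker) and receives output_keys and the new name directly, instead of having f_cpy.name and unpack_single patched on afterwards. The example assumes the usual dict-outputs form of pytensor.function (which is what populates output_keys):

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pytensor.function([x], {"double": 2 * x})  # dict outputs -> Function.output_keys
g = f.copy(name="g")                           # the copy should keep dict-style returns and the new name

print(f(3.0))  # expected something like {'double': array(6.)}
print(g(3.0))  # same structure from the copy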
@@ -808,7 +834,7 @@ def _restore_defaults(self):
value = value.storage[0]
self[i] = value

def __call__(self, *args, **kwargs):
def __call__(self, *args, output_subset=None, **kwargs):
"""
Evaluates value of a function on given arguments.
@@ -842,7 +868,6 @@ def __call__(self, *args, **kwargs):
if profile:
t0 = time.perf_counter()

output_subset = kwargs.pop("output_subset", None)
if output_subset is not None:
warnings.warn("output_subset is deprecated.", FutureWarning)
if self.output_keys is not None:
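A short, hedged illustration of the deprecation added above: output_subset (a Theano-era keyword selecting which outputs to return, and where supported to compute) still works here but now emits a FutureWarning:

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pytensor.function([x], [x + 1, x * 2])

res = f(3.0, output_subset=[1])  # expected [array(6.)] plus a FutureWarning about output_subset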
@@ -993,37 +1018,18 @@ def __call__(self, *args, **kwargs):
if outputs is None:
outputs = [x.data for x in self.output_storage]

# Remove internal references to required inputs.
# These cannot be re-used anyway.
for arg_container in input_storage:
if arg_container.required:
arg_container.storage[0] = None

# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if getattr(self.vm, "allow_gc", False):
# strict=False because we are in a hot loop
for o_container, o_variable in zip(
self.output_storage, self.maker.fgraph.outputs, strict=False
):
if o_variable.owner is not None:
# this node is the variable of computation
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None

if getattr(self.vm, "need_update_inputs", True):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for input, storage in reversed(
list(zip(self.maker.expanded_inputs, input_storage, strict=False))
):
if input.update is not None:
storage.data = outputs.pop()
else:
outputs = outputs[: self.n_returned_outputs]
# Set updates and filter them out from the returned outputs
for i, input_storage in self.update_input_storage:
input_storage.data = outputs[i]
outputs = outputs[: self.n_returned_outputs]

# Remove input and output values from storage data
for storage_data in self.clear_input_output_storage_data:
storage_data[0] = None

# Put default values back in the storage
self._restore_defaults()
if self.defaults:
self._restore_defaults()

if profile:
dt_call = time.perf_counter() - t0
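The update handling above (writing outputs[i] back through update_input_storage, then trimming to n_returned_outputs) is the machinery behind updates= on shared variables. A small example of that path using only the public API:

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
s = pytensor.shared(0.0, name="s")
f = pytensor.function([x], x + s, updates=[(s, s + x)])

print(f(1.0))         # computed with the old s (0.0) -> 1.0
print(s.get_value())  # the extra, non-returned update output was written back: s is now 1.0
print(f(2.0))         # 2.0 + 1.0 -> 3.0; afterwards s is 3.0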
@@ -1039,25 +1045,21 @@ def __call__(self, *args, **kwargs):

if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)

if output_subset is None:
# strict=False because we are in a hot loop
return dict(zip(self.output_keys, outputs, strict=False))
else:
return {
self.output_keys[index]: outputs[index]
for index in output_subset
}
if output_subset is not None:
outputs = [outputs[i] for i in output_subset]

if output_subset is None:
return outputs
if self.output_keys is None:
if self.unpack_single:
[out] = outputs
return out
else:
return [outputs[i] for i in output_subset]
return outputs
else:
output_keys = self.output_keys
if output_subset is not None:
output_keys = [output_keys[i] for i in output_subset]
return dict(zip(output_keys, outputs, strict=True))

value = property(
lambda self: self._value,
@@ -1091,10 +1093,6 @@ def get_shared(self):
"""
return [i.variable for i in self.maker.inputs if i.implicit]

def sync_shared(self):
# NOTE: sync was needed on old gpu backend
pass

def dprint(self, **kwargs):
"""Debug print itself