Enabling engine to run single epochs #1371
@alxlampe interesting idea, thanks! If I understand correctly what you would like to achieve, it can be done with some event filtering:

```python
from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

def once_at_start(engine, _):
    return engine.state.epoch == 0

def once_at_end(engine, _):
    return engine.state.epoch == 10

engine.add_event_handler(Events.STARTED(once_at_start), lambda x: print("started"))
engine.add_event_handler(Events.EPOCH_STARTED, lambda x: print("{} epoch started".format(x.state.epoch)))
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("{} epoch completed".format(x.state.epoch)))
engine.add_event_handler(Events.COMPLETED(once_at_end), lambda x: print("completed"))

engine.run([0, 1, 2], max_epochs=3)
print("Do something else")
engine.run([0, 1, 2], max_epochs=6)
print("Do something else")
engine.run([0, 1, 2], max_epochs=10)
```

Can this be generalized to your use case, or is it too ugly and specific? What do you think?

Maybe this requirement is not satisfied by the above code.
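The filtering idea above can be illustrated without ignite. Below is a minimal, library-free sketch of the same pattern — a handler fires only when a predicate on the current state accepts it, so "one-time" events survive repeated partial runs. `SimpleEmitter` is a hypothetical stand-in, not ignite API:

```python
# Hypothetical stand-in for ignite's filtered events -- NOT ignite API.
class SimpleEmitter:
    def __init__(self):
        self._handlers = []  # (event_name, filter_fn, handler)

    def on(self, event, handler, filter_fn=lambda state: True):
        self._handlers.append((event, filter_fn, handler))

    def fire(self, event, state):
        for name, filt, handler in self._handlers:
            if name == event and filt(state):
                handler(state)


calls = []
emitter = SimpleEmitter()
# mirrors once_at_start: fires only before the first epoch has run
emitter.on("STARTED", lambda s: calls.append("started"),
           filter_fn=lambda s: s["epoch"] == 0)
# mirrors once_at_end: fires only when the final epoch is reached
emitter.on("COMPLETED", lambda s: calls.append("completed"),
           filter_fn=lambda s: s["epoch"] == s["max_epochs"])

state = {"epoch": 0, "max_epochs": 10}
# three partial "runs", as in engine.run(..., max_epochs=3/6/10)
for stop in (3, 6, 10):
    emitter.fire("STARTED", state)
    while state["epoch"] < stop:
        state["epoch"] += 1
    emitter.fire("COMPLETED", state)

print(calls)  # each one-time handler fired exactly once across all runs
```

The filter runs at fire time, so resuming a run cannot re-trigger a handler whose predicate no longer holds.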
@vfdev-5 Thanks for your reply. The check for the last epoch can be made independent of a hard-coded number:

```python
def once_at_end(engine, _):
    return engine.state.epoch == engine.state.max_epochs
```

What I think is that the event handlers that should occur exactly once during a run need this kind of filtering. However, I have a more detailed use case where this won't work (I think): nested engines.

Below you find the complete example of the use case and how I solved it. I added event handlers to all three engines to print messages; the indent of each message corresponds to the nesting depth.
Some notes:
```python
import time
import warnings
from typing import Callable, Iterable, Optional

from ignite._utils import _to_hours_mins_secs
from ignite.engine import Engine, Events, State


class IterableEngine(Engine):

    def _internal_run(self, return_generator) -> State:
        self.should_terminate = self.should_terminate_single_epoch = False
        self._init_timers(self.state)
        try:
            start_time = time.time()
            self._fire_event(Events.STARTED)
            while self.state.epoch < self.state.max_epochs and not self.should_terminate:
                self.state.epoch += 1
                self._fire_event(Events.EPOCH_STARTED)

                if self._dataloader_iter is None:
                    self._setup_engine()

                time_taken = self._run_once_on_dataset()
                # time is available for handlers but must be updated after fire
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                handlers_start_time = time.time()
                if self.should_terminate:
                    self._fire_event(Events.TERMINATE)
                else:
                    self._fire_event(Events.EPOCH_COMPLETED)
                time_taken += time.time() - handlers_start_time
                # update time wrt handlers
                self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
                hours, mins, secs = _to_hours_mins_secs(time_taken)
                self.logger.info(
                    "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self.state.epoch, hours, mins, secs)
                )
                if self.should_terminate:
                    break

                if return_generator:
                    yield self.state

            time_taken = time.time() - start_time
            # time is available for handlers but must be updated after fire
            self.state.times[Events.COMPLETED.name] = time_taken
            handlers_start_time = time.time()
            self._fire_event(Events.COMPLETED)
            time_taken += time.time() - handlers_start_time
            # update time wrt handlers
            self.state.times[Events.COMPLETED.name] = time_taken
            hours, mins, secs = _to_hours_mins_secs(time_taken)
            self.logger.info("Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs))

        except BaseException as e:
            self._dataloader_iter = None
            self.logger.error("Engine run is terminating due to exception: %s.", str(e))
            self._handle_exception(e)

        self._dataloader_iter = None
        return self.state

    def run(
        self,
        data: Iterable,
        max_epochs: Optional[int] = None,
        epoch_length: Optional[int] = None,
        seed: Optional[int] = None,
        return_generator: Optional[bool] = False,
    ) -> State:
        """Runs the `process_function` over the passed data.

        Engine has a state and the following logic is applied in this function:

        - At the first call, a new state is defined by `max_epochs`, `epoch_length`, `seed` if provided.
          A timer for total and per-epoch time is initialized when Events.STARTED is handled.
        - If state is already defined such that there are iterations to run until `max_epochs` and no input
          arguments are provided, state is kept and used in the function.
        - If state is defined and the engine is "done" (no iterations to run until `max_epochs`), a new
          state is defined.
        - If state is defined and the engine is NOT "done", then input arguments, if provided, override the
          defined state.

        Args:
            data (Iterable): Collection of batches allowing repeated iteration (e.g., list or `DataLoader`).
            max_epochs (int, optional): Max epochs to run for (default: None).
                If a new state should be created (first run or run again from an ended engine), its default
                value is 1. If the run is resuming from a state, the provided `max_epochs` will be taken
                into account and should be larger than `engine.state.max_epochs`.
            epoch_length (int, optional): Number of iterations to count as one epoch. By default, it can be
                set as `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be
                automatically determined as the iteration on which the data iterator raises `StopIteration`.
                This argument should not change if the run is resuming from a state.
            seed (int, optional): Deprecated argument. Please, use `torch.manual_seed` or
                :meth:`~ignite.utils.manual_seed`.

        Returns:
            State: output state.

        Note:
            User can dynamically preprocess input batch at :attr:`~ignite.engine.events.Events.ITERATION_STARTED`
            and store output batch in `engine.state.batch`. Latter is passed as usually to `process_function`
            as argument:

            .. code-block:: python

                trainer = ...

                @trainer.on(Events.ITERATION_STARTED)
                def switch_batch(engine):
                    engine.state.batch = preprocess_batch(engine.state.batch)
        """
        if seed is not None:
            warnings.warn(
                "Argument seed is deprecated. It will be removed in 0.5.0. "
                "Please, use torch.manual_seed or ignite.utils.manual_seed"
            )

        if not isinstance(data, Iterable):
            raise TypeError("Argument data should be iterable")

        if self.state.max_epochs is not None:
            # Check and apply overridden parameters
            if max_epochs is not None:
                if max_epochs < self.state.epoch:
                    raise ValueError(
                        "Argument max_epochs should be larger than the start epoch "
                        "defined in the state: {} vs {}".format(max_epochs, self.state.epoch)
                    )
                self.state.max_epochs = max_epochs
            if epoch_length is not None:
                if epoch_length != self.state.epoch_length:
                    raise ValueError(
                        "Argument epoch_length should be same as in the state, given {} vs {}".format(
                            epoch_length, self.state.epoch_length
                        )
                    )

        if self.state.max_epochs is None or self._is_done(self.state):
            # Create new state
            if max_epochs is None:
                max_epochs = 1
            if epoch_length is None:
                epoch_length = self._get_data_length(data)
                if epoch_length is not None and epoch_length < 1:
                    raise ValueError("Input data has zero size. Please provide non-empty data")

            self.state.iteration = 0
            self.state.epoch = 0
            self.state.max_epochs = max_epochs
            self.state.epoch_length = epoch_length
            self.logger.info("Engine run starting with max_epochs={}.".format(max_epochs))
        else:
            self.logger.info(
                "Engine run resuming from iteration {}, epoch {} until {} epochs".format(
                    self.state.iteration, self.state.epoch, self.state.max_epochs
                )
            )

        self.state.dataloader = data
        return self._internal_run(return_generator=return_generator)


class ChildEngine(IterableEngine):
    """Engine that is attached to a parent engine; it runs an unbounded number of epochs,
    until the parent engine terminates or completes."""

    def __init__(self, process_function: Callable):
        super().__init__(process_function)
        self.epoch_iterator = None

    def _iterate_epoch(self, parent_engine):
        next(self.epoch_iterator)

    def _update_max_epochs(self):
        """Makes engine never reach max_epochs"""
        self.state.max_epochs = self.state.epoch + 1

    def _setup_epoch_iterator(self, data):
        self.add_event_handler(Events.EPOCH_STARTED, self._update_max_epochs)  # ensures child engine never completes
        self.epoch_iterator = self.run(data, max_epochs=1, return_generator=True)

    def attach_to_parent_engine(self, parent_engine, data):
        # set up the iterator object on the parent engine's STARTED event
        parent_engine.add_event_handler(Events.STARTED, self._setup_epoch_iterator, data)
        # run one epoch of the child engine on EPOCH_COMPLETED
        parent_engine.add_event_handler(Events.EPOCH_COMPLETED, self._iterate_epoch)
        # forward all termination events of the parent engine to the child engine
        parent_engine.add_event_handler(Events.COMPLETED, lambda engine: self.fire_event(Events.COMPLETED))
        parent_engine.add_event_handler(Events.TERMINATE, lambda engine: self.fire_event(Events.TERMINATE))
        parent_engine.add_event_handler(Events.EXCEPTION_RAISED,
                                        lambda engine, e: self._fire_event(Events.EXCEPTION_RAISED, e))


if __name__ == '__main__':
    parent_engine = Engine(lambda e, b: 0.)
    child_engine = ChildEngine(lambda e, b: 0.)
    child_child_engine = ChildEngine(lambda e, b: 0.)
    dummy_data = [1, 2, 3]
    child_engine.attach_to_parent_engine(parent_engine, dummy_data)
    child_child_engine.attach_to_parent_engine(child_engine, dummy_data)

    # add event handlers for demonstration
    parent_engine.add_event_handler(Events.STARTED, lambda x: print("EVENT: STARTED"))
    parent_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("EVENT: EPOCH_COMPLETED"))
    parent_engine.add_event_handler(Events.COMPLETED, lambda x: print("EVENT: COMPLETED"))
    child_engine.add_event_handler(Events.STARTED, lambda x: print("\tEVENT: STARTED"))
    child_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("\tEVENT: EPOCH_COMPLETED"))
    child_engine.add_event_handler(Events.COMPLETED, lambda x: print("\tEVENT: COMPLETED"))
    child_child_engine.add_event_handler(Events.STARTED, lambda x: print("\t\t EVENT: STARTED"))
    child_child_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda x: print("\t\t EVENT: EPOCH_COMPLETED"))
    child_child_engine.add_event_handler(Events.COMPLETED, lambda x: print("\t\t EVENT: COMPLETED"))

    parent_engine.run(dummy_data, max_epochs=2)
```
@alxlampe thanks for the details. I like what you would like to achieve, and I think this can be interesting and important for other applications! I'm not 100% sure about the way to implement it. Let me think about what can be done with the API.
@vfdev-5 I did some experiments with the example from my last post, and there are some issues with the added lines:

```python
if return_generator:
    yield self.state
```

This does not work as expected, even when the generator mode is not requested. Second, it also breaks if I have two doubly nested engines at the same nesting level.

I've created a gist here with an engine that implements the required methods. This engine is attachable and nestable.
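The nesting difficulty can be seen with plain Python generators. `yield from` hands control to the *whole* inner run at once, while nesting engines needs the inner run advanced exactly one step per outer step. A minimal illustration (not ignite code; `inner`/`outer_*` are made-up names):

```python
def inner():
    # stand-in for a child engine run that pauses after each epoch
    for epoch in range(2):
        yield ("inner", epoch)

def outer_flattened():
    for epoch in range(2):
        yield ("outer", epoch)
        # yield from drains a fresh inner run completely each time
        yield from inner()

def outer_stepped():
    it = inner()  # one persistent inner run
    for epoch in range(2):
        yield ("outer", epoch)
        # next() advances the inner engine by exactly one epoch,
        # which is what "run one child epoch per parent epoch" needs
        yield ("stepped", next(it))

flat = list(outer_flattened())
stepped = list(outer_stepped())
```

`flat` interleaves complete inner runs into the outer stream, whereas `stepped` advances the single persistent inner run one epoch at a time.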
Another case would be if the parent engine only serves to drive the run of its child engines. Then one could create the parent engine with a process function that runs epochs of the child engines:

```python
def _process_child_engines(engine, child_engines):
    for child_engine in child_engines:
        child_engine.run_epoch()
```

This could also solve the problem in #1384 (Step 2 in the discussion). Here is how it could look:

```python
child_engines = []
for k in range(num_k):
    k_fold_data_loader = get_k_fold_data(k, num_k)
    # setup_engine is a function that sets up the training process for one engine
    # with all its metrics, loggers etc.
    engine = setup_engine(data=k_fold_data_loader)  # takes arguments like engine.run, stores k_fold_data_loader
    child_engines.append(engine)

serving_engine = Engine(_process_child_engines)  # process function from above

# attach children to have one-time events synchronized with the serving engine
for child_engine in child_engines:
    child_engine.attach_to_parent_engine(serving_engine)

# add metrics summarizer
metrics_summarizer = MetricsSummarizer(child_engines, ...)  # add children to the summarizer
metrics_summarizer.attach(serving_engine)  # adds an EPOCH_COMPLETED handler to summarize metrics after each epoch

serving_engine.run(data=child_engines, max_epochs=100)
```

I hope that gives some insights and shows that it is a useful feature :-) Another use case that fits nicely into this framework is to run an experiment with multiple seeds at the same time and compute a metrics summary (mean, min, max) on the fly.
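The serving-engine idea above — a parent whose process function just advances each child by one epoch — can be sketched with plain generators standing in for child engine runs (everything here is a toy illustration; `make_child_run` and `process_children` are hypothetical names):

```python
def make_child_run(name, n_epochs):
    # stand-in for a child engine's run-as-generator: pauses after each epoch
    for epoch in range(1, n_epochs + 1):
        yield (name, epoch)

log = []
# two "folds", as in the k-fold cross-validation use case
children = [make_child_run("fold{}".format(k), 3) for k in range(2)]

def process_children(parent_state, children):
    # parent "process function": one parent epoch advances every child by one epoch
    for child in children:
        log.append(next(child))

for parent_epoch in range(3):
    process_children(None, children)
    # a MetricsSummarizer-style EPOCH_COMPLETED handler could aggregate
    # the children's per-epoch metrics here

print(log)  # children advance in lockstep, epoch by epoch
```

The key property is that all children stay at the same epoch after each parent step, so a summarizer attached to the parent always sees comparable per-epoch metrics.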
@alxlampe thanks for the update and the detailed info about your use case! Yes, this could definitely be an awesome and useful feature!

Yes, I also noticed that while playing around. Currently, I am thinking of rewriting Engine as a generator and wrapping it such that there is no BC break. Having a generator would also be interesting in the case of Federated Learning (see #1382), where we need to stop at a number of iterations and not epochs...

Yes, this can be helpful too. Maybe we can open another FR for that. Let's first discuss the Engine and what can be done. I am thinking about two things now on how to recode Engine:

I wonder if the second change could cover, in an acceptable way, your need to split the loop on epochs.
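The "rewrite Engine as a generator and wrap it so there is no BC break" idea can be sketched on a toy class. This is an illustration of the pattern only, not ignite's eventual implementation; `ToyEngine` and its methods are made-up names:

```python
class ToyEngine:
    """Toy sketch: a generator core plus a backwards-compatible blocking run()."""

    def __init__(self):
        self.epoch = 0

    def _run_as_generator(self, data, max_epochs):
        # generator core: control returns to the caller after every epoch
        while self.epoch < max_epochs:
            self.epoch += 1
            for batch in data:
                pass  # process_function(engine, batch) would go here
            yield self.epoch

    def run(self, data, max_epochs=1):
        # BC wrapper: exhaust the generator internally so that existing
        # callers still get a plain blocking run()
        for _ in self._run_as_generator(data, max_epochs):
            pass
        return self.epoch


blocking = ToyEngine()
blocking.run([0, 1, 2], max_epochs=3)   # behaves like today's run()

# the same core supports pausing at epoch boundaries, e.g. for #1382
stepped = ToyEngine()
gen = stepped._run_as_generator([0, 1, 2], max_epochs=10)
next(gen)  # exactly one epoch runs, then execution is suspended
```

The wrapper keeps the public contract unchanged while the generator core becomes available for callers that need finer-grained control.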
I think #1382 could be solved if one could write:

```python
for _ in range(10):
    engine.run_iteration()
```

Obviously, this is not a clean solution. But what do you think of something more general, like:

```python
engine.run(data, max_epochs, epoch_length, seed)
engine.run_epochs(data=None, num_epochs=1, epoch_length=None, seed=None)
engine.run_iterations(data=None, num_iterations=1, epoch_length=None, seed=None)
```

This would give the options to assign a new dataset and to control how many epochs/iterations should be iterated.

```python
# setup from the main program
engine.setup_run()
engine.run_iterations(num_iterations=10)
# do something in the main program
engine.run_iterations(num_iterations=10)
# ...
```

And a nice feature would be to create generators from that.
@alxlampe there is a sort of major issue with this:

Yes, this API is very interesting, and such decoupling can provide another level of flexibility: Engine could be inserted into any kind of loop 👍. Still, the API should be discussed and all side effects understood.

Seems like mixing both behaviours is not that trivial. Not yet sure, but doing something like the following can maybe work out:

```python
as_generator = True

def bar(i):
    if as_generator:
        yield None
    return None

def foo(n):
    print("start")
    for i in range(n):
        print("-", i)
        if as_generator:
            yield from bar(i)
        else:
            bar(i)
        print("-- ")
    print("final return")
    return i
```

Main point is that we could
🚀 Feature

Problem

I am using multiple engines in a nested way. That means that if e.g. the main engine fires `Events.EPOCH_COMPLETED`, another child engine is attached to this event and shall run only one epoch. A solution would be to run the child engine with `engine.run(max_epochs=1)`, but then the engine fires setup and teardown events like `Events.STARTED` and `Events.COMPLETED` each time I call `engine.run(max_epochs=1)`, even though those events are meant to be fired only once, as far as I understand. Since my child engine must set up and tear down things, I could attach event handlers to the main engine, but the handlers I want to attach do not know that a main engine exists. The handlers shouldn't have any access to the main engine.

Solution

I need some functionality such that the engine can do the following (this is just an example with a bad but possible way of implementing it). Instead of calling a function, one could create an iterable object from `engine.run` and get the same behavior in a nicer way. Or one can use loops.

I added the code at the bottom, where I subclass `Engine` and overload the `_internal_run` method with a copy of the original method, adding one line with the `yield` statement. You can execute it, and it outputs the example. To switch between the actual and this behavior, one could put the `yield` into an if statement and pass an additional argument to `engine.run`, e.g. `engine.run(max_epochs=3, return_generator=True)`, or set a flag of the engine to enable this functionality.

What do you think?
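The requested behavior — a run that pauses after each epoch while `STARTED` and `COMPLETED` still fire exactly once — can be mimicked on a toy function. This is a hedged sketch only (`toy_run` and the event strings are made-up, not ignite API):

```python
fired = []

def toy_run(data, max_epochs, return_generator=False):
    def _internal_run():
        fired.append("STARTED")            # fired once per run
        for epoch in range(1, max_epochs + 1):
            for batch in data:
                pass                       # process_function would run here
            fired.append("EPOCH_COMPLETED")
            yield epoch                    # pause between epochs
        fired.append("COMPLETED")          # fired once per run
    gen = _internal_run()
    if return_generator:
        return gen
    for _ in gen:                          # blocking behavior unchanged
        pass

for epoch in toy_run([0, 1, 2], max_epochs=3, return_generator=True):
    pass  # "do something else" between epochs here

print(fired)  # one STARTED, three EPOCH_COMPLETED, one COMPLETED
```

The one-time events bracket the whole run regardless of how many times the caller resumes the generator, which is exactly what repeated `engine.run(max_epochs=1)` calls cannot provide.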