Loop.State incomplete #35

Closed
albertz opened this issue Oct 21, 2021 · 5 comments

@albertz
Member

albertz commented Oct 21, 2021

(Initial design via #16.)

The initial state is not really handled yet.

Also, we need to think about shape and dtype.

Maybe we also want to pass shape and dtype on to RETURNN, to simplify the recurrent template construction.
Currently, this would be via out_type.
When we have rwth-i6/returnn#706, maybe this would be another way, by out_shape or so.

shape must also be able to handle dynamic dims, which can change in each iteration (e.g. for cum_concat).

We also don't handle nested state (e.g. LayerState) yet, but this is really required.
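
To make the nested-state requirement concrete, here is a minimal sketch in plain Python. Everything in it is an assumption for illustration: map_nested is a hypothetical helper, plain dicts stand in for LayerState, and the "prev:..." strings only mimic the idea of a per-leaf reference to the previous iteration. It is not how RETURNN or this project implements it.

```python
def map_nested(fn, struct, path=()):
    """Apply fn(path, leaf) to every leaf of a nested dict/tuple/list structure.

    Hypothetical helper for illustration only.
    """
    if isinstance(struct, dict):
        return {key: map_nested(fn, value, path + (key,)) for key, value in struct.items()}
    if isinstance(struct, (tuple, list)):
        return type(struct)(map_nested(fn, value, path + (i,)) for i, value in enumerate(struct))
    return fn(path, struct)


# A plain dict stands in for a nested state such as LayerState (assumption).
initial = {"lstm": {"h": 0.0, "c": 0.0}, "step": 0}

# Build one "previous iteration" name per leaf, keeping the nesting intact.
prev_refs = map_nested(lambda path, _leaf: "prev:" + "/".join(str(p) for p in path), initial)
```

The point is only that each leaf of the structure needs its own previous-iteration reference, rather than one reference for the whole state.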

@albertz
Member Author

albertz commented Nov 2, 2021

Note that the example in #31 does not work yet:

```python
lstm = Lstm(...)
with Loop() as loop:
  ...
  out, loop.state.lstm = lstm(x, state=loop.state.lstm)
```

Because nested state (LayerState) is not handled yet.

I also wonder: what should the initial state logic look like for this example with nested state?

Layers with state by convention (or by rule) return a tuple with LayerState as the last element, and get a state argument of the same type. Should they also be required to define a default initial_state that returns a structure of the same type?

What does this look like then? Like this?

```python
loop.state.lstm = State(initial=lstm.initial_state())
```
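
The convention discussed above could be sketched as follows. This is a toy in plain Python, not the real Lstm: ToyLstm, ToyState, and run are hypothetical names, and the arithmetic is arbitrary. It only shows that initial_state() returns the same structure that the call takes as state and returns as the last tuple element.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyState:
    """Hypothetical stand-in for a nested LayerState with two entries."""
    h: float
    c: float


class ToyLstm:
    """Toy recurrent cell (not the real Lstm); only the call convention matters."""

    def initial_state(self) -> ToyState:
        # Same structure as the state returned by __call__.
        return ToyState(h=0.0, c=0.0)

    def __call__(self, x: float, *, state: ToyState) -> Tuple[float, ToyState]:
        # Arbitrary arithmetic, chosen only so that the state visibly changes.
        c = 0.5 * state.c + x
        h = 0.5 * state.h + c
        # Convention: the new state is the last element of the returned tuple.
        return h, ToyState(h=h, c=c)


def run(cell: ToyLstm, xs: List[float]) -> List[float]:
    """Plain Python loop that threads the state through every step."""
    state = cell.initial_state()
    outs = []
    for x in xs:
        out, state = cell(x, state=state)
        outs.append(out)
    return outs
```

Because the initial state and the per-step state share one type, the loop body does not need a special case for the first iteration.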

Should we maybe decouple the two uses of State: explicitly defining the initial loop state, and serving as the internal logic for Loop and _StateHolder? Currently, the only time the user sees and uses State is to define the initial state. Otherwise it is always only internal.

@albertz
Member Author

albertz commented Nov 2, 2021

Thinking further, I think this example (out, loop.state.lstm = lstm(x, state=loop.state.lstm)) cannot really work in any case as long as loop.state.lstm has not been defined explicitly (i.e. its initial state has not been set). When accessing loop.state.lstm, what should this return?

Currently it returns a LayerRef to prev:lstm, which is wrong because we expect a nested LayerState here and not just a single layer.

Returning None to trigger the default init or so would also not make sense, as this is the state used in every iteration, not just in the first one. So maybe we should even explicitly throw an exception if None is passed here when lstm operates on the single-frame level.

Maybe it should return a special UnknownInitialState object or so? This would not be a layer ref. It could possibly represent nested structure. The behavior is a bit unclear, though. Would all modules and layers need to handle this explicitly? That would be bad. Should it just emulate any possible nested structure? Is this even possible, and not too hacky?

Or should we just always require the initial state to be explicitly assigned before we use this? So like this?

```python
loop.state.lstm = State(initial=lstm.initial_state())
out, loop.state.lstm = lstm(x, state=loop.state.lstm)
```
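
A minimal sketch of the "always assign first" option, assuming a hypothetical ToyStateHolder in place of whatever backs loop.state. Reading a state that was never assigned raises an error instead of guessing a default, and the assigned value is returned unchanged, so nested structures pass through as they are.

```python
class ToyStateHolder:
    """Hypothetical stand-in for the object behind loop.state (illustration only)."""

    def __init__(self):
        # Bypass our own __setattr__ so the storage dict is a normal attribute.
        object.__setattr__(self, "_values", {})

    def __setattr__(self, name, value):
        # Every attribute assignment defines or updates a named state.
        self._values[name] = value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, i.e. for state names.
        try:
            return self.__dict__["_values"][name]
        except KeyError:
            raise AttributeError(
                f"state {name!r} was accessed before its initial value was assigned"
            ) from None


state = ToyStateHolder()
try:
    state.lstm  # not assigned yet, so this raises
    accessed_without_init = True
except AttributeError:
    accessed_without_init = False

# After the explicit assignment, the (possibly nested) value is available.
state.lstm = {"h": 0.0, "c": 0.0}
```

With this behavior there is no need for a placeholder object that emulates arbitrary nested structure, since the structure is always known from the assigned value.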

@albertz
Member Author

albertz commented Nov 2, 2021

I think we need the explicit assignment of the initial state in any case. This is maybe also not so bad anyway because we usually prefer explicit over implicit (e.g. #16, #31).

So the question is just what that should look like. The code example above follows our current design, but we can still change this. It's maybe not the most intuitive way yet.

@albertz
Member Author

albertz commented Nov 2, 2021

When we make setting the initial state mandatory, I think we don't need to think about shape and dtype anymore. Also, dynamic shapes should be supported via the standard dim tag logic.

@albertz
Member Author

albertz commented Jan 5, 2022

Setting the initial state is mandatory now.

Also, Loop.State became hidden and is not intended for direct use anymore.

The code looks like:

```python
loop = nn.Loop(axis=axis)
loop.state.lstm = self.lstm.default_initial_state()
with loop:
  x_ = loop.unstack(x)
  y_, loop.state.lstm = self.lstm(x_, state=loop.state.lstm, axis=nn.single_step_dim)
  y = loop.stack(y_)
```

(via)

@albertz albertz closed this as completed Jan 5, 2022