
Rec design for recurrent definitions / loops #16

Closed
albertz opened this issue Aug 5, 2021 · 53 comments
albertz commented Aug 5, 2021

This issue is to collect some thoughts on the recurrent loops design, which wraps the RecLayer with an explicit subnetwork in RETURNN.

The main goal is to have this very straight-forward and simple for the user. We can abstract away from the underlying RecLayer if that makes things easier. We can also extend RETURNN itself if needed.

Also related is #6 (rec prev mechanism); this issue here might resolve #6, although not necessarily.

This also needs some mechanism for unrolling/unstacking, i.e. when we iterate over an input x with some time axis, to get x[t]. This is rwth-i6/returnn#552.


To define a loop like this pseudo Python code:

x  # given, shape {batch, time, dim}
h = Zeros({batch,dim})()  # initial state, shape {batch,dim}
out = []
for t in range(x.max_seq_len):
  x_lin = Linear(dim)(x[t])
  h_prev = h
  h = Linear(dim)(x_lin + h_prev)
  out.append(h)

h  # final state
out  # shape {time, batch, dim}

Current design:

There is Loop() which can be used in a with context, which corresponds to the for-loop in the example, or in general to a while-loop. Like:

with Loop() as loop:
  ...

There is State() which can define hidden state (for any module or any code).

The example above can be written as:

h = State({batch, dim}, initial=0)
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}

  x_lin = Linear(dim)(x_t)
  h_prev = h.get()
  h_ = Linear(dim)(x_lin + h_prev)  # shape {batch, dim}
  h.assign(h_)

  out = loop.stack(h_)  # shape {time,batch,dim}
  h_last = loop.last(h_)

# h.get() would now return the last state
# h_last is an alternative

Or with a module as:

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.x_linear = Linear(dim)
    self.h_linear = Linear(dim)
    self.h = State({batch, dim}, initial=0)

  def forward(self, x):
    # x shape is {batch, dim}
    x_lin = self.x_linear(x)
    h_prev = self.h.get()
    h = self.h_linear(x_lin + h_prev)  # shape {batch, dim}
    self.h.assign(h)
    return h

rec = MyRec()
with Loop() as loop:  # this introduces a new loop
  x_t = loop.unstack(x)  # shape {batch, dim}
  h_ = rec(x_t)  # shape {batch,dim}. this represents the inner value
  h = loop.last(h_)  # shape {batch,dim}
  out = loop.stack(h_)  # shape {time,batch,dim}

For the TF name scopes (and variable scopes), we should follow #25, i.e. make it exactly as the module hierarchy.

The RETURNN layer name of the created RecLayer via Loop does not matter too much. It could be arbitrary, or some clever (but simple) logic to use the first module name or so. The RETURNN layer hierarchy can be independent from the actual TF name scopes (via #25).

Special options for the RecLayer like include_eos can be options for Loop, like Loop(include_eos=True). Or as a method, like loop.set_include_eos(True).

Loop has (potential) methods such as unstack, stack, last and end (see the examples above and below).

State has methods get and assign. (See the discussion below for more.)
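
As a minimal illustration of this API (a sketch only; it just counts the frames of x, and the exact State constructor arguments are assumptions based on the drafts in this thread):

i = State({batch}, dtype=int32, initial=0)
with Loop() as loop:
  loop.unstack(x)              # iterate over the time axis of x
  i_ = i.get() + 1             # i.get() refers to the value from the previous iteration
  i.assign(i_)
  num_frames = loop.last(i_)   # counter value after the final iteration, shape {batch}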

Current reasonings:

Why no special base class Rec which derives from Module? We want to make it easy to use any kind of module inside a loop. We think the current API makes this more straightforward.

Why is h not an argument of forward, and why State instead? This allows calling other submodules, which might define their own hidden state, so the root recurrent module does not need to know about all the hidden states of its submodules.

Why have the hidden state explicit, and not use something closer to self.prev? To make the behavior more straightforward.

The current design allows for nested loops and sub modules with hidden state.
Only the Loop() call actually introduces a new loop.

class MySubRec(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})

  def forward(self, a):
    # assume a shape {batch,dim}
    h = self.h.get() + a
    self.h.assign(h)
    return h

class MyRec(Module):
  def __init__(self):
    super().__init__()
    self.sub = MySubRec()
    self.h = State({batch,dim})

  def forward(self, x):
    a = self.h.get() + x

    # example with sub as nested loop
    with Loop() as loop:
      y = self.sub(a)
      y = loop.last(y)

    # or: example with sub in same loop
    y = self.sub(a)
    
    self.h.assign(y)
    return y

There should not be any special handling needed for the Choice layer.
Note that the search flag and train flag logic is a separate thing (#18).

There should not be any special handling needed depending on whether the input to a rec module call is inside the current/same loop or not. unstack on some value which is already inside the loop would not make sense, though, and should result in an error. But this should all be covered by existing RETURNN logic already.

RETURNN's automatic rec optimization should not cause any problems. RETURNN should already guarantee that it is equivalent. From the user's viewpoint, it should never matter whether it is optimized; otherwise that is rwth-i6/returnn#573. On the returnn-common level, it should not matter.


Example for LSTM for a single step:

class Lstm(Module):
  def __init__(self):
    super().__init__()
    self.h = State({batch,dim})
    self.c = State({batch,dim})
    self.ff_linear = Linear(dim * 4)
    self.rec_linear = Linear(dim * 4)

  def forward(self, x):
    # x shape is {batch,dim} (single frame)
    x_ = self.ff_linear(x)
    h_ = self.rec_linear(self.h.get())
    x_in, g_in, g_forget, g_out = split(x_ + h_, 4)
    c = self.c.get() * sigmoid(g_forget) + tanh(x_in) * sigmoid(g_in)
    self.c.assign(c)
    h = tanh(c) * sigmoid(g_out)
    self.h.assign(h)
    return h

albertz commented Aug 9, 2021

Also related: rwth-i6/returnn#391 rwth-i6/returnn#545

I.e. the question of how to define masked self-attention in a generic/flexible way in this Rec concept, e.g. via HistoryMaskLayer or so. This could all be wrapped directly, but is that intuitive then? What does it look like? Can we make it more intuitive?

Extending on this, example draft:

with Loop() as loop:
  # x is [B,D] inside, [T,B,D] outside
  qkv = Linear(2*K+V)(x)  # [B,2*K+V] inside, [T,B,2*K+V] outside
  q, k, v = split(qkv, size_splits=[K,K,V])  # [B,K|V] inside, [T,B,K|V] outside
  k_accum = cum_concat(k)  # [B,T',K] inside, [B,T'',K] outside
  v_accum = cum_concat(v)  # [B,T',V] inside, [B,T'',V] outside
  energy = dot(q, k_accum, red1="static:-1", red2="static:-1", var1="T?", var2="T")  # [B,T] inside, [T,B,T'] outside
  att_weights = softmax_over_spatial(energy)  # [B,T] inside, [T,B,T'] outside
  att = dot(v_accum, att_weights, red1="T", red2="stag:history", var1="static:-1", var2=[])  # [B,V] inside, [T,B,V] outside

Note that the cum_concat behavior is conceptually similar to our loop.stack. Maybe this should be unified?

Note that we have the principle that the user should not need to think about the automatic optimization (rwth-i6/returnn#573).

The axes descriptions and axes themselves still need to be worked out (see referenced RETURNN issue above).
softmax_over_spatial (SoftmaxOverSpatialLayer) needs to be clever about the history time axis.

albertz commented Aug 9, 2021

We should also work out what masked computation looks like, on the example of a transducer with SlowRNN and FastRNN. Again, in this example, what we want:

  • It should be straight-forward to write. The model definition would be for the recognition, i.e. using a Loop() and Choice() (or choice()) on the probability distribution.
  • This should be efficient in training (when possible). I.e. the automatic optimization should be able to handle this. But no special care by the user should be needed. (In the worst case, it would just not be optimized, but it never would be wrong.)

Example code draft:

x  # shape {batch,enc_time,dim}
slow_rnn = SlowRNN()
fast_rnn = FastRNN()
blank_pred = Linear(1)
non_blank_pred = Linear(...)
t = State({Batch}, dtype=int32, initial=0)
align_label = State({Batch}, dtype=int32, initial=0)
with Loop() as loop:  # over alignment labels
  x_t = x[t]  # shape {batch,dim}
  with MaskedComputation(mask=(align_label.get() != BLANK)):
    slow = slow_rnn(align_label.get(), x_t)
  fast = fast_rnn(align_label.get(), x_t, slow)
  blank_pred_energy = blank_pred(fast)
  log_prob_blank = log_sigmoid(blank_pred_energy)
  log_prob_not_blank = log_sigmoid(-blank_pred_energy)
  log_prob_non_blank_labels = log_softmax(non_blank_pred(fast))
  log_prob_combined = concat(log_prob_non_blank_labels + log_prob_not_blank, log_prob_blank)
  align_label.assign(choice(log_prob_combined, input_type="log_prob"))
  loop.end(t >= x.seq_len)

This is not so much about MaskedComputation here (see #23 for that), but rather about the question whether there is anything special w.r.t. the Loop concept, or whether this would just work as-is. Maybe that is already the case.

albertz commented Aug 16, 2021

What do you mean or intend?

When not setting target for the Rec unit in the current config this causes an error with the end layer during training, that batch shapes (None, ) and (None, 1) do not match.

I don't really understand. It sounds like a bug. But anyway, things like this are really irrelevant for the discussion here. We should just fix or extend RETURNN in any way needed.

I was thinking each prev: later in the RETURNN config introduces at least 3 new lines of code,

Exactly like in the pure Python code in the beginning. Which is extremely simple?

which in total would sum up quite quickly.

How? The user probably rarely would write such code. Most States would be introduced by some existing Modules and the user would just use those Modules.

Most likely my problem right now is, I don't really see how this would translate to a RETURNN config.

Well, speaking purely abstractly: RETURNN is generic enough to allow just any possible construction, or if not, it should be, and should be extended to allow that.

So this should never be a concern here in this repo (how it translates to a RETURNN config). The main goal is to have it logical, simple, straight-forward, expectable. How it translates to RETURNN is an implementation detail, after we have clarified what design we want.

Speaking more specifically now, how we actually do the implementation: I think most aspects were already discussed here (the thread gets too long to easily see that...). There might be some open questions but I think they can be solved easily.

I will start working on this once I have finished the generalized self-attention logic in RETURNN.

I can see that you try to do more with State here, but right now I am not sure what for, which, if I understand you correctly, is the current goal. Can you maybe give an example where we would want to do more with the State?

I just gave you one? This is one such example:

lstm = Lstm(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

I am not sure whether I understand this example correctly. Right here we don't have a state variable, so why would we assign it twice if we only call the assign command once?

We do have one. The Lstm introduces it. See the Lstm definition above. E.g. there is then lstm.c. And the lstm(...) call results in lstm.c.assign.

This has multiple state assign calls because you call the same module multiple times which reuses the same hidden state.

Although, maybe this is actually not such a good example. I wonder now. I'm thinking about e.g. Universal Transformer now, where you also would call the same self attention layer multiple times. But there you do not want to have it use the same hidden state, but every call (layer) should still have its own hidden state, just the parameters should be shared.

So maybe the idea that a module can have hidden state is not good after all? Again, speaking of PyTorch: PyTorch modules can have that as well (as buffers), but usually it is never done this way. E.g. the LSTM or LSTMCell module in PyTorch explicitly gets the prev state and cell as arguments and returns the new state and cell.

Maybe we should somehow make it more explicit, what hidden state is being used? Maybe like this:

lstm = Lstm(...)
with Loop() as loop:
  with StateScope():
    layer1 = lstm(loop.unstack(x))
  with StateScope():
    layer2 = lstm(layer1)
  output = loop.stack(layer2)

Maybe we should also differentiate between buffers and state? Buffers (in general, and also in PyTorch) are intended to not really have influence on the behavior, or not on the model behavior at least. I'm actually not so sure about the typical usage of buffers in PyTorch. But they are definitely not used as hidden state. Hidden state is always passed explicitly.
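
For comparison, here is a small runnable PyTorch sketch (illustrative only, not part of this proposal): a buffer is registered on the module and updated as a side effect, while the recurrent state of LSTMCell is passed in and returned explicitly:

import torch
from torch import nn

class Counter(nn.Module):
  def __init__(self):
    super().__init__()
    # A buffer lives on the module (saved in the state dict) and is updated as a side effect.
    self.register_buffer("num_calls", torch.zeros((), dtype=torch.int64))

  def forward(self, x):
    self.num_calls += 1  # side effect on the module, not part of the returned value
    return x

# Explicit recurrent state: passed in and returned, no hidden side effect on the module.
cell = nn.LSTMCell(input_size=8, hidden_size=8)
x = torch.randn(10, 3, 8)  # (time, batch, feature)
h, c = torch.zeros(3, 8), torch.zeros(3, 8)
for x_t in x:
  h, c = cell(x_t, (h, c))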

albertz commented Aug 16, 2021

So (to verify I am not mistaken) the state is pretty much a concept of returnn-common which then in some cases might translate to prev: in RETURNN, but is a lot more powerful and could do a lot more than just that. prev: is only one of the possible cases State can be used in. Using this definition I agree with you, the code for this is simple and straightforward.

Yes. Actually, in most of the common cases here in returnn-common it would result in "prev:...". I'm thinking more about the maybe somewhat unusual cases.

Maybe we should somehow make it more explicit, what hidden state is being used?

I don't know. If we really want to end up with a module-like structure, I feel like hidden state should be something the user should not really need to deal with themselves if they are not changing the basic configuration of that module (which shouldn't happen too often, because users usually should put their model together from existing modules without needing to make too many new ones, imo). Adding something like with StateScope() would, if not needed for some logic/expressiveness, make it more confusing, I think.

There are multiple aspects/concepts:

  • A Module class.
  • A Module object (instance of the class).
    A module object can have parameters, and sub modules.
    We propose here that it can also have state.
  • A Module object call. This performs the computation by using the model parameters.
    Calling the same module object again will reuse the same model parameters.
    We propose that it also uses the same state. Although that is what I question now a bit.

So the model parameters are shared in this example:

lstm = Lstm(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

In the current proposal, the hidden state is also shared in this example.

However, maybe the more common use case is that the user wants to share the parameters but not the hidden state. Again, I'm thinking about Universal Transformer. But also in most other examples I can think of where parameters are shared, you would not want to share the hidden state.

I think it should be simple to write code for the common cases, while also allowing for exotic things, while also being straight-forward/logical.

How would you actually implement this common case, where you want to share parameters but not the hidden state?
Maybe like this:

lstm = Lstm(...)
with Loop() as loop:
  # Introduce separate state vars such that each layer has own hidden states.
  # Copy state var now to have the right initial state.
  h1, c1 = lstm.h.copy(), lstm.c.copy()
  h2, c2 = lstm.h.copy(), lstm.c.copy()
  lstm.h, lstm.c = h1, c1  # assign layer-1's own hidden state vars
  layer1 = lstm(loop.unstack(x))
  lstm.h, lstm.c = h2, c2  # assign layer-2's own hidden state vars
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

I'm not sure if this is still so easy. Or I would rather say no.

Note, in PyTorch, this (sharing parameters but not hidden state) would look like:

lstm = LstmCell(...)
layer1_state = lstm.initial_state()
layer2_state = lstm.initial_state()
out = []
for x_ in x:
  layer1, layer1_state = lstm(x_, layer1_state)
  layer2, layer2_state = lstm(layer1, layer2_state)
  out.append(layer2)

In PyTorch, sharing parameters and hidden state would look like:

lstm = LstmCell(...)
lstm_state = lstm.initial_state()
out = []
for x_ in x:
  layer1, lstm_state = lstm(x_, lstm_state)
  layer2, lstm_state = lstm(layer1, lstm_state)
  out.append(layer2)

In both cases, you make the handling of state explicit. So there is no confusion on the behavior of state, because it is always explicit.

So, I'm wondering if we also should make it always explicit, but still in a generic way. That is why I proposed StateScope. Maybe StateScope should be passed to the module call via state=StateScope() instead of using with here. Or maybe call it StateHolder or so, which can hold any nested structure of state vars. The idea with with was that it would be nested and automatically used for any nested submodule calls. I agree that the initial example code for StateScope above is not so clear. This should be worked out.

This is all about parameter sharing, i.e. calling the same module multiple times, and how the hidden state should be handled in this case. Because we hide the hidden state away in the current proposal, such a module call has side effects on internal state. Generally speaking, side effects are bad because they make it more difficult to reason about the behavior of code. Usually you want a call like lstm(...) to not have side effects.

Btw, this argument is also about module calls in general. Not really about rec loop actually. E.g. this code in the current proposal:

lstm = LstmOnSeq(...)
layer1 = lstm(x)
layer2 = lstm(layer1)
output = layer2

It again would share the parameters (as expected) but also the hidden state (not sure if that is expected, although it follows logically from the current proposal).

In PyTorch, the example of sharing the parameters but not the hidden state would look like:

lstm = Lstm(...)
layer1, _ = lstm(x)  # default initial hidden state
layer2, _ = lstm(layer1)  # default initial hidden state
output = layer2

albertz commented Aug 16, 2021

Ok, following from those thoughts, I'm now thinking that we definitely should make it explicit. For all modules which have state (by themselves, or via sub modules). And state is not an attrib of the module but just an explicit parameter and return value of a module call.

There can be a reasonable default initial state. E.g. when you have def forward(self, ..., state=None) in a module, you can check if state is None: state = some_default_initial_state() or so. Every such module would return the new state, e.g. as a tuple like in the PyTorch LSTM example.
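
A minimal sketch of this convention (a hypothetical module, reusing only names from the drafts above; not a definitive implementation): the state is an ordinary argument with a default, and the new state is returned together with the output:

class MyCell(Module):  # hypothetical example module
  def __init__(self):
    super().__init__()
    self.linear = Linear(dim)

  def initial_state(self):
    return Zeros({batch, dim})()  # assumption: zero initial state, as in the pseudocode at the top

  def forward(self, x, state=None):
    if state is None:
      state = self.initial_state()  # reasonable default initial state
    h = self.linear(x + state)
    return h, h  # output and new state (they coincide here)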

The example for sharing parameters but not sharing hidden state would look something like:

lstm = Lstm(...)
layer1_state = StateScope()
layer2_state = StateScope()
with Loop() as loop:
  layer1, layer1_state = lstm(loop.unstack(x), layer1_state)
  layer2, layer2_state = lstm(layer1, layer2_state)
  output = loop.stack(layer2)

Which is very similar to the PyTorch example.

From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal LayerRef instances but State or StateScope). And maybe sometimes you want one argument to become a state, at other times a different argument, and sometimes you do not want any of the arguments to be states, e.g. calling lstm(x, (prev_h, prev_c)) where prev_h and prev_c are just other normal LayerRefs.

So this is a problem. I think the model definition should handle it all the same and just always expect LayerRefs (or nested structures of LayerRefs).

But then how do we handle this? Maybe more explicitly:

lstm = Lstm(...)
layer1_state = State(lstm.initial_state())
layer2_state = State(lstm.initial_state())
with Loop() as loop:
  layer1, layer1_state_ = lstm(loop.unstack(x), layer1_state.get())
  layer1_state.assign(layer1_state_)
  layer2, layer2_state_ = lstm(layer1, layer2_state.get())
  layer2_state.assign(layer2_state_)
  output = loop.stack(layer2)

State here would be extended a bit so that it can also handle nested structures.
lstm.initial_state would still return a normal LayerRef.

I think this would be very logical and straight-forward now again. And all is explicit and behavior is always clean.

Other modules (e.g. Lstm etc.) would never deal with the State concept. The only point where you deal with that is when you write a rec loop, i.e. Loop(). Which is probably also encapsulated away in some module, like Decoder, and the user in most common cases probably just writes something like:

encoder = Encoder(...)
decoder = Decoder(...)
output = decoder(encoder(...))

I wonder a bit if this is too complicated now, to write the Loop and explicitly handle the state this way. But I'm not really sure how to make this simpler.

Or maybe like:

lstm = Lstm(...)
loop_state = State()
loop_state.layer1 = lstm.initial_state()
loop_state.layer2 = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state.layer1 = lstm(loop.unstack(x), loop_state.layer1)
  layer2, loop_state.layer2 = lstm(layer1, loop_state.layer2)
  output = loop.stack(layer2)

This would add a bit of clever handling into the State object. Basically the assign and get calls here are covered by __setattr__ and __getattr__. This allows writing the loop code a bit shorter and maybe more readably.
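
A rough sketch of how the assign/get handling via __setattr__ and __getattr__ could look (purely illustrative; the inner State(...) constructor and its semantics are assumptions):

class _StateHolder:
  """Illustrative only: forwards attribute access to State.get()/State.assign()."""

  def __init__(self):
    super().__setattr__("_state_vars", {})  # name -> State

  def __getattr__(self, name):
    # Reading holder.<name> returns the value from the previous iteration.
    return self._state_vars[name].get()

  def __setattr__(self, name, value):
    if name not in self._state_vars:
      self._state_vars[name] = State(value)  # assumption: create the state var on first assignment
    else:
      self._state_vars[name].assign(value)   # assign the new value for this iteration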

albertz commented Sep 24, 2021

I was quite busy with all the recent dim tag work (rwth-i6/returnn#577) and generalized self attention (rwth-i6/returnn#391). Generalized self attention is finished now, and the concept of consistent dim tags has been improved and extended a lot. Although there are still on-going discussions on taking this even further (rwth-i6/returnn#597, rwth-i6/returnn#632). Some of these might also be relevant for returnn-common but we should discuss this independently, maybe in #17.

I lost a bit track on the current state here, and the reasoning.

I think the final conclusion on the hidden state was to have it always explicit as a (state) argument to the module call, and also explicitly return the new state. So very much like PyTorch. So it is not really hidden at all.

I like it that way because there is no ambiguity in the code, how the hidden state is handled. This is all explicit.

The initial post description is maybe outdated on this. Do we still need to have the state as an attribute to the module (e.g. self.h = State({batch,dim}))? Why?

And why do we need this special State object with assign and get calls? Or the simplified code which uses __setattr__ and __getattr__ on some special State object?

I can see that the state module call argument might not just be a single LayerRef but it could also be some nested structure, esp if this is a module with other sub modules which also could have state. But this does not really explain why we need State. Maybe the concept of the State object was just introduced for the initial idea where we did not want to make this explicit, where the hidden state would have been all implicit and hidden? And now not needed anymore?

Edit The last comment just before this actually says:

state is not an attrib of the module but just an explicit parameter and return value of a module call.

So as I argued above, we would not have this module attrib anymore (like self.h = State({batch,dim})).
States are always explicitly passed to a module call, and new states are returned from it.

However, I don't understand this anymore:

From the point of view of the model implementation, it is a bit strange now that there is a conceptual difference between arguments which are state (and thus not normal LayerRef instances but State or StateScope)

Why is there a conceptual difference? Does it need to be? Why? Why can't the state module call argument just be a regular LayerRef (or nested structure)?

Actually I also addressed this before:

I think the model definition should handle it all the same and just expect always LayerRefs (or nested structures of LayerRefs).

But still in this comment I keep State (or StateScope) as a special concept (e.g. loop_state = State()). Why is this needed?

Edit

The first example from above would look like:

lstm = Lstm(...)
loop_state_layer = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state_layer = lstm(loop.unstack(x), loop_state_layer)
  layer2, loop_state_layer = lstm(layer1, loop_state_layer)
  output = loop.stack(layer2)

The second example from above would look like:

lstm = Lstm(...)
loop_state_layer1 = lstm.initial_state()
loop_state_layer2 = lstm.initial_state()
with Loop() as loop:
  layer1, loop_state_layer1 = lstm(loop.unstack(x), loop_state_layer1)
  layer2, loop_state_layer2 = lstm(layer1, loop_state_layer2)
  output = loop.stack(layer2)

The problem is that this does not exactly correspond to the Python while ...: loop as we want it to w.r.t. the Python local variables. We cannot correctly infer from this Python code that loop_state_layer is a recurrent variable which changes inside the loop. So this is probably one reason for StateScope as it was suggested.

But the same problem actually also exists for any other output: anything in the loop which wants to use the value from the previous iteration. This was not really addressed before? Is this not a problem? Or was this solved differently?

Edit Ah, this is actually also in the very first proposal. That is what State really is for. To handle the RETURNN prev: logic in a generic way. So basically this solves #6.

So, to recap:

  • Other recurrent state (prev:... in RETURNN, Prev: in Rec #6) and hidden state would be handled in the same way.
  • And State (or StateScope or however we call it) is only relevant for the code which directly operates with Loop().
  • Any other submodules would not have any special logic for this. The prev hidden state would just be a normal module call parameter. However, we would have some conventions here for modules with state:
    • They would have a state argument in the module call. Which can be any arbitrary nested structure.
    • They would return the state as well. So usually a tuple where the last item is the state.

I still see a problem in catching errors here. When the user writes the code ignoring State but just like before, it would compile without error, but it would do the wrong thing, i.e. not the expected behavior. It would always use lstm.initial_state() as state in every iteration in this example.

Can we somehow catch such errors to avoid unexpected behavior?

Or is this maybe not too much a problem as this is actually not too much unexpected?

Also, do we want to be able to skip the state argument, i.e. that it has a reasonable default? Modules might have state=None in the function signature and then internally do something like if state is None: state = self.initial_state(). However, this code would have exactly the problem just described, i.e. it would not use the prev state but always the initial state in every iteration. Is this a fundamental problem which cannot really be solved?

In PyTorch, this is the same behavior though, right? In PyTorch, there is the difference between LSTM (on a seq) and LSTMCell (on a single frame). LSTM does have this default initial state, but LSTMCell does not, as it does not make sense for this case. In RETURNN, we have both together, which maybe causes this confusion. But we do not need to wrap it exactly the same here in returnn-common. We could also have some LSTMCell. Or maybe use RnnCell or RecCell to be able to use other rec units from RETURNN as well (not just NativeLstm2). Or wrap them all separately. Or both. And for all of these, we require the state to be an explicit argument (no default). Although such modules would still have some function initial_state.

albertz commented Sep 24, 2021

I'm questioning now whether this explicit code is maybe too complicated for many of the common use cases. On the RETURNN side, the hidden state is hidden and not explicit. So when translating some old-style RETURNN model to this new way, it would take some extra effort (although it should be straightforward).

One alternative is to introduce State inside the module call (but not as a module attrib). So an LSTMCell or RecCell could be defined like:

class LstmCell(Module):
  def __init__(...): ...
  def forward(self, input, state=None):
    if state is None:
      state = StateScope(...)
    h, c = state.h, state.c
    ...
    h.assign(new_h)
    c.assign(...)
    return new_h

Then we can use e.g. this code:

lstm = LstmCell(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

This would do param sharing between both lstm calls. However, it would not share the same hidden state (as expected).

The StateScope would be attached to the Loop (via something like Loop.get_current() -> Loop, which we can implement via the with loop: logic).

The not-so-nice thing about this is that we clearly differentiate between state and other input now. So it becomes complicated/non-straightforward when the user also wants to pass some custom state (custom h or c) to lstm, or to return the new cell state c.
It would maybe look like this:

layer1_c = get_state_scope(layer1).c.get()

Atticus1806 commented:

So I just read through everything and I will add my comments. Since it was quite a lot, maybe I also misunderstood something; if so, just correct me:
So first let's start with the concept of a State. From what I remember, we started with adding logic which is able to handle the prev: logic from RETURNN and then started expanding on that, adding more "features" like hidden state handling to it.

One of my general questions would be: why do we even do explicit recurrent handling like getting state updates and so on, instead of just "references" to these updates which would be passed to RETURNN for the config? Isn't the actual handling of how the LSTM unit works something that is part of RETURNN, while what we try to achieve in this repo is a nice way of writing it down? From what I understand right now, you are also looking to include additional concepts.

From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of RETURNN, even though of course it's a sharp sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) quite a bit easier to start and also work with. What we should aim for in my opinion is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it. So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible, but if you want to make use of some stronger concepts you would have to go a bit more in depth. The problem why this in other cases causes trouble is when the documentation is not good enough, making users who try to make the transition into more detail feel lost.

Now onto the specific example: I would prefer the second option in the general case. For the more advanced options I would then include an option to get the StateScope or certain elements of it, and also include the possibility to make modifications to it. Again, I think this is more of an advanced concept, and I am not sure how much it will be used. Maybe I am mistaken here. So what we could allow is doing something like:
scope = get_state_scope(layer1) and then do stuff like scope.c.get() like you suggested, but also maybe scope.c.set() or even set_state_scope(layer1, scope) to overwrite the full scope with the (modified) new scope. This would leave the whole construction flexible enough to handle these cases, in my opinion, without too much of a workaround.

But overall I think this is a point we need to put some thought into. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?

albertz commented Sep 24, 2021

So first let's start with the concept of a State. From what I remember, we started with adding logic which is able to handle the prev: logic from RETURNN and then started expanding on that, adding more "features" like hidden state handling to it.

Basically, but not directly. All the discussion here should be seen as independent from RETURNN really, and more about what would be a straightforward design (for people who do not know RETURNN). It should not be that we adopt some strange thing from RETURNN only because that is how it is now in RETURNN.

So, when we think about loops (for ...: or while ...:), we need some way to access values from the previous iteration. And the question is, how to design that.

The next question is whether we want to allow hidden state, which can be hidden and thus is a separate concept, or whether there should not be a separate concept for hidden state, and it would just be the same as other values from the previous iteration.

The cleanliness and straightforwardness of the design is of highest priority here. How this maps to RETURNN in the end is only secondary. We can surely map whatever we come up with, as long as it is well defined. Or if not, we can simply extend RETURNN such that we can. Although for almost everything discussed here, I think that RETURNN already supports it, so no extension or modification on the RETURNN side would be needed.

One of my general questions would be: why do we even do explicit recurrent handling like getting state updates and so on, instead of just "references" to these updates which would be passed to RETURNN for the config?

I don't exactly understand. What do you mean by references to the updates?

The argument of explicit vs hidden/implicit is simple: Because only that way, it is straightforward. Hidden/implicit is always problematic. Esp when you want to change it, or have some control over it, it becomes unnatural. As long as you do not want to touch or access the hidden state, it doesn't matter. But as soon as you do, it matters. And there always will be such cases.

Isn't the actual handling of how the LSTM unit works something that is part of RETURNN, while what we try to achieve in this repo is a nice way of writing it down?

We are not changing that. Here we simply discuss how we design the handling of accessing previous values (values from the prev loop iteration), and hidden state, or whether hidden state should be handled differently or just the same as other previous values.

From what I understand right now you are also looking to include additional concepts.

No. No underlying concept is really new. It would still all map to what RETURNN does right now. Just the API is new. This is the whole point here of returnn-common. And for designing the API, we have the freedom to do it as we want. And I think we should try to prioritize cleanliness and straightforwardness.

From this I would also conclude my reasoning for deciding between the two variants for the hidden state: I feel like the implicit handling is one of the biggest strengths of RETURNN, even though of course it's a sharp sword to work with. Not explicitly having to worry about certain details of a layer makes it (at least from my HiWi view) quite a bit easier to start and also work with.

Many people claim that PyTorch is easier because it is all explicit. Usually nothing is hidden away. When reading other people's code, you would rarely ask yourself what it actually does, or whether this module has some hidden state, because it is all explicit.

Explicitness can result in slightly more code but it is usually still pretty simple and short, and it is easier to follow and reason about because you don't have to think about implicit behavior.

Implicit behavior is maybe fine for all the simple cases but once it gets more complex, it can make it really hard to reason about.

I spoke with some other people and they all strictly preferred the explicitness.

What we should aim for in my opinion is an interface which is as simple as possible to "just start", but has the flexibility of allowing stronger configurations once the user is more used to it.

Yes, simplicity and flexibility are both the main goals of RETURNN, and also here of returnn-common.

However, I think you argue for exactly the opposite of what I argued before.

What does simple mean? Simple does not necessarily mean short code. Simple is about writing code, reading code, and understanding code. It should never be ambiguous, otherwise it is not simple. It should be clear and straightforward. Straightforwardness makes it simple to write and understand. Clarity makes it simple to read.

What does flexibility mean? It does not just mean that more complex things are possible. More complex things are always possible. Flexibility also means that more complex things are straightforward to do. Otherwise it is actually not really flexible, if something is not straightforward or is unclear.

So I feel like it would be fine to accept that basic configurations (where in this case you don't do modifications to the internal hidden state logic) are as easy as possible, but if you want to make use of some stronger concepts you would have to go a bit more in depth. The problem why this in other cases causes trouble is when the documentation is not good enough, making users who try to make the transition into more detail feel lost.

You cannot really compensate for a complicated, non-straightforward design by just having better documentation. Treating hidden state as something different from non-hidden state just makes it more complicated, and not straightforward. When you have worked with non-hidden state before, it is not clear or straightforward how to work with hidden state now, when this is a different thing or concept.

Now onto the specific example: I would prefer the second option in the general case.

I actually asked someone what behavior he would expect from this code:

lstm = LstmCell(...)
with Loop() as loop:
  layer1 = lstm(loop.unstack(x))
  layer2 = lstm(layer1)
  output = loop.stack(layer2)

He expected that the two lstm calls would not only share the params but also the hidden state. Which is exactly not what would happen. Or it depends on the implementation of LstmCell. So this is a perfect example of what I meant before: It is not easy to read or understand. The behavior of the hidden state is unclear and ambiguous. And it is not straightforward how to handle hidden state now.

But overall I think this is a point we need to put some thought into. Maybe we could work out 2 or 3 concrete ways (with 3-4 examples each) and then ask other users about it, because I feel like this is something where user feedback might be meaningful to make a decision. What do you think?

Yea, I also thought about getting some more feedback. It's a good idea to prepare some examples.

In all cases, what I think is important:

  • The code should be clear, not be ambiguous. It should be clear from reading the code, what will happen. Even from someone who is unfamiliar with RETURNN, or this new API. I would argue an API is bad if code is not clear.
  • It should be straightforward to achieve something, esp something more complex
    • by extending and modifying some existing code,
    • or writing it from scratch.

So the different examples could be:

  • 2 LSTM layers, not sharing params, not sharing hidden state.
  • 2 LSTM layers, not sharing params, not sharing hidden state, the first layer gets also the prev output of the second layer.
  • 2 LSTM layers, sharing params, not sharing hidden state.
  • 2 LSTM layers, sharing params, sharing hidden state.
  • 2 LSTM layers, not sharing params, sharing hidden state.
  • 2 LSTM layers, not sharing params, sharing hidden state, applying tanh on the hidden state in between.
  • Prefix decoding

I agree, some of these examples are maybe a bit exotic. But that is my point. It should still be straightforward to do. Otherwise it is not really flexible. In PyTorch, all of these are very simple and straightforward. In current RETURNN (dict-style net def), while all are possible in principle, only the first three are simple, while the others are definitely not, esp not straightforward. I expect and argue that whenever you have it explicit, it becomes straightforward.
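
For instance, the first case (2 LSTM layers, not sharing params, not sharing hidden state) might look like this in the explicit-state proposal (a sketch; the Lstm/Loop API is the one drafted above and still subject to change):

lstm1 = Lstm(...)
lstm2 = Lstm(...)  # separate instance -> separate parameters
layer1_state = lstm1.initial_state()
layer2_state = lstm2.initial_state()
with Loop() as loop:
  layer1, layer1_state = lstm1(loop.unstack(x), layer1_state)
  layer2, layer2_state = lstm2(layer1, layer2_state)
  output = loop.stack(layer2)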

albertz commented Sep 28, 2021

Some further thoughts on the handling of state in general in a loop (orthogonal to the discussion whether hidden state should be a separate concept or not):

While assign and get on such a State object are somewhat canonical, this leads to lots of boilerplate code, which makes it somewhat more complicated to write than the corresponding natural logic in normal Python (or PyTorch) code (in a normal for or while loop).

I'm thinking about the options to simplify that to more canonical simplified Python, while still also not doing too much magic, such that it is still clear what happens.

  • One approach was already proposed, which would be a StateHolder object or so, where we do the same logic in __setattr__ and __getattr__.

    This is mostly fine, except for:

    • still a little overhead (always needs loop.state. or so as a prefix)
    • maybe already too much magic?
    • If the user forgets about this, and writes the logic without loop.state. as prefix, there is no error and just wrong behavior. I don't think there is a good way we could detect this as an error.
  • Another idea I had was to pass locals() to Loop. At the exit of Loop, could this detect what local vars have changed inside the loop? Then this can also be used to implement such logic.

    Downsides:

    • Even more magic?
    • Maybe depends on specific CPython version, and might break?
    • Imagine you write b = b + 1 in the loop. So b gets reassigned. But now IDEs (e.g. PyCharm) and code checkers would complain that the new b is not used anymore. It is only used because of the locals magic.

    A variant which solves some of the downsides, while again adding some further function:
    The user could call something like loop.exit(locals()) explicitly at the end of the loop. This is slightly less magic, more robust (should always work), and IDEs (at least PyCharm) will also not complain about unused local vars. (A toy sketch of this idea follows after this list.)

    I played a bit around with variations of this here.

  • We could do some Python-level code transformation, similar to JAX, TF tf.function (see AutoGraph transformations), PyTorch jit, etc. This is extremely flexible and powerful and basically allows us to do it in whatever way we want. We could even simply use normal for or while loops. This directly allows us to write very straightforward Python code.

    Main downside: This is a heavy and complex thing to do. This adds a lot of complexity. Also, while I have some ideas how this can be implemented, and I have implemented some similar code before (on AST level, for pytorch-to-returnn), there are various different possible approaches here, and this would also need some more research, e.g. how tf.function does this, etc.

    Just out of interest, I'm following the logic of tf.function to the autograph transformation. This looks extremely complex, with lots of edge cases. At some point, it calls autograph.converted_call. And after a long list of exceptions and extra checks, that calls conversion.convert. And that calls AutoGraphTranspiler.transform_function. Then there is FunctionTranspiler, which seems to work on Python AST level. How does it get the AST? This looks ugly. There is inspect_utils.getimmediatesource which uses inspect.findsource and inspect.getblock. Which simply tries to get the source code filename and then loads that file. Then it calls gast.parse, where gast is this external Python package, which seems to wrap some incompatibilities in the official Python ast package between Python 2 and Python 3. But this is basically ast.parse, which uses the Python compile builtin, with flags = PyCF_ONLY_AST. Then the AST transformation logic happens in AutoGraphTranspiler. Maybe most interesting is the ControlFlowTransformer which handles if and while.

    And I just scratched the surface of tf.function autograph. This goes much deeper.
    The question is if we maybe can get away with much simpler Python AST transpile logic and code. We can maybe reuse FunctionTranspiler. Or maybe some other Python library for transpiling.

    I played a bit around with the TF transpiler code, which is generic (although it would have been better if it were independent of TF, because this probably only exists in TF2, and also the API might not be stable).
    A simple example can be seen here. While this is actually not too complicated for this simple logic, I'm not sure if this is still not way too complex.
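
Here is a toy sketch of the loop.exit(locals()) idea in plain Python (illustrative only, not the proposed implementation): snapshot the locals at loop entry and diff them at exit to find which variables were reassigned:

def body():
  a, b = 0, 0
  before = dict(locals())  # snapshot at "loop entry"
  a = a + 1                # reassigned inside the "loop body"
  after = dict(locals())
  # Variables whose value changed between entry and exit:
  return {k for k, v in before.items() if after.get(k) != v}

print(body())  # {'a'}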

So, given these options, I tend to prefer StateHolder with __setattr__ and __getattr__.

albertz commented Sep 28, 2021

What would the StateHolder with __setattr__ and __getattr__ look like? Here are some possible variations:

The Loop object could already create that, as loop.state. It would then only exist inside the with block, but this is maybe ok.

Should we allow usages without defining the initial value or the shape? Maybe. In that case, the code for 2 LSTM layers, sharing params, not sharing hidden state can look like:

lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.layer1 = lstm(loop.unstack(x), loop.state.layer1)
  layer2, loop.state.layer2 = lstm(layer1, loop.state.layer2)
  output = loop.stack(layer2)

Or for 2 LSTM layers, sharing params, sharing hidden state:

lstm = Lstm(...)
with Loop() as loop:
  layer1, loop.state.lstm = lstm(loop.unstack(x), loop.state.lstm)
  layer2, loop.state.lstm = lstm(layer1, loop.state.lstm)
  output = loop.stack(layer2)

Or consider this Python code:

i = 0
for x_ in x:
  i = i + 1

Equivalent code here:

with Loop() as loop:
  loop.unstack(x)
  loop.state.i = loop.state.i + 1

How to explicitly specify the initial value, and maybe other things like shape? Maybe this can just be extended, like so:

with Loop() as loop:
  loop.unstack(x)
  loop.state.i = State(shape=(), initial=0)
  loop.state.i = loop.state.i + 1

This would maybe be a bit counterintuitive, as loop.state.i assignments and reads would normally expect or return a LayerRef, while an assignment of a State is handled specially.
But other variants might also look a bit inconsistent, like loop.define_state("i", initial=0) or so. I'm not sure.

albertz added a commit that referenced this issue Sep 29, 2021
JackTemaki commented Oct 1, 2021

The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:

  • Either the unstack function or the loop object itself should be able to take some kind of information about which axis will be used for the loop. I think in the current example everything was based on the time axis.
  • loop.last should be able to get an n parameter so that you can get the n last states (just what the window layer inside a recurrent net does now), to implement e.g. causal convolution decoders.
  • It is not yet clear to me how masks can/should be used, e.g. if I want to update states only under a certain condition, like with a future Cond object, but with a mask tensor matching the "unstacked" axis. Sorry, this is the masked computation wrapper, #23.

albertz commented Oct 1, 2021

The "current" example at the top looks already quite understandable and straightforward, but I have some comments / questions:

  • Either the unstack function or the loop object itself should be able to take some kind of information about which axis will be used for the loop. I think in the current example everything was based on the time axis.

Yes right. This is basically the discussion here: rwth-i6/returnn#597, rwth-i6/returnn#632 and related.

We still did not fully clarify whether we maybe should allow some defaults for cases where it is unique. Or basically we anyway need to do that for all existing layers to not break backward compatibility.

But anyway, this is somewhat orthogonal to the discussion here.

  • loop.last should be able to get an n parameter so that you can get the n last states (just what the window layer inside a recurrent net does now), to implement e.g. causal convolution decoders.

No, I don't think so. It should follow the same principles as everything in RETURNN: it should be as simple as possible, and atomic. You can very easily get this functionality, e.g. by putting in a causal WindowLayer and then taking the loop.last of that (that is anyway how loop.last with n would work internally).

  • It is not yet clear to me how masks can/should be used, e.g. if I want to update states only under a certain condition, like with a future Cond object, but with a mask tensor matching the "unstacked" axis. Sorry, this is the masked computation wrapper, #23.

Yes, this is #23, but actually, when we now have all hidden state explicit as well, i.e. no distinction anymore, this also becomes pretty straightforward even without any such wrapper. The only reason such a wrapper can be useful is to allow potential further automatic optimizations (as MaskedComputationLayer does right now).

albertz commented Oct 6, 2021

So, a first version is implemented now.

See the test in test_rec_ff.
It uses this code:

x = get_extern_data("data")
with Loop() as loop:
  x_ = loop.unstack(x, axis="T")
  loop.state.h = y_ = Linear(n_out=13)([x_, loop.state.h])
  y = loop.stack(y_)
return y

Which results in this net dict:

{'loop': {'class': 'rec',
          'from': [],
          'unit': {'h': {'class': 'copy', 'from': 'linear'},
                   'linear': {'class': 'linear',
                              'from': ['rec_unstack', 'prev:h'],
                              'n_out': 13},
                   'output': {'class': 'copy', 'from': 'linear'},
                   'rec_unstack': {'axis': 'T',
                                   'class': 'rec_unstack',
                                   'from': 'base:data:data'}}},
 'output': {'class': 'copy', 'from': 'loop/output'}}

(Sorry for the pprint formatting...)
(Some of the layer names will probably change in some future version.)

albertz commented Oct 6, 2021

So I'm closing this now, as we have the initial design implemented.

Please open separate issues if sth is broken, missing, or whatever.

albertz closed this as completed Oct 6, 2021
albertz commented Oct 13, 2021

Just for reference: Loop.end has also been implemented now.

albertz commented Oct 13, 2021

What's still missing is the default interface for all wrapped RETURNN layers with hidden state, which should make the state more explicit, as discussed here. This is #31.

albertz commented Nov 2, 2021

Loop.State is also still incomplete w.r.t. handling of the initial state, which should be explicit. But see #35 for the discussion.

albertz commented Nov 8, 2021

I just found that JAX has quite a similar concept: jax.experimental.loops.

There you have loops.Scope() which is kind of similar to our Loop and Loop.State.

albertz commented Dec 29, 2021

Some update:

Instead of using State to initialize a state, the initial state is now defined outside the loop, and then when used inside the loop, it is always the prev value. Multiple assignments are not allowed, neither inside nor outside the loop.

I think this makes it more straightforward.

Example:

loop = nn.Loop(axis=axis)
loop.state.lstm = self.lstm.default_initial_state()
with loop:
  x_ = loop.unstack(x)
  y_, loop.state.lstm = self.lstm(x_, state=loop.state.lstm, axis=nn.single_step_dim)
  y = loop.stack(y_)
