
Transformer Modules #55

Merged: 4 commits merged into main from benedikt_transformer on Feb 3, 2022

Conversation

@Atticus1806 (Contributor) commented Nov 3, 2021

Fix #53.

This is a draft for now, since the attention modules from #52 still need to be implemented for it. Comments on code style, changes, etc. are of course already welcome.

Variable and function naming is oriented on PyTorch (see the PyTorch documentation).

Note that PyTorch also has masking logic in almost all of these modules; it was left out for now, since in PyTorch it lives inside the attention modules. Depending on the structure of the returnn_common attention, this can of course be added again.
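For illustration, a minimal sketch of the kind of masking meant here (plain PyTorch, hypothetical names, not code from this PR):

```python
# Hypothetical illustration only, not part of this PR: masking as typically done
# inside attention, by excluding masked positions from the softmax.
import torch

def masked_attention_weights(logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """logits: [batch, heads, time_q, time_kv]; mask: same shape or broadcastable, True = keep."""
    logits = logits.masked_fill(~mask, float("-inf"))  # masked positions get zero weight
    return torch.softmax(logits, dim=-1)
```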

@Atticus1806 (Contributor Author) commented Nov 4, 2021

Okay so for now I decided on the following conventions:

d_model becomes dim_model, dim_feedforward becomes dim_ff. I first had a d_ prefix in mind, since that is what Attention Is All You Need uses, but I am not a big fan of one-letter prefixes; they are not always unambiguous. Shortening feedforward also seemed reasonable; ff is what Attention Is All You Need uses as well.

For the n vs. num prefix I decided on num, since in RETURNN we also have n_out and I found it less confusing to reserve n for that. Also, again, one-letter prefixes are somewhat odd in my opinion. Thus n_head becomes num_heads, and num_layers stays the same.

Next point, normalization vs. norm: it is now norm, which is shorter and clearer I think. This was initially not possible because for some reason I additionally imported norm from ., but that import is gone now, so this works fine.

act is now activation again, since it is not imported either. This unfortunately does not work for dropout, so I think that one has to stay drop or some variant of it.
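For reference, a hypothetical sketch of what signatures with these names could look like (not the actual code in nn/transformer.py; defaults and the remaining argument names are placeholders):

```python
# Hypothetical naming sketch only; the dropout argument name ("drop") was still
# open at this point, see above. nn.Module is the returnn_common base class;
# the import path is an assumption (inside the repo it would be a relative import).
from returnn_common import nn

class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim_model: int, dim_ff: int, num_heads: int,
                 drop: float = 0.1, activation=..., norm=...):
        super().__init__()
        ...

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer: TransformerEncoderLayer, num_layers: int):
        super().__init__()
        ...
```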

I also added a bunch of documentation to the parameters. Let me know if I missed anything or if you would name variables differently.

@albertz (Member) commented Nov 4, 2021

d_model becomes dim_model, dim_feedforward becomes dim_ff.
n_head becomes num_heads, num_layers stays the same.
Next point normalization vs norm is now norm
act is now activation again

Yes, all good.

this was initially not possible because for some reason I additionally imported norm from .
This unfortunately does not work for dropout.

I don't understand this. Why is it not possible?

@JackTemaki JackTemaki self-requested a review November 4, 2021 09:23
@albertz (Member) commented Nov 4, 2021

Note that when we adopt #17, all the dimensions which are currently of type int would change their type to DimensionTag, although the code would otherwise not change.
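As a hypothetical before/after illustration, continuing the naming sketch above (the exact Dim API, e.g. nn.FeatureDim, is an assumption here and would follow whatever #17 settles on):

```python
# Hypothetical sketch only; TransformerEncoderLayer as sketched earlier in this thread.

# currently: dimensions as plain ints
layer = TransformerEncoderLayer(dim_model=512, dim_ff=2048, num_heads=8)

# after #17: dimension tags instead of ints; the surrounding code stays the same
model_dim = nn.FeatureDim("dim_model", 512)   # assumed Dim/DimensionTag constructor
ff_dim = nn.FeatureDim("dim_ff", 2048)
layer = TransformerEncoderLayer(dim_model=model_dim, dim_ff=ff_dim, num_heads=8)
```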

@albertz albertz mentioned this pull request Nov 4, 2021
@albertz (Member) commented Nov 4, 2021

Btw, as usual, also see failing tests.

@Atticus1806 Atticus1806 force-pushed the benedikt_transformer branch 3 times, most recently from 0bfb37c to a30633d Compare November 4, 2021 14:53
@albertz (Member) commented Nov 20, 2021

What do you think about leaving positional encodings out of this model for now?

As relative positional encoding is some property deep inside the Transformer (inside the self-attention), and also an important mechanism at least for certain tasks, I think we should not ignore it now, because it might have a big influence on our design. Our design should make it easy to use relative pos encoding.

Or more generally: the current Transformer design should allow for changes inside the core self-attention process in an easy way, e.g. also to replace the self-attention by other attention types, like sparse attention, LSH-based attention and all that, Linformer, etc. (@Zettelkasten)

I think the current design does not really allow this. The way you would currently do that is probably to copy & paste the current Transformer code and replace the self-attention by some own custom thing, which basically shows that the design does not allow for it. (Note that this is maybe not always bad. In some situations, this approach could just be fine.)

However, whatever potential generic API we propose here, we should also be careful that it is not too difficult to follow because it is too abstract, has too many indirections, etc. Or, even worse, we end up with some quite abstract API which allows for rel pos encoding, but it turns out later that it is not really generic enough to allow other things, so it is only useful for this one specific thing and thus pointless (if we want it to be specific, it could just be a flag rel_pos_encoding: bool anyway, but we want to avoid that).

Maybe this is tricky to get right. Or maybe not really possible. But we should at least think about it.

@Atticus1806 (Contributor Author)

The current Transformer design should allow for changes inside the core self-attention process in an easy way.

Trying to phrase it more generally: the Transformer is a stack of multiple components. Depending on the view and the features, the exact number can differ a bit, but in general we have something like a potential PreEncoderNet where, e.g. (if used), absolute encodings are applied, an EncoderNet, a PreDecoderNet, a DecoderNet and a PostNet. The Encoder and Decoder nets just stack encoder/decoder layers on top of each other in most cases.
The point I am getting at is: I am wondering whether we should maybe provide two variants. One would be a "plug and play" Transformer with the standard architecture (possibly including improvements like relative encodings etc.), solid defaults and the features users need for the vanilla Transformer, at the price of flexibility. The second, let's call it TransformerConstructor for now, would get the building blocks I talked about above as input, potentially stack the encoder and decoder layers, but other than that mainly handle merging these components together. What do you think?

@albertz (Member) commented Nov 24, 2021

When I hear "Plug and Play", my association is a flexible set of building blocks which can be plugged together in many ways, so basically what you describe in your second variant. But anyway, this is just terminology.

I'm not sure if we really need a TransformerConstructor class or so. Maybe the building blocks are already enough.

We should make this a bit more concrete. What do the building blocks look like? What is different from what we already have right now? Because we already have building blocks right now; they are just designed in a kind of hierarchical way: Transformer uses TransformerEncoder, TransformerEncoder uses TransformerEncoderLayer, and TransformerEncoderLayer uses SelfAttention. So when you want to replace or modify SelfAttention in this design, you basically need to rebuild everything.

I don't really have a good solution so far. I have some ideas but I'm not exactly happy with them. I think we should think a bit more on this.

In any case, yes, there should be a ready-to-use Transformer as well. But this should then just be based on the building blocks, and should not be another separate implementation.

Also, this ready-to-use Transformer should in any case have an easy way to use relative positional encoding. Maybe that's not always needed, but this will be often needed. Maybe that should even be the default. Maybe specifically the Transformer XL variant.

@albertz (Member) commented Nov 25, 2021

We maybe should also not overthink this. It doesn't need to be too generic. What we want is an easy way to replace the default self attention inside the Transformer, as this is probably some frequent thing to change. So maybe self_att can just be an argument of the Transformer module, or maybe separately enc_self_att and dec_self_att.

@Atticus1806 (Contributor Author)

So maybe self_att can just be an argument of the Transformer module, or maybe separately enc_self_att and dec_self_att.

Would you hand over an initialized layer which is just called in forward, or would you init it within the modules? Usually I would prefer within the modules, but then the arguments are fixed.

@albertz (Member) commented Nov 25, 2021

So maybe self_att can just be an argument of the Transformer module, or maybe separately enc_self_att and dec_self_att.

Would you hand over an initialized layer which is just called in forward, or would you init it within the modules? Usually I would prefer within the modules, but then the arguments are fixed.

Why would you init within the modules?

Of course, we need the ability here to pass in any module with any options. There are several possibilities (sketched below the list), like:

  • Passing module (an instance of Module). It would be deepcopy'd just like we do e.g. for TransformerEncoderLayer, so this is kind of consistent with other existing code.
  • Passing module_cls, module_args, module_kwargs. Ugly in my opinion.
  • Passing module_constructor (or so), which is supposed to be a function without arguments, so the user could pass something like lambda: SelfAttention(...). This is also very flexible, but maybe a bit ugly and unintuitive.
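A self-contained toy sketch of the first and third option (all names here are made up for illustration, not returnn_common API):

```python
# Toy illustration of the two passing conventions: a module instance that gets
# deep-copied per layer vs. a zero-argument constructor (factory) called per layer.
import copy
from typing import Callable, Union

class SelfAttentionStub:
    def __init__(self, num_heads: int = 8):
        self.num_heads = num_heads

class EncoderStub:
    def __init__(self, self_att: Union[SelfAttentionStub, Callable[[], SelfAttentionStub]],
                 num_layers: int = 6):
        if callable(self_att):
            # option 3: factory, called once per layer
            self.self_atts = [self_att() for _ in range(num_layers)]
        else:
            # option 1: instance, deep-copied per layer
            # (real modules are often callable themselves, so a real implementation
            # would distinguish the cases differently; this check is just for the toy)
            self.self_atts = [copy.deepcopy(self_att) for _ in range(num_layers)]

enc_a = EncoderStub(SelfAttentionStub(num_heads=4))          # pass an instance
enc_b = EncoderStub(lambda: SelfAttentionStub(num_heads=4))  # pass a factory
```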

@Atticus1806 (Contributor Author) commented Nov 25, 2021

I decided on your first option for now. Also changed the norm_first default to True, since I understood this should be better.
Would you type them with nn.Module? This causes PyCharm to complain that Module does not have an initial state function, but I am not sure if this is relevant.

So in general what is "missing" for this to be finished is:

  • nn.SelfAttention initial state
  • nn.Attention implementation; I think this should be its own module.
  • self.search Edit: Or is this working? If so, how?
  • positional encodings: I am still not sure if I understand this correctly, but should this be part of nn.SelfAttention? From my understanding it is a trainable weight added during self-attention, right? (See the sketch after this list.)
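To illustrate the question, a minimal sketch of one common variant of relative positional encoding: a learned bias, indexed by clipped relative distance, added to the attention energies inside self-attention (plain PyTorch, hypothetical, not this PR's code or the returnn_common API):

```python
import torch

num_heads, max_dist = 8, 16
# one learned scalar per head and per clipped relative distance
rel_bias = torch.nn.Embedding(2 * max_dist + 1, num_heads)

def add_rel_pos_bias(energies: torch.Tensor) -> torch.Tensor:
    """energies: [batch, num_heads, time_q, time_kv], attention logits before softmax."""
    time_q, time_kv = energies.shape[-2], energies.shape[-1]
    q_pos = torch.arange(time_q).unsqueeze(1)                      # [time_q, 1]
    kv_pos = torch.arange(time_kv).unsqueeze(0)                    # [1, time_kv]
    rel = (kv_pos - q_pos).clamp(-max_dist, max_dist) + max_dist   # [time_q, time_kv]
    bias = rel_bias(rel).permute(2, 0, 1)                          # [num_heads, time_q, time_kv]
    return energies + bias.unsqueeze(0)                            # broadcast over batch
```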

@Atticus1806 (Contributor Author)

Edit: Or is this working? If so, how?

Removed all init parts of search for now and only kept it in forward, since we agreed on that API beforehand.

@albertz (Member) commented Dec 30, 2021

See also this: pytorch/pytorch#67999

@albertz (Member) commented Jan 3, 2022

I think we should not try to get this here into a perfect and ready state but just merge it soon and then iterate on it.

@Atticus1806 When do you think would be a good time to merge? Do you want to improve anything further before the merge, or just merge as-is?

@Atticus1806 (Contributor Author)

So, to be honest, I put the extension of this model on hold for now, since this is a full model which might change a number of times in the concrete implementation, depending on how much returnn_common changes until its first full release. While it still has errors in the tests, I think they are related to things outside of this PR.

So we could either merge it now and update it when the time comes and it is ready to use, or we could leave this PR as-is and update it then. I would be fine with both; if you want to keep the number of open PRs as low as possible, we could merge now.

@albertz (Member) commented Jan 3, 2022

I think it's easier to merge it now (or with some cleanup given my recent comments) to more easily allow for changes.

Also, most things are actually ready in returnn_common to test this now. Although you are right that some things might still change.

@Atticus1806 (Contributor Author)

I did not really follow the development of returnn_common the last few weeks, since I was working on something else, so from my view what is open is:

  • Line 276 encoder-decoder attention. This should be nn.dot_attention I guess?
  • search (is this solved in general for returnn_common yet? Otherwise add this later I think)
  • Handling of dimension tags (again: is this fully decided yet? How much do we need to care about this here right now?)

Other things can be extended later. I think search and dimension tags might also be extended later, depending on their state right now.

@albertz (Member) commented Jan 3, 2022

  • Line 276 encoder-decoder attention. This should be nn.dot_attention I guess?

Yes, I think so (without looking at it now).
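For reference, a minimal sketch of the operation meant here, plain scaled dot-product cross attention between decoder queries and encoder keys/values (generic PyTorch; this is not the nn.dot_attention signature):

```python
import math
import torch

def cross_attention(query: torch.Tensor, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """query: [batch, time_dec, dim]; keys/values: [batch, time_enc, dim] from the encoder."""
    energies = torch.matmul(query, keys.transpose(-2, -1)) / math.sqrt(query.shape[-1])
    weights = torch.softmax(energies, dim=-1)   # [batch, time_dec, time_enc]
    return torch.matmul(weights, values)        # [batch, time_dec, dim]
```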

  • search (is this solved in general for returnn_common yet? Otherwise add this later I think)

There is still some open question on the basic design but in principle it is available. But let's later look at that. Not needed for the merge.

  • Handling of dimension tags (again: is this fully decided yet? How much we need to care about this here right now?)

Everything regarding dimension tags should be clear for the Transformer. They should be used consistently everywhere; there should not be int for dimensions. See the Conformer for similar code.

@Atticus1806 (Contributor Author)

Okay, so the commit was a bit larger than expected. Updated the dimension tags (and added defaults in the Transformer class).
Do we want these defaults? Or should this be mandatory to be set by the user?
Also some naming changes.
From what I can see the only thing missing now is the "if self.search" part, but as you said this can be done later. Should we comment it out / remove it for now?

@Atticus1806 Atticus1806 marked this pull request as ready for review January 5, 2022 10:14
@albertz (Member) commented Jan 5, 2022

Just leave it as it is. I will merge and update it.

Where do the defaults come from now? This should be documented.

@Atticus1806 (Contributor Author)

Updated the documentation, linking the paper and adding a remark for norm_first, since, as discussed here, True should be the default.

@albertz (Member) left a review comment

So I cleaned up a bit. As discussed, we will just merge this now and can then do further improvements.

@albertz albertz merged commit 89ae989 into main Feb 3, 2022
@albertz albertz deleted the benedikt_transformer branch February 3, 2022 15:03
@albertz (Member) commented Aug 22, 2022

I think we should rename Transformer/TransformerDecoder to fit the generic high-level decoder (#49, still work-in-progress).

Successfully merging this pull request may close these issues: Implement standard Transformer encoder and decoder (#53).