Question 1x1 conv vs linear #18
Update: I retested the comparison, and the information I gave before was not exactly accurate (I remembered wrong, so I deleted it to avoid misleading anyone). Sorry about the confusion. My latest test observations on V100 GPU inference throughput are below:
Looking at 2 and 3, the ultimate reason why (NCHW -> permute to NHWC -> PyTorch LN -> linear layers -> layer scale -> permute back to NCHW) is slightly faster than (NCHW -> custom LN -> 1x1 convs -> layer scale) seems to be that our custom LN operating on NCHW tensors is much slower than PyTorch's LN, which only supports normalizing over the last (channel) dimension, i.e. NHWC tensors. So we need the permutation to NHWC anyway to use PyTorch's LN, and given the observation in 1 (without permutation, linear is faster than 1x1 convs), we use linear layers for the "MLP" part before permuting back.
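For concreteness, here is a minimal sketch of the two orderings being compared for the "MLP" part of the block. This is not the repository's exact Block code; the layer names, dimensions, and the simple channels-first LN below are illustrative assumptions.

```python
import torch
import torch.nn as nn

class ChannelsFirstLN(nn.Module):
    """A simple channels-first LayerNorm: normalizes over C at each (h, w) position."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):                                  # x: (N, C, H, W)
        mu = x.mean(1, keepdim=True)
        var = x.var(1, keepdim=True, unbiased=False)
        x = (x - mu) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

class MLPBlockNHWC(nn.Module):
    """NCHW -> permute to NHWC -> PyTorch LN -> Linear layers -> layer scale -> permute back."""
    def __init__(self, dim, expansion=4, ls_init=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim)                      # normalizes the last (channel) dim
        self.fc1 = nn.Linear(dim, expansion * dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(expansion * dim, dim)
        self.gamma = nn.Parameter(ls_init * torch.ones(dim))   # layer scale

    def forward(self, x):                                  # x: (N, C, H, W)
        y = x.permute(0, 2, 3, 1)                          # -> (N, H, W, C)
        y = self.gamma * self.fc2(self.act(self.fc1(self.norm(y))))
        return x + y.permute(0, 3, 1, 2)                   # back to (N, C, H, W)

class MLPBlockNCHW(nn.Module):
    """NCHW -> custom channels-first LN -> 1x1 convs -> layer scale, no permutes."""
    def __init__(self, dim, expansion=4, ls_init=1e-6):
        super().__init__()
        self.norm = ChannelsFirstLN(dim)
        self.pw1 = nn.Conv2d(dim, expansion * dim, kernel_size=1)
        self.act = nn.GELU()
        self.pw2 = nn.Conv2d(expansion * dim, dim, kernel_size=1)
        self.gamma = nn.Parameter(ls_init * torch.ones(dim, 1, 1))

    def forward(self, x):                                  # x: (N, C, H, W)
        return x + self.gamma * self.pw2(self.act(self.pw1(self.norm(x))))
```

Both variants compute the same function; any throughput difference comes from the memory layout, the permutations, and which LN kernel is used.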
Thank you for your swift response and the very detailed account of the architectural design; indeed, I only noticed your comment on line 30 :). One last question, if you don't mind, before I close this issue: was the choice of GELU over ReLU due to some observation of dying neurons, or was it chosen solely based on the related Transformer papers (BERT, GPT-2) as mentioned? Is there a case where you experimented with alternatives like Swish? Many thanks again, all the best for your future work.
The choice of GELU over ReLU is partly due to imitating Transformers. Another interesting observation is that if we stick to ReLU, the training curve in the next step, "Fewer activations", becomes a bit strange, although it can still converge to a reasonable level in the end. We didn't try activations other than ReLU and GELU.
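To make the "Fewer activations" step concrete, here is a hedged sketch (not the repository's code; layer shapes and names are illustrative) contrasting a ResNet-style stack that applies an activation after every conv with a ConvNeXt-style block that keeps a single activation between the two pointwise layers. Swapping `act=nn.ReLU` reproduces the ReLU-vs-GELU comparison described above.

```python
import torch.nn as nn

def resnet_style(dim):
    # activation after each conv (three in total)
    return nn.Sequential(
        nn.Conv2d(dim, dim, 1), nn.ReLU(),
        nn.Conv2d(dim, dim, 3, padding=1), nn.ReLU(),
        nn.Conv2d(dim, dim, 1), nn.ReLU(),
    )

def convnext_style(dim, act=nn.GELU):
    # single activation, placed between the two pointwise (1x1) layers
    return nn.Sequential(
        nn.Conv2d(dim, dim, 7, padding=3, groups=dim),  # depthwise conv
        nn.Conv2d(dim, 4 * dim, 1),
        act(),
        nn.Conv2d(4 * dim, dim, 1),
    )
```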
Congratulations on your work and thanks for sharing! I'd like to naively ask: what is the reason behind implementing the 1x1 convs with fully connected layers? I know they are equivalent, but I had thought the latter was less efficient.
Thanks in advance!
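As a quick sanity check of the equivalence mentioned in the question, the following snippet (shapes are arbitrary examples) copies a 1x1 conv's weights into a Linear layer and verifies that both produce the same output; only the tensor layout differs.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 14, 14)                           # (N, C, H, W)

conv = nn.Conv2d(64, 128, kernel_size=1)
lin = nn.Linear(64, 128)
lin.weight.data.copy_(conv.weight.data.view(128, 64))    # reuse the conv's weights
lin.bias.data.copy_(conv.bias.data)

out_conv = conv(x)                                        # (N, 128, H, W)
out_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # NHWC -> Linear -> NCHW

print(torch.allclose(out_conv, out_lin, atol=1e-5))       # True, up to float tolerance
```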