
FourierMatmul function giving some errors #11

Open

jaytimbadia opened this issue Feb 23, 2022 · 0 comments

jaytimbadia commented Feb 23, 2022

Hey,

Wonderful translation!

I just tried implementing it myself, but this FourierMatmul is giving a dimension-mismatch error.
Could you please let me know what dimensions it expects?

My sample inputs:

import json 
from fnet import FNetPretraining
from transformers import FNetTokenizer

with open('config.json', 'r') as f:
    config = json.load(f)


tokenizer = FNetTokenizer.from_pretrained("google/fnet-base")

inputs = tokenizer(['Hello, my dog is so cute', 'Hello world'], 
                return_tensors='pt',
                padding=True,
                truncation=True, max_length=512)

# print(inputs) shows:
# {'input_ids': tensor([[    4,  9665, 16680,   275,  3314,    65,   215,  6387,     5],
#         [    4,  9665,   725,     5,     3,     3,     3,     3,     3]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0]])}

input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']

obj1 = FNetPretraining(config=config)
obj1.forward(input_ids, token_type_ids)

# The layer in fnet.py that raises the error:
import torch
from torch import nn
from scipy import linalg


class FourierMMLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # DFT matrices for the sequence and hidden axes; linalg.dft
        # returns complex128, so the input is cast to match below.
        self.dft_mat_seq = torch.tensor(linalg.dft(config['max_position_embeddings']))
        self.dft_mat_hidden = torch.tensor(linalg.dft(config['hidden_size']))

    def forward(self, hidden_states):
        hidden_states_complex = hidden_states.type(torch.complex128)
        # 2-D DFT via matrix multiplication: contract the hidden axis (j)
        # against dft_mat_hidden and the sequence axis (i) against dft_mat_seq.
        return torch.einsum(
            "...ij,...jk,...ni->...nk",
            hidden_states_complex,
            self.dft_mat_hidden,
            self.dft_mat_seq
        ).real.type(torch.float32)
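
For reference, here is a minimal shape check (a sketch, assuming the fnet-base defaults max_position_embeddings=512 and hidden_size=768). The shared einsum index i ties the sequence axis of hidden_states to dft_mat_seq, so the contraction only broadcasts when the input is exactly max_position_embeddings tokens long:

import torch
from scipy import linalg

seq_len, hidden = 512, 768                     # assumed config values
dft_seq = torch.tensor(linalg.dft(seq_len))    # (512, 512), complex128
dft_hid = torch.tensor(linalg.dft(hidden))     # (768, 768), complex128
x = torch.randn(2, seq_len, hidden, dtype=torch.complex128)
out = torch.einsum("...ij,...jk,...ni->...nk", x, dft_hid, dft_seq)
print(out.shape)  # torch.Size([2, 512, 768]) -- runs only when seq_len == 512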

Error

Traceback (most recent call last):
  File "inference.py", line 22, in <module>
    obj1.forward(input_ids, token_type_ids)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 124, in forward
    self.encoder(input_ids, type_ids)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 113, in forward
    sequence_output = self.encoder(embedding_output)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 94, in forward
    hidden_states = layer_module(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 80, in forward
    fft_output = self.fft(hidden_states)
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/sda1/ml_models/fourier_net/fnet.py", line 62, in forward
    return torch.einsum(
  File "/mnt/sda1/luck/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [2, 9, 768]->[2, 1, 1, 9, 768] [768, 768]->[1, 1, 768, 1, 768] [512, 512]->[1, 512, 1, 512, 1]
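
Reading the remapped shapes in the error: hidden_states arrives as (2, 9, 768) while dft_mat_seq is (512, 512), so the shared einsum index i sees 9 on one operand and 512 on the other. A hedged workaround sketch, assuming fixed-length inputs are acceptable here: pad every batch to max_position_embeddings so the sequence axis matches dft_mat_seq (alternatively, the layer could build its DFT matrix from the actual sequence length at forward time):

inputs = tokenizer(
    ['Hello, my dog is so cute', 'Hello world'],
    return_tensors='pt',
    padding='max_length',   # pad to max_length instead of longest-in-batch
    truncation=True,
    max_length=512,         # must equal config['max_position_embeddings']
)
# inputs['input_ids'].shape -> torch.Size([2, 512])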