
Optimize MF model #544

Merged: 8 commits merged into alibaba:master on Mar 17, 2023
Conversation

@rayrayraykk (Collaborator) commented Mar 15, 2023:

Optimize MF model forward:

  • Before: (user_embedding * item_embedding)[mask]
  • Now: (user_embedding[user_mask].to_sparse() * item_embedding[item_mask].to_sparse())[mask]

Optimization of MF

  1. Convert the MF model to an nn.Embedding version: forward with only the used embeddings (see the sketch after this list).
  2. Save client data to processed_data.pkl to save time when repeating the experiment.
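
A rough sketch of the before/after forward described above (sizes, pair indices, and mask construction are illustrative, not the PR's exact code):

    import torch

    # Hypothetical sizes; num_user/num_item/num_hidden come from the dataset in the PR.
    num_user, num_item, d = 100, 200, 16

    # Hypothetical observed (user, item) pairs and their dense mask.
    rated = torch.tensor([[0, 5], [3, 7], [9, 2]])
    mask = torch.zeros(num_user, num_item)
    mask[rated[:, 0], rated[:, 1]] = 1.0

    # Before: materialize the full (num_user, num_item) prediction matrix, then mask.
    user_embedding = torch.randn(num_user, d)
    item_embedding = torch.randn(num_item, d)
    pred_before = (user_embedding @ item_embedding.T) * mask

    # Now: nn.Embedding looks up only the rated pairs (the final form after review).
    embed_user = torch.nn.Embedding(num_user, d)
    embed_item = torch.nn.Embedding(num_item, d)
    pred_now = (embed_user(rated[:, 0]) * embed_item(rated[:, 1])).sum(dim=1)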

@rayrayraykk added the enhancement label on Mar 15, 2023
@rayrayraykk reopened this on Mar 15, 2023
@rayrayraykk (Collaborator, Author) commented Mar 15, 2023:

A CUDA error is caused by sparse tensor asserts; see pytorch/pytorch#68323 for details. (This cannot be fixed in torch 1.10.)

Code context (diff around the embedding definition):

    dtype=torch.float32))
    self.register_parameter('embed_item', self.embed_item)
    self.num_user, self.num_item = num_user, num_item
    self.embed_user = torch.nn.Embedding(num_user, num_hidden, sparse=True)
Collaborator:
I am wondering why sparse is needed?

Collaborator:

Me too. In my opinion, the label matrix for the MF task is very sparse, but the user/item embeddings and the prediction results should be dense.

Collaborator (Author):

See this discussion for details.
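
As background on what sparse=True changes (a general PyTorch fact, not from this thread): it makes the embedding weight's gradient a sparse tensor holding rows only for the looked-up indices, which in turn restricts the optimizer choice. A minimal sketch:

    import torch

    emb = torch.nn.Embedding(10, 4, sparse=True)
    loss = emb(torch.tensor([1, 3])).sum()
    loss.backward()
    print(emb.weight.grad.is_sparse)  # True: gradient rows exist only for indices 1 and 3

    # Sparse gradients work only with some optimizers, e.g. SGD or SparseAdam.
    opt = torch.optim.SparseAdam(emb.parameters(), lr=0.01)
    opt.step()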

Code context (diff in the forward function):

    dtype=torch.float32).to_dense()
    return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
    device = self.embed_user.weight.device
Collaborator:
What indices and ratings are should be stated (e.g., in a docstring).


Code context:

    label = torch.tensor(np.array(ratings)).to(device)
    return pred, label
Collaborator:
When do you calculate these for negative examples?

Collaborator:
Fitting only the observed entries is OK with me, but I guess it is a rather naive baseline approach.

@joneswong (Collaborator) commented:
@DavdGao Hi Dawei, please help us check the changes to the datasets. Thanks!

@DavdGao (Collaborator) left a review:

Please see the inline comments.


Code context (dataset caching):

    self.processed_data = os.path.join(self.root, self.base_folder,
                                       'processed_data.pkl')
    if os.path.exists(self.processed_data):
Collaborator:
I don't think lines 104-105 will be executed, since a new exp directory (self.root) is created each time.

Collaborator (Author):

self.root = 'data/'
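
The caching pattern under discussion, as a minimal sketch (the class name, pickle layout, and _process helper are illustrative, not the PR's exact code; the point is that self.root is the persistent 'data/' directory, so the cache survives across runs):

    import os
    import pickle

    class MFDataset:
        def __init__(self, root='data/', base_folder='movielens'):
            self.root = root
            self.base_folder = base_folder
            self.processed_data = os.path.join(self.root, self.base_folder,
                                               'processed_data.pkl')
            if os.path.exists(self.processed_data):
                # Cache hit: reuse the processed split from a previous run.
                with open(self.processed_data, 'rb') as f:
                    self.data = pickle.load(f)
            else:
                self.data = self._process()
                os.makedirs(os.path.dirname(self.processed_data), exist_ok=True)
                with open(self.processed_data, 'wb') as f:
                    pickle.dump(self.data, f)

        def _process(self):
            # Hypothetical placeholder for the real preprocessing.
            return {'train': [], 'test': []}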

(Same code context as above: the nn.Embedding definition.)
Collaborator:
What's the difference between torch.nn.Embedding and torch.normal, and is torch.nn.Embedding much faster than torch.normal?

Collaborator (Author):
nn.Embedding lets the matmul use only the embedding rows that are actually needed, e.g.:
user_emb(users_idx) @ item_emb(item_idx).T is (n, d) x (d, n), over only the n rated pairs.
The raw implementation multiplies the full matrices: (N, d) x (d, N).
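
Concretely (sizes hypothetical): with N total users/items, d hidden dimensions, and n observed ratings, the lookup version scales with n rather than N:

    import torch

    N, n, d = 5000, 64, 32                    # hypothetical sizes
    user_emb = torch.nn.Embedding(N, d)
    item_emb = torch.nn.Embedding(N, d)
    idx_u = torch.randint(0, N, (n,))
    idx_i = torch.randint(0, N, (n,))

    full = user_emb.weight @ item_emb.weight.T           # (N, N): ~N*N*d multiplies
    batch = (user_emb(idx_u) * item_emb(idx_i)).sum(1)   # (n,):   ~n*d multiplies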


Code context (prediction in forward):

    user_embedding = self.embed_user(indices[0])
    item_embedding = self.embed_item(indices[1])
    pred = torch.diag(torch.matmul(user_embedding, item_embedding.T))
Collaborator:
Why is pred calculated with torch.diag here?

@rayrayraykk (Collaborator, Author) commented Mar 17, 2023:

Updated to: pred = (user_embedding * item_embedding).sum(dim=1)
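
The two forms agree; the update just avoids materializing the full pairwise product before taking the diagonal (a quick check, names illustrative):

    import torch

    u = torch.randn(8, 4)                   # stands in for user_embedding
    v = torch.randn(8, 4)                   # stands in for item_embedding
    old = torch.diag(torch.matmul(u, v.T))  # computes all 8x8 pairs, keeps the diagonal
    new = (u * v).sum(dim=1)                # computes only the 8 needed entries
    assert torch.allclose(old, new)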

@joneswong (Collaborator) left a review:
approved.

@joneswong joneswong merged commit c6a7de4 into alibaba:master Mar 17, 2023