Skip to content

Commit

Permalink
Perf: use F.linear for MLP (#4513)
Browse files Browse the repository at this point in the history
It brings <1% speedup.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Refactor**
- Simplified linear transformation implementation in the neural network
layer
	- Improved code readability and efficiency in matrix operations
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
caic99 authored Jan 16, 2025
1 parent fdf8049 commit 2ba3100
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from deepmd.pt.utils import (
env,
Expand Down Expand Up @@ -202,18 +203,14 @@ def forward(
ori_prec = xx.dtype
if not env.DP_DTYPE_PROMOTION_STRICT:
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
else torch.matmul(xx, self.matrix)
)
yy = self.activate(yy).clone()
yy = F.linear(xx, self.matrix.t(), self.bias)
yy = self.activate(yy)
yy = yy * self.idt if self.idt is not None else yy
if self.resnet:
if xx.shape[-1] == yy.shape[-1]:
yy += xx
yy = yy + xx
elif 2 * xx.shape[-1] == yy.shape[-1]:
yy += torch.concat([xx, xx], dim=-1)
yy = yy + torch.concat([xx, xx], dim=-1)
else:
yy = yy
if not env.DP_DTYPE_PROMOTION_STRICT:
Expand Down

0 comments on commit 2ba3100

Please sign in to comment.