
Commit 5dd73df
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 7, 2024
1 parent b38a62e commit 5dd73df
Showing 6 changed files with 35 additions and 37 deletions.
9 changes: 2 additions & 7 deletions deepmd/pt/model/descriptor/se_atten.py
@@ -71,7 +71,6 @@ def tabulate_fusion_se_atten(

@DescriptorBlock.register("se_atten")
class DescrptBlockSeAtten(DescriptorBlock):
-
    def __init__(
        self,
        rcut: float,
@@ -274,14 +273,10 @@ def __init__(
        self.compress = False
        self.is_sorted = False
        self.compress_info = nn.ParameterList(
-            [
-                nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))
-            ]
+            [nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
        )
        self.compress_data = nn.ParameterList(
-            [
-                nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))
-            ]
+            [nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
        )

    def get_rcut(self) -> float:
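Most hunks in this commit are mechanical Black rewrites (Black is one of the hooks this repository runs through pre-commit.ci). Below is a minimal sketch, assuming Black's default 88-column line length and that the black package is installed, of the rule behind the nn.ParameterList change above; the nn/torch names only need to parse, not resolve:

import black

src = '''\
class DescrptBlockSeAtten:
    def __init__(self):
        self.compress_info = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))
            ]
        )
'''
# Black collapses the bracketed literal because the collapsed form fits the
# line limit; a trailing ("magic") comma after the inner element would have
# told it to keep the expanded layout.
print(black.format_str(src, mode=black.Mode()))
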
6 changes: 2 additions & 4 deletions deepmd/pt/model/descriptor/se_r.py
@@ -79,7 +79,6 @@ def tabulate_fusion_se_r(
@BaseDescriptor.register("se_e2_r")
@BaseDescriptor.register("se_r")
class DescrptSeR(BaseDescriptor, torch.nn.Module):
-
    def __init__(
        self,
        rcut,
@@ -150,7 +149,7 @@ def __init__(
        self.trainable = trainable
        for param in self.parameters():
            param.requires_grad = trainable
-
+
        # add for compression
        self.compress = False
        self.compress_info = nn.ParameterList(
@@ -414,9 +413,8 @@ def enable_compression(
            tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec)
            self.compress_data[ii] = tensor_data_ii
            self.compress_info[ii] = info_ii
-
-        self.compress = True
+        self.compress = True

    def forward(
        self,
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_t.py
@@ -568,7 +568,7 @@ def __init__(
        self.trainable = trainable
        for param in self.parameters():
            param.requires_grad = trainable
-
+
        # add for compression
        self.compress = False
        self.compress_info = nn.ParameterList(
6 changes: 5 additions & 1 deletion deepmd/pt/utils/serialization.py
@@ -76,6 +76,10 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
        # JIT will happy in this way...
        model.model_def_script = json.dumps(data["model_def_script"])
        if "min_nbor_dist" in data.get("@variables", {}):
-            model.min_nbor_dist = torch.tensor(float(data["@variables"]["min_nbor_dist"]), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
+            model.min_nbor_dist = torch.tensor(
+                float(data["@variables"]["min_nbor_dist"]),
+                dtype=env.GLOBAL_PT_FLOAT_PRECISION,
+                device=env.DEVICE,
+            )
        model = torch.jit.script(model)
        torch.jit.save(model, model_file)
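The rewrite above goes the other way: a call that overflows the line limit is exploded to one argument per line, and the trailing comma after device=env.DEVICE keeps it that way. A quick, hedged way to convince yourself such a re-wrap is behavior-preserving is to compare the ASTs of the two spellings, copied from the hunk:

import ast

old = (
    'model.min_nbor_dist = torch.tensor('
    'float(data["@variables"]["min_nbor_dist"]), '
    'dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)'
)
new = '''\
model.min_nbor_dist = torch.tensor(
    float(data["@variables"]["min_nbor_dist"]),
    dtype=env.GLOBAL_PT_FLOAT_PRECISION,
    device=env.DEVICE,
)
'''
# Identical ASTs: the formatter changed layout only, not semantics.
assert ast.dump(ast.parse(old)) == ast.dump(ast.parse(new))
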
48 changes: 25 additions & 23 deletions deepmd/pt/utils/tabulate.py
@@ -429,36 +429,37 @@ def _n_all_excluded(self) -> int:
def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int):
    if functype == 1:
        return 1 - y * y

    elif functype == 2:
        var = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
        return (
            0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1)
            + 0.5 * var
            + 0.5
        )

    elif functype == 3:
        return torch.where(xbar > 0, torch.ones_like(xbar), torch.zeros_like(xbar))

    elif functype == 4:
-        return torch.where((xbar > 0) & (xbar < 6), torch.ones_like(xbar), torch.zeros_like(xbar))
+        return torch.where(
+            (xbar > 0) & (xbar < 6), torch.ones_like(xbar), torch.zeros_like(xbar)
+        )

    elif functype == 5:
        return 1.0 - 1.0 / (1.0 + torch.exp(xbar))

    elif functype == 6:
        return y * (1 - y)

    else:
        raise ValueError(f"Unsupported function type: {functype}")


-
def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int):
    if functype == 1:
        return -2 * y * (1 - y * y)

    elif functype == 2:
        var1 = torch.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
        var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1)
@@ -467,22 +468,21 @@ def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int):
            - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1
            + var2
        )

    elif functype in [3, 4]:
        return torch.zeros_like(xbar)

    elif functype == 5:
        exp_xbar = torch.exp(xbar)
        return exp_xbar / ((1 + exp_xbar) * (1 + exp_xbar))

    elif functype == 6:
        return y * (1 - y) * (1 - 2 * y)

    else:
        return -torch.ones_like(xbar)


-

def unaggregated_dy_dx_s(
    y: torch.Tensor, w_np: np.ndarray, xbar: torch.Tensor, functype: int
):
@@ -497,8 +497,8 @@ def unaggregated_dy_dx_s(
raise ValueError("Dim of input xbar should be 2")

grad_xbar_y = grad(xbar, y, functype)
w = torch.flatten(w)[:y.shape[1]].repeat(y.shape[0], 1)

w = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1)

dy_dx = grad_xbar_y * w
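
The slice rewrites in this file ([:y.shape[1]] → [: y.shape[1]] here, [:(length * size)] → [: (length * size)] below) apply Black's PEP 8-derived rule that a slice colon adjacent to a non-trivial expression gets a space. The two spellings are equivalent; a small check with stand-in values:

import torch

t = torch.arange(12)
n = 3
# Only the formatting differs between the two slice spellings.
assert torch.equal(t[:n * 2], t[: n * 2])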

@@ -527,7 +527,7 @@ def unaggregated_dy2_dx_s(

    grad_grad_result = grad_grad(xbar, y, functype)

-    w_flattened = torch.flatten(w)[:y.shape[1]].repeat(y.shape[0], 1)
+    w_flattened = torch.flatten(w)[: y.shape[1]].repeat(y.shape[0], 1)

    dy2_dx = grad_grad_result * w_flattened * w_flattened

@@ -556,7 +556,7 @@ def unaggregated_dy_dx(

    grad_ybar_z = grad(ybar, z, functype)

-    dy_dx = dy_dx.view(-1)[:(length * size)].view(length, size)
+    dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size)

    accumulator = dy_dx @ w

Expand Down Expand Up @@ -597,18 +597,20 @@ def unaggregated_dy2_dx(
    grad_ybar_z = grad(ybar, z, functype)
    grad_grad_ybar_z = grad_grad(ybar, z, functype)

-    dy2_dx = dy2_dx.view(-1)[:(length * size)].view(length, size)
-    dy_dx = dy_dx.view(-1)[:(length * size)].view(length, size)
+    dy2_dx = dy2_dx.view(-1)[: (length * size)].view(length, size)
+    dy_dx = dy_dx.view(-1)[: (length * size)].view(length, size)

    accumulator1 = dy2_dx @ w
    accumulator2 = dy_dx @ w

-    dz_drou = grad_ybar_z * accumulator1 + grad_grad_ybar_z * accumulator2 * accumulator2
+    dz_drou = (
+        grad_ybar_z * accumulator1 + grad_grad_ybar_z * accumulator2 * accumulator2
+    )

    if width == size:
        dz_drou += dy2_dx
    if width == 2 * size:
        dy2_dx = torch.cat((dy2_dx, dy2_dx), dim=1)
        dz_drou += dy2_dx
-    return dz_drou

+    return dz_drou
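
For context, grad and grad_grad in this file return closed-form first and second derivatives of tabulated activation functions, dispatched on functype. A hedged sanity check, assuming functype == 1 corresponds to y = tanh(xbar) (so dy/dxbar = 1 - y**2, matching the branch above), cross-checks the closed form against autograd:

import torch

xbar = torch.linspace(-3.0, 3.0, steps=11, requires_grad=True)
y = torch.tanh(xbar)
analytic = 1 - y * y  # the functype == 1 branch of grad()
(via_autograd,) = torch.autograd.grad(y.sum(), xbar)
assert torch.allclose(analytic, via_autograd, atol=1e-6)
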
1 change: 0 additions & 1 deletion source/tests/pt/test_model_compression_se_a.py
@@ -18,7 +18,6 @@
    tests_path,
)

-
if GLOBAL_NP_FLOAT_PRECISION == np.float32:
    default_places = 4
else:
