diff --git a/deepmd/common.py b/deepmd/common.py index 1f9d3afb0c..d32422f0db 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -538,11 +538,6 @@ def cast_precision(func: Callable) -> Callable: If it does not match (e.g. it is an integer), the decorator will do nothing on it. - Parameters - ---------- - precision : tf.DType - Tensor data type that casts to - Returns ------- Callable @@ -560,6 +555,7 @@ def cast_precision(func: Callable) -> Callable: ... def f(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: ... return x ** 2 + y """ + @wraps(func) def wrapper(self, *args, **kwargs): # only convert tensors returned_tensor = func(