-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][B-41,B-48][BUAA] Add type annotations for python/paddle/autograd/ir_backward.py,python/paddle/cost_model/cost_model.py
#66890
Changes from 6 commits
1dcbb32
9e5bd28
75e212d
d721d8f
e7108f3
da287c6
c379da4
607e3d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -16,6 +16,7 @@ | |||||
|
||||||
import logging | ||||||
import warnings | ||||||
from typing import TYPE_CHECKING | ||||||
|
||||||
import paddle.pir | ||||||
from paddle.autograd.backward_utils import ( | ||||||
|
@@ -59,6 +60,11 @@ | |||||
""" | ||||||
__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper'] | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from typing import Sequence | ||||||
|
||||||
from paddle.base.libpaddle.pir import Value | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
|
||||||
def append_full_like(float_value, copy_value, value, state, backward_ops): | ||||||
with paddle.amp.auto_cast(enable=False): | ||||||
|
@@ -956,7 +962,12 @@ def create_backward_prune_set( | |||||
return outputs_set, inputs_set, no_gradvar_set | ||||||
|
||||||
|
||||||
def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): | ||||||
def calc_gradient_helper( | ||||||
outputs: Value | Sequence[Value], | ||||||
inputs: Value | Sequence[Value], | ||||||
grad_outputs: Value | Sequence[Value | None] | None = None, | ||||||
no_grad_set: set[Value] | None = None, | ||||||
) -> dict[Value, list[Value]]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
block = outputs[0].get_defining_op().get_parent_block() | ||||||
state = State(block) | ||||||
if all_stop_gradient_true(block): | ||||||
|
@@ -1038,7 +1049,12 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): | |||||
return input_grad_map | ||||||
|
||||||
|
||||||
def calc_gradient(outputs, inputs, grad_outputs, no_grad_set): | ||||||
def calc_gradient( | ||||||
outputs: Value | Sequence[Value], | ||||||
inputs: Value | Sequence[Value], | ||||||
grad_outputs: Value | Sequence[Value | None] | None = None, | ||||||
no_grad_set: set[Value] | None = None, | ||||||
) -> list[Value]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
calculate gradient of input | ||||||
|
||||||
|
@@ -1084,15 +1100,15 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set): | |||||
|
||||||
|
||||||
def grad( | ||||||
outputs, | ||||||
inputs, | ||||||
grad_outputs=None, | ||||||
retain_graph=None, | ||||||
create_graph=False, | ||||||
only_inputs=True, | ||||||
allow_unused=False, | ||||||
no_grad_vars=None, | ||||||
): | ||||||
outputs: Value | Sequence[Value], | ||||||
inputs: Value | Sequence[Value], | ||||||
grad_outputs: Value | Sequence[Value | None] | None = None, | ||||||
retain_graph: bool | None = None, | ||||||
create_graph: bool | None = False, | ||||||
only_inputs: bool | None = True, | ||||||
allow_unused: bool | None = False, | ||||||
no_grad_vars: Value | Sequence[Value] | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) -> list[Value]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
''' | ||||||
.. note:: | ||||||
**This API is ONLY available in imperative mode.** | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,22 +11,28 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
from __future__ import annotations | ||||||
|
||||||
import json | ||||||
import os | ||||||
from typing import TYPE_CHECKING, Sequence | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 从 |
||||||
|
||||||
import numpy as np | ||||||
|
||||||
import paddle | ||||||
from paddle import static | ||||||
from paddle.base import core | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
from paddle.base.compiler import CompiledProgram | ||||||
from paddle.base.framework import Program | ||||||
|
||||||
|
||||||
class CostModel: | ||||||
def __init__(self): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
pass | ||||||
|
||||||
def build_program(self): | ||||||
def build_program(self) -> tuple[Program, Program]: | ||||||
paddle.enable_static() | ||||||
|
||||||
main_program = static.Program() | ||||||
|
@@ -47,11 +53,11 @@ def build_program(self): | |||||
|
||||||
def profile_measure( | ||||||
self, | ||||||
startup_program, | ||||||
main_program, | ||||||
device='gpu', | ||||||
fetch_cost_list=['time'], | ||||||
): | ||||||
startup_program: Program | CompiledProgram, | ||||||
main_program: Program | CompiledProgram, | ||||||
device: str = 'gpu', | ||||||
fetch_cost_list: Sequence = ['time'], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) -> None: | ||||||
place = paddle.set_device('gpu') | ||||||
x = np.random.random(size=(10, 1)).astype('float32') | ||||||
exe = paddle.static.Executor(place) | ||||||
|
@@ -64,7 +70,7 @@ def profile_measure( | |||||
cost_model = core.CostModel() | ||||||
cost_data = cost_model.ProfileMeasure(device) | ||||||
|
||||||
def static_cost_data(self): | ||||||
def static_cost_data(self) -> dict[str, str | float]: | ||||||
static_cost_data_path = os.path.join( | ||||||
os.path.dirname(__file__), "static_op_benchmark.json" | ||||||
) | ||||||
|
@@ -74,7 +80,9 @@ def static_cost_data(self): | |||||
# return all static cost data | ||||||
return load_dict | ||||||
|
||||||
def get_static_op_time(self, op_name, forward=True, dtype="float32"): | ||||||
def get_static_op_time( | ||||||
self, op_name: str, forward: bool = True, dtype: str = "float32" | ||||||
) -> dict[str, str | float]: | ||||||
# if forward is True, return op forward time, otherwise return op backward time. | ||||||
if op_name is None: | ||||||
raise ValueError( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从
collections.abc
导入