
[Typing][B-41,B-48][BUAA] Add type annotations for python/paddle/autograd/ir_backward.py,python/paddle/cost_model/cost_model.py #66890

Merged 8 commits · Aug 16, 2024 · Changes from 6 commits
38 changes: 27 additions & 11 deletions python/paddle/autograd/ir_backward.py
@@ -16,6 +16,7 @@
 import logging
 import warnings
+from typing import TYPE_CHECKING

 import paddle.pir
 from paddle.autograd.backward_utils import (
@@ -59,6 +60,11 @@
 """
 __all__ = ['grad', 'calc_gradient', 'calc_gradient_helper']

+if TYPE_CHECKING:
+    from typing import Sequence
Contributor: Import `Sequence` from `collections.abc`.
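For reference, the pattern the reviewer is pointing at, as a minimal standalone sketch (not the PR's exact code): under PEP 585 the generic `Sequence` lives in `collections.abc`, and annotation-only imports can sit behind the `TYPE_CHECKING` guard so they add no runtime cost.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Annotation-only import: seen by the type checker, skipped at runtime.
    from collections.abc import Sequence


def first(values: Sequence[int]) -> int:
    # The annotation stays a string thanks to `from __future__ import
    # annotations`, so the guarded import is never evaluated at runtime.
    return values[0]
```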


+    from paddle.base.libpaddle.pir import Value
Contributor suggested change:
-    from paddle.base.libpaddle.pir import Value
+    from paddle.pir import Value



 def append_full_like(float_value, copy_value, value, state, backward_ops):
     with paddle.amp.auto_cast(enable=False):

@@ -956,7 +962,12 @@ def create_backward_prune_set(
     return outputs_set, inputs_set, no_gradvar_set


-def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
+def calc_gradient_helper(
+    outputs: Value | Sequence[Value],
+    inputs: Value | Sequence[Value],
+    grad_outputs: Value | Sequence[Value | None] | None = None,
+    no_grad_set: set[Value] | None = None,
+) -> dict[Value, list[Value]]:
Contributor suggested change:
-) -> dict[Value, list[Value]]:
+) -> ValueDict:
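`ValueDict` is presumably the `Value`-keyed mapping helper from `paddle.autograd.backward_utils` (that import list is collapsed in this diff). A hypothetical sketch of why such a wrapper can be preferable to a plain `dict[Value, ...]`: it keys by identity, so lookups never depend on the key type's hashing. This illustrates the idea only and is not Paddle's actual implementation:

```python
class IdentityKeyedDict:
    """Hypothetical stand-in for a ValueDict-style mapping."""

    def __init__(self) -> None:
        # (key, value) pairs stored under id(key), so lookups do not
        # rely on the key type's __hash__/__eq__ semantics.
        self._items = {}

    def __setitem__(self, key, value) -> None:
        self._items[id(key)] = (key, value)

    def __getitem__(self, key):
        return self._items[id(key)][1]

    def __contains__(self, key) -> bool:
        return id(key) in self._items
```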

     block = outputs[0].get_defining_op().get_parent_block()
     state = State(block)
     if all_stop_gradient_true(block):

@@ -1038,7 +1049,12 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
     return input_grad_map


-def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
+def calc_gradient(
+    outputs: Value | Sequence[Value],
+    inputs: Value | Sequence[Value],
+    grad_outputs: Value | Sequence[Value | None] | None = None,
+    no_grad_set: set[Value] | None = None,
+) -> list[Value]:
Contributor suggested change:
-) -> list[Value]:
+) -> list[Value | None]:

"""
calculate gradient of input

Expand Down Expand Up @@ -1084,15 +1100,15 @@ def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):


 def grad(
-    outputs,
-    inputs,
-    grad_outputs=None,
-    retain_graph=None,
-    create_graph=False,
-    only_inputs=True,
-    allow_unused=False,
-    no_grad_vars=None,
-):
+    outputs: Value | Sequence[Value],
+    inputs: Value | Sequence[Value],
+    grad_outputs: Value | Sequence[Value | None] | None = None,
+    retain_graph: bool | None = None,
+    create_graph: bool | None = False,
+    only_inputs: bool | None = True,
+    allow_unused: bool | None = False,
+    no_grad_vars: Value | Sequence[Value] | None = None,
Contributor suggested change:
-    no_grad_vars: Value | Sequence[Value] | None = None,
+    no_grad_vars: Value | Sequence[Value] | set[Value] | None = None,

+) -> list[Value]:
Contributor suggested change:
-) -> list[Value]:
+) -> list[Value | None]:
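The `list[Value | None]` suggestion matches the documented contract of `paddle.grad`: with `allow_unused=True`, inputs that do not affect the outputs yield `None` instead of a gradient. A dygraph-flavored sketch of that contract (the PIR variant annotated here is assumed to follow the same convention):

```python
import paddle

x = paddle.to_tensor([1.0], stop_gradient=False)
y = paddle.to_tensor([2.0], stop_gradient=False)
z = x * 2.0  # y never participates in the graph

grads = paddle.grad(outputs=[z], inputs=[x, y], allow_unused=True)
# grads[0] is the gradient w.r.t. x; grads[1] is None because y is
# unused, hence the Value | None element type.
print(grads)
```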

     '''
     .. note::
         **This API is ONLY available in imperative mode.**
24 changes: 16 additions & 8 deletions python/paddle/cost_model/cost_model.py
@@ -11,22 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import json
 import os
+from typing import TYPE_CHECKING, Sequence
Contributor: Import `Sequence` from `collections.abc`.


 import numpy as np

 import paddle
 from paddle import static
 from paddle.base import core

+if TYPE_CHECKING:
+    from paddle.base.compiler import CompiledProgram
+    from paddle.base.framework import Program


 class CostModel:
     def __init__(self):
Contributor suggested change:
-    def __init__(self):
+    def __init__(self) -> None:

         pass

-    def build_program(self):
+    def build_program(self) -> tuple[Program, Program]:
         paddle.enable_static()

         main_program = static.Program()

@@ -47,11 +53,11 @@ def build_program(self):

     def profile_measure(
         self,
-        startup_program,
-        main_program,
-        device='gpu',
-        fetch_cost_list=['time'],
-    ):
+        startup_program: Program | CompiledProgram,
+        main_program: Program | CompiledProgram,
+        device: str = 'gpu',
+        fetch_cost_list: Sequence = ['time'],
Contributor suggested change:
-        fetch_cost_list: Sequence = ['time'],
+        fetch_cost_list: Sequence[???] = ['time'],
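The `???` is the reviewer asking for a concrete element type. Judging only from the default `['time']`, `str` looks like the intended element (an assumption, not something confirmed in this thread):

```python
from collections.abc import Sequence

# Assumed refinement: the cost names appear to be strings such as 'time'.
fetch_cost_list: Sequence[str] = ['time']
```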

+    ) -> None:
         place = paddle.set_device('gpu')
         x = np.random.random(size=(10, 1)).astype('float32')
         exe = paddle.static.Executor(place)

@@ -64,7 +70,7 @@ def profile_measure(
         cost_model = core.CostModel()
         cost_data = cost_model.ProfileMeasure(device)

-    def static_cost_data(self):
+    def static_cost_data(self) -> dict[str, str | float]:
         static_cost_data_path = os.path.join(
             os.path.dirname(__file__), "static_op_benchmark.json"
         )
@@ -74,7 +80,9 @@ def static_cost_data(self):
         # return all static cost data
         return load_dict

-    def get_static_op_time(self, op_name, forward=True, dtype="float32"):
+    def get_static_op_time(
+        self, op_name: str, forward: bool = True, dtype: str = "float32"
+    ) -> dict[str, str | float]:
         # if forward is True, return op forward time, otherwise return op backward time.
         if op_name is None:
             raise ValueError(
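Putting the annotated methods together, a hedged usage sketch (assuming `CostModel` is importable from `paddle.cost_model`, and that `"conv2d"` is a key present in `static_op_benchmark.json`; neither is verified here):

```python
from paddle.cost_model import CostModel

cost_model = CostModel()

# build_program() is annotated tuple[Program, Program]; its body (and
# the order of the two programs) is collapsed out of this diff.
programs = cost_model.build_program()

# get_static_op_time() returns one dict[str, str | float] benchmark
# entry; "conv2d" is an assumed op name.
op_time = cost_model.get_static_op_time("conv2d", forward=True)
print(op_time)
```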