Skip to content

Commit

Permalink
fix data_op device for gpu pinned tensor (#60357)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Dec 27, 2023
1 parent 9952846 commit 430894e
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,12 @@ phi::KernelKey GetKernelKey(
auto data_place =
op->attributes().at("place").dyn_cast<PlaceAttribute>().data();

auto backend = paddle::experimental::ParseBackend(data_place);
phi::Backend backend;
if (data_place.GetType() == AllocationType::GPUPINNED) {
backend = phi::Backend::CPU;
} else {
backend = paddle::experimental::ParseBackend(data_place);
}

return {backend,
phi::DataLayout::ANY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ def collect(inp):
def guard_fn(self) -> Guard:
with tmp_name_guard():
guards = []
with EventGuard(
"guard_fn: find vars and make stringify guard", event_level=1
):
with EventGuard("guard_fn: find vars and make stringify guard"):
for variable in find_traceable_vars(
self.input_variables + list(self._global_guarded_variables)
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def step(self, instr: Instruction):

opname = instr.opname if instr.opname != "PRECALL" else "PRECALL__CALL"
assert opname != "CALL", "CALL should fused with PRECALL"
with EventGuard(f"{opname}", event_level=1):
with EventGuard(f"{opname}", event_level=2):
return getattr(self, opname)(instr) # run single step.

def indexof(self, instr: Instruction):
Expand Down
17 changes: 4 additions & 13 deletions python/paddle/jit/sot/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from paddle.framework import core

_event_level = int(os.environ.get("EVENT_LEVEL", "-1"))
_event_level = int(os.environ.get("EVENT_LEVEL", "0"))


class SotProfiler:
Expand All @@ -37,7 +37,7 @@ def disable(self):


@contextmanager
def EventGuard(event_name, event_level=0):
def EventGuard(event_name, event_level=1):
try:
global _event_level
need_pop = False
Expand All @@ -50,20 +50,11 @@ def EventGuard(event_name, event_level=0):
core.nvprof_nvtx_pop()


if _event_level == -1:

@contextmanager
def _EmptyEventGuard(event_name, event_level=0):
yield

EventGuard = _EmptyEventGuard # noqa: F811


def event_register(event_name, event_level=0):
def event_register(event_name, event_level=1):
def event_wrapper(func):
@wraps(func)
def call_with_event(*args, **kwargs):
with EventGuard(event_name, event_level=0):
with EventGuard(event_name, event_level=event_level):
return func(*args, **kwargs)

return call_with_event
Expand Down

0 comments on commit 430894e

Please sign in to comment.