From cd34953366075ea93f02e2f4d36de53cef17836c Mon Sep 17 00:00:00 2001 From: heheda Date: Fri, 6 Oct 2023 23:20:08 +0800 Subject: [PATCH] support blockdrop.forward_full --- frontend/guard_tracker.py | 90 +++++++----- frontend/object_table.py | 2 +- frontend/utils.py | 4 +- test/test_model_blockdrop.py | 256 +++++++++++++++++++++++++++++++++++ test/test_tensor.py | 18 ++- 5 files changed, 330 insertions(+), 40 deletions(-) create mode 100644 test/test_model_blockdrop.py diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py index 56ebd8cbf913..7154f3e84af0 100644 --- a/frontend/guard_tracker.py +++ b/frontend/guard_tracker.py @@ -10,6 +10,7 @@ import traceback import copy import dataclasses +import torch.fx.immutable_collections as fx_immutable from .code import ProcessedCode from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect, get_code_map, is_bound_method from .instruction import Instruction, ci @@ -103,20 +104,18 @@ def as_node_args_kwargs( self, args: list[Any], kwargs: dict[str, Any] ) -> tuple[tuple[torch.fx.Node, ...], dict[str, torch.fx.Node]]: - def as_fx_node(var: vs.Variable) -> NodeArgs: + def as_fx_node(arg: Any) -> NodeArgs: + if isinstance(arg, (tuple, list)): + return fx_immutable.immutable_list([as_fx_node(x) for x in arg]) + var = self.objects.get(arg, allow_unexist_const=True) if isinstance(var, vs.TorchParamVar): return self.fx_graph.create_node("get_attr", self.subparam_paths[var.obj], (), {}) return var.as_fx_node() - node_args = tuple( - as_fx_node(self.objects.get(arg, allow_unexist_const=True)) - for arg in args) - node_kwargs = { - key: as_fx_node(self.objects.get(arg, allow_unexist_const=True)) - for key, arg in kwargs.items() - } + node_args = tuple(as_fx_node(arg) for arg in args) + node_kwargs = {key: as_fx_node(arg) for key, arg in kwargs.items()} return node_args, node_kwargs @@ -195,9 +194,12 @@ def add_stored_globals(self, name: str) -> None: self.written = True self.stored_globals.add(name) - def store_pos_in_callee(self, pos: StorePos, idx: int) -> StorePos: + def store_pos_in_callee(self, pos: StorePos, + idx: int) -> Optional[StorePos]: if idx in self.objects.objs: var = self.objects.objs[idx] + if len(var.extract_code_at_start) == 0: + return None return var.extract_code_at_start[0] if isinstance(pos, StoreInLocal): raise ValueError("unknown local in callee", pos) @@ -206,21 +208,25 @@ def store_pos_in_callee(self, pos: StorePos, idx: int) -> StorePos: elif isinstance(pos, (StoreInGlobal, StoreInBuiltin)): return pos elif isinstance(pos, StoreInAttr): - return StoreInAttr( - self.store_pos_in_callee(pos.self_pos, pos.self_id), - pos.self_id, pos.attr_name) + parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id) + if parent_pos is None: + return None + return StoreInAttr(parent_pos, pos.self_id, pos.attr_name) elif isinstance(pos, StoreInIndex): - return StoreInIndex( - self.store_pos_in_callee(pos.self_pos, pos.self_id), - pos.self_id, pos.self_index) + parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id) + if parent_pos is None: + return None + return StoreInIndex(parent_pos, pos.self_id, pos.self_index) elif isinstance(pos, ExtractFromMethod): - return ExtractFromMethod( - self.store_pos_in_callee(pos.self_pos, pos.self_id), - pos.self_id, pos.method_name) + parent_pos = self.store_pos_in_callee(pos.self_pos, pos.self_id) + if parent_pos is None: + return None + return ExtractFromMethod(parent_pos, pos.self_id, pos.method_name) elif isinstance(pos, ExtractFromFunction): - return ExtractFromFunction( - self.store_pos_in_callee(pos.var_pos, pos.var_id), pos.var_id, - pos.func_name) + parent_pos = self.store_pos_in_callee(pos.var_pos, pos.var_id) + if parent_pos is None: + return None + return ExtractFromFunction(parent_pos, pos.var_id, pos.func_name) else: raise NotImplementedError @@ -246,11 +252,13 @@ def merge_call_guard() -> None: elif isinstance(pos, StoreInGlobal): new_var.extract_code_at_start.append(pos) elif isinstance(pos, StoreInAttr): - new_var.extract_code_at_start.append( - self.store_pos_in_callee(pos, idx)) + self_pos = self.store_pos_in_callee(pos, idx) + assert self_pos is not None + new_var.extract_code_at_start.append(self_pos) elif isinstance(pos, StoreInIndex): - new_var.extract_code_at_start.append( - self.store_pos_in_callee(pos, idx)) + self_pos = self.store_pos_in_callee(pos, idx) + assert self_pos is not None + new_var.extract_code_at_start.append(self_pos) else: raise NotImplementedError(pos) @@ -387,25 +395,29 @@ def get_or_make_var( merged_ids.add(id(obj)) return new_var + def get_new_store_pos(old: list[StorePos], + idx: int) -> list[StorePos]: + new: list[StorePos] = [] + for pos in old: + new_pos = self.store_pos_in_callee(pos, idx) + if new_pos is not None: + new.append(new_pos) + return new + for idx, var in state.objects.get_all_with_id(): if var.prev is not None: oldest = var.get_oldest_var() if len(oldest.extract_code_at_start) == 0: continue - new_extract = [ - self.store_pos_in_callee(pos, idx) - for pos in oldest.extract_code_at_start - ] + new_extract: list[StorePos] = get_new_store_pos( + oldest.extract_code_at_start, idx) get_or_make_var(var.obj, var.need_guard_check, self.fx_graph, new_extract) var = state.objects.get(return_value, allow_unexist_const=True) - new_extract = [ - self.store_pos_in_callee(pos, id(var.obj)) - for pos in var.extract_code_at_start - ] - get_or_make_var(return_value, var.need_guard_check, self.fx_graph, - new_extract) + new_extract = get_new_store_pos(var.extract_code_at_start, + id(var.obj)) + get_or_make_var(return_value, False, self.fx_graph, new_extract) if len(state.objects.objs_no_id) > 0: raise NotImplementedError @@ -745,6 +757,7 @@ def call_function( ] }) if is_user_defined_func(func) or isinstance(func, torch.nn.Sequential): + print("run into user defined function") stack_objs = get_all_objects_in_stack(self.frame) self.state.mark_defer_restart(f"call_function", stack_objs) from .tracer import get_process_frame @@ -860,6 +873,13 @@ def BINARY_SUBSCR(self, inst: Instruction) -> None: obj2 = get_value_stack_from_top(self.frame, 0) self.call_function(operator.getitem, [obj1, obj2], {}) + def COMPARE_OP(self, inst: Instruction) -> None: + obj1 = get_value_stack_from_top(self.frame, 1) + obj2 = get_value_stack_from_top(self.frame, 0) + cmp_op = ('lt', 'le', 'eq', 'ne', 'gt', 'ge') + self.call_function(getattr(operator, cmp_op[inst.arg]), [obj1, obj2], + {}) + def INPLACE_POWER(self, _inst: Instruction) -> None: self.binary_operation(operator.ipow) diff --git a/frontend/object_table.py b/frontend/object_table.py index 7f4d135e9fb4..723f94613822 100644 --- a/frontend/object_table.py +++ b/frontend/object_table.py @@ -24,7 +24,7 @@ def add(self, var: Variable, value: Any) -> None: old_var.extract_code_at_start.extend(var.extract_code_at_start) old_var.need_guard_check |= var.need_guard_check else: - self.objs[id(value)] = var + self.add_by_id(var, id(value)) var.add_subvars_to_table(self) def add_by_id(self, var: Variable, idx: int) -> None: diff --git a/frontend/utils.py b/frontend/utils.py index 1f91f148f0df..9e6bd9aa4bea 100644 --- a/frontend/utils.py +++ b/frontend/utils.py @@ -107,8 +107,6 @@ def get_root_module(func: Callable[..., Any]) -> str: def is_user_defined_func(func: Callable[..., Any]) -> bool: - if func in fx_graph_functions: - return False if hasattr(func, '__objclass__') and func.__objclass__ == torch._C._TensorBase: return False @@ -116,7 +114,7 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool: root_module = get_root_module(func) if root_module == '': return True - if root_module in ('math', 'builtins', 'torch', 'numpy'): + if root_module in ('math', 'builtins', 'torch', 'numpy', '_operator'): return False return True diff --git a/test/test_model_blockdrop.py b/test/test_model_blockdrop.py new file mode 100644 index 000000000000..084a63fdae6a --- /dev/null +++ b/test/test_model_blockdrop.py @@ -0,0 +1,256 @@ +import pytest +from frontend.compile import compile, reset +from frontend.utils import add_force_graph_break +from frontend.c_api import get_next_frame_id +import logging +from common.checker import run_and_check, HIT, MISS + +import torch +import torch.nn as nn +import random +import math +import torch.nn.functional as F + + +class Identity(nn.Module): + + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return x + + +class Flatten(nn.Module): + + def __init__(self): + super(Flatten, self).__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + def forward(self, x): + + out = self.conv1(x) + out = self.bn1(out) + out = F.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + return out + + +class DownsampleB(nn.Module): + + def __init__(self, nIn, nOut, stride): + super(DownsampleB, self).__init__() + self.avg = nn.AvgPool2d(stride) + self.expand_ratio = nOut // nIn + + def forward(self, x): + x = self.avg(x) + return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1) + + +#--------------------------------------------------------------------------------------------------# +class FlatResNet(nn.Module): + + def seed(self, x): + # x = self.relu(self.bn1(self.conv1(x))) -- CIFAR + # x = self.maxpool(self.relu(self.bn1(self.conv1(x)))) -- ImageNet + raise NotImplementedError + + # run a variable policy batch through the resnet implemented as a full mask over the residual + # fast to train, non-indicative of time saving (use forward_single instead) + def forward(self, x, policy): + + x = self.seed(x) + + t = 0 + for segment, num_blocks in enumerate(self.layer_config): + for b in range(num_blocks): + action = policy[:, t].contiguous() + residual = self.ds[segment](x) if b == 0 else x + + # early termination if all actions in the batch are zero + if action.data.sum() == 0: + x = residual + t += 1 + continue + + action_mask = action.float().view(-1, 1, 1, 1) + fx = F.relu(residual + self.blocks[segment][b](x)) + x = fx * action_mask + residual * (1 - action_mask) + t += 1 + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + # run a single, fixed policy for all items in the batch + # policy is a (15,) vector. Use with batch_size=1 for profiling + def forward_single(self, x, policy): + x = self.seed(x) + + t = 0 + for segment, num_blocks in enumerate(self.layer_config): + for b in range(num_blocks): + residual = self.ds[segment](x) if b == 0 else x + if policy[t] == 1: + x = residual + self.blocks[segment][b](x) + x = F.relu(x) + else: + x = residual + t += 1 + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def forward_full(self, x): + x = self.seed(x) + + for segment, num_blocks in enumerate(self.layer_config): + for b in range(num_blocks): + residual = self.ds[segment](x) if b == 0 else x + x = F.relu(residual + self.blocks[segment][b](x)) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +# Smaller Flattened Resnet, tailored for CIFAR +class FlatResNet32(FlatResNet): + + def __init__(self, block, layers, num_classes=10): + super(FlatResNet32, self).__init__() + + self.inplanes = 16 + self.conv1 = conv3x3(3, 16) + self.bn1 = nn.BatchNorm2d(16) + self.relu = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(8) + + strides = [1, 2, 2] + filt_sizes = [16, 32, 64] + self.blocks, self.ds = [], [] + for idx, (filt_size, num_blocks, + stride) in enumerate(zip(filt_sizes, layers, strides)): + blocks, ds = self._make_layer(block, + filt_size, + num_blocks, + stride=stride) + self.blocks.append(nn.ModuleList(blocks)) + self.ds.append(ds) + + self.blocks = nn.ModuleList(self.blocks) + self.ds = nn.ModuleList(self.ds) + self.fc = nn.Linear(64 * block.expansion, num_classes) + self.fc_dim = 64 * block.expansion + + self.layer_config = layers + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def seed(self, x): + x = self.relu(self.bn1(self.conv1(x))) + return x + + def _make_layer(self, block, planes, blocks, stride=1): + + downsample = nn.Sequential() + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleB(self.inplanes, planes * block.expansion, + stride) + + layers = [block(self.inplanes, planes, stride)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, 1)) + + return layers, downsample + + +def get_input(batch_size): + return torch.randn(batch_size, 3, 32, + 32).cuda(), torch.randint(0, 2, (batch_size, 15)).cuda() + + +def test_blockdrop_full(caplog): + reset() + with torch.no_grad(): + model = FlatResNet32(BasicBlock, [5, 5, 5]).cuda() + model.eval() + batch_size = 2 + inp = torch.randn(batch_size, 3, 32, 32).cuda() + expect_result = model.forward_full(inp) + compiled = compile(model.forward_full) + run_and_check(compiled, [MISS] * 20, 1, caplog, expect_result, inp) + run_and_check(compiled, [HIT], 1, caplog, expect_result, inp) diff --git a/test/test_tensor.py b/test/test_tensor.py index a58a8b3de5c3..f108f1bd32d6 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -147,6 +147,22 @@ def test_tensor_functional(caplog): a) +def fx_nest(x): + return torch.cat([x] + [x.mul(0)] * 2, 1) + + +def test_fx_nest(caplog): + reset() + compiled_fx_nest = compile(fx_nest) + a = torch.randn((3, 3)) + expect_result = fx_nest(a) + run_and_check(compiled_fx_nest, [MISS], 1, caplog, expect_result, a) + run_and_check(compiled_fx_nest, [HIT], 1, caplog, expect_result, a) + b = torch.randn((3, 3)) + expect_result = fx_nest(b) + run_and_check(compiled_fx_nest, [HIT], 1, caplog, expect_result, b) + + def tensor_shape(a): return a.size(), a.shape @@ -170,4 +186,4 @@ def test_tensor_dtype(caplog): a = torch.randn((3, 3)) expect_result = tensor_dtype(a) run_and_check(compiled_tensor_dtype, [MISS], 1, caplog, expect_result, a) - run_and_check(compiled_tensor_dtype, [HIT], 1, caplog, expect_result, a) \ No newline at end of file + run_and_check(compiled_tensor_dtype, [HIT], 1, caplog, expect_result, a)