Skip to content

Commit

Permalink
Prepare for "Fix type-safety of torch.nn.Module instances": fbcode/d* (pytorch#3397)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#3397

X-link: facebookresearch/FBGEMM#485

See D52890934

Reviewed By: r-barnes

Differential Revision: D66234699

fbshipit-source-id: 38586521c2e274d6f469b7d361adcd739a6736af
  • Loading branch information
ezyang authored and facebook-github-bot committed Nov 21, 2024
1 parent bea3968 commit 8993811
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 3 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/bench/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def fill_random_scale_bias(
weights_precision: SparseType,
) -> None:
for t in range(T):
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
(weights, scale_shift) = emb.split_embedding_weights()[t]
if scale_shift is not None:
(E, R) = scale_shift.shape
Expand Down
12 changes: 12 additions & 0 deletions fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,14 @@ def replay_populate(linear_indices: Tensor) -> None:

total_rows = 0
for request in requests:
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
# function.
prev = replay_cc.lxu_cache_state.clone().detach()
replay_populate(request)
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
# function.
after = replay_cc.lxu_cache_state.clone().detach()

diff = after - prev
Expand Down Expand Up @@ -538,8 +544,14 @@ def replay_populate(linear_indices: Tensor) -> None:

total_rows = 0
for request in requests:
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
# function.
prev = replay_cc.lxu_cache_state.clone().detach()
replay_populate(request)
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
# function.
after = replay_cc.lxu_cache_state.clone().detach()

diff = after - prev
Expand Down
20 changes: 18 additions & 2 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ def device( # noqa C901
emb = emb.to(get_device())

if weights_precision == SparseType.INT8:
# pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
# min_val: float, max_val: float) -> None, (self:
# SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
# None, Tensor, Module]` is not a function.
emb.init_embedding_weights_uniform(-0.0003, 0.0003)

nparams = sum(d * E for d in Ds)
Expand Down Expand Up @@ -886,10 +890,16 @@ def cache( # noqa C901
NOT_FOUND = -1
for req in requests:
indices, offsets = req.unpack_2()
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a
# function.
old_lxu_cache_state = emb.lxu_cache_state.clone()
emb.prefetch(indices.long(), offsets.long())
exchanged_cache_lines.append(
(emb.lxu_cache_state != old_lxu_cache_state).sum().item()
# pyre-fixme[16]: Item `bool` of `bool | Tensor` has no attribute `sum`.
(emb.lxu_cache_state != old_lxu_cache_state)
.sum()
.item()
)
cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item())
emb.forward(indices.long(), offsets.long())
Expand Down Expand Up @@ -2433,10 +2443,16 @@ def nbit_cache( # noqa C901

for req in requests:
indices, offsets = req.unpack_2()
# pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
# Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a
# function.
old_lxu_cache_state = emb.lxu_cache_state.clone()
emb.prefetch(indices, offsets)
exchanged_cache_lines.append(
(emb.lxu_cache_state != old_lxu_cache_state).sum().item()
# pyre-fixme[16]: Item `bool` of `bool | Tensor` has no attribute `sum`.
(emb.lxu_cache_state != old_lxu_cache_state)
.sum()
.item()
)
cache_misses.append(
(emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,8 @@ def print_uvm_cache_stats(self) -> None:
def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.timestep_counter.increment()
self.timestep_prefetch_size.increment()
# pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
# a function.
if not self.lxu_cache_weights.numel():
return

Expand Down Expand Up @@ -891,6 +893,8 @@ def _update_tablewise_cache_miss(
CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)

# pyre-fixme[6]: For 1st argument expected
# `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
num_tables = len(self.cache_hash_size_cumsum) - 1
num_offsets_per_table = (len(offsets) - 1) // num_tables
cache_missed_locations = torch.where(
Expand Down Expand Up @@ -962,6 +966,8 @@ def forward(
self.index_remappings_array,
self.index_remappings_array_offsets,
)
# pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
# a function.
if self.lxu_cache_weights.numel() > 0:
if self.timestep_prefetch_size.get() <= 0:
self.prefetch(indices, offsets)
Expand Down Expand Up @@ -1136,8 +1142,13 @@ def recompute_module_buffers(self) -> None:
self.index_remappings_array_offsets = torch.empty_like(
self.index_remappings_array_offsets, device=self.current_device
)
# pyre-fixme[16]: `IntNBitTableBatchedEmbeddingBagsCodegen` has no attribute
# `lxu_cache_weights`.
self.lxu_cache_weights = torch.empty_like(
self.lxu_cache_weights, device=self.current_device
# pyre-fixme[6]: For 1st argument expected `Tensor` but got
# `Union[Module, Tensor]`.
self.lxu_cache_weights,
device=self.current_device,
)
self.original_rows_per_table = torch.empty_like(
self.original_rows_per_table, device=self.current_device
Expand Down Expand Up @@ -1432,6 +1443,8 @@ def _apply_cache_state(
self.reset_uvm_cache_stats()

def reset_cache_states(self) -> None:
# pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
# a function.
if not self.lxu_cache_weights.numel():
return
self.lxu_cache_state.fill_(-1)
Expand Down Expand Up @@ -1842,6 +1855,8 @@ def embedding_inplace_update_internal(
)

lxu_cache_locations = None
# pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
# a function.
if self.lxu_cache_weights.numel() > 0:
linear_cache_indices = (
torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
Expand Down
Loading

0 comments on commit 8993811

Please sign in to comment.