diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py
index a7f43a0ead..7d2c650925 100644
--- a/fbgemm_gpu/bench/bench_utils.py
+++ b/fbgemm_gpu/bench/bench_utils.py
@@ -496,6 +496,7 @@ def fill_random_scale_bias(
     weights_precision: SparseType,
 ) -> None:
     for t in range(T):
+        # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
         (weights, scale_shift) = emb.split_embedding_weights()[t]
         if scale_shift is not None:
             (E, R) = scale_shift.shape
diff --git a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
index d3169ca814..8d9be72cb6 100644
--- a/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
+++ b/fbgemm_gpu/bench/split_embeddings_cache_benchmark.py
@@ -421,8 +421,14 @@ def replay_populate(linear_indices: Tensor) -> None:
 
     total_rows = 0
     for request in requests:
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
+        #  function.
         prev = replay_cc.lxu_cache_state.clone().detach()
         replay_populate(request)
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
+        #  function.
         after = replay_cc.lxu_cache_state.clone().detach()
 
         diff = after - prev
@@ -538,8 +544,14 @@ def replay_populate(linear_indices: Tensor) -> None:
 
     total_rows = 0
     for request in requests:
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
+        #  function.
         prev = replay_cc.lxu_cache_state.clone().detach()
         replay_populate(request)
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Module, Tensor]` is not a
+        #  function.
         after = replay_cc.lxu_cache_state.clone().detach()
 
         diff = after - prev
diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index f6c29f1a5c..beb78120d7 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -290,6 +290,10 @@ def device(  # noqa C901
     emb = emb.to(get_device())
 
     if weights_precision == SparseType.INT8:
+        # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
+        #  min_val: float, max_val: float) -> None, (self:
+        #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
+        #  None, Tensor, Module]` is not a function.
         emb.init_embedding_weights_uniform(-0.0003, 0.0003)
 
     nparams = sum(d * E for d in Ds)
@@ -886,10 +890,16 @@ def cache(  # noqa C901
     NOT_FOUND = -1
     for req in requests:
         indices, offsets = req.unpack_2()
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a
+        #  function.
         old_lxu_cache_state = emb.lxu_cache_state.clone()
         emb.prefetch(indices.long(), offsets.long())
         exchanged_cache_lines.append(
-            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
+            # pyre-fixme[16]: Item `bool` of `bool | Tensor` has no attribute `sum`.
+            (emb.lxu_cache_state != old_lxu_cache_state)
+            .sum()
+            .item()
         )
         cache_misses.append((emb.lxu_cache_locations_list[0] == NOT_FOUND).sum().item())
         emb.forward(indices.long(), offsets.long())
@@ -2433,10 +2443,16 @@ def nbit_cache(  # noqa C901
     for req in requests:
         indices, offsets = req.unpack_2()
+        # pyre-fixme[29]: `Union[(self: TensorBase, memory_format:
+        #  Optional[memory_format] = ...) -> Tensor, Tensor, Module]` is not a
+        #  function.
         old_lxu_cache_state = emb.lxu_cache_state.clone()
         emb.prefetch(indices, offsets)
         exchanged_cache_lines.append(
-            (emb.lxu_cache_state != old_lxu_cache_state).sum().item()
+            # pyre-fixme[16]: Item `bool` of `bool | Tensor` has no attribute `sum`.
+            (emb.lxu_cache_state != old_lxu_cache_state)
+            .sum()
+            .item()
         )
         cache_misses.append(
             (emb.lxu_cache_locations_list.top() == NOT_FOUND).sum().item()
         )
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index eba628ab3e..63a646dfbc 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -709,6 +709,8 @@ def print_uvm_cache_stats(self) -> None:
     def prefetch(self, indices: Tensor, offsets: Tensor) -> None:
         self.timestep_counter.increment()
         self.timestep_prefetch_size.increment()
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         if not self.lxu_cache_weights.numel():
             return
@@ -891,6 +893,8 @@ def _update_tablewise_cache_miss(
         CACHE_MISS = torch.tensor([-1], device=self.current_device, dtype=torch.int32)
         CACHE_HIT = torch.tensor([-2], device=self.current_device, dtype=torch.int32)
 
+        # pyre-fixme[6]: For 1st argument expected
+        #  `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
         num_tables = len(self.cache_hash_size_cumsum) - 1
         num_offsets_per_table = (len(offsets) - 1) // num_tables
         cache_missed_locations = torch.where(
@@ -962,6 +966,8 @@ def forward(
                 self.index_remappings_array,
                 self.index_remappings_array_offsets,
             )
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         if self.lxu_cache_weights.numel() > 0:
             if self.timestep_prefetch_size.get() <= 0:
                 self.prefetch(indices, offsets)
@@ -1136,8 +1142,13 @@ def recompute_module_buffers(self) -> None:
         self.index_remappings_array_offsets = torch.empty_like(
             self.index_remappings_array_offsets, device=self.current_device
         )
+        # pyre-fixme[16]: `IntNBitTableBatchedEmbeddingBagsCodegen` has no attribute
+        #  `lxu_cache_weights`.
         self.lxu_cache_weights = torch.empty_like(
-            self.lxu_cache_weights, device=self.current_device
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
+            self.lxu_cache_weights,
+            device=self.current_device,
         )
         self.original_rows_per_table = torch.empty_like(
             self.original_rows_per_table, device=self.current_device
@@ -1432,6 +1443,8 @@ def _apply_cache_state(
             self.reset_uvm_cache_stats()
 
     def reset_cache_states(self) -> None:
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         if not self.lxu_cache_weights.numel():
             return
         self.lxu_cache_state.fill_(-1)
@@ -1842,6 +1855,8 @@ def embedding_inplace_update_internal(
             )
 
         lxu_cache_locations = None
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
        if self.lxu_cache_weights.numel() > 0:
            linear_cache_indices = (
                torch.ops.fbgemm.linearize_cache_indices_from_row_idx(
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 8eb3a3de51..25b8255f19 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -1411,6 +1411,7 @@ def get_cache_miss_counter(self) -> Tensor:
         Returns:
             The cache miss counter
         """
+        # pyre-fixme[7]: Expected `Tensor` but got `Union[Module, Tensor]`.
         return self.cache_miss_counter
 
     @torch.jit.export
@@ -1709,11 +1710,23 @@ def forward(  # noqa: C901
             )
         common_args = invokers.lookup_args.CommonArgs(
             placeholder_autograd_tensor=self.placeholder_autograd_tensor,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             dev_weights=self.weights_dev,
+            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             host_weights=self.weights_host,
+            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             uvm_weights=self.weights_uvm,
+            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             lxu_cache_weights=self.lxu_cache_weights,
+            # pyre-fixme[6]: For 6th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             weights_placements=self.weights_placements,
+            # pyre-fixme[6]: For 7th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             weights_offsets=self.weights_offsets,
             D_offsets=self.D_offsets,
             total_D=self.total_D,
@@ -1762,10 +1775,20 @@ def forward(  # noqa: C901
             )
 
         momentum1 = invokers.lookup_args.Momentum(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             dev=self.momentum1_dev,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             host=self.momentum1_host,
+            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             uvm=self.momentum1_uvm,
+            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             offsets=self.momentum1_offsets,
+            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             placements=self.momentum1_placements,
         )
@@ -1785,10 +1808,20 @@ def forward(  # noqa: C901
             )
 
         momentum2 = invokers.lookup_args.Momentum(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             dev=self.momentum2_dev,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             host=self.momentum2_host,
+            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             uvm=self.momentum2_uvm,
+            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             offsets=self.momentum2_offsets,
+            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             placements=self.momentum2_placements,
         )
         # Sync with loaded state
@@ -1847,17 +1880,37 @@ def forward(  # noqa: C901
             )
 
         prev_iter = invokers.lookup_args.Momentum(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             dev=self.prev_iter_dev,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             host=self.prev_iter_host,
+            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             uvm=self.prev_iter_uvm,
+            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             offsets=self.prev_iter_offsets,
+            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             placements=self.prev_iter_placements,
         )
         row_counter = invokers.lookup_args.Momentum(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             dev=self.row_counter_dev,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             host=self.row_counter_host,
+            # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             uvm=self.row_counter_uvm,
+            # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             offsets=self.row_counter_offsets,
+            # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+            #  `Union[Module, Tensor]`.
             placements=self.row_counter_placements,
         )
@@ -1914,6 +1967,8 @@ def forward(  # noqa: C901
                         momentum1,
                         iter=iter_int,
                         apply_global_weight_decay=apply_global_weight_decay,
+                        # pyre-fixme[6]: For 6th argument expected
+                        #  `Optional[Tensor]` but got `Union[Module, Tensor]`.
                         prev_iter_dev=self.prev_iter_dev,
                         gwd_lower_bound=self.gwd_lower_bound,
                     ),
@@ -2051,6 +2106,8 @@ def _report_uvm_cache_stats(self) -> None:
             )
         self.last_reported_uvm_stats = uvm_cache_stats
 
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         element_size = self.lxu_cache_weights.element_size()
         for stat_index in UVMCacheStatsIndex:
             stats_reporter.report_data_amount(
@@ -2108,6 +2165,8 @@ def _prefetch(
         self.timestep += 1
         self.timesteps_prefetched.append(self.timestep)
 
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         if not self.lxu_cache_weights.numel():
             return
@@ -2248,8 +2307,12 @@ def _update_cache_miss_counter(
 
         miss_count = torch.sum(unique_ids_count_list)
 
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
         self.cache_miss_counter[0] += (miss_count > 0).to(torch.int64)
 
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
         self.cache_miss_counter[1] += miss_count
 
     def _update_tablewise_cache_miss(
@@ -2261,6 +2324,8 @@ def _update_tablewise_cache_miss(
         CACHE_MISS = -1
         CACHE_HIT = -2
 
+        # pyre-fixme[6]: For 1st argument expected
+        #  `pyre_extensions.PyreReadOnly[Sized]` but got `Union[Module, Tensor]`.
         num_tables = len(self.cache_hash_size_cumsum) - 1
         num_offsets_per_table = (len(offsets) - 1) // num_tables
         cache_missed_locations = torch.where(
@@ -2308,7 +2373,9 @@ def split_embedding_weights(self) -> List[Tensor]:
         for t, (rows, dim, _, _) in enumerate(self.embedding_specs):
             if self.weights_precision == SparseType.INT8:
                 dim += self.int8_emb_row_dim_offset
+            # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
             placement = self.weights_physical_placements[t]
+            # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
             offset = self.weights_physical_offsets[t]
             if placement == EmbeddingLocation.DEVICE.value:
                 weights = self.weights_dev
@@ -2316,6 +2383,8 @@ def split_embedding_weights(self) -> List[Tensor]:
                 weights = self.weights_host
             else:
                 weights = self.weights_uvm
+            # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is
+            #  not a function.
            if weights.dim() == 2:
                weights = weights.flatten()
            splits.append(
@@ -2464,10 +2533,20 @@ def get_optimizer_states(
         if self.optimizer not in (OptimType.EXACT_SGD,):
             states.append(
                 get_optimizer_states(
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum1_dev,
+                    # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum1_host,
+                    # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum1_uvm,
+                    # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum1_physical_offsets,
+                    # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum1_physical_placements,
                     rowwise=self.optimizer
                     in [
@@ -2485,10 +2564,20 @@ def get_optimizer_states(
         ):
             states.append(
                 get_optimizer_states(
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum2_dev,
+                    # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum2_host,
+                    # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum2_uvm,
+                    # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum2_physical_offsets,
+                    # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.momentum2_physical_placements,
                     rowwise=self.optimizer
                     in (OptimType.PARTIAL_ROWWISE_ADAM, OptimType.PARTIAL_ROWWISE_LAMB),
@@ -2500,10 +2589,20 @@ def get_optimizer_states(
         ):
             states.append(
                 get_optimizer_states(
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.prev_iter_dev,
+                    # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.prev_iter_host,
+                    # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.prev_iter_uvm,
+                    # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.prev_iter_physical_offsets,
+                    # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.prev_iter_physical_placements,
                     rowwise=True,
                 )
@@ -2511,10 +2610,20 @@ def get_optimizer_states(
         if self._used_rowwise_adagrad_with_counter:
             states.append(
                 get_optimizer_states(
+                    # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.row_counter_dev,
+                    # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.row_counter_host,
+                    # pyre-fixme[6]: For 3rd argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.row_counter_uvm,
+                    # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.row_counter_physical_offsets,
+                    # pyre-fixme[6]: For 5th argument expected `Tensor` but got
+                    #  `Union[Module, Tensor]`.
                     self.row_counter_physical_placements,
                     rowwise=True,
                 )
@@ -2593,6 +2702,8 @@ def set_optimizer_step(self, step: int) -> None:
 
     @torch.jit.export
     def flush(self) -> None:
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
        if not self.lxu_cache_weights.numel():
            return
        torch.ops.fbgemm.lxu_cache_flush(
@@ -2997,6 +3108,8 @@ def _init_uvm_cache_stats(self) -> None:
         self.last_uvm_cache_print_state = torch.zeros_like(self.uvm_cache_stats)
 
     def reset_cache_states(self) -> None:
+        # pyre-fixme[29]: `Union[(self: TensorBase) -> int, Module, Tensor]` is not
+        #  a function.
         if not self.lxu_cache_weights.numel():
             return
         self.lxu_cache_state.fill_(-1)
@@ -3139,6 +3252,8 @@ def _debug_print_input_stats_factory_impl(
                 per_sample_weights (Optional[Tensor]): Input per sample weights
             """
 
+            # pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
+            #  float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
             if self.debug_step % 100 == 0:
                 # Get number of features (T) and batch size (B)
                 T = len(self.feature_table_map)
@@ -3224,6 +3339,10 @@ def compute_numel_and_avg(counts: Tensor) -> Tuple[int, float]:
                         avg_seglen_cta_per_row_mth,
                     )
                 )
+            # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
+            #  attribute `debug_step`.
+            # pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
+            #  float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
             self.debug_step += 1
 
         @torch.jit.ignore
@@ -3235,6 +3354,8 @@ def _debug_print_input_stats_factory_null(
             pass
 
         if int(os.environ.get("FBGEMM_DEBUG_PRINT_INPUT_STATS", "0")) == 1:
+            # pyre-fixme[16]: `SplitTableBatchedEmbeddingBagsCodegen` has no
+            #  attribute `debug_step`.
             self.debug_step = 0
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
diff --git a/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py b/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py
index 50856aec28..003159e59b 100644
--- a/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py
+++ b/fbgemm_gpu/test/tbe/cache/cache_overflow_test.py
@@ -81,6 +81,7 @@ def test_cache_int32_overflow(self, stochastic_rounding: bool) -> None:
         lxu_cache_locations = to_device(lxu_cache_locations, use_cpu=False)
 
         # Does prefetch into the cache
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
         cc.lxu_cache_weights[cache_idx] = cc_ref.weights_dev.view(-1, D)[0]
 
         # Mimic cache prefetching behavior
diff --git a/fbgemm_gpu/test/tbe/cache/cache_test.py b/fbgemm_gpu/test/tbe/cache/cache_test.py
index 40ef05e21e..8757617fa1 100644
--- a/fbgemm_gpu/test/tbe/cache/cache_test.py
+++ b/fbgemm_gpu/test/tbe/cache/cache_test.py
@@ -264,6 +264,8 @@ def _test_cache_prefetch_pipeline(  # noqa C901
             .cuda()
         )
         torch.cuda.synchronize()  # make sure TBEs and inputs are ready
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[bool,
+        #  Tensor]`.
         self.assertTrue(torch.all(cc.lxu_cache_locking_counter == 0))
 
         cur_stream: torch.cuda.Stream = torch.cuda.current_stream()
@@ -335,6 +337,8 @@ def _prefetch(
             )
 
         assert_cache(output, output_ref, stochastic_rounding)
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[bool,
+        #  Tensor]`.
         self.assertTrue(torch.all(cc.lxu_cache_locking_counter == 0))
 
         if prefetch_stream:
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
index d700785364..955395b6ef 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
@@ -308,9 +308,13 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         # cache status; we use the exact same logic, but still assigning ways in a associative cache can be
         # arbitrary. We compare sum along ways in each set, instead of expecting exact tensor match.
         cache_weights_ref = torch.reshape(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Tensor, Module]`.
             cc_ref.lxu_cache_weights,
             [-1, associativity],
         )
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+        #  `Union[Tensor, Module]`.
         cache_weights = torch.reshape(cc.lxu_cache_weights, [-1, associativity])
         torch.testing.assert_close(
             torch.sum(cache_weights_ref, 1),
@@ -318,16 +322,26 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
             equal_nan=True,
         )
         torch.testing.assert_close(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Tensor, Module]`.
             torch.sum(cc.lxu_cache_state, 1),
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Tensor, Module]`.
             torch.sum(cc_ref.lxu_cache_state, 1),
             equal_nan=True,
         )
         # lxu_state can be different as time_stamp values can be different.
         # we check the entries with max value.
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+        #  `Union[Tensor, Module]`.
         max_timestamp_ref = torch.max(cc_ref.lxu_state)
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+        #  `Union[Tensor, Module]`.
         max_timestamp_uvm_caching = torch.max(cc.lxu_state)
         x = cc_ref.lxu_state == max_timestamp_ref
         y = cc.lxu_state == max_timestamp_uvm_caching
+        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[bool,
+        #  Tensor]`.
         torch.testing.assert_close(torch.sum(x, 1), torch.sum(y, 1))
 
         # int_nbit_split_embedding_uvm_caching_codegen_lookup_function for UVM.
diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py
index ab015aaaaf..e9aa836116 100644
--- a/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_global_weight_decay_test.py
@@ -336,6 +336,8 @@ def execute_global_weight_decay(  # noqa C901
                     T,
                     Bs,
                     tbe_ref,
+                    # pyre-fixme[6]: For 4th argument expected `Tensor` but got
+                    #  `Union[Tensor, Module]`.
                     tbe.prev_iter_dev,
                     i,
                     indices,
@@ -365,7 +367,11 @@ def execute_global_weight_decay(  # noqa C901
         # compare weights
         output_ref.backward(grad_ref)
         compare_output(
+            # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+            #  `Union[Tensor, Module]`.
             tbe_ref.weights_dev,
+            # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+            #  `Union[Tensor, Module]`.
             tbe.weights_dev,
             is_fp32,
         )
diff --git a/fbgemm_gpu/test/tbe/training/backward_none_test.py b/fbgemm_gpu/test/tbe/training/backward_none_test.py
index 3a6d687bfe..5da1bc089d 100644
--- a/fbgemm_gpu/test/tbe/training/backward_none_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_none_test.py
@@ -370,9 +370,15 @@ def execute_backward_none_(  # noqa C901
             fc2.backward(goc)
 
         if optimizer is not None:
+            # pyre-fixme[6]: For 1st argument expected `Parameter` but got
+            #  `Union[Tensor, Module]`.
             params = SplitEmbeddingOptimizerParams(weights_dev=cc.weights_dev)
             embedding_args = SplitEmbeddingArgs(
+                # pyre-fixme[6]: For 1st argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
                 weights_placements=cc.weights_placements,
+                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
                 weights_offsets=cc.weights_offsets,
                 max_D=cc.max_D,
             )
@@ -403,6 +409,8 @@ def execute_backward_none_(  # noqa C901
                 ref_grad.half() if weights_precision == SparseType.FP16 else ref_grad
             )
         else:
+            # pyre-fixme[16]: Item `None` of `None | Tensor | Module` has no
+            #  attribute `_indices`.
            indices = cc.weights_dev.grad._indices().flatten()
            # Select only the part in the table that is updated
            test_tensor = torch.index_select(cc.weights_dev.view(-1, D), 0, indices)
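
Note on these suppressions: `torch.nn.Module.__getattr__` is annotated in the PyTorch stubs as
returning `Union[Tensor, Module]`, so any attribute resolved through it at runtime (buffers
registered via `register_buffer`, such as `lxu_cache_state`, `lxu_cache_weights`, `weights_dev`,
and the optimizer-state tensors) type-checks as that union. Calling tensor methods like `.numel()`
or `.clone()`, indexing, or passing such an attribute where a plain `Tensor` is expected is what
produces the pyre-fixme[29]/[16]/[7]/[6] errors suppressed above. The sketch below is a minimal
illustration of the mechanism and of one possible alternative to per-call-site suppression; the
`CacheState` class and `_buffer` helper are hypothetical stand-ins, not FBGEMM APIs.

    import torch
    from torch import Tensor, nn


    class CacheState(nn.Module):
        """Minimal stand-in (hypothetical, not FBGEMM code) for a TBE module
        that keeps its cache state in registered buffers."""

        def __init__(self, cache_sets: int, associativity: int = 32) -> None:
            super().__init__()
            # Buffers live in self._buffers and are resolved through
            # nn.Module.__getattr__, whose stub returns Union[Tensor, Module];
            # that union is the source of every error suppressed in this patch.
            self.register_buffer(
                "lxu_cache_state",
                torch.full((cache_sets, associativity), -1, dtype=torch.int64),
            )

        def _buffer(self, name: str) -> Tensor:
            # One alternative to a pyre-fixme at each call site: narrow the
            # Union[Tensor, Module] to Tensor with a single runtime check.
            value = getattr(self, name)
            assert isinstance(value, Tensor), f"buffer {name!r} is not a Tensor"
            return value

        def reset_cache_states(self) -> None:
            # Mirrors the suppressed call sites above, but the typed accessor
            # satisfies the checker without a suppression comment.
            if not self._buffer("lxu_cache_state").numel():
                return
            self._buffer("lxu_cache_state").fill_(-1)


    if __name__ == "__main__":
        m = CacheState(cache_sets=4)
        m.reset_cache_states()
        print(m.lxu_cache_state.shape)  # torch.Size([4, 32])

The trade-off is an extra getattr plus isinstance assert per access, which is one reason a large
codebase may prefer mechanical suppressions on hot paths; whether such a helper stays compatible
with TorchScript in `@torch.jit.export` methods would also need to be verified before adopting it.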