diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
index afb931bb3..ee4176d93 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -34,6 +34,7 @@ class EmbOptimType(enum.Enum):
     MADGRAD = "madgrad"
     EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad"  # deprecated
     ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
+    EMAWITHINTABLE_ROWWISE_ADAGRAD = "ema_within_table_row_wise_adagrad"
     NONE = "none"

     def __str__(self) -> str:
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 25b8255f1..57905cd5f 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -154,6 +154,13 @@ class EnsembleModeDefinition:
     step_mode: StepMode = StepMode.USE_ITER


+@dataclass(frozen=True)
+class EmawithintableModeDefinition:
+    step_ema: float = 10
+    step_start: float = 0
+    step_ema_coef: float = 0.6
+
+
 # Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
 class UVMCacheStatsIndex(enum.IntEnum):
     num_calls = 0
@@ -430,7 +437,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):

             (9) `ENSEMBLE_ROWWISE_ADAGRAD` = Ensemble rowwise-Adagrad

-            (10) `NONE` = Not applying an optimizer update in the backward pass
+            (10) `EMAWITHINTABLE_ROWWISE_ADAGRAD` = EMA-within-table rowwise-Adagrad
+
+            (11) `NONE` = Not applying an optimizer update in the backward pass
                 and outputting a sparse weight gradient

         record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
@@ -484,6 +493,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
         ensemble_mode (Optional[EnsembleModeDefinition] = None): Used by
             Ensemble Rowwise Adagrad

+        emawithintable_mode (Optional[EmawithintableModeDefinition] = None):
+            Used by EMA-within-table Rowwise Adagrad
+
         counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None):
             Used by Rowwise Adagrad

@@ -601,6 +613,7 @@ def __init__(  # noqa C901
         beta1: float = 0.9,
         beta2: float = 0.999,
         ensemble_mode: Optional[EnsembleModeDefinition] = None,
+        emawithintable_mode: Optional[EmawithintableModeDefinition] = None,
         counter_based_regularization: Optional[
             CounterBasedRegularizationDefinition
         ] = None,
@@ -863,6 +876,7 @@ def __init__(  # noqa C901
                 OptimType.EXACT_ADAGRAD,
                 OptimType.EXACT_ROWWISE_ADAGRAD,
                 OptimType.EXACT_SGD,
+                OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
             ), f"Optimizer {optimizer} is not supported in CPU mode."
         else:
             assert optimizer in (
                 OptimType.ADAM,
                 OptimType.EXACT_ADAGRAD,
                 OptimType.EXACT_ROWWISE_ADAGRAD,
                 OptimType.EXACT_SGD,
                 OptimType.LAMB,
                 OptimType.LARS_SGD,
                 OptimType.PARTIAL_ROWWISE_ADAM,
                 OptimType.PARTIAL_ROWWISE_LAMB,
                 OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
                 OptimType.NONE,
             ), f"Optimizer {optimizer} is not supported."
@@ -932,6 +947,12 @@ def __init__(  # noqa C901
             key: float(fval) for key, fval in ensemble_mode.__dict__.items()
         }

+        if emawithintable_mode is None:
+            emawithintable_mode = EmawithintableModeDefinition()
+        self._emawithintable_mode: Dict[str, float] = {
+            key: float(fval) for key, fval in emawithintable_mode.__dict__.items()
+        }
+
         if counter_based_regularization is None:
             counter_based_regularization = CounterBasedRegularizationDefinition()
         if cowclip_regularization is None:
@@ -1010,6 +1031,7 @@ def __init__(  # noqa C901
         rowwise = optimizer in [
             OptimType.EXACT_ROWWISE_ADAGRAD,
             OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+            OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
         ]
         self._apply_split(
             construct_split_state(
@@ -1137,6 +1159,7 @@ def __init__(  # noqa C901
                 OptimType.PARTIAL_ROWWISE_ADAM,
                 OptimType.PARTIAL_ROWWISE_LAMB,
                 OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
             )
             or self._used_rowwise_adagrad_with_global_weight_decay
         ):
@@ -1520,12 +1543,12 @@ def _generate_vbe_metadata(
             OptimType.EXACT_ROWWISE_ADAGRAD,
             OptimType.EXACT_SGD,
             OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+            OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
             OptimType.NONE,
-        ), (
-            "Variable batch size TBE support is enabled for "
-            "OptimType.EXACT_ROWWISE_ADAGRAD and "
-            "ENSEMBLE_ROWWISE_ADAGRAD only"
-        )
+        ), """
+        Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD,
+        OptimType.ENSEMBLE_ROWWISE_ADAGRAD, and OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD only.
+        """
         return generate_vbe_metadata(
             offsets,
             batch_size_per_feature_per_rank,
@@ -1914,6 +1937,19 @@ def forward(  # noqa: C901
                 placements=self.row_counter_placements,
             )

+        if self.optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
+            with torch.no_grad():
+                if self.training:
+                    self.ema_within_table(self._emawithintable_mode)
+            return self._report_io_size_count(
+                "fwd_output",
+                invokers.lookup_rowwise_adagrad.invoke(
+                    common_args,
+                    self.optimizer_args,
+                    momentum1,
+                ),
+            )
+
         if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
             assert self._feature_is_enabled(
                 FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD
@@ -1985,6 +2021,36 @@ def forward(  # noqa: C901

         raise ValueError(f"Invalid OptimType: {self.optimizer}")

+    def ema_within_table(self, emawithintable_mode: Dict[str, float]) -> None:
+        """
+        Perform EMA operations on the full sparse embedding tables.
+        We organize each sparse table in the following way:
+
+        Emb_table:
+        -------------------------------------------------
+        -                       --                      -
+        -      Fast part        --      Slow part       -
+        -    (RL) main part     --     target part      -
+        -                       --                      -
+        -------------------------------------------------
+
+        Every "step_ema" steps, we perform
+            slow_part += step_ema_coef * (fast_part - slow_part)
+        """
+        iter_int = int(self.iter_cpu.item())
+        if iter_int % int(emawithintable_mode["step_ema"]) == 0 and iter_int >= int(
+            emawithintable_mode["step_start"]
+        ):
+            weights = self.split_embedding_weights()
+            for table_i, (_, dim, _, _) in enumerate(self.embedding_specs):
+                assert (
+                    dim & 1 == 0
+                ), f"table dimension {dim} is odd, not supported for ema_within_table"  # make sure that the dimension is even
+                weights[table_i][:, dim // 2 :].data.lerp_(
+                    weights[table_i][:, : dim // 2].data,
+                    emawithintable_mode["step_ema_coef"],
+                )
+
     def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
         """
         Perform ensemble and swap operations on the full sparse embedding tables.
@@ -2018,8 +2084,6 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
             if should_ema:
                 if int(ensemble_mode["step_mode"]) == 0:  # embedding scaling
                     states[i][1].mul_(0.0)
-                elif int(ensemble_mode["step_mode"]) == 1:  # nesterov
-                    states[i][1].copy_(weights_cpu, non_blocking=True)
                 # elif int(ensemble_mode["step_mode"]) == 2: pure ema

     def reset_uvm_cache_stats(self) -> None:
@@ -2413,6 +2477,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
         if (
             self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
             or self.optimizer == OptimType.EXACT_ADAGRAD
+            or self.optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD
         ):
             list_of_state_dict = [
                 (
@@ -2552,6 +2617,7 @@ def get_optimizer_states(
                 in [
                     OptimType.EXACT_ROWWISE_ADAGRAD,
                     OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+                    OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
                 ],
             )
         )
diff --git a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
index c75ededf7..6b5bdaaba 100644
--- a/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
+++ b/fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -26,6 +26,7 @@
     CounterBasedRegularizationDefinition,
     CounterWeightDecayMode,
     CowClipDefinition,
+    EmawithintableModeDefinition,
     EnsembleModeDefinition,
     GradSumDecay,
     LearningRateMode,
@@ -333,6 +334,19 @@ def execute_backward_optimizers_(  # noqa C901
                 step_mode=step_mode,
             )

+        if optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
+            (eps, step_ema, step_start) = (
+                1e-4,
+                1.0,
+                0.0,
+            )
+            optimizer_kwargs["eps"] = eps
+            optimizer_kwargs["emawithintable_mode"] = EmawithintableModeDefinition(
+                step_ema=step_ema,
+                step_start=step_start,
+                step_ema_coef=momentum,
+            )
+
         cc = emb_op(
             embedding_specs=[
                 (E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
@@ -398,6 +412,7 @@ def execute_backward_optimizers_(  # noqa C901
             OptimType.EXACT_ROWWISE_ADAGRAD,
             OptimType.EXACT_ADAGRAD,
             OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
+            OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
         )

         if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
@@ -612,6 +627,34 @@ def execute_backward_optimizers_(  # noqa C901
                 "sparse_ema",
             }

+        if optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
+            for t in range(T):
+                iter_ = cc.iter.item()
+                weights_new = split_weights[t]
+
+                # forward update
+                weights_ref = bs[t].weight.cpu()
+                dim = weights_ref.shape[1]
+                weights_ref[:, dim // 2 :] = (1 - momentum) * weights_ref[
+                    :, dim // 2 :
+                ] + momentum * weights_ref[:, : dim // 2]
+                dense_cpu_grad = bs[t].weight.grad.cpu().to_dense()
+                v_hat_t = dense_cpu_grad.pow(2).mean(dim=1)
+                v_hat_t = v_hat_t.view(v_hat_t.numel(), 1)
+                weights_ref = torch.addcdiv(
+                    weights_ref,
+                    value=-lr,
+                    tensor1=dense_cpu_grad,
+                    tensor2=v_hat_t.sqrt_().add_(eps),
+                )
+
+                torch.testing.assert_close(
+                    weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(),
+                    weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()),
+                    atol=1.0e-3,
+                    rtol=1.0e-3,
+                )
+
         if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
             rowwise = optimizer == OptimType.PARTIAL_ROWWISE_LAMB
             for t in range(T):
@@ -1107,6 +1150,66 @@ def test_backward_optimizers_ensemble_rowwise_adagrad(  # noqa C901
             optimizer_state_dtypes=optimizer_state_dtypes,
         )

+    @given(
+        T=st.integers(min_value=1, max_value=5),
+        D=st.integers(min_value=2, max_value=256),
+        B=st.integers(min_value=1, max_value=128),
+        log_E=st.integers(min_value=3, max_value=5),
+        L=st.integers(min_value=2, max_value=20),
+        weighted=st.booleans(),
+        mixed=st.booleans(),
+        mixed_B=st.booleans(),
+        optimizer=st.just(OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD),
+        long_segments=st.booleans(),
+        pooling_mode=st.sampled_from(
+            [
+                PoolingMode.SUM,
+                PoolingMode.MEAN,
+                PoolingMode.NONE,
+            ]
+        ),
+        use_cpu=use_cpu_strategy(),
+        uvm_non_rowwise_momentum=st.booleans(),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
+    )
+    @unittest.skipIf(*gpu_unavailable)
+    def test_backward_optimizers_emawithintable_rowwise_adagrad(  # noqa C901
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+        weighted: bool,
+        mixed: bool,
+        mixed_B: bool,
+        optimizer: OptimType,
+        long_segments: bool,
+        pooling_mode: PoolingMode,
+        use_cpu: bool,
+        uvm_non_rowwise_momentum: bool,
+    ) -> None:
+        self.execute_backward_optimizers_(
+            T,
+            D * 2,  # dimension is required to be even
+            B,
+            log_E,
+            L,
+            weighted,
+            mixed,
+            mixed_B,
+            optimizer,
+            long_segments,
+            pooling_mode,
+            use_cpu,
+            uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
+        )
+
     @given(
         T=st.integers(min_value=1, max_value=5),
         D=st.integers(min_value=2, max_value=256),
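
A self-contained sketch of the update rule the new `ema_within_table` method applies, for readers who want to see it outside the TBE plumbing. This is plain PyTorch on a toy tensor, not the FBGEMM API; the table shape, iteration loop, and variable names are illustrative only. Each row is split into a fast (left) half and a slow (right) half, and every `step_ema` iterations the slow half is pulled toward the fast half, which is exactly the `lerp_` call above:

    import torch

    # Defaults mirror EmawithintableModeDefinition.
    step_ema, step_start, step_ema_coef = 10, 0, 0.6

    weights = torch.randn(100, 8)  # toy table: 100 rows, even dimension D=8
    dim = weights.shape[1]
    assert dim % 2 == 0, "EMA-within-table requires an even table dimension"

    for it in range(100):
        # ... lookup, backward, and the rowwise-Adagrad update happen here ...
        if it % step_ema == 0 and it >= step_start:
            # slow_part += step_ema_coef * (fast_part - slow_part)
            weights[:, dim // 2 :].lerp_(weights[:, : dim // 2], step_ema_coef)

`Tensor.lerp_(end, weight)` computes `self + weight * (end - self)` in place, so the slow half tracks an exponential moving average of the fast half, sampled every `step_ema` iterations.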
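How a caller would opt in, following the constructor arguments added above. A sketch under assumptions: the spec values are placeholders, and the import paths plus the `learning_rate`/`eps` arguments follow current FBGEMM conventions rather than anything shown in this diff:

    from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        ComputeDevice,
        EmbeddingLocation,
    )
    from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
        EmawithintableModeDefinition,
        SplitTableBatchedEmbeddingBagsCodegen,
    )

    tbe = SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[
            # (num_embeddings, embedding_dim, location, compute_device);
            # embedding_dim must be even so each row splits into two halves
            (1000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        ],
        optimizer=OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
        learning_rate=0.01,
        eps=1.0e-4,
        emawithintable_mode=EmawithintableModeDefinition(
            step_ema=10, step_start=0, step_ema_coef=0.6
        ),
    )

Omitting `emawithintable_mode` is also valid: the constructor falls back to `EmawithintableModeDefinition()` with the defaults shown in the dataclass.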
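Last, the test's reference computation restated as one function, which may help when reviewing the expected-value block. With `step_ema=1.0` the EMA fires on every iteration before the rowwise-Adagrad update, and because the test takes a single optimizer step, the Adagrad accumulator equals the per-row mean of the squared gradient. A sketch only; `w` and `g` stand for dense copies of a table and its gradient:

    import torch

    def reference_step(
        w: torch.Tensor, g: torch.Tensor, lr: float, eps: float, coef: float
    ) -> torch.Tensor:
        d = w.shape[1]
        # EMA applied during the forward pass: slow half moves toward the fast half.
        w[:, d // 2 :] = (1 - coef) * w[:, d // 2 :] + coef * w[:, : d // 2]
        # Rowwise Adagrad after one step from a zero-initialized accumulator:
        # a single second-moment estimate per row.
        v_hat = g.pow(2).mean(dim=1, keepdim=True)
        return w - lr * g / (v_hat.sqrt() + eps)

This matches the test's `torch.addcdiv(weights_ref, value=-lr, tensor1=dense_cpu_grad, tensor2=v_hat_t.sqrt_().add_(eps))` term for term.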