optimizer 1d -- EMA in place (fbgemm part) (#3402)
Summary:
X-link: facebookresearch/FBGEMM#490


Implement the ema_within_table_rowwise_adagrad optimizer.
```
        Emb_table:
         -------------------------------------------------
         -                        --                     -
         -        Fast part       --      Slow part      -
         -    (RL) main part      --      target part    -
         -                        --                     -
         -------------------------------------------------

         In every "step_ema" step, we perform
            slow_part += coef_ema * (fast_part - slow_part)
```

It mainly serves as the target network in reinforcement learning frameworks.
Design doc https://fburl.com/gdoc/qyfv7tyi
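
For concreteness, a minimal PyTorch sketch of the update above (tensor names and sizes are illustrative, not taken from this diff): the slow/target half of each row tracks the fast/main half in place.
```python
import torch

coef_ema = 0.6                      # illustrative; matches the default step_ema_coef introduced below
emb_table = torch.randn(1000, 128)  # toy table: columns [0, 64) fast part, [64, 128) slow part

dim = emb_table.shape[1]
# slow_part += coef_ema * (fast_part - slow_part), performed in place via lerp_
emb_table[:, dim // 2:].lerp_(emb_table[:, :dim // 2], coef_ema)
```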

Differential Revision: D66015331
Zhihao Cen authored and facebook-github-bot committed Nov 21, 2024
1 parent b94be33 commit 00bc046
Showing 3 changed files with 178 additions and 8 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -34,6 +34,7 @@ class EmbOptimType(enum.Enum):
MADGRAD = "madgrad"
EXACT_ROWWISE_WEIGHTED_ADAGRAD = "exact_row_wise_weighted_adagrad" # deprecated
ENSEMBLE_ROWWISE_ADAGRAD = "ensemble_row_wise_adagrad"
EMAWITHINTABLE_ROWWISE_ADAGRAD = "ema_within_table_row_wise_adagrad"
NONE = "none"

def __str__(self) -> str:
@@ -154,6 +154,13 @@ class EnsembleModeDefinition:
step_mode: StepMode = StepMode.USE_ITER


@dataclass(frozen=True)
class EmawithintableModeDefinition:
step_ema: float = 10
step_start: float = 0
step_ema_coef: float = 0.6


# Keep in sync with fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache_cuda.cuh
class UVMCacheStatsIndex(enum.IntEnum):
num_calls = 0
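
As a reading aid for the three fields just added, a small self-contained sketch (the `ema_due` helper is illustrative and not part of this commit): the EMA runs every `step_ema` iterations once the iteration counter has reached `step_start`, blending with coefficient `step_ema_coef`.
```python
from dataclasses import dataclass

@dataclass(frozen=True)
class EmawithintableModeDefinition:
    step_ema: float = 10
    step_start: float = 0
    step_ema_coef: float = 0.6

def ema_due(iter_int: int, mode: EmawithintableModeDefinition) -> bool:
    # Mirrors the gating used by ema_within_table() further down in this diff.
    return iter_int % int(mode.step_ema) == 0 and iter_int >= int(mode.step_start)

mode = EmawithintableModeDefinition()
assert ema_due(0, mode) and ema_due(10, mode) and not ema_due(7, mode)
```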
@@ -430,7 +437,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
(9) `ENSEMBLE_ROWWISE_ADAGRAD` = Ensemble rowwise-Adagrad
(10) `NONE` = Not applying an optimizer update in the backward pass
(10) `EMAWITHINTABLE_ROWWISE_ADAGRAD` = Ema within table rowwise-Adagrad
(11) `NONE` = Not applying an optimizer update in the backward pass
and outputting a sparse weight gradient
record_cache_metrics (Optional[RecordCacheMetrics] = None): Record
@@ -484,6 +493,9 @@ class SplitTableBatchedEmbeddingBagsCodegen(nn.Module):
ensemble_mode (Optional[EnsembleModeDefinition] = None):
Used by Ensemble Rowwise Adagrad
emawithintable_mode (Optional[EmawithintableModeDefinition] = None):
Used by Ema within table Rowwise Adagrad
counter_based_regularization (Optional[CounterBasedRegularizationDefinition] = None):
Used by Rowwise Adagrad
@@ -601,6 +613,7 @@ def __init__( # noqa C901
beta1: float = 0.9,
beta2: float = 0.999,
ensemble_mode: Optional[EnsembleModeDefinition] = None,
emawithintable_mode: Optional[EmawithintableModeDefinition] = None,
counter_based_regularization: Optional[
CounterBasedRegularizationDefinition
] = None,
@@ -863,6 +876,7 @@ def __init__( # noqa C901
OptimType.EXACT_ADAGRAD,
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
), f"Optimizer {optimizer} is not supported in CPU mode."
else:
assert optimizer in (
@@ -875,6 +889,7 @@ def __init__( # noqa C901
OptimType.PARTIAL_ROWWISE_ADAM,
OptimType.PARTIAL_ROWWISE_LAMB,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
OptimType.NONE,
), f"Optimizer {optimizer} is not supported."

@@ -932,6 +947,12 @@ def __init__( # noqa C901
key: float(fval) for key, fval in ensemble_mode.__dict__.items()
}

if emawithintable_mode is None:
emawithintable_mode = EmawithintableModeDefinition()
self._emawithintable_mode: Dict[str, float] = {
key: float(fval) for key, fval in emawithintable_mode.__dict__.items()
}

if counter_based_regularization is None:
counter_based_regularization = CounterBasedRegularizationDefinition()
if cowclip_regularization is None:
@@ -1010,6 +1031,7 @@ def __init__( # noqa C901
rowwise = optimizer in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
]
self._apply_split(
construct_split_state(
@@ -1137,6 +1159,7 @@ def __init__( # noqa C901
OptimType.PARTIAL_ROWWISE_ADAM,
OptimType.PARTIAL_ROWWISE_LAMB,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
)
or self._used_rowwise_adagrad_with_global_weight_decay
):
@@ -1520,12 +1543,12 @@ def _generate_vbe_metadata(
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_SGD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
OptimType.NONE,
), (
"Variable batch size TBE support is enabled for "
"OptimType.EXACT_ROWWISE_ADAGRAD and "
"ENSEMBLE_ROWWISE_ADAGRAD only"
)
), """
Variable batch size TBE support is enabled for OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD, and OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD only.
"""
return generate_vbe_metadata(
offsets,
batch_size_per_feature_per_rank,
@@ -1914,6 +1937,19 @@ def forward( # noqa: C901
placements=self.row_counter_placements,
)

if self.optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
with torch.no_grad():
if self.training:
self.ema_within_table(self._emawithintable_mode)
return self._report_io_size_count(
"fwd_output",
invokers.lookup_rowwise_adagrad.invoke(
common_args,
self.optimizer_args,
momentum1,
),
)

if self.optimizer == OptimType.ENSEMBLE_ROWWISE_ADAGRAD:
assert self._feature_is_enabled(
FeatureGateName.TBE_ENSEMBLE_ROWWISE_ADAGRAD
@@ -1985,6 +2021,36 @@ def forward( # noqa: C901

raise ValueError(f"Invalid OptimType: {self.optimizer}")

def ema_within_table(self, emawithintable_mode: Dict[str, float]) -> None:
"""
Perform ema operations on the full sparse embedding tables.
        We organize the sparse table in the following way:
Emb_table:
-------------------------------------------------
- -- -
- Fast part -- Slow part -
- (RL) main part -- target part -
- -- -
-------------------------------------------------
In every "step_ema" step, we perform
slow_part += coef_ema * (fast_part - slow_part)
"""
iter_int = int(self.iter_cpu.item())
if iter_int % int(emawithintable_mode["step_ema"]) == 0 and iter_int >= int(
emawithintable_mode["step_start"]
):
weights = self.split_embedding_weights()
for table_i, (_, dim, _, _) in enumerate(self.embedding_specs):
assert (
dim & 1 == 0
), f"table dimension {dim} is odd, not supported for ema_within_table" # make sure that the dimension is even
weights[table_i][:, dim // 2 :].data.lerp_(
weights[table_i][:, : dim // 2].data,
emawithintable_mode["step_ema_coef"],
)

def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
"""
Perform ensemble and swap operations on the full sparse embedding tables.
@@ -2018,8 +2084,6 @@ def ensemble_and_swap(self, ensemble_mode: Dict[str, float]) -> None:
if should_ema:
if int(ensemble_mode["step_mode"]) == 0: # embedding scaling
states[i][1].mul_(0.0)
elif int(ensemble_mode["step_mode"]) == 1: # nesterov
states[i][1].copy_(weights_cpu, non_blocking=True)
# elif int(ensemble_mode["step_mode"]) == 2: pure ema

def reset_uvm_cache_stats(self) -> None:
@@ -2413,6 +2477,7 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
if (
self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
or self.optimizer == OptimType.EXACT_ADAGRAD
or self.optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD
):
list_of_state_dict = [
(
@@ -2552,6 +2617,7 @@ def get_optimizer_states(
in [
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
],
)
)
103 changes: 103 additions & 0 deletions fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -26,6 +26,7 @@
CounterBasedRegularizationDefinition,
CounterWeightDecayMode,
CowClipDefinition,
EmawithintableModeDefinition,
EnsembleModeDefinition,
GradSumDecay,
LearningRateMode,
@@ -333,6 +334,19 @@ def execute_backward_optimizers_( # noqa C901
step_mode=step_mode,
)

if optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
(eps, step_ema, step_start) = (
1e-4,
1.0,
0.0,
)
optimizer_kwargs["eps"] = eps
optimizer_kwargs["emawithintable_mode"] = EmawithintableModeDefinition(
step_ema=step_ema,
step_start=step_start,
step_ema_coef=momentum,
)

cc = emb_op(
embedding_specs=[
(E, D, M, compute_device) for (E, D, M) in zip(Es, Ds, managed)
@@ -398,6 +412,7 @@ def execute_backward_optimizers_( # noqa C901
OptimType.EXACT_ROWWISE_ADAGRAD,
OptimType.EXACT_ADAGRAD,
OptimType.ENSEMBLE_ROWWISE_ADAGRAD,
OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD,
)

if optimizer in (OptimType.EXACT_ROWWISE_ADAGRAD, OptimType.EXACT_ADAGRAD):
@@ -612,6 +627,34 @@ def execute_backward_optimizers_( # noqa C901
"sparse_ema",
}

if optimizer == OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD:
for t in range(T):
iter_ = cc.iter.item()
weights_new = split_weights[t]

# forward update
weights_ref = bs[t].weight.cpu()
dim = weights_ref.shape[1]
weights_ref[:, dim // 2 :] = (1 - momentum) * weights_ref[
:, dim // 2 :
] + momentum * weights_ref[:, : dim // 2]
dense_cpu_grad = bs[t].weight.grad.cpu().to_dense()
v_hat_t = dense_cpu_grad.pow(2).mean(dim=1)
v_hat_t = v_hat_t.view(v_hat_t.numel(), 1)
weights_ref = torch.addcdiv(
weights_ref,
value=-lr,
tensor1=dense_cpu_grad,
tensor2=v_hat_t.sqrt_().add_(eps),
)

torch.testing.assert_close(
weights_new.index_select(dim=0, index=xs[t].view(-1)).cpu(),
weights_ref.index_select(dim=0, index=xs[t].view(-1).cpu()),
atol=1.0e-3,
rtol=1.0e-3,
)

if optimizer in (OptimType.PARTIAL_ROWWISE_LAMB, OptimType.LAMB):
rowwise = optimizer == OptimType.PARTIAL_ROWWISE_LAMB
for t in range(T):
@@ -1107,6 +1150,66 @@ def test_backward_optimizers_ensemble_rowwise_adagrad( # noqa C901
optimizer_state_dtypes=optimizer_state_dtypes,
)

@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
B=st.integers(min_value=1, max_value=128),
log_E=st.integers(min_value=3, max_value=5),
L=st.integers(min_value=2, max_value=20),
weighted=st.booleans(),
mixed=st.booleans(),
mixed_B=st.booleans(),
optimizer=st.just(OptimType.EMAWITHINTABLE_ROWWISE_ADAGRAD),
long_segments=st.booleans(),
pooling_mode=st.sampled_from(
[
PoolingMode.SUM,
PoolingMode.MEAN,
PoolingMode.NONE,
]
),
use_cpu=use_cpu_strategy(),
uvm_non_rowwise_momentum=st.booleans(),
)
@settings(
verbosity=VERBOSITY,
max_examples=MAX_EXAMPLES_LONG_RUNNING,
deadline=None,
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.data_too_large],
)
@unittest.skipIf(*gpu_unavailable)
def test_backward_optimizers_emawithintable_rowwise_adagrad( # noqa C901
self,
T: int,
D: int,
B: int,
log_E: int,
L: int,
weighted: bool,
mixed: bool,
mixed_B: bool,
optimizer: OptimType,
long_segments: bool,
pooling_mode: PoolingMode,
use_cpu: bool,
uvm_non_rowwise_momentum: bool,
) -> None:
self.execute_backward_optimizers_(
T,
D * 2, # dimension is required to be even
B,
log_E,
L,
weighted,
mixed,
mixed_B,
optimizer,
long_segments,
pooling_mode,
use_cpu,
uvm_non_rowwise_momentum=uvm_non_rowwise_momentum,
)

@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
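
Finally, a self-contained single-step sketch that replays the reference math the new test checks against (the function name and the example values are hypothetical; on the first iteration the rowwise-Adagrad accumulator is just the row mean of the squared gradient):
```python
import torch

def emawithintable_rowwise_adagrad_reference_step(
    weights: torch.Tensor,  # (num_rows, dim); dim must be even
    grad: torch.Tensor,     # dense gradient with the same shape
    lr: float,
    eps: float,
    step_ema_coef: float,
) -> torch.Tensor:
    dim = weights.shape[1]
    assert dim % 2 == 0, "ema_within_table requires an even embedding dimension"
    # EMA within the table: slow (right) half tracks fast (left) half, in place.
    weights[:, dim // 2:].lerp_(weights[:, :dim // 2], step_ema_coef)
    # Rowwise-Adagrad-style update, first step: accumulator == mean of squared grads per row.
    v_hat = grad.pow(2).mean(dim=1, keepdim=True)
    weights.addcdiv_(grad, v_hat.sqrt_().add_(eps), value=-lr)
    return weights

# Example: one step on a toy 4-row, 8-dim table.
w = torch.randn(4, 8)
g = torch.randn(4, 8)
emawithintable_rowwise_adagrad_reference_step(w, g, lr=0.05, eps=1e-4, step_ema_coef=0.6)
```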
