Skip to content

Commit

Permalink
black formatted
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Oct 9, 2024
1 parent 83da918 commit e99605d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
4 changes: 1 addition & 3 deletions egs/libritts/CODEC/encodec/balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ def averager(beta: float = 1):
fix: Dict[str, float] = defaultdict(float)
total: Dict[str, float] = defaultdict(float)

def _update(
metrics: Dict[str, Any], weight: float = 1
) -> Dict[str, float]:
def _update(metrics: Dict[str, Any], weight: float = 1) -> Dict[str, float]:
nonlocal total, fix
for key, value in metrics.items():
total[key] = total[key] * beta + weight * float(value)
Expand Down
36 changes: 21 additions & 15 deletions egs/libritts/CODEC/encodec/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,10 @@ def save_bad_model(suffix: str = ""):
+ balancer.weights["feature_stft_loss"] * feature_stft_loss
+ balancer.weights["feature_period_loss"] * feature_period_loss
+ balancer.weights["feature_scale_loss"] * feature_scale_loss
+ balancer.weights["wav_reconstruction_loss"] * wav_reconstruction_loss
+ balancer.weights["mel_reconstruction_loss"] * mel_reconstruction_loss
+ balancer.weights["wav_reconstruction_loss"]
* wav_reconstruction_loss
+ balancer.weights["mel_reconstruction_loss"]
* mel_reconstruction_loss
)
else:
gen_loss = (
Expand Down Expand Up @@ -1112,19 +1114,23 @@ def run(rank, world_size, args):
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])

balancer = Balancer(
weights={
"gen_stft_adv_loss": 3.0,
"gen_period_adv_loss": 3.0,
"gen_scale_adv_loss": 3.0,
"feature_stft_loss": 3.0,
"feature_period_loss": 3.0,
"feature_scale_loss": 3.0,
"wav_reconstruction_loss": 0.1,
"mel_reconstruction_loss": 1.0,
}
# this setup follows the one described in the Encodec paper
) if params.use_balancer else None
balancer = (
Balancer(
weights={
"gen_stft_adv_loss": 3.0,
"gen_period_adv_loss": 3.0,
"gen_scale_adv_loss": 3.0,
"feature_stft_loss": 3.0,
"feature_period_loss": 3.0,
"feature_scale_loss": 3.0,
"wav_reconstruction_loss": 0.1,
"mel_reconstruction_loss": 1.0,
}
# this setup follows the one described in the Encodec paper
)
if params.use_balancer
else None
)

for epoch in range(params.start_epoch, params.num_epochs + 1):
logging.info(f"Start epoch {epoch}")
Expand Down

0 comments on commit e99605d

Please sign in to comment.