Skip to content

Commit

Permalink
fix generating with all tags if all probability < 0
Browse files Browse the repository at this point in the history
  • Loading branch information
toshiaki1729 committed Dec 29, 2022
1 parent 362da88 commit 238318c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scripts/t2p/prompt_generator/wd_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
probs_np = probs_cpu.detach().numpy()

if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
else:
results = np.random.choice(a=tags_np, size=opts.n, replace=False, p=probs_np)
else:
Expand All @@ -150,7 +150,7 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
probs_np = np.nan_to_num(probs_np)

if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:
Expand All @@ -176,7 +176,7 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
probs_np = np.nan_to_num(probs_np)

if np.count_nonzero(probs_np) <= opts.n:
results = tags_np
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:
Expand Down

0 comments on commit 238318c

Please sign in to comment.