Skip to content

Commit

Permalink
Update wd_like.py
Browse files Browse the repository at this point in the history
  • Loading branch information
toshiaki1729 committed Dec 29, 2022
1 parent 238318c commit f45016a
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions scripts/t2p/prompt_generator/wd_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
if opts.n <= 0: return []
if opts.weighted:
probs_np = probs_cpu.detach().numpy()

if np.count_nonzero(probs_np) <= opts.n:
results = np.random.choice(a=tags_np, size=opts.n, replace=False)
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(a=tags_np, size=opts.n, replace=False, p=probs_np)
else:
Expand All @@ -148,9 +151,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
probs_np = probs.detach().numpy()
probs_np /= np.sum(probs_np)
probs_np = np.nan_to_num(probs_np)

if np.count_nonzero(probs_np) <= opts.n:
results = np.random.choice(tags_np, opts.n, replace=False)
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:
Expand All @@ -174,9 +180,12 @@ def __call__(self, text: str, text_neg: str, neg_weight: float, opts: pgen.Gener
probs_np = np.array([sorted_probs[i] for i in indices])
probs_np /= np.sum(probs_np)
probs_np = np.nan_to_num(probs_np)

if np.count_nonzero(probs_np) <= opts.n:
results = np.random.choice(tags_np, opts.n, replace=False)
num_nonzero = np.count_nonzero(probs_np)
if num_nonzero <= opts.n:
if num_nonzero > 0:
results=np.random.choice(tags_np, num_nonzero, replace=False, p=probs_np)
else:
results = np.random.choice(tags_np, opts.n, replace=False)
else:
results = np.random.choice(tags_np, opts.n, replace=False, p=probs_np)
else:
Expand Down

0 comments on commit f45016a

Please sign in to comment.