Linting with updated black
tomhosking committed Mar 25, 2024
1 parent e96dd1e commit f482b2c
Showing 8 changed files with 79 additions and 52 deletions.
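All eight files are touched by the same mechanical change: the commit re-runs black after an upgrade, and the newer style (a) wraps multi-line conditional (ternary) expressions in parentheses instead of letting the if/else hang at the indentation of the enclosing argument or assignment, and (b) adds a blank line between a module- or class-level docstring and the code that follows it, as seen in pythae_vq.py and fleiss.py below. A minimal sketch of the first rule, using hypothetical names (build_dataset, root, name) rather than code from this repository:

import os


def build_dataset(path=None):  # hypothetical stand-in for the dataset constructors below
    return path


root, name = "./data", "train.jsonl"  # hypothetical values

# Older black style: the if/else hangs at the same indentation as the keyword argument.
old_style = build_dataset(
    path=os.path.join(root, name)
    if name is not None
    else None,
)

# Newer black style: the whole conditional expression is wrapped in parentheses.
new_style = build_dataset(
    path=(
        os.path.join(root, name)
        if name is not None
        else None
    ),
)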
24 changes: 15 additions & 9 deletions torchseq/datasets/json_loader.py
@@ -32,9 +32,11 @@ def __init__(self, config, data_path="./data", train_samples=None, dev_samples=N
             config=config,
             input_tokenizer=self.input_tokenizer,
             output_tokenizer=self.output_tokenizer,
-            path=os.path.join(data_path, self.config.json_dataset.path)
-            if self.config.json_dataset.path is not None
-            else None,
+            path=(
+                os.path.join(data_path, self.config.json_dataset.path)
+                if self.config.json_dataset.path is not None
+                else None
+            ),
             samples=train_samples,
             dev=False,
             test=False,
@@ -46,9 +48,11 @@ def __init__(self, config, data_path="./data", train_samples=None, dev_samples=N
             config=config,
             input_tokenizer=self.input_tokenizer,
             output_tokenizer=self.output_tokenizer,
-            path=os.path.join(data_path, self.config.json_dataset.path)
-            if self.config.json_dataset.path is not None
-            else None,
+            path=(
+                os.path.join(data_path, self.config.json_dataset.path)
+                if self.config.json_dataset.path is not None
+                else None
+            ),
             samples=dev_samples,
             dev=True,
             test=False,
@@ -59,9 +63,11 @@ def __init__(self, config, data_path="./data", train_samples=None, dev_samples=N
             config=config,
             input_tokenizer=self.input_tokenizer,
             output_tokenizer=self.output_tokenizer,
-            path=os.path.join(data_path, self.config.json_dataset.path)
-            if self.config.json_dataset.path is not None
-            else None,
+            path=(
+                os.path.join(data_path, self.config.json_dataset.path)
+                if self.config.json_dataset.path is not None
+                else None
+            ),
             samples=test_samples,
             dev=False,
             test=True,
24 changes: 15 additions & 9 deletions torchseq/datasets/qa_loader.py
@@ -25,9 +25,11 @@ def __init__(self, config, data_path, train_samples=None, dev_samples=None, test
         self.logger = logging.getLogger("DataLoader")
 
         train = QADataset(
-            path=os.path.join(data_path, config.training.dataset) + "/"
-            if self.config.training.dataset is not None
-            else None,
+            path=(
+                os.path.join(data_path, config.training.dataset) + "/"
+                if self.config.training.dataset is not None
+                else None
+            ),
             samples=train_samples,
             config=config,
             tokenizer=self.tokenizer,
@@ -36,9 +38,11 @@ def __init__(self, config, data_path, train_samples=None, dev_samples=None, test
             length_limit=self.config.training.get("truncate_dataset", None),
         )
         valid = QADataset(
-            path=os.path.join(data_path, self.config.training.dataset) + "/"
-            if self.config.training.dataset is not None
-            else None,
+            path=(
+                os.path.join(data_path, self.config.training.dataset) + "/"
+                if self.config.training.dataset is not None
+                else None
+            ),
             samples=dev_samples,
             config=config,
             tokenizer=self.tokenizer,
@@ -47,9 +51,11 @@ def __init__(self, config, data_path, train_samples=None, dev_samples=None, test
             length_limit=self.config.eval.get("truncate_dataset", None),
         )
         test = QADataset(
-            path=os.path.join(data_path, self.config.training.dataset) + "/"
-            if self.config.training.dataset is not None
-            else None,
+            path=(
+                os.path.join(data_path, self.config.training.dataset) + "/"
+                if self.config.training.dataset is not None
+                else None
+            ),
             samples=test_samples,
             config=config,
             tokenizer=self.tokenizer,
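In both json_loader.py and qa_loader.py above, the same path-or-None conditional is repeated for the train, dev, and test constructors; the commit only reformats it. If the duplication were ever worth removing (not something this commit does), a small helper along these lines would do it; resolve_path and its arguments are hypothetical names, not part of torchseq:

import os
from typing import Optional


def resolve_path(data_path: str, dataset: Optional[str]) -> Optional[str]:
    # Join the dataset name onto the data directory, or pass None straight through.
    return os.path.join(data_path, dataset) if dataset is not None else None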
8 changes: 5 additions & 3 deletions torchseq/metric_hooks/hrq_agg.py
@@ -1297,9 +1297,11 @@ def prefilter_condition(sentence, hotel_aspect_filter=True, amazon_filter=False,
             truncation_length=None,
             prune_min_weight=0.01,
             prune_max_paths=None,
-            use_tfidf=True
-            if config.eval.metrics.hrq_agg.get("summary_smart_generic_weight_scheme", None) is not None
-            else False,
+            use_tfidf=(
+                True
+                if config.eval.metrics.hrq_agg.get("summary_smart_generic_weight_scheme", None) is not None
+                else False
+            ),
             tfidf_weighting_scheme=config.eval.metrics.hrq_agg.get("summary_smart_generic_weight_scheme", 5),
             # block_paths={k: [p[:1] for p in v] for k, v in summary_paths_specific.items()},
             block_paths={k: v for k, v in summary_paths_specific.items()},
31 changes: 17 additions & 14 deletions torchseq/metric_hooks/opsumm_cluster_aug.py
@@ -145,17 +145,17 @@ def on_end_epoch(self, agent, use_test=False):
 
         if self.config.eval.metrics.opsumm_cluster_aug.get("run_selection_oracle_comparison", False):
             logger.info("Running cluster vs oracle comparison...")
-            self.scores[
-                "hiro_selection_vs_oracle"
-            ] = OpSummClusterAugMetricHook.eval_compare_selected_clusters_to_oracle(
-                self.config,
-                agent.data_path,
-                generated_summaries["paths"],
-                # {
-                #     sent: code
-                #     for sent, code in zip(generated_summaries["inputs"], generated_summaries["all_codes"])
-                # },
-                test=use_test,
+            self.scores["hiro_selection_vs_oracle"] = (
+                OpSummClusterAugMetricHook.eval_compare_selected_clusters_to_oracle(
+                    self.config,
+                    agent.data_path,
+                    generated_summaries["paths"],
+                    # {
+                    #     sent: code
+                    #     for sent, code in zip(generated_summaries["inputs"], generated_summaries["all_codes"])
+                    # },
+                    test=use_test,
+                )
             )
             logger.info("...done!")
         if self.config.eval.metrics.opsumm_cluster_aug.get("run_selection_prevalence", False):
@@ -1356,9 +1356,12 @@ def prefilter_condition(sentence, hotel_aspect_filter=True, amazon_filter=False,
             truncation_length=None,
             prune_min_weight=config.eval.metrics.opsumm_cluster_aug.get("summary_smart_generic_min_weight", 0.01),
             prune_max_paths=None,
-            use_tfidf=True
-            if config.eval.metrics.opsumm_cluster_aug.get("summary_smart_generic_weight_scheme", None) is not None
-            else False,
+            use_tfidf=(
+                True
+                if config.eval.metrics.opsumm_cluster_aug.get("summary_smart_generic_weight_scheme", None)
+                is not None
+                else False
+            ),
             tfidf_weighting_scheme=config.eval.metrics.opsumm_cluster_aug.get(
                 "summary_smart_generic_weight_scheme", 5
             ),
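The use_tfidf arguments reformatted in hrq_agg.py and opsumm_cluster_aug.py keep the original "True if ... is not None else False" expression; black only adds parentheses, it never simplifies logic. The comparison already yields a bool, so the two forms are interchangeable, as this small check (with a hypothetical scheme value standing in for the config lookup) illustrates:

# The ternary and the bare comparison produce the same value for any scheme.
for scheme in (None, 5):
    assert (True if scheme is not None else False) == (scheme is not None)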
34 changes: 20 additions & 14 deletions torchseq/models/exemplar_guided_autoencoder.py
@@ -53,14 +53,16 @@ def __init__(self, config, input_tokenizer, output_tokenizer, src_field="source"
         self.seq_encoder_2 = SequenceEncoder(
             global_config=config,
             encoder_config=config.get_first(["template_encoder", "encoder"]),
-            tokenizer=self.output_tokenizer
-            if self.config.bottleneck.get("template_tokenizer", "input") == "output"
-            else self.input_tokenizer,
-            freeze_embeddings=config.template_encoder.get(
-                "freeze_embeddings", config.get("freeze_embeddings", False)
-            )
-            if "template_encoder" in config.data
-            else config.freeze_embeddings,
+            tokenizer=(
+                self.output_tokenizer
+                if self.config.bottleneck.get("template_tokenizer", "input") == "output"
+                else self.input_tokenizer
+            ),
+            freeze_embeddings=(
+                config.template_encoder.get("freeze_embeddings", config.get("freeze_embeddings", False))
+                if "template_encoder" in config.data
+                else config.freeze_embeddings
+            ),
         )
         # self.bottleneck_2 = PoolingBottleneck(config)
 
@@ -245,16 +247,20 @@ def forward(self, batch, output, memory=None, tgt_field=None):
 
         if self.config.bottleneck.code_predictor.get("sem_only", False):
             codepred_input = self.reduce_fn(
-                sem_encoding_pooled[:, :, : self.sep_splice_ix]
-                if self.config.bottleneck.code_predictor.get("post_bottleneck", False)
-                else prebn_sem_encoding_pooled[:, :, : self.sep_splice_ix],
+                (
+                    sem_encoding_pooled[:, :, : self.sep_splice_ix]
+                    if self.config.bottleneck.code_predictor.get("post_bottleneck", False)
+                    else prebn_sem_encoding_pooled[:, :, : self.sep_splice_ix]
+                ),
                 mask=memory["encoding_mask"],
             )
         else:
             codepred_input = self.reduce_fn(
-                sem_encoding_pooled
-                if self.config.bottleneck.code_predictor.get("post_bottleneck", False)
-                else prebn_sem_encoding_pooled,
+                (
+                    sem_encoding_pooled
+                    if self.config.bottleneck.code_predictor.get("post_bottleneck", False)
+                    else prebn_sem_encoding_pooled
+                ),
                 mask=memory["encoding_mask"],
             )
 
1 change: 1 addition & 0 deletions torchseq/models/pythae_vq.py
@@ -23,6 +23,7 @@ class VQVAEConfig:
         use_ema (bool): Whether to use the Exponential Movng Average Update (EMA). Default: False.
         decay (float): The decay to apply in the EMA update. Must be in [0, 1]. Default: 0.99.
     """
+
     latent_dim: int = 10
     commitment_loss_factor: float = 0.25
     quantization_loss_factor: float = 1.0
8 changes: 5 additions & 3 deletions torchseq/models/vq_vae.py
@@ -157,9 +157,11 @@ def __init__(
             [nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) for _ in range(num_heads)]
         )
         for hix, embedding in enumerate(self._embedding):
-            torch.nn.init.xavier_uniform_(
-                embedding.weight.data, gain=6.0 * init_scale * init_decay_weight**hix
-            ) if init_embeds_xavier else embedding.weight.data.normal_(std=init_scale * init_decay_weight**hix)
+            (
+                torch.nn.init.xavier_uniform_(embedding.weight.data, gain=6.0 * init_scale * init_decay_weight**hix)
+                if init_embeds_xavier
+                else embedding.weight.data.normal_(std=init_scale * init_decay_weight**hix)
+            )
             if init_sphere:
                 embedding.weight.data = (
                     embedding.weight.data
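The vq_vae.py hunk wraps a conditional expression that is executed only for its side effect: one branch applies Xavier initialization, the other draws normal noise, and the result is discarded. Black keeps the expression and just parenthesizes it. Written as an ordinary if/else statement, the same initialization looks like the sketch below, where embedding, init_embeds_xavier, init_scale, init_decay_weight, and hix are hypothetical stand-ins for the variables visible in the surrounding loop:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the values available inside the loop above.
embedding = nn.Embedding(16, 8)
init_embeds_xavier, init_scale, init_decay_weight, hix = True, 1.0, 0.5, 0

# Equivalent to the parenthesized conditional expression in the diff, as a statement.
if init_embeds_xavier:
    torch.nn.init.xavier_uniform_(embedding.weight.data, gain=6.0 * init_scale * init_decay_weight**hix)
else:
    embedding.weight.data.normal_(std=init_scale * init_decay_weight**hix)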
1 change: 1 addition & 0 deletions torchseq/utils/fleiss.py
@@ -7,6 +7,7 @@
 Cardillo G. (2007) Fleisses kappa: compute the Fleiss'es kappa for multiple raters.
 http://www.mathworks.com/matlabcentral/fileexchange/15426
 """
+
 import numpy
 
 # from scipy.special import erfc
