Skip to content

Commit

Permalink
Fix arg formatting in preprocess.py and add fmt control for black formatting (#399)
Browse files Browse the repository at this point in the history

Summary:
Not switching to Black formatting just yet, but adding fmt: off directives in case we decide to later.
Pull Request resolved: facebookresearch/fairseq#399

Differential Revision: D13364674

Pulled By: myleott

fbshipit-source-id: a20a11a18be3d583ee30eff770278fb4bd05b93c
  • Loading branch information
myleott authored and yzpang committed Feb 19, 2021
1 parent fee772d commit 0c30a0d
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 50 deletions.
2 changes: 2 additions & 0 deletions fairseq/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def __init__(self, decoder):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
Expand Down Expand Up @@ -427,6 +428,7 @@ def add_args(parser):
help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
# fmt: on

@classmethod
def build_model(cls, args, task):
Expand Down
2 changes: 1 addition & 1 deletion fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,4 @@ def step_update(self, num_updates):
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))

self.optimizer.set_lr(self.lr)
return self.lr
return self.lr
59 changes: 39 additions & 20 deletions preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,45 @@

def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--trainpref', metavar='FP', default=None, help='target language')
parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid language prefixes')
parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test language prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
help='Pad dictionary size to be multiple of N')
parser.add_argument('--workers', metavar='N', default=1, type=int, help='number of parallel workers')
# fmt: off
parser.add_argument("-s", "--source-lang", default=None, metavar="SRC",
help="source language")
parser.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
help="target language")
parser.add_argument("--trainpref", metavar="FP", default=None,
help="train file prefix")
parser.add_argument("--validpref", metavar="FP", default=None,
help="comma separated, valid file prefixes")
parser.add_argument("--testpref", metavar="FP", default=None,
help="comma separated, test file prefixes")
parser.add_argument("--destdir", metavar="DIR", default="data-bin",
help="destination dir")
parser.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
parser.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
help="map words appearing less than threshold times to unknown")
parser.add_argument("--tgtdict", metavar="FP",
help="reuse given target dictionary")
parser.add_argument("--srcdict", metavar="FP",
help="reuse given source dictionary")
parser.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
help="number of target words to retain")
parser.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
help="number of source words to retain")
parser.add_argument("--alignfile", metavar="ALIGN", default=None,
help="an alignment file (optional)")
parser.add_argument("--output-format", metavar="FORMAT", default="binary",
choices=["binary", "raw"],
help="output format (optional)")
parser.add_argument("--joined-dictionary", action="store_true",
help="Generate joined dictionary")
parser.add_argument("--only-source", action="store_true",
help="Only process the source language")
parser.add_argument("--padding-factor", metavar="N", default=8, type=int,
help="Pad dictionary size to be multiple of N")
parser.add_argument("--workers", metavar="N", default=1, type=int,
help="number of parallel workers")
# fmt: on
return parser


Expand Down
5 changes: 4 additions & 1 deletion score.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

def get_parser():
parser = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
# fmt: off
parser.add_argument('-s', '--sys', default='-', help='system output')
parser.add_argument('-r', '--ref', required=True, help='references')
parser.add_argument('-o', '--order', default=4, metavar='N',
type=int, help='consider ngrams up to this order')
parser.add_argument('--ignore-case', action='store_true',
help='case-insensitive scoring')
# fmt: on
return parser


Expand All @@ -44,7 +46,8 @@ def readlines(fd):
for line in fd.readlines():
if args.ignore_case:
yield line.lower()
yield line
else:
yield line

def score(fdsys):
with open(args.ref) as fdref:
Expand Down
38 changes: 13 additions & 25 deletions scripts/average_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,19 @@ def main():
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
)

parser.add_argument(
'--inputs',
required=True,
nargs='+',
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
parser.add_argument(
'--num',
type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last num of those',
)
parser.add_argument(
'--update-based-checkpoints',
action='store_true',
help='if set and used together with --num, averages update-based checkpoints instead of epoch-based checkpoints'
)
# fmt: off
parser.add_argument('--inputs', required=True, nargs='+',
help='Input checkpoint file paths.')
parser.add_argument('--output', required=True, metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this path.')
num_group = parser.add_mutually_exclusive_group()
num_group.add_argument('--num-epoch-checkpoints', type=int,
help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
'and average last this many of them.')
num_group.add_argument('--num-update-checkpoints', type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
'and average last this many of them.')
# fmt: on
args = parser.parse_args()
print(args)

Expand Down
2 changes: 2 additions & 0 deletions scripts/build_sym_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

def main():
parser = argparse.ArgumentParser(description='symmetric alignment builer')
# fmt: off
parser.add_argument('--fast_align_dir',
help='path to fast_align build directory')
parser.add_argument('--mosesdecoder_dir',
Expand All @@ -47,6 +48,7 @@ def main():
'in the target language')
parser.add_argument('--output_dir',
help='output directory')
# fmt: on
args = parser.parse_args()

fast_align_bin = os.path.join(args.fast_align_dir, 'fast_align')
Expand Down
4 changes: 1 addition & 3 deletions scripts/read_binarized.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ def get_parser():
parser = argparse.ArgumentParser(
description='writes text from binarized file to stdout')
# fmt: off
parser.add_argument('--dataset-impl', help='dataset implementation',
choices=indexed_dataset.get_available_dataset_impl())
parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words')
parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
# fmt: on

Expand Down

0 comments on commit 0c30a0d

Please sign in to comment.