Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 5, 2021
1 parent de7825c commit 0a159a8
Show file tree
Hide file tree
Showing 208 changed files with 2,458 additions and 2,673 deletions.
66 changes: 33 additions & 33 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pt_lightning_sphinx_theme

_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.join(_PATH_HERE, '..', '..')
_PATH_ROOT = os.path.join(_PATH_HERE, "..", "..")
sys.path.insert(0, os.path.abspath(_PATH_ROOT))

try:
Expand All @@ -33,9 +33,9 @@ def _load_py_module(fname, pkg="flash"):

about = _load_py_module("__about__.py")

SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True))
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

html_favicon = '_static/images/icon.svg'
html_favicon = "_static/images/icon.svg"

# -- Project information -----------------------------------------------------

Expand All @@ -49,22 +49,22 @@ def _load_py_module(fname, pkg="flash"):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
"sphinx.ext.autodoc",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
# 'sphinx.ext.coverage',
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.imgmath',
'recommonmark',
"sphinx.ext.viewcode",
"sphinx.ext.autosummary",
"sphinx.ext.napoleon",
"sphinx.ext.imgmath",
"recommonmark",
# 'sphinx.ext.autosectionlabel',
# 'nbsphinx', # it seems some sphinx issue
'sphinx_autodoc_typehints',
'sphinx_copybutton',
'sphinx_paramlinks',
'sphinx_togglebutton',
"sphinx_autodoc_typehints",
"sphinx_copybutton",
"sphinx_paramlinks",
"sphinx_togglebutton",
]

# autodoc: Default to members and undoc-members
Expand Down Expand Up @@ -114,8 +114,8 @@ def _load_py_module(fname, pkg="flash"):
# documentation.

html_theme_options = {
'pytorch_project': 'https://pytorchlightning.ai',
'canonical_url': about.__docs_url__,
"pytorch_project": "https://pytorchlightning.ai",
"canonical_url": about.__docs_url__,
"collapse_navigation": False,
"display_version": True,
"logo_only": False,
Expand All @@ -132,47 +132,47 @@ def _load_py_module(fname, pkg="flash"):
def setup(app):
# this is for hiding doctest decoration,
# see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/
app.add_js_file('copybutton.js')
app.add_css_file('main.css')
app.add_js_file("copybutton.js")
app.add_css_file("main.css")


# Ignoring Third-party packages
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
def _package_list_from_file(pfile):
assert os.path.isfile(pfile)
with open(pfile, 'r') as fp:
with open(pfile, "r") as fp:
lines = fp.readlines()
list_pkgs = []
for ln in lines:
found = [ln.index(ch) for ch in list(',=<>#@') if ch in ln]
pkg = ln[:min(found)] if found else ln
found = [ln.index(ch) for ch in list(",=<>#@") if ch in ln]
pkg = ln[: min(found)] if found else ln
if pkg.strip():
list_pkgs.append(pkg.strip())
return list_pkgs


# define mapping from PyPI names to python imports
PACKAGE_MAPPING = {
'pytorch-lightning': 'pytorch_lightning',
'scikit-learn': 'sklearn',
'Pillow': 'PIL',
'PyYAML': 'yaml',
'rouge-score': 'rouge_score',
'lightning-bolts': 'pl_bolts',
'pytorch-tabnet': 'pytorch_tabnet',
'pyDeprecate': 'deprecate',
"pytorch-lightning": "pytorch_lightning",
"scikit-learn": "sklearn",
"Pillow": "PIL",
"PyYAML": "yaml",
"rouge-score": "rouge_score",
"lightning-bolts": "pl_bolts",
"pytorch-tabnet": "pytorch_tabnet",
"pyDeprecate": "deprecate",
}
MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
# mock also base packages when we are on RTD since we don't install them there
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, 'requirements.txt'))
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
# replace PyPI packages by importing ones
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]

autodoc_mock_imports = MOCK_PACKAGES

# only run doctests marked with a ".. doctest::" directive
doctest_test_doctest_blocks = ''
doctest_test_doctest_blocks = ""
doctest_global_setup = """
import torch
import pytorch_lightning as pl
Expand Down
2 changes: 1 addition & 1 deletion flash/__about__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = "0.4.1dev"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = 'Apache-2.0'
__license__ = "Apache-2.0"
__copyright__ = f"Copyright (c) 2020-2021, f{__author__}."
__homepage__ = "https://github.com/PyTorchLightning/lightning-flash"
__docs_url__ = "https://lightning-flash.readthedocs.io/en/stable/"
Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if _IS_TESTING:
from pytorch_lightning import seed_everything

seed_everything(42)

__all__ = [
Expand Down
17 changes: 9 additions & 8 deletions flash/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ def main():


def register_command(command):

@main.command(context_settings=dict(
help_option_names=[],
ignore_unknown_options=True,
))
@click.argument('cli_args', nargs=-1, type=click.UNPROCESSED)
@main.command(
context_settings=dict(
help_option_names=[],
ignore_unknown_options=True,
)
)
@click.argument("cli_args", nargs=-1, type=click.UNPROCESSED)
@functools.wraps(command)
def wrapper(cli_args):
with patch('sys.argv', [command.__name__] + list(cli_args)):
with patch("sys.argv", [command.__name__] + list(cli_args)):
command()


Expand Down Expand Up @@ -63,5 +64,5 @@ def wrapper(cli_args):
except ImportError:
pass

if __name__ == '__main__':
if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions flash/audio/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def audio_classification():
AudioClassificationData,
default_datamodule_builder=from_urban8k,
default_arguments={
'trainer.max_epochs': 3,
}
"trainer.max_epochs": 3,
},
)

cli.trainer.save_checkpoint("audio_classification_model.pt")


if __name__ == '__main__':
if __name__ == "__main__":
audio_classification()
5 changes: 2 additions & 3 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class AudioClassificationPreprocess(Preprocess):

@requires_extras(["audio", "image"])
def __init__(
self,
Expand All @@ -35,7 +34,7 @@ def __init__(
spectrogram_size: Tuple[int, int] = (196, 196),
time_mask_param: int = 80,
freq_mask_param: int = 80,
deserializer: Optional['Deserializer'] = None,
deserializer: Optional["Deserializer"] = None,
):
self.spectrogram_size = spectrogram_size
self.time_mask_param = time_mask_param
Expand All @@ -48,7 +47,7 @@ def __init__(
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource()
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand Down
7 changes: 4 additions & 3 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]
}


def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
freq_mask_param: int) -> Dict[str, Callable]:
def train_default_transforms(
spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int
) -> Dict[str, Callable]:
"""During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``"""
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)),
)
}

Expand Down
4 changes: 2 additions & 2 deletions flash/audio/speech_recognition/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def speech_recognition():
SpeechRecognitionData,
default_datamodule_builder=from_timit,
default_arguments={
'trainer.max_epochs': 3,
"trainer.max_epochs": 3,
},
finetune=False,
)

cli.trainer.save_checkpoint("speech_recognition_model.pt")


if __name__ == '__main__':
if __name__ == "__main__":
speech_recognition()
41 changes: 18 additions & 23 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,14 @@


class SpeechRecognitionDeserializer(Deserializer):

def deserialize(self, sample: Any) -> Dict:
encoded_with_padding = (sample + "===").encode("ascii")
audio = base64.b64decode(encoded_with_padding)
buffer = io.BytesIO(audio)
data, sampling_rate = sf.read(buffer)
return {
DefaultDataKeys.INPUT: data,
DefaultDataKeys.METADATA: {
"sampling_rate": sampling_rate
},
DefaultDataKeys.METADATA: {"sampling_rate": sampling_rate},
}

@property
Expand All @@ -64,11 +61,13 @@ def example_input(self) -> str:


class BaseSpeechRecognition:

def _load_sample(self, sample: Dict[str, Any]) -> Any:
path = sample[DefaultDataKeys.INPUT]
if not os.path.isabs(path) and DefaultDataKeys.METADATA in sample and "root" in sample[DefaultDataKeys.METADATA
]:
if (
not os.path.isabs(path)
and DefaultDataKeys.METADATA in sample
and "root" in sample[DefaultDataKeys.METADATA]
):
path = os.path.join(sample[DefaultDataKeys.METADATA]["root"], path)
speech_array, sampling_rate = sf.read(path)
sample[DefaultDataKeys.INPUT] = speech_array
Expand All @@ -77,7 +76,6 @@ def _load_sample(self, sample: Dict[str, Any]) -> Any:


class SpeechRecognitionFileDataSource(DataSource, BaseSpeechRecognition):

def __init__(self, filetype: Optional[str] = None):
super().__init__()
self.filetype = filetype
Expand All @@ -87,50 +85,49 @@ def load_data(
data: Tuple[str, Union[str, List[str]], Union[str, List[str]]],
dataset: Optional[Any] = None,
) -> Union[Sequence[Mapping[str, Any]]]:
if self.filetype == 'json':
if self.filetype == "json":
file, input_key, target_key, field = data
else:
file, input_key, target_key = data
stage = self.running_stage.value
if self.filetype == 'json' and field is not None:
if self.filetype == "json" and field is not None:
dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)}, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files={stage: str(file)})

dataset = dataset_dict[stage]
meta = {"root": os.path.dirname(file)}
return [{
DefaultDataKeys.INPUT: input_file,
DefaultDataKeys.TARGET: target,
DefaultDataKeys.METADATA: meta,
} for input_file, target in zip(dataset[input_key], dataset[target_key])]
return [
{
DefaultDataKeys.INPUT: input_file,
DefaultDataKeys.TARGET: target,
DefaultDataKeys.METADATA: meta,
}
for input_file, target in zip(dataset[input_key], dataset[target_key])
]

def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:
return self._load_sample(sample)


class SpeechRecognitionCSVDataSource(SpeechRecognitionFileDataSource):

def __init__(self):
super().__init__(filetype='csv')
super().__init__(filetype="csv")


class SpeechRecognitionJSONDataSource(SpeechRecognitionFileDataSource):

def __init__(self):
super().__init__(filetype='json')
super().__init__(filetype="json")


class SpeechRecognitionDatasetDataSource(DatasetDataSource, BaseSpeechRecognition):

def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Union[Sequence[Mapping[str, Any]]]:
if isinstance(data, HFDataset):
data = list(zip(data["file"], data["text"]))
return super().load_data(data, dataset)


class SpeechRecognitionPathsDataSource(PathsDataSource, BaseSpeechRecognition):

def __init__(self):
super().__init__(("wav", "ogg", "flac", "mat"))

Expand All @@ -139,7 +136,6 @@ def load_sample(self, sample: Dict[str, Any], dataset: Any = None) -> Any:


class SpeechRecognitionPreprocess(Preprocess):

@requires_extras("audio")
def __init__(
self,
Expand Down Expand Up @@ -181,7 +177,6 @@ class SpeechRecognitionBackboneState(ProcessState):


class SpeechRecognitionPostprocess(Postprocess):

@requires_extras("audio")
def __init__(self):
super().__init__()
Expand Down
7 changes: 4 additions & 3 deletions flash/audio/speech_recognition/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def __init__(
# set os environ variable for multiprocesses
os.environ["PYTHONWARNINGS"] = "ignore"

model = self.backbones.get(backbone
)() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone)
model = (
self.backbones.get(backbone)() if backbone in self.backbones else Wav2Vec2ForCTC.from_pretrained(backbone)
)
super().__init__(
model=model,
loss_fn=loss_fn,
Expand All @@ -74,5 +75,5 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A

def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
out = self.model(batch["input_values"], labels=batch["labels"])
out["logs"] = {'loss': out.loss}
out["logs"] = {"loss": out.loss}
return out
Loading

0 comments on commit 0a159a8

Please sign in to comment.