Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

use Black #634

Merged
merged 5 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .deepsource.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,3 @@ enabled = true
[analyzers.meta]
runtime_version = "3.x.x"
max_line_length = 120

[[transformers]]
name = "autopep8"
enabled = true
30 changes: 0 additions & 30 deletions .github/workflows/code-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,6 @@ jobs:
- name: PEP8
run: flake8 .

#format-check-yapf:
# runs-on: ubuntu-20.04
# steps:
# - uses: actions/checkout@master
# - uses: actions/setup-python@v2
# with:
# python-version: 3.8
# - name: Install dependencies
# run: |
# pip install --upgrade pip
# pip install yapf
# pip list
# shell: bash
# - name: yapf
# run: yapf --diff --parallel --recursive .

#imports-check-isort:
# runs-on: ubuntu-20.04
# steps:
# - uses: actions/checkout@master
# - uses: actions/setup-python@v2
# with:
# python-version: 3.8
# - name: Install isort
# run: |
# pip install isort
# pip list
# - name: isort
# run: isort --check-only .

#typing-check-mypy:
# runs-on: ubuntu-20.04
# steps:
Expand Down
28 changes: 13 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,12 @@ repos:
- id: detect-private-key

- repo: https://github.com/PyCQA/isort
rev: 5.9.1
rev: 5.9.3
hooks:
- id: isort
name: imports
require_serial: false

- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0
hooks:
- id: yapf
name: formatting
language: python
require_serial: false

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
name: PEP8

- repo: https://github.com/kynan/nbstripout
rev: 0.5.0
hooks:
Expand All @@ -65,3 +51,15 @@ repos:
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]

- repo: https://github.com/psf/black
rev: 21.7b0
hooks:
- id: black
name: Format code

- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
name: PEP8
66 changes: 33 additions & 33 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pt_lightning_sphinx_theme

_PATH_HERE = os.path.abspath(os.path.dirname(__file__))
_PATH_ROOT = os.path.join(_PATH_HERE, '..', '..')
_PATH_ROOT = os.path.join(_PATH_HERE, "..", "..")
sys.path.insert(0, os.path.abspath(_PATH_ROOT))

try:
Expand All @@ -33,9 +33,9 @@ def _load_py_module(fname, pkg="flash"):

about = _load_py_module("__about__.py")

SPHINX_MOCK_REQUIREMENTS = int(os.environ.get('SPHINX_MOCK_REQUIREMENTS', True))
SPHINX_MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", True))

html_favicon = '_static/images/icon.svg'
html_favicon = "_static/images/icon.svg"

# -- Project information -----------------------------------------------------

Expand All @@ -49,22 +49,22 @@ def _load_py_module(fname, pkg="flash"):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
"sphinx.ext.autodoc",
"sphinx.ext.doctest",
"sphinx.ext.intersphinx",
"sphinx.ext.todo",
# 'sphinx.ext.coverage',
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
'sphinx.ext.imgmath',
'recommonmark',
"sphinx.ext.viewcode",
"sphinx.ext.autosummary",
"sphinx.ext.napoleon",
"sphinx.ext.imgmath",
"recommonmark",
# 'sphinx.ext.autosectionlabel',
# 'nbsphinx', # it seems some sphinx issue
'sphinx_autodoc_typehints',
'sphinx_copybutton',
'sphinx_paramlinks',
'sphinx_togglebutton',
"sphinx_autodoc_typehints",
"sphinx_copybutton",
"sphinx_paramlinks",
"sphinx_togglebutton",
]

# autodoc: Default to members and undoc-members
Expand Down Expand Up @@ -114,8 +114,8 @@ def _load_py_module(fname, pkg="flash"):
# documentation.

html_theme_options = {
'pytorch_project': 'https://pytorchlightning.ai',
'canonical_url': about.__docs_url__,
"pytorch_project": "https://pytorchlightning.ai",
"canonical_url": about.__docs_url__,
"collapse_navigation": False,
"display_version": True,
"logo_only": False,
Expand All @@ -132,47 +132,47 @@ def _load_py_module(fname, pkg="flash"):
def setup(app):
# this is for hiding doctest decoration,
# see: http://z4r.github.io/python/2011/12/02/hides-the-prompts-and-output/
app.add_js_file('copybutton.js')
app.add_css_file('main.css')
app.add_js_file("copybutton.js")
app.add_css_file("main.css")


# Ignoring Third-party packages
# https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
def _package_list_from_file(pfile):
assert os.path.isfile(pfile)
with open(pfile, 'r') as fp:
with open(pfile, "r") as fp:
lines = fp.readlines()
list_pkgs = []
for ln in lines:
found = [ln.index(ch) for ch in list(',=<>#@') if ch in ln]
pkg = ln[:min(found)] if found else ln
found = [ln.index(ch) for ch in list(",=<>#@") if ch in ln]
pkg = ln[: min(found)] if found else ln
if pkg.strip():
list_pkgs.append(pkg.strip())
return list_pkgs


# define mapping from PyPI names to python imports
PACKAGE_MAPPING = {
'pytorch-lightning': 'pytorch_lightning',
'scikit-learn': 'sklearn',
'Pillow': 'PIL',
'PyYAML': 'yaml',
'rouge-score': 'rouge_score',
'lightning-bolts': 'pl_bolts',
'pytorch-tabnet': 'pytorch_tabnet',
'pyDeprecate': 'deprecate',
"pytorch-lightning": "pytorch_lightning",
"scikit-learn": "sklearn",
"Pillow": "PIL",
"PyYAML": "yaml",
"rouge-score": "rouge_score",
"lightning-bolts": "pl_bolts",
"pytorch-tabnet": "pytorch_tabnet",
"pyDeprecate": "deprecate",
}
MOCK_PACKAGES = []
if SPHINX_MOCK_REQUIREMENTS:
# mock also base packages when we are on RTD since we don't install them there
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, 'requirements.txt'))
MOCK_PACKAGES += _package_list_from_file(os.path.join(_PATH_ROOT, "requirements.txt"))
# replace PyPI packages by importing ones
MOCK_PACKAGES = [PACKAGE_MAPPING.get(pkg, pkg) for pkg in MOCK_PACKAGES]

autodoc_mock_imports = MOCK_PACKAGES

# only run doctests marked with a ".. doctest::" directive
doctest_test_doctest_blocks = ''
doctest_test_doctest_blocks = ""
doctest_global_setup = """
import torch
import pytorch_lightning as pl
Expand Down
2 changes: 1 addition & 1 deletion flash/__about__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = "0.4.1dev"
__author__ = "PyTorchLightning et al."
__author_email__ = "[email protected]"
__license__ = 'Apache-2.0'
__license__ = "Apache-2.0"
__copyright__ = f"Copyright (c) 2020-2021, f{__author__}."
__homepage__ = "https://github.com/PyTorchLightning/lightning-flash"
__docs_url__ = "https://lightning-flash.readthedocs.io/en/stable/"
Expand Down
1 change: 1 addition & 0 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if _IS_TESTING:
from pytorch_lightning import seed_everything

seed_everything(42)

__all__ = [
Expand Down
17 changes: 9 additions & 8 deletions flash/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ def main():


def register_command(command):

@main.command(context_settings=dict(
help_option_names=[],
ignore_unknown_options=True,
))
@click.argument('cli_args', nargs=-1, type=click.UNPROCESSED)
@main.command(
context_settings=dict(
help_option_names=[],
ignore_unknown_options=True,
)
)
@click.argument("cli_args", nargs=-1, type=click.UNPROCESSED)
@functools.wraps(command)
def wrapper(cli_args):
with patch('sys.argv', [command.__name__] + list(cli_args)):
with patch("sys.argv", [command.__name__] + list(cli_args)):
command()


Expand Down Expand Up @@ -63,5 +64,5 @@ def wrapper(cli_args):
except ImportError:
pass

if __name__ == '__main__':
if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions flash/audio/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ def audio_classification():
AudioClassificationData,
default_datamodule_builder=from_urban8k,
default_arguments={
'trainer.max_epochs': 3,
}
"trainer.max_epochs": 3,
},
)

cli.trainer.save_checkpoint("audio_classification_model.pt")


if __name__ == '__main__':
if __name__ == "__main__":
audio_classification()
5 changes: 2 additions & 3 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class AudioClassificationPreprocess(Preprocess):

@requires_extras(["audio", "image"])
def __init__(
self,
Expand All @@ -35,7 +34,7 @@ def __init__(
spectrogram_size: Tuple[int, int] = (196, 196),
time_mask_param: int = 80,
freq_mask_param: int = 80,
deserializer: Optional['Deserializer'] = None,
deserializer: Optional["Deserializer"] = None,
):
self.spectrogram_size = spectrogram_size
self.time_mask_param = time_mask_param
Expand All @@ -48,7 +47,7 @@ def __init__(
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource()
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
},
deserializer=deserializer or ImageDeserializer(),
default_data_source=DefaultDataSources.FILES,
Expand Down
7 changes: 4 additions & 3 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]
}


def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
freq_mask_param: int) -> Dict[str, Callable]:
def train_default_transforms(
spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int
) -> Dict[str, Callable]:
"""During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``"""
transforms = {
"post_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)),
)
}

Expand Down
4 changes: 2 additions & 2 deletions flash/audio/speech_recognition/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def speech_recognition():
SpeechRecognitionData,
default_datamodule_builder=from_timit,
default_arguments={
'trainer.max_epochs': 3,
"trainer.max_epochs": 3,
},
finetune=False,
)

cli.trainer.save_checkpoint("speech_recognition_model.pt")


if __name__ == '__main__':
if __name__ == "__main__":
speech_recognition()
Loading