diff --git a/.github/DEVELOPMENT.md b/.github/DEVELOPMENT.md index acd934b9..b52359ad 100644 --- a/.github/DEVELOPMENT.md +++ b/.github/DEVELOPMENT.md @@ -63,6 +63,8 @@ See [Commits and pull requests](https://github.com/trinodb/trino/blob/master/.gi To run linting and formatting checks before opening a PR: `pip install pre-commit && pre-commit run --all-files` +Code can also be automatically checked on commit by a [pre-commit](https://pre-commit.com/) git hook by executing `pre-commit install`. + In addition to that you should also adhere to the following: ### Readability diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 34a0c2d9..8196638a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,7 +29,7 @@ jobs: run: pip install pre-commit - name: "Run pre-commit checks" - run: pre-commit run --all-files + run: pre-commit run --hook-stage manual --all-files build: runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32bf21cc..16ee054d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,3 +13,9 @@ repos: additional_dependencies: - "types-pytz" - "types-requests" + + - repo: https://github.com/pycqa/isort + rev: 5.6.4 + hooks: + - id: isort + args: [ "--profile", "black"] diff --git a/setup.py b/setup.py index 3a06d40d..737ac912 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ import re import textwrap -from setuptools import setup, find_packages +from setuptools import find_packages, setup _version_re = re.compile(r"__version__\s+=\s+(.*)") @@ -40,6 +40,9 @@ "pytest-runner", "click", "sqlalchemy_utils", + "pre-commit", + "black", + "isort", ] setup( @@ -76,7 +79,7 @@ "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Database :: Front-Ends", ], - python_requires='>=3.7', + python_requires=">=3.7", install_requires=["pytz", "requests"], extras_require={ "all": all_require, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 840ee8f3..06b5249c 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,11 +18,11 @@ from uuid import uuid4 import click -import trino.logging import pytest -from trino.client import TrinoQuery, TrinoRequest, ClientSession -from trino.constants import DEFAULT_PORT +import trino.logging +from trino.client import ClientSession, TrinoQuery, TrinoRequest +from trino.constants import DEFAULT_PORT logger = trino.logging.get_logger(__name__) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index e660b586..fb398296 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from datetime import datetime, time, date, timezone, timedelta +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal import pytest @@ -20,7 +20,7 @@ import trino from tests.integration.conftest import trino_version from trino import constants -from trino.exceptions import TrinoQueryError, TrinoUserError, NotSupportedError +from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError from trino.transaction import IsolationLevel diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 593e71f9..fccfb35f 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -11,7 +11,7 @@ # limitations under the License import pytest import sqlalchemy as sqla -from sqlalchemy.sql import and_, or_, not_ +from sqlalchemy.sql import and_, not_, or_ from sqlalchemy_utils import create_view from tests.unit.conftest import sqlalchemy_version diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 5749a820..101c347b 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -1,6 +1,8 @@ import math -import pytest from decimal import Decimal + +import pytest + import trino diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8c5284e7..1f84d13e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import MagicMock, patch +import pytest + @pytest.fixture(scope="session") def sample_post_response_data(): diff --git a/tests/unit/sqlalchemy/conftest.py b/tests/unit/sqlalchemy/conftest.py index e80f19b8..71d6f74d 100644 --- a/tests/unit/sqlalchemy/conftest.py +++ b/tests/unit/sqlalchemy/conftest.py @@ -12,7 +12,7 @@ import pytest from sqlalchemy.sql.sqltypes import ARRAY -from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP, SQLType @pytest.fixture(scope="session") diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 9c27c041..1051bf3d 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -10,15 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy import ( - Column, - insert, - Integer, - MetaData, - select, - String, - Table, -) +from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select from sqlalchemy.schema import CreateTable from sqlalchemy.sql import column, table diff --git a/tests/unit/sqlalchemy/test_datatype_parse.py b/tests/unit/sqlalchemy/test_datatype_parse.py index daee569c..c345e3d4 100644 --- a/tests/unit/sqlalchemy/test_datatype_parse.py +++ b/tests/unit/sqlalchemy/test_datatype_parse.py @@ -11,23 +11,11 @@ # limitations under the License. import pytest from sqlalchemy.exc import UnsupportedCompilationError -from sqlalchemy.sql.sqltypes import ( - CHAR, - VARCHAR, - ARRAY, - INTEGER, - DECIMAL, - DATE -) +from sqlalchemy.sql.sqltypes import ARRAY, CHAR, DATE, DECIMAL, INTEGER, VARCHAR from sqlalchemy.sql.type_api import TypeEngine from trino.sqlalchemy import datatype -from trino.sqlalchemy.datatype import ( - MAP, - ROW, - TIME, - TIMESTAMP -) +from trino.sqlalchemy.datatype import MAP, ROW, TIME, TIMESTAMP @pytest.mark.parametrize( diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index 31c29670..3316fa6c 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -2,13 +2,17 @@ from unittest import mock import pytest -from sqlalchemy.engine.url import make_url, URL +from sqlalchemy.engine.url import URL, make_url from trino.auth import BasicAuthentication from trino.dbapi import Connection -from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect -from trino.transaction import IsolationLevel from trino.sqlalchemy import URL as trino_url +from trino.sqlalchemy.dialect import ( + CertificateAuthentication, + JWTAuthentication, + TrinoDialect, +) +from trino.transaction import IsolationLevel class TestTrinoDialect: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 5477a39d..abcf76c6 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,7 +13,7 @@ import threading import time import uuid -from typing import Optional, Dict +from typing import Dict, Optional from unittest import mock from urllib.parse import urlparse @@ -24,13 +24,28 @@ from requests_kerberos.exceptions import KerberosExchangeError import trino.exceptions -from tests.unit.oauth_test_utils import RedirectHandler, GetTokenCallback, PostStatementCallback, \ - MultithreadedTokenServer, _post_statement_requests, _get_token_requests, REDIRECT_RESOURCE, TOKEN_RESOURCE, \ - SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + MultithreadedTokenServer, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import KerberosAuthentication, _OAuth2TokenBearer -from trino.client import TrinoQuery, TrinoRequest, TrinoResult, ClientSession, _DelayExponential, _retry_with, \ - _RetryWithExponentialBackoff +from trino.client import ( + ClientSession, + TrinoQuery, + TrinoRequest, + TrinoResult, + _DelayExponential, + _retry_with, + _RetryWithExponentialBackoff, +) @mock.patch("trino.client.TrinoRequest.http") diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index 81386430..7065dd4b 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -17,8 +17,16 @@ from httpretty import httprettified from requests import Session -from tests.unit.oauth_test_utils import _post_statement_requests, _get_token_requests, RedirectHandler, \ - GetTokenCallback, REDIRECT_RESOURCE, TOKEN_RESOURCE, PostStatementCallback, SERVER_ADDRESS +from tests.unit.oauth_test_utils import ( + REDIRECT_RESOURCE, + SERVER_ADDRESS, + TOKEN_RESOURCE, + GetTokenCallback, + PostStatementCallback, + RedirectHandler, + _get_token_requests, + _post_statement_requests, +) from trino import constants from trino.auth import OAuth2Authentication from trino.dbapi import connect diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index fd3c5e2d..9753bab5 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -10,8 +10,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.client import get_header_values, get_session_property_values from trino import constants +from trino.client import get_header_values, get_session_property_values def test_get_header_values(): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index e97903cf..ec919ce3 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -10,9 +10,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trino.transaction import IsolationLevel import pytest +from trino.transaction import IsolationLevel + def test_isolation_level_levels() -> None: levels = { diff --git a/trino/__init__.py b/trino/__init__.py index 8dd7f7c2..31c5dee2 100644 --- a/trino/__init__.py +++ b/trino/__init__.py @@ -10,12 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import auth -from . import dbapi -from . import client -from . import constants -from . import exceptions -from . import logging +from . import auth, client, constants, dbapi, exceptions, logging __all__ = ['auth', 'dbapi', 'client', 'constants', 'exceptions', 'logging'] diff --git a/trino/auth.py b/trino/auth.py index e6b4f04c..f7921e03 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -11,18 +11,18 @@ # limitations under the License. import abc +import importlib import json import os import re import threading import webbrowser -from typing import Optional, List, Callable +from typing import Callable, List, Optional from urllib.parse import urlparse from requests import Request from requests.auth import AuthBase, extract_cookies_to_jar from requests.utils import parse_dict_header -import importlib import trino.logging from trino.client import exceptions diff --git a/trino/constants.py b/trino/constants.py index 30046908..f8477b08 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -12,7 +12,6 @@ from typing import Any, Optional - DEFAULT_PORT = 8080 DEFAULT_TLS_PORT = 443 DEFAULT_SOURCE = "trino-python-client" diff --git a/trino/dbapi.py b/trino/dbapi.py index 683ced67..6a9542d7 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -17,30 +17,29 @@ Fetch methods returns rows as a list of lists on purpose to let the caller decide to convert then to a list of tuples. """ -from decimal import Decimal -from typing import Any, List, Optional # NOQA for mypy types - -import uuid import datetime import math +import uuid +from decimal import Decimal +from typing import Any, List, Optional # NOQA for mypy types -from trino import constants -import trino.exceptions import trino.client +import trino.exceptions import trino.logging -from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION +from trino import constants from trino.exceptions import ( - Warning, - Error, - InterfaceError, DatabaseError, DataError, - OperationalError, + Error, IntegrityError, + InterfaceError, InternalError, - ProgrammingError, NotSupportedError, + OperationalError, + ProgrammingError, + Warning, ) +from trino.transaction import NO_TRANSACTION, IsolationLevel, Transaction __all__ = [ # https://www.python.org/dev/peps/pep-0249/#globals diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py index 3c10f0b8..0908d275 100644 --- a/trino/sqlalchemy/__init__.py +++ b/trino/sqlalchemy/__init__.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.dialects import registry + from .util import _url as URL # noqa registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect") diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 0078e689..8747d190 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -12,7 +12,6 @@ from sqlalchemy.sql import compiler from sqlalchemy.sql.base import DialectKWArgs - # https://trino.io/docs/current/language/reserved.html RESERVED_WORDS = { "alter", diff --git a/trino/sqlalchemy/datatype.py b/trino/sqlalchemy/datatype.py index 44961762..224996e9 100644 --- a/trino/sqlalchemy/datatype.py +++ b/trino/sqlalchemy/datatype.py @@ -11,7 +11,7 @@ # limitations under the License. import json import re -from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from sqlalchemy import util from sqlalchemy.sql import sqltypes diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 46349088..7b0b4bf9 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -20,7 +20,8 @@ from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.engine.url import URL -from trino import dbapi as trino_dbapi, logging +from trino import dbapi as trino_dbapi +from trino import logging from trino.auth import BasicAuthentication, CertificateAuthentication, JWTAuthentication from trino.dbapi import Cursor from trino.sqlalchemy import compiler, datatype, error diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py index 44814020..117830bb 100644 --- a/trino/sqlalchemy/util.py +++ b/trino/sqlalchemy/util.py @@ -1,8 +1,8 @@ import json import re +from typing import Dict, List, Optional, Tuple, Union from urllib.parse import quote_plus -from typing import Optional, Dict, List, Union, Tuple from sqlalchemy import exc diff --git a/trino/transaction.py b/trino/transaction.py index e6c85234..c6e6257d 100644 --- a/trino/transaction.py +++ b/trino/transaction.py @@ -12,11 +12,10 @@ from enum import Enum, unique from typing import Iterable -from trino import constants import trino.client import trino.exceptions import trino.logging - +from trino import constants logger = trino.logging.get_logger(__name__)