Skip to content

Commit

Permalink
Add MyPy type checking to our CI process (#180)
Browse files Browse the repository at this point in the history
* Start adding mypy

* Add mypy to github actions.

* Add mypy workflow to gh actions.

* Use future annotations instead of string annotations.

* Temporarily pin `markupsafe` for breaking change.

- Temporarily solves pallets/markupsafe#284
- Remove `markupsafe` entry once `jinja2` has been upgraded.

* Temporarily exclude some files from mypy checks.

Co-authored-by: Tim DiLauro <[email protected]>
  • Loading branch information
jonathangreen and tdilauro authored Mar 2, 2022
1 parent a54bb16 commit 532b7c7
Show file tree
Hide file tree
Showing 60 changed files with 1,600 additions and 986 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Mypy (Type check)
on: [push, pull_request]
env:
PYTHON_VERSION: 3.9

jobs:
mypy:
runs-on: ubuntu-latest
permissions:
contents: read

steps:
- uses: actions/checkout@v2

- name: Set up Python 🐍
uses: actions/setup-python@v2
with:
python-version: ${{ env.PYTHON_VERSION }}

- name: Install Poetry 🎸
uses: ./.github/actions/poetry

- name: Install OS Packages 🧰
run: |
sudo apt-get update
sudo apt-get install --yes libxmlsec1-dev libxml2-dev
- name: Install Python Packages 📦
run: poetry install

- name: Run MyPy 🪄
run: poetry run mypy
4 changes: 2 additions & 2 deletions core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ class Configuration(ConfigurationConstants):
"key": facet,
"label": FacetConstants.FACET_DISPLAY_TITLES.get(facet),
}
for facet in FacetConstants.FACETS_BY_GROUP.get(group)
for facet in FacetConstants.FACETS_BY_GROUP.get(group, [])
],
"default": FacetConstants.FACETS_BY_GROUP.get(group),
"category": "Lanes & Filters",
Expand All @@ -348,7 +348,7 @@ class Configuration(ConfigurationConstants):
"key": facet,
"label": FacetConstants.FACET_DISPLAY_TITLES.get(facet),
}
for facet in FacetConstants.FACETS_BY_GROUP.get(group)
for facet in FacetConstants.FACETS_BY_GROUP.get(group, [])
],
"default": FacetConstants.DEFAULT_FACET.get(group),
"category": "Lanes & Filters",
Expand Down
15 changes: 8 additions & 7 deletions core/coverage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import traceback
from typing import Optional, Union

from sqlalchemy.orm.session import Session
from sqlalchemy.sql.functions import func
Expand Down Expand Up @@ -135,21 +136,21 @@ class BaseCoverageProvider(object):

# In your subclass, set this to the name of the service,
# e.g. "Overdrive Bibliographic Coverage Provider".
SERVICE_NAME = None
SERVICE_NAME: Optional[str] = None

# In your subclass, you _may_ set this to a string that distinguishes
# two different CoverageProviders from the same data source.
# (You may also override the operation method, if you need
# database access to determine which operation to use.)
OPERATION = None
OPERATION: Optional[str] = None

# The database session will be committed each time the
# BaseCoverageProvider has (attempted to) provide coverage to this
# number of Identifiers. You may change this in your subclass.
# It's also possible to change it by passing in a value for
# `batch_size` in the constructor, but generally nobody bothers
# doing this.
DEFAULT_BATCH_SIZE = 100
DEFAULT_BATCH_SIZE: int = 100

def __init__(
self,
Expand Down Expand Up @@ -582,7 +583,7 @@ class IdentifierCoverageProvider(BaseCoverageProvider):

# In your subclass, set this to the name of the data source you
# consult when providing coverage, e.g. DataSource.OVERDRIVE.
DATA_SOURCE_NAME = None
DATA_SOURCE_NAME: str

# In your subclass, set this to a single identifier type, or a list
# of identifier types. The CoverageProvider will attempt to give
Expand All @@ -591,7 +592,7 @@ class IdentifierCoverageProvider(BaseCoverageProvider):
# Setting this to None will attempt to give coverage to every single
# Identifier in the system, which is probably not what you want.
NO_SPECIFIED_TYPES = object()
INPUT_IDENTIFIER_TYPES = NO_SPECIFIED_TYPES
INPUT_IDENTIFIER_TYPES: Union[None, str, object] = NO_SPECIFIED_TYPES

# Set this to False if a given Identifier needs to be run through
# this CoverageProvider once for every Collection that has this
Expand Down Expand Up @@ -1077,14 +1078,14 @@ class CollectionCoverageProvider(IdentifierCoverageProvider):

# By default, this type of CoverageProvider will provide coverage to
# all Identifiers in the given Collection, regardless of their type.
INPUT_IDENTIFIER_TYPES = None
INPUT_IDENTIFIER_TYPES: Union[None, str, object] = None

DEFAULT_BATCH_SIZE = 10

# Set this to the name of the protocol managed by this type of
# CoverageProvider. If this CoverageProvider can manage collections
# for any protocol, leave this as None.
PROTOCOL = None
PROTOCOL: Optional[str] = None

# By default, Works calculated by a CollectionCoverageProvider update
# the ExternalSearchIndex. Set this value to True for applications that
Expand Down
16 changes: 11 additions & 5 deletions core/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from typing import Dict, List, Optional, Type


class EntryPoint(object):

"""A EntryPoint is a top-level entry point into a library's Lane structure
Expand All @@ -21,14 +26,15 @@ class by calling EntryPoint.register.
# enabled.
ENABLED_SETTING = "enabled_entry_points"

ENTRY_POINTS = []
DEFAULT_ENABLED = []
DISPLAY_TITLES = {}
BY_INTERNAL_NAME = {}
ENTRY_POINTS: List[Type[EntryPoint]] = []
DEFAULT_ENABLED: List[Type[EntryPoint]] = []
DISPLAY_TITLES: Dict[Type[EntryPoint], str] = {}
BY_INTERNAL_NAME: Dict[str, Type[EntryPoint]] = {}

# A distinctive URI designating the sort of thing found through this
# EntryPoint.
URI = None
URI: Optional[str] = None
INTERNAL_NAME: str

@classmethod
def register(cls, entrypoint_class, display_title, default_enabled=False):
Expand Down
15 changes: 8 additions & 7 deletions core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Optional


class BaseError(Exception):
"""Base class for all errors"""

def __init__(self, message=None, inner_exception=None):
def __init__(
self, message: Optional[str] = None, inner_exception: Optional[Exception] = None
):
"""Initializes a new instance of BaseError class
:param message: String containing description of the error occurred
Expand All @@ -18,22 +23,18 @@ def __hash__(self):
return hash(str(self))

@property
def inner_exception(self):
def inner_exception(self) -> Optional[str]:
"""Returns an inner exception
:return: Inner exception
:rtype: Exception
"""
return self._inner_exception

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
"""Compares two BaseError objects
:param other: BaseError object
:type other: BaseError
:return: Boolean value indicating whether two items are equal
:rtype: bool
"""
if not isinstance(other, BaseError):
return False
Expand Down
3 changes: 2 additions & 1 deletion core/external_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import time
from collections import defaultdict
from typing import Optional

from elasticsearch import Elasticsearch
from elasticsearch.exceptions import ElasticsearchException, RequestError
Expand Down Expand Up @@ -828,7 +829,7 @@ class Mapping(MappingDocument):
can change between versions without affecting anything.)
"""

VERSION_NAME = None
VERSION_NAME: Optional[str] = None

@classmethod
def version_name(cls):
Expand Down
41 changes: 20 additions & 21 deletions core/lane.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Optional
from urllib.parse import quote_plus

import elasticsearch
Expand All @@ -22,7 +23,6 @@
)
from sqlalchemy.dialects.postgresql import ARRAY, INT4RANGE, JSON
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
aliased,
backref,
Expand All @@ -31,9 +31,12 @@
joinedload,
relationship,
)
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import Select

from core.model.hybrid import hybrid_property

from .classifier import Classifier
from .config import Configuration
from .entrypoint import EntryPoint, EverythingEntryPoint
Expand All @@ -49,22 +52,23 @@
Genre,
Library,
LicensePool,
Session,
Work,
WorkGenre,
directly_modified,
get_one_or_create,
site_configuration_has_changed,
tuple_to_numericrange,
)
from .model.constants import EditionConstants
from .model.listeners import directly_modified, site_configuration_has_changed
from .problem_details import *
from .util import LanguageCodes
from .util.accept_language import parse_accept_language
from .util.datetime_helpers import utc_now
from .util.opds_writer import OPDSFeed
from .util.problem_detail import ProblemDetail

if TYPE_CHECKING:
from core.model import CachedMARCFile # noqa: autoflake


class BaseFacets(FacetConstants):
"""Basic faceting class that doesn't modify a search filter at all.
Expand All @@ -76,7 +80,7 @@ class BaseFacets(FacetConstants):
# type of feed (the way FeaturedFacets always implies a 'groups' feed),
# set the type of feed here. This will override any CACHED_FEED_TYPE
# associated with the WorkList.
CACHED_FEED_TYPE = None
CACHED_FEED_TYPE: Optional[str] = None

# By default, faceting objects have no opinion on how long the feeds
# generated using them should be cached.
Expand Down Expand Up @@ -1281,10 +1285,9 @@ class WorkList(object):
CACHED_FEED_TYPE = None

# By default, a WorkList is always visible.
visible = True

# By default, a WorkList does not draw from CustomLists
uses_customlists = False
@property
def visible(self) -> bool:
return True

def max_cache_age(self, type):
"""Determine how long a feed for this WorkList should be cached
Expand Down Expand Up @@ -2500,7 +2503,7 @@ def from_genre(cls, genre):
return lg


Genre.lane_genres = relationship(
Genre.lane_genres = relationship( # type: ignore
"LaneGenre", foreign_keys=LaneGenre.genre_id, backref="genre"
)

Expand Down Expand Up @@ -2565,11 +2568,11 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):

# A lane may be restricted to works classified for specific audiences
# (e.g. only Young Adult works).
_audiences = Column(ARRAY(Unicode), name="audiences")
_audiences = Column("audiences", ARRAY(Unicode))

# A lane may further be restricted to works classified as suitable
# for a specific age range.
_target_age = Column(INT4RANGE, name="target_age", index=True)
_target_age = Column("target_age", INT4RANGE, index=True)

# A lane may be restricted to works available in certain languages.
languages = Column(ARRAY(Unicode))
Expand All @@ -2594,7 +2597,7 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):

# Only the books on these specific CustomLists will be shown.
customlists = relationship(
"CustomList", secondary=lambda: lanes_customlists, backref="lane"
"CustomList", secondary=lambda: lanes_customlists, backref="lane" # type: ignore
)

# This has no effect unless list_datasource_id or
Expand Down Expand Up @@ -2626,7 +2629,7 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):

# Only a visible lane will show up in the user interface. The
# admin interface can see all the lanes, visible or not.
_visible = Column(Boolean, default=True, nullable=False, name="visible")
_visible = Column("visible", Boolean, default=True, nullable=False)

# A Lane may have many CachedFeeds.
cachedfeeds = relationship(
Expand All @@ -2648,10 +2651,6 @@ def get_library(self, _db):
"""For compatibility with WorkList.get_library()."""
return self.library

@property
def list_datasource_id(self):
return self._list_datasource_id

@property
def collection_ids(self):
return [x.id for x in self.library.collections]
Expand Down Expand Up @@ -3148,16 +3147,16 @@ def explain(self):
return lines


Library.lanes = relationship(
Library.lanes = relationship( # type: ignore
"Lane",
backref="library",
foreign_keys=Lane.library_id,
cascade="all, delete-orphan",
)
DataSource.list_lanes = relationship(
DataSource.list_lanes = relationship( # type: ignore
"Lane", backref="_list_datasource", foreign_keys=Lane._list_datasource_id
)
DataSource.license_lanes = relationship(
DataSource.license_lanes = relationship( # type: ignore
"Lane", backref="license_datasource", foreign_keys=Lane.license_datasource_id
)

Expand Down
Loading

0 comments on commit 532b7c7

Please sign in to comment.