Skip to content

Commit

Permalink
Add various type annotations (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardm-stripe authored Oct 13, 2023
1 parent f41d521 commit 22d7d88
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 71 deletions.
3 changes: 3 additions & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class TypingImportsChecker:
"List",
"Generic",
"Mapping",
"Tuple",
"Iterator",
"Mapping",
]

def __init__(self, tree: ast.AST):
Expand Down
119 changes: 119 additions & 0 deletions flake8_stripe/flake8_stripe.py.orig
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Hint: if you're developing this plugin, test changes with:
# venv/bin/tox -e lint -r
# so that tox re-installs the plugin from the local directory
import ast
from typing import Iterator, Tuple


class TypingImportsChecker:
name = __name__
version = "0.1.0"

# Rules:
# * typing_extensions v4.1.1 is the latest that supports Python 3.6
# so don't depend on anything from a more recent version than that.
#
# If we need something newer, maybe we can provide it for users on
# newer versions with a conditional import, but we'll cross that
# bridge when we come to it.

# If a symbol exists in both `typing` and `typing_extensions`, which
# should you use? Prefer `typing_extensions` if the symbol available there.
# in 4.1.1. In typing_extensions 4.7.0, `typing_extensions` started re-exporting
# EVERYTHING from `typing` but this is not the case in v4.1.1.
allowed_typing_extensions_imports = [
"Literal",
"NoReturn",
"Protocol",
"TYPE_CHECKING",
"Type",
"TypedDict",
"NotRequired",
"Self",
"Unpack",
]

allowed_typing_imports = [
"Any",
"ClassVar",
"Optional",
"TypeVar",
"Union",
"cast",
"overload",
"Dict",
"Tuple",
"List",
"Generic",
<<<<<<< HEAD
"Mapping",
||||||| parent of f2e8187 (Lint)
"Tuple",
=======
"Tuple",
"Iterator",
"Mapping",
>>>>>>> f2e8187 (Lint)
]

def __init__(self, tree: ast.AST):
self.tree = tree

intersection = set(self.allowed_typing_imports) & set(
self.allowed_typing_extensions_imports
)
if len(intersection) > 0:
raise AssertionError(
"TypingImportsChecker: allowed_typing_imports and allowed_typing_extensions_imports must not overlap. Both entries contained: %s"
% (intersection)
)

def run(self) -> Iterator[Tuple[int, int, str, type]]:
for node in ast.walk(self.tree):
if isinstance(node, ast.ImportFrom):
if node.module == "typing":
for name in node.names:
if name.name not in self.allowed_typing_imports:
msg = None
if (
name.name
in self.allowed_typing_extensions_imports
):
msg = (
"SPY100 Don't import %s from 'typing', instead import from 'typing_extensions'"
% (name.name)
)
else:
msg = (
"SPY101 Importing %s from 'typing' is prohibited. Do you need to add to the allowlist in flake8_stripe.py?"
% (name.name)
)
yield (
name.lineno,
name.col_offset,
msg,
type(self),
)
elif node.module == "typing_extensions":
for name in node.names:
if (
name.name
not in self.allowed_typing_extensions_imports
):
msg = None
if name.name in self.allowed_typing_imports:
msg = (
"SPY102 Don't import '%s' from 'typing_extensions', instead import from 'typing'"
% (name.name)
)
else:
msg = (
"SPY103 Importing '%s' from 'typing_extensions' is prohibited. Do you need to add to the allowlist in flake8_stripe.py?"
% (name.name)
)
yield (
name.lineno,
name.col_offset,
msg,
type(self),
)
15 changes: 11 additions & 4 deletions stripe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
# Configuration variables
from stripe.api_version import _ApiVersion

api_key = None
client_id = None
from stripe.app_info import AppInfo

api_key: Optional[str] = None
client_id: Optional[str] = None
api_base = "https://api.stripe.com"
connect_api_base = "https://connect.stripe.com"
upload_api_base = "https://files.stripe.com"
api_version = _ApiVersion.CURRENT
verify_ssl_certs = True
proxy = None
default_http_client = None
app_info = None
app_info: Optional[AppInfo] = None
enable_telemetry = True
max_network_retries = 0
ca_bundle_path = os.path.join(
Expand Down Expand Up @@ -52,7 +54,12 @@
# communicating with Stripe.
#
# Takes a name and optional version and plugin URL.
def set_app_info(name, partner_id=None, url=None, version=None):
def set_app_info(
name: str,
partner_id: Optional[str] = None,
url: Optional[str] = None,
version: Optional[str] = None,
):
global app_info
app_info = {
"name": name,
Expand Down
131 changes: 84 additions & 47 deletions stripe/api_resources/list_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# pyright: strict
from typing_extensions import Self

from typing import List, Generic, TypeVar
from typing import (
Any,
Iterator,
List,
Generic,
Optional,
TypeVar,
cast,
Mapping,
)
from stripe.stripe_object import StripeObject

from urllib.parse import quote_plus
Expand All @@ -15,52 +25,62 @@ class ListObject(StripeObject, Generic[T]):
url: str

def list(
self, api_key=None, stripe_version=None, stripe_account=None, **params
):
self,
api_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
**params: Mapping[str, Any]
) -> Self:
url = self.get("url")
if not isinstance(url, str):
raise ValueError(
'Cannot call .list on a list object without a string "url" property'
)
return self._request(
"get",
url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
return cast(
Self,
self._request(
"get",
url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
),
)

def create(
self,
api_key=None,
idempotency_key=None,
stripe_version=None,
stripe_account=None,
**params
):
api_key: Optional[str] = None,
idempotency_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
**params: Mapping[str, Any]
) -> T:
url = self.get("url")
if not isinstance(url, str):
raise ValueError(
'Cannot call .create on a list object for the collection of an object without a string "url" property'
)
return self._request(
"post",
url,
api_key=api_key,
idempotency_key=idempotency_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
return cast(
T,
self._request(
"post",
url,
api_key=api_key,
idempotency_key=idempotency_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
),
)

def retrieve(
self,
id,
api_key=None,
stripe_version=None,
stripe_account=None,
**params
id: str,
api_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
**params: Mapping[str, Any]
):
url = self.get("url")
if not isinstance(url, str):
Expand All @@ -69,17 +89,20 @@ def retrieve(
)

url = "%s/%s" % (self.get("url"), quote_plus(id))
return self._request(
"get",
url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
return cast(
T,
self._request(
"get",
url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
),
)

def __getitem__(self, k):
if isinstance(k, str):
def __getitem__(self, k: str) -> T:
if isinstance(k, str): # pyright: ignore
return super(ListObject, self).__getitem__(k)
else:
raise KeyError(
Expand All @@ -89,16 +112,19 @@ def __getitem__(self, k):
".data[%s])" % (repr(k), repr(k))
)

def __iter__(self):
# Pyright doesn't like this because ListObject inherits from StripeObject inherits from Dict[str, Any]
# and so it wants the type of __iter__ to agree with __iter__ from Dict[str, Any]
# But we are iterating through "data", which is a List[T].
def __iter__(self) -> Iterator[T]: # pyright: ignore
return getattr(self, "data", []).__iter__()

def __len__(self):
def __len__(self) -> int:
return getattr(self, "data", []).__len__()

def __reversed__(self):
def __reversed__(self) -> Iterator[T]: # pyright: ignore (see above)
return getattr(self, "data", []).__reversed__()

def auto_paging_iter(self):
def auto_paging_iter(self) -> Iterator[T]:
page = self

while True:
Expand All @@ -119,8 +145,11 @@ def auto_paging_iter(self):

@classmethod
def empty_list(
cls, api_key=None, stripe_version=None, stripe_account=None
):
cls,
api_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
) -> Self:
return cls.construct_from(
{"data": []},
key=api_key,
Expand All @@ -130,11 +159,15 @@ def empty_list(
)

@property
def is_empty(self):
def is_empty(self) -> bool:
return not self.data

def next_page(
self, api_key=None, stripe_version=None, stripe_account=None, **params
self,
api_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
**params: Mapping[str, Any]
) -> Self:
if not self.has_more:
return self.empty_list(
Expand Down Expand Up @@ -163,7 +196,11 @@ def next_page(
return result

def previous_page(
self, api_key=None, stripe_version=None, stripe_account=None, **params
self,
api_key: Optional[str] = None,
stripe_version: Optional[str] = None,
stripe_account: Optional[str] = None,
**params: Mapping[str, Any]
) -> Self:
if not self.has_more:
return self.empty_list(
Expand Down
Loading

0 comments on commit 22d7d88

Please sign in to comment.