From f4647f536eb92f7b0f18857ee6560ad20544be09 Mon Sep 17 00:00:00 2001 From: pierwill Date: Wed, 3 Jun 2020 16:06:11 -0700 Subject: [PATCH] Update syntax for existing type annotations Uses syntax described in PEP484. --- securedrop/crypto_util.py | 3 +- securedrop/journalist_app/__init__.py | 12 +-- securedrop/journalist_app/utils.py | 6 +- securedrop/models.py | 140 +++++++++++--------------- securedrop/sdconfig.py | 3 +- securedrop/source_app/__init__.py | 3 +- securedrop/store.py | 50 +++++---- securedrop/worker.py | 12 +-- 8 files changed, 99 insertions(+), 130 deletions(-) diff --git a/securedrop/crypto_util.py b/securedrop/crypto_util.py index 5d7a05008e7..6918f3916bd 100644 --- a/securedrop/crypto_util.py +++ b/securedrop/crypto_util.py @@ -135,8 +135,7 @@ def do_runtime_tests(self): if not rm.check_secure_delete_capability(): raise AssertionError("Secure file deletion is not possible.") - def get_wordlist(self, locale): - # type: (Text) -> List[str] + def get_wordlist(self, locale: Text) -> List[str]: """" Ensure the wordlist for the desired locale is read and available in the words global variable. If there is no wordlist for the desired local, fallback to the default english wordlist. diff --git a/securedrop/journalist_app/__init__.py b/securedrop/journalist_app/__init__.py index 959dd46bc3c..33f03e7ce4b 100644 --- a/securedrop/journalist_app/__init__.py +++ b/securedrop/journalist_app/__init__.py @@ -38,8 +38,7 @@ _insecure_views = ['main.login', 'main.select_logo', 'static'] -def create_app(config): - # type: (SDConfig) -> Flask +def create_app(config: SDConfig) -> Flask: app = Flask(__name__, template_folder=config.JOURNALIST_TEMPLATES_DIR, static_folder=path.join(config.SECUREDROP_ROOT, 'static')) @@ -82,16 +81,14 @@ def create_app(config): ) @app.errorhandler(CSRFError) - def handle_csrf_error(e): - # type: (CSRFError) -> Response + def handle_csrf_error(e: CSRFError) -> Response: # render the message first to ensure it's localized. msg = gettext('You have been logged out due to inactivity') session.clear() flash(msg, 'error') return redirect(url_for('main.login')) - def _handle_http_exception(error): - # type: (HTTPException) -> Tuple[Union[Response, str], Optional[int]] + def _handle_http_exception(error: HTTPException) -> Tuple[Union[Response, str], Optional[int]]: # Workaround for no blueprint-level 404/5 error handlers, see: # https://github.com/pallets/flask/issues/503#issuecomment-71383286 handler = list(app.error_handler_spec['api'][error.code].values())[0] @@ -129,8 +126,7 @@ def load_instance_config(): app.instance_config = InstanceConfig.get_current() @app.before_request - def setup_g(): - # type: () -> Optional[Response] + def setup_g() -> Optional[Response]: """Store commonly used values in Flask's special g object""" if 'expires' in session and datetime.utcnow() >= session['expires']: session.clear() diff --git a/securedrop/journalist_app/utils.py b/securedrop/journalist_app/utils.py index ca941572672..6f1bf73010f 100644 --- a/securedrop/journalist_app/utils.py +++ b/securedrop/journalist_app/utils.py @@ -27,8 +27,7 @@ from sdconfig import SDConfig # noqa: F401 -def logged_in(): - # type: () -> bool +def logged_in() -> bool: # When a user is logged in, we push their user ID (database primary key) # into the session. setup_g checks for this value, and if it finds it, # stores a reference to the user's Journalist object in g. @@ -255,8 +254,7 @@ def col_delete(cols_selected): return redirect(url_for('main.index')) -def make_password(config): - # type: (SDConfig) -> str +def make_password(config: SDConfig) -> str: while True: password = current_app.crypto_util.genrandomid( 7, diff --git a/securedrop/models.py b/securedrop/models.py index 68e02ca16fe..b6faeb3ea0f 100644 --- a/securedrop/models.py +++ b/securedrop/models.py @@ -42,8 +42,9 @@ ARGON2_PARAMS = dict(memory_cost=2**16, rounds=4, parallelism=2) -def get_one_or_else(query, logger, failure_method): - # type: (Query, Logger, Callable[[int], None]) -> None +def get_one_or_else(query: Query, + logger: Logger, + failure_method: Callable[[int], None]) -> None: try: return query.one() except MultipleResultsFound as e: @@ -80,25 +81,23 @@ class Source(db.Model): NUM_WORDS = 7 MAX_CODENAME_LEN = 128 - def __init__(self, filesystem_id=None, journalist_designation=None): - # type: (str, str) -> None + def __init__(self, + filesystem_id: str = None, + journalist_designation: str = None) -> None: self.filesystem_id = filesystem_id self.journalist_designation = journalist_designation self.uuid = str(uuid.uuid4()) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '' % (self.journalist_designation) @property - def journalist_filename(self): - # type: () -> str + def journalist_filename(self) -> str: valid_chars = 'abcdefghijklmnopqrstuvwxyz1234567890-_' return ''.join([c for c in self.journalist_designation.lower().replace( ' ', '_') if c in valid_chars]) - def documents_messages_count(self): - # type: () -> Dict[str, int] + def documents_messages_count(self) -> Dict[str, int]: self.docs_msgs_count = {'messages': 0, 'documents': 0} for submission in self.submissions: if submission.filename.endswith('msg.gpg'): @@ -109,8 +108,7 @@ def documents_messages_count(self): return self.docs_msgs_count @property - def collection(self): - # type: () -> List[Union[Submission, Reply]] + def collection(self) -> List[Union[Submission, Reply]]: """Return the list of submissions and replies for this source, sorted in ascending order by the filename/interaction count.""" collection = [] # type: List[Union[Submission, Reply]] @@ -132,22 +130,18 @@ def fingerprint(self): raise NotImplementedError @property - def public_key(self): - # type: () -> str + def public_key(self) -> str: return current_app.crypto_util.get_pubkey(self.filesystem_id) @public_key.setter - def public_key(self, value): - # type: (str) -> None + def public_key(self, value: str) -> None: raise NotImplementedError @public_key.deleter - def public_key(self): - # type: () -> None + def public_key(self) -> None: raise NotImplementedError - def to_json(self): - # type: () -> Dict[str, Union[str, bool, int, str]] + def to_json(self) -> Dict[str, Union[str, bool, int, str]]: docs_msg_count = self.documents_messages_count() if self.last_updated: @@ -208,20 +202,17 @@ class Submission(db.Model): ''' checksum = Column(String(255)) - def __init__(self, source, filename): - # type: (Source, str) -> None + def __init__(self, source: Source, filename: str) -> None: self.source_id = source.id self.filename = filename self.uuid = str(uuid.uuid4()) self.size = os.stat(current_app.storage.path(source.filesystem_id, filename)).st_size - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '' % (self.filename) - def to_json(self): - # type: () -> Dict[str, Union[str, int, bool]] + def to_json(self) -> Dict[str, Union[str, int, bool]]: json_submission = { 'source_url': url_for('api.single_source', source_uuid=self.source.uuid), @@ -268,8 +259,10 @@ class Reply(db.Model): deleted_by_source = Column(Boolean, default=False, nullable=False) - def __init__(self, journalist, source, filename): - # type: (Journalist, Source, str) -> None + def __init__(self, + journalist: Journalist, + source: Source, + filename: str) -> None: self.journalist_id = journalist.id self.source_id = source.id self.uuid = str(uuid.uuid4()) @@ -277,12 +270,10 @@ def __init__(self, journalist, source, filename): self.size = os.stat(current_app.storage.path(source.filesystem_id, filename)).st_size - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return '' % (self.filename) - def to_json(self): - # type: () -> Dict[str, Union[str, int, bool]] + def to_json(self) -> Dict[str, Union[str, int, bool]]: username = "deleted" first_name = "" last_name = "" @@ -316,15 +307,13 @@ class SourceStar(db.Model): source_id = Column("source_id", Integer, ForeignKey('sources.id')) starred = Column("starred", Boolean, default=True) - def __eq__(self, other): - # type: (Any) -> bool + def __eq__(self, other: Any) -> bool: if isinstance(other, SourceStar): return (self.source_id == other.source_id and self.id == other.id and self.starred == other.starred) return False - def __init__(self, source, starred=True): - # type: (Source, bool) -> None + def __init__(self, source: Source, starred: bool = True) -> None: self.source_id = source.id self.starred = starred @@ -379,12 +368,10 @@ class InvalidPasswordLength(PasswordError): password length. """ - def __init__(self, passphrase): - # type: (str) -> None + def __init__(self, passphrase: str) -> None: self.passphrase_len = len(passphrase) - def __str__(self): - # type: () -> str + def __str__(self) -> str: if self.passphrase_len > Journalist.MAX_PASSWORD_LEN: return "Password too long (len={})".format(self.passphrase_len) if self.passphrase_len < Journalist.MIN_PASSWORD_LEN: @@ -428,9 +415,13 @@ class Journalist(db.Model): MIN_NAME_LEN = 0 MAX_NAME_LEN = 100 - def __init__(self, username, password, first_name=None, last_name=None, is_admin=False, - otp_secret=None): - # type: (str, str, Optional[str], Optional[str], bool, Optional[str]) -> None + def __init__(self, + username: str, + password: str, + first_name: Optional[str] = None, + last_name: Optional[str] = None, + is_admin: bool = False, + otp_secret: Optional[str] = None) -> None: self.check_username_acceptable(username) self.username = username @@ -447,23 +438,20 @@ def __init__(self, username, password, first_name=None, last_name=None, is_admin if otp_secret: self.set_hotp_secret(otp_secret) - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return "".format( self.username, " [admin]" if self.is_admin else "") _LEGACY_SCRYPT_PARAMS = dict(N=2**14, r=8, p=1) - def _scrypt_hash(self, password, salt): - # type: (str, str) -> str + def _scrypt_hash(self, password: str, salt: str) -> str: return scrypt.hash(str(password), salt, **self._LEGACY_SCRYPT_PARAMS) MAX_PASSWORD_LEN = 128 MIN_PASSWORD_LEN = 14 - def set_password(self, passphrase): - # type: (str) -> None + def set_password(self, passphrase: str) -> None: self.check_password_acceptable(passphrase) # "migrate" from the legacy case @@ -490,8 +478,7 @@ def set_name(self, first_name, last_name): self.last_name = last_name @classmethod - def check_username_acceptable(cls, username): - # type: (str) -> None + def check_username_acceptable(cls, username: str) -> None: if len(username) < cls.MIN_USERNAME_LEN: raise InvalidUsernameException( 'Username "{}" must be at least {} characters long.' @@ -504,8 +491,7 @@ def check_name_acceptable(cls, name): raise InvalidNameLength(name) @classmethod - def check_password_acceptable(cls, password): - # type: (str) -> None + def check_password_acceptable(cls, password: str) -> None: # Enforce a reasonable maximum length for passwords to avoid DoS if len(password) > cls.MAX_PASSWORD_LEN: raise InvalidPasswordLength(password) @@ -518,8 +504,7 @@ def check_password_acceptable(cls, password): if len(password.split()) < 7: raise NonDicewarePassword() - def valid_password(self, passphrase): - # type: (str) -> bool + def valid_password(self, passphrase: str) -> bool: # Avoid hashing passwords that are over the maximum length if len(passphrase) > self.MAX_PASSWORD_LEN: raise InvalidPasswordLength(passphrase) @@ -549,12 +534,10 @@ def valid_password(self, passphrase): return is_valid - def regenerate_totp_shared_secret(self): - # type: () -> None + def regenerate_totp_shared_secret(self) -> None: self.otp_secret = pyotp.random_base32() - def set_hotp_secret(self, otp_secret): - # type: (str) -> None + def set_hotp_secret(self, otp_secret) -> None: self.otp_secret = base64.b32encode( binascii.unhexlify( otp_secret.replace( @@ -564,24 +547,21 @@ def set_hotp_secret(self, otp_secret): self.hotp_counter = 0 @property - def totp(self): - # type: () -> OTP + def totp(self) -> OTP: if self.is_totp: return pyotp.TOTP(self.otp_secret) else: raise ValueError('{} is not using TOTP'.format(self)) @property - def hotp(self): - # type: () -> OTP + def hotp(self) -> OTP: if not self.is_totp: return pyotp.HOTP(self.otp_secret) else: raise ValueError('{} is not using HOTP'.format(self)) @property - def shared_secret_qrcode(self): - # type: () -> Markup + def shared_secret_qrcode(self) -> Markup: uri = self.totp.provisioning_uri( self.username, issuer_name="SecureDrop") @@ -598,8 +578,7 @@ def shared_secret_qrcode(self): return Markup(svg_out.getvalue().decode('utf-8')) @property - def formatted_otp_secret(self): - # type: () -> str + def formatted_otp_secret(self) -> str: """The OTP secret is easier to read and manually enter if it is all lowercase and split into four groups of four characters. The secret is base32-encoded, so it is case insensitive.""" @@ -607,14 +586,12 @@ def formatted_otp_secret(self): chunks = [sec[i:i + 4] for i in range(0, len(sec), 4)] return ' '.join(chunks).lower() - def _format_token(self, token): - # type: (str) -> str + def _format_token(self, token: str) -> str: """Strips from authentication tokens the whitespace that many clients add for readability""" return ''.join(token.split()) - def verify_token(self, token): - # type: (str) -> bool + def verify_token(self, token: str) -> bool: token = self._format_token(token) # Store latest token to prevent OTP token reuse @@ -641,8 +618,7 @@ def verify_token(self, token): _MAX_LOGIN_ATTEMPTS_PER_PERIOD = 5 @classmethod - def throttle_login(cls, user): - # type: (Journalist) -> None + def throttle_login(cls, user: Journalist) -> None: # Record the login attempt... login_attempt = JournalistLoginAttempt(user) db.session.add(login_attempt) @@ -661,8 +637,10 @@ def throttle_login(cls, user): cls._LOGIN_ATTEMPT_PERIOD)) @classmethod - def login(cls, username, password, token): - # type: (str, str, str) -> Journalist + def login(cls, + username: str, + password: str, + token: str) -> Journalist: try: user = Journalist.query.filter_by(username=username).one() except NoResultFound: @@ -683,8 +661,7 @@ def login(cls, username, password, token): raise WrongPasswordException("invalid password") return user - def generate_api_token(self, expiration): - # type: (int) -> str + def generate_api_token(self, expiration: int) -> str: s = TimedJSONWebSignatureSerializer( current_app.config['SECRET_KEY'], expires_in=expiration) return s.dumps({'id': self.id}).decode('ascii') # type:ignore @@ -700,8 +677,7 @@ def validate_token_is_not_expired_or_invalid(token): return True @staticmethod - def validate_api_token_and_get_user(token): - # type: (str) -> Union[Journalist, None] + def validate_api_token_and_get_user(token: str) -> Union[Journalist, None]: s = TimedJSONWebSignatureSerializer(current_app.config['SECRET_KEY']) try: data = s.loads(token) @@ -714,8 +690,7 @@ def validate_api_token_and_get_user(token): return Journalist.query.get(data['id']) - def to_json(self): - # type: () -> Dict[str, Union[str, bool, str]] + def to_json(self) -> Dict[str, Union[str, bool, str]]: json_user = { 'username': self.username, 'last_login': self.last_access.isoformat() + 'Z', @@ -737,8 +712,7 @@ class JournalistLoginAttempt(db.Model): timestamp = Column(DateTime, default=datetime.datetime.utcnow) journalist_id = Column(Integer, ForeignKey('journalists.id')) - def __init__(self, journalist): - # type: (Journalist) -> None + def __init__(self, journalist: Journalist) -> None: self.journalist_id = journalist.id diff --git a/securedrop/sdconfig.py b/securedrop/sdconfig.py index 638e71a7872..c0267a6a38e 100644 --- a/securedrop/sdconfig.py +++ b/securedrop/sdconfig.py @@ -13,8 +13,7 @@ class SDConfig(object): - def __init__(self): - # type: () -> None + def __init__(self) -> None: try: self.JournalistInterfaceFlaskConfig = \ _config.JournalistInterfaceFlaskConfig # type: ignore diff --git a/securedrop/source_app/__init__.py b/securedrop/source_app/__init__.py index 9ae573294cd..1845c2e46d5 100644 --- a/securedrop/source_app/__init__.py +++ b/securedrop/source_app/__init__.py @@ -31,8 +31,7 @@ from sdconfig import SDConfig # noqa: F401 -def create_app(config): - # type: (SDConfig) -> Flask +def create_app(config: SDConfig) -> Flask: app = Flask(__name__, template_folder=config.SOURCE_TEMPLATES_DIR, static_folder=path.join(config.SECUREDROP_ROOT, 'static')) diff --git a/securedrop/store.py b/securedrop/store.py index ad5271579d3..a1962051f24 100644 --- a/securedrop/store.py +++ b/securedrop/store.py @@ -89,8 +89,7 @@ def safe_renames(old, new): class Storage: - def __init__(self, storage_path, temp_dir, gpg_key): - # type: (str, str, str) -> None + def __init__(self, storage_path: str, temp_dir: str, gpg_key: str) -> None: if not os.path.isabs(storage_path): raise PathException("storage_path {} is not absolute".format( storage_path)) @@ -165,8 +164,7 @@ def path(self, filesystem_id: str, filename: str = '') -> str: ) return absolute - def path_without_filesystem_id(self, filename): - # type: (str) -> str + def path_without_filesystem_id(self, filename: str) -> str: """Get the normalized, absolute file path, within `self.__storage_path` for a filename when the filesystem_id is not known. @@ -191,8 +189,9 @@ def path_without_filesystem_id(self, filename): ) return absolute - def get_bulk_archive(self, selected_submissions, zip_directory=''): - # type: (List, str) -> _TemporaryFileWrapper + def get_bulk_archive(self, + selected_submissions: List, + zip_directory: str = '') -> _TemporaryFileWrapper: """Generate a zip file from the selected submissions""" zip_file = tempfile.NamedTemporaryFile( prefix='tmp_securedrop_bulk_dl_', @@ -295,9 +294,12 @@ def clear_shredder(self): os.rmdir(d) current_app.logger.debug("Removed directory {}/{}: {}".format(i, dir_count, d)) - def save_file_submission(self, filesystem_id, count, journalist_filename, - filename, stream): - # type: (str, int, str, str, BufferedIOBase) -> str + def save_file_submission(self, + filesystem_id: str, + count: int, + journalist_filename: str, + filename: str, + stream: BufferedIOBase) -> str: sanitized_filename = secure_filename(filename) # We store file submissions in a .gz file for two reasons: @@ -333,9 +335,11 @@ def save_file_submission(self, filesystem_id, count, journalist_filename, return encrypted_file_name - def save_pre_encrypted_reply(self, filesystem_id, count, - journalist_filename, content): - # type: (str, int, str, str) -> str + def save_pre_encrypted_reply(self, + filesystem_id: str, + count: int, + journalist_filename: str, + content: str) -> str: if '-----BEGIN PGP MESSAGE-----' not in content.split('\n')[0]: raise NotEncrypted @@ -348,17 +352,18 @@ def save_pre_encrypted_reply(self, filesystem_id, count, return encrypted_file_path - def save_message_submission(self, filesystem_id, count, - journalist_filename, message): - # type: (str, int, str, str) -> str + def save_message_submission(self, + filesystem_id: str, + count: int, + journalist_filename: str, + message: str) -> str: filename = "{0}-{1}-msg.gpg".format(count, journalist_filename) msg_loc = self.path(filesystem_id, filename) current_app.crypto_util.encrypt(message, self.__gpg_key, msg_loc) return filename -def async_add_checksum_for_file(db_obj): - # type: (Union[Submission, Reply]) -> str +def async_add_checksum_for_file(db_obj: Union[Submission, Reply]) -> str: return create_queue().enqueue( queued_add_checksum_for_file, type(db_obj), @@ -368,8 +373,10 @@ def async_add_checksum_for_file(db_obj): ) -def queued_add_checksum_for_file(db_model, model_id, file_path, db_uri): - # type: (Union[Type[Submission], Type[Reply]], int, str, str) -> str +def queued_add_checksum_for_file(db_model: Union[Type[Submission], Type[Reply]], + model_id: int, + file_path: str, + db_uri: str) -> str: # we have to create our own DB session because there is no app context session = sessionmaker(bind=create_engine(db_uri))() db_obj = session.query(db_model).filter_by(id=model_id).one() @@ -378,8 +385,9 @@ def queued_add_checksum_for_file(db_model, model_id, file_path, db_uri): return "success" -def add_checksum_for_file(session, db_obj, file_path): - # type: (Session, Union[Submission, Reply], str) -> None +def add_checksum_for_file(session: Session, + db_obj: Union[Submission, Reply], + file_path: str) -> None: hasher = sha256() with open(file_path, 'rb') as f: while True: diff --git a/securedrop/worker.py b/securedrop/worker.py index ed2ea63a493..3360af9642e 100644 --- a/securedrop/worker.py +++ b/securedrop/worker.py @@ -11,8 +11,7 @@ from sdconfig import config -def create_queue(name=None, timeout=3600): - # type: (str, int) -> Queue +def create_queue(name: str = None, timeout: int = 3600) -> Queue: """ Create an rq ``Queue`` named ``name`` with default timeout ``timeout``. @@ -24,8 +23,7 @@ def create_queue(name=None, timeout=3600): return q -def rq_workers(queue=None): - # type: (Queue) -> List[Worker] +def rq_workers(queue: Queue = None) -> List[Worker]: """ Returns the list of current rq ``Worker``s. """ @@ -33,8 +31,7 @@ def rq_workers(queue=None): return Worker.all(connection=Redis(), queue=queue) -def worker_for_job(job_id): - # type: (str) -> Optional[Worker] +def worker_for_job(job_id: str) -> Optional[Worker]: """ If the job is being run, return its ``Worker``. """ @@ -55,8 +52,7 @@ def worker_for_job(job_id): return None -def requeue_interrupted_jobs(queue_name=None): - # type: (str) -> None +def requeue_interrupted_jobs(queue_name: str = None) -> None: """ Requeues jobs found in the given queue's started job registry.