Skip to content

Commit

Permalink
feat(connector): adding validation for auth params
Browse files Browse the repository at this point in the history
Fixes #380
  • Loading branch information
pallavibharadwaj authored and dovahcrow committed Nov 2, 2020
1 parent f9d7b08 commit 0a7c712
Show file tree
Hide file tree
Showing 3 changed files with 771 additions and 675 deletions.
28 changes: 28 additions & 0 deletions dataprep/connector/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Module defines errors used in this library.
"""
from typing import Set

from ..errors import DataprepError


Expand Down Expand Up @@ -64,3 +66,29 @@ def __init__(self, param: str) -> None:

def __str__(self) -> str:
return f"the parameter {self.param} is invalid, refer info method"


class MissingRequiredAuthParams(ValueError):
"""Some parameters for Authorization are missing."""

params: Set[str]

def __init__(self, params: Set[str]) -> None:
super().__init__()
self.params = params

def __str__(self) -> str:
return f"Missing required authorization parameter(s) {self.params} in _auth"


class InvalidAuthParams(ValueError):
"""The parameters used for Authorization are invalid."""

params: Set[str]

def __init__(self, params: Set[str]) -> None:
super().__init__()
self.params = params

def __str__(self) -> str:
return f"Authorization parameter(s) {self.params} in _auth are not required."
23 changes: 22 additions & 1 deletion dataprep/connector/schema/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from pathlib import Path
from threading import Thread
from time import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, Set
from urllib.parse import parse_qs, urlparse
import socket
import requests
from pydantic import Field

from ...utils import is_notebook
from .base import BaseDef, BaseDefT
from ..errors import MissingRequiredAuthParams, InvalidAuthParams

# pylint: disable=missing-class-docstring,missing-function-docstring
FILE_PATH: Path = Path(__file__).resolve().parent
Expand All @@ -30,6 +31,16 @@ def get_random_string(length: int) -> str:
return result_str


def validate_auth(required: Set[str], passed: Dict[str, Any]) -> None:
required_not_passed = required - passed.keys()
passed_not_required = passed.keys() - required
if required_not_passed:
raise MissingRequiredAuthParams(required_not_passed)

if passed_not_required:
raise InvalidAuthParams(passed_not_required)


class OffsetPaginationDef(BaseDef):
type: str = Field("offset", const=True)
max_count: int
Expand Down Expand Up @@ -130,6 +141,8 @@ def build(
port = params.get("port", 9999)
code = self._auth(params["client_id"], port)

validate_auth({"client_id", "client_secret"}, params)

ckey = params["client_id"]
csecret = params["client_secret"]
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
Expand Down Expand Up @@ -208,6 +221,8 @@ def build(
raise ValueError("storage is required for OAuth2")

if "access_token" not in storage or storage.get("expires_at", 0) < time():
validate_auth({"client_id", "client_secret"}, params)

# Not yet authorized
ckey = params["client_id"]
csecret = params["client_secret"]
Expand Down Expand Up @@ -242,6 +257,8 @@ def build(
) -> None:
"""Populate some required fields to the request data."""

validate_auth({"access_token"}, params)

req_data["params"][self.key_param] = params["access_token"]


Expand All @@ -256,6 +273,8 @@ def build(
) -> None:
"""Populate some required fields to the request data."""

validate_auth({"access_token"}, params)

req_data["headers"]["Authorization"] = f"Bearer {params['access_token']}"


Expand All @@ -272,6 +291,8 @@ def build(
) -> None:
"""Populate some required fields to the request data."""

validate_auth({"access_token"}, params)

req_data["headers"][self.key_name] = params["access_token"]
req_data["headers"].update(self.extra)

Expand Down
Loading

0 comments on commit 0a7c712

Please sign in to comment.