Skip to content

Commit

Permalink
[py] get auth header from client config
Browse files Browse the repository at this point in the history
Signed-off-by: Viet Nguyen Duc <[email protected]>
  • Loading branch information
VietND96 committed Oct 3, 2024
1 parent 5600cc7 commit 26b1978
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 8 deletions.
67 changes: 59 additions & 8 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import base64
import os
from urllib import parse

Expand All @@ -26,11 +27,19 @@ def __init__(
self,
remote_server_addr: str,
keep_alive: bool = True,
proxy=None,
proxy: Proxy = Proxy(raw={"proxyType": ProxyType.SYSTEM}),
username: str = None,
password: str = None,
auth_type: str = "Basic",
token: str = None,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
self.proxy = proxy
self.username = username
self.password = password
self.auth_type = auth_type
self.token = token

@property
def remote_server_addr(self) -> str:
Expand All @@ -57,8 +66,6 @@ def keep_alive(self, value: bool) -> None:
@property
def proxy(self) -> Proxy:
""":Returns: The proxy used for communicating to the driver/server."""

self._proxy = self._proxy or Proxy(raw={"proxyType": ProxyType.SYSTEM})
return self._proxy

@proxy.setter
Expand All @@ -71,17 +78,49 @@ def proxy(self, proxy: Proxy) -> None:
"""
self._proxy = proxy

def get_proxy_url(self):
@property
def username(self) -> str:
return self._username

@username.setter
def username(self, value: str) -> None:
self._username = value

@property
def password(self) -> str:
return self._password

@password.setter
def password(self, value: str) -> None:
self._password = value

@property
def auth_type(self) -> str:
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
self._auth_type = value

@property
def token(self) -> str:
return self._token

@token.setter
def token(self, value: str) -> None:
self._token = value

def get_proxy_url(self) -> str:
if self.proxy.proxy_type == ProxyType.DIRECT:
return None
elif self.proxy.proxy_type == ProxyType.SYSTEM:
_no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY"))
if _no_proxy:
for npu in _no_proxy.split(","):
npu = npu.strip()
if npu == "*":
for entry in _no_proxy.split(","):
entry = entry.strip()
if entry == "*":
return None
n_url = parse.urlparse(npu)
n_url = parse.urlparse(entry)
remote_add = parse.urlparse(self.remote_server_addr)
if n_url.netloc:
if remote_add.netloc == n_url.netloc:
Expand All @@ -102,3 +141,15 @@ def get_proxy_url(self):
return None
else:
return None

def get_auth_header(self):
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
credentials = f"{self.username}:{self.password}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
return {"Authorization": f"Basic {encoded_credentials}"}
elif auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
elif auth_type == "oauth" and self.token:
return {"Authorization": f"OAuth {self.token}"}
return None
5 changes: 5 additions & 0 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def _request(self, method, url, body=None):
"""
parsed_url = parse.urlparse(url)
headers = self.get_remote_connection_headers(parsed_url, self._client_config.keep_alive)
auth_header = self._client_config.get_auth_header()

if auth_header:
headers.update(auth_header)

if body and method not in ("POST", "PUT"):
body = None

Expand Down
14 changes: 14 additions & 0 deletions py/test/unit/selenium/webdriver/remote/remote_connection_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from selenium import __version__
from selenium.webdriver.remote.remote_connection import RemoteConnection
from selenium.webdriver.remote.remote_connection import ClientConfig


def test_get_remote_connection_headers_defaults():
Expand Down Expand Up @@ -54,6 +55,19 @@ def test_get_proxy_url_http(mock_proxy_settings):
assert proxy_url == proxy


def test_get_auth_header_if_client_config_pass():
custom_config = ClientConfig(
remote_server_addr="http://remote",
keep_alive=True,
username="user",
password="pass",
auth_type="Basic"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
assert headers.get("Authorization") == "Basic dXNlcjpwYXNz"


def test_get_proxy_url_https(mock_proxy_settings):
proxy = "http://https_proxy.com:8080"
remote_connection = RemoteConnection("https://remote", keep_alive=False)
Expand Down

0 comments on commit 26b1978

Please sign in to comment.