Skip to content

Commit

Permalink
implement pluggable auth in python with basic and sigv4 as reference (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
xiazcy committed Aug 28, 2024
1 parent bfe1178 commit 9797799
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 62 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ image::https://raw.githubusercontent.com/apache/tinkerpop/master/docs/static/ima
* `EmbeddedRemoteConnection` will use `Gremlinlang`, not `JavaTranslator`.
* Java `Client` will no longer support submitting traversals. `DriverRemoteConnection` should be used instead.
* Removed usage of `Bytecode` from `gremlin-python`.
* Added `auth` module in `gremlin-python` for pluggable authentication.
* Fixed `GremlinLangScriptEngine` handling for some strategies.
* Modified the `split()` step to split a string into a list of its characters if the given separator is an empty string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, call_from_event_loop=None, read_timeout=None, write_timeout=N
self._client_session = None
self._http_req_resp = None
self._enable_ssl = False
self._url = None

# Set all inner variables to parameters passed in.
self._aiohttp_kwargs = kwargs
Expand All @@ -65,31 +66,31 @@ def __del__(self):
self.close()

def connect(self, url, headers=None):
self._url = url
# Inner function to perform async connect.
async def async_connect():
# Start client session and use it to send all HTTP requests. Base url is the endpoint, headers are set here
# Base url can only parse basic url with no path, see https://github.com/aio-libs/aiohttp/issues/6647
# Start client session and use it to send all HTTP requests. Headers can be set here.
if self._enable_ssl:
# ssl context is established through tcp connector
tcp_conn = aiohttp.TCPConnector(ssl_context=self._ssl_context)
self._client_session = aiohttp.ClientSession(connector=tcp_conn,
base_url=url, headers=headers, loop=self._loop)
headers=headers, loop=self._loop)
else:
self._client_session = aiohttp.ClientSession(base_url=url, headers=headers, loop=self._loop)
self._client_session = aiohttp.ClientSession(headers=headers, loop=self._loop)

# Execute the async connect synchronously.
self._loop.run_until_complete(async_connect())

def write(self, message):
# Inner function to perform async write.
async def async_write():
basic_auth = None
# basic password authentication for https connections
# To pass url into message for request authentication processing
message.update({'url': self._url})
if message['auth']:
basic_auth = aiohttp.BasicAuth(message['auth']['username'], message['auth']['password'])
message['auth'](message)

async with async_timeout.timeout(self._write_timeout):
self._http_req_resp = await self._client_session.post(url="/gremlin",
auth=basic_auth,
self._http_req_resp = await self._client_session.post(url=self._url,
data=message['payload'],
headers=message['headers'],
**self._aiohttp_kwargs)
Expand Down
55 changes: 55 additions & 0 deletions gremlin-python/src/main/python/gremlin_python/driver/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


def basic(username, password):
from aiohttp import BasicAuth as aiohttpBasicAuth

def apply(request):
return request['headers'].update({'authorization': aiohttpBasicAuth(username, password).encode()})

return apply


def sigv4(region, service):
import os
from boto3 import Session
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

def apply(request):
access_key = os.environ.get('AWS_ACCESS_KEY_ID', '')
secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '')
session_token = os.environ.get('AWS_SESSION_TOKEN', '')

session = Session(
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
aws_session_token=session_token,
region_name=region
)

sigv4_request = AWSRequest(method="POST", url=request['url'], data=request['payload'])
SigV4Auth(session.get_credentials(), service, region).add_auth(sigv4_request)
request['headers'].update(sigv4_request.headers)
request['payload'] = sigv4_request.data
return request

return apply

18 changes: 5 additions & 13 deletions gremlin-python/src/main/python/gremlin_python/driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
import logging
import warnings
import queue
import re
from concurrent.futures import ThreadPoolExecutor

from gremlin_python.driver import connection, protocol, request, serializer
from gremlin_python.process import traversal
from gremlin_python.driver.request import TokensV4

log = logging.getLogger("gremlinpython")

Expand All @@ -44,13 +41,10 @@ class Client:

def __init__(self, url, traversal_source, protocol_factory=None,
transport_factory=None, pool_size=None, max_workers=None,
message_serializer=None, username="", password="", headers=None,
message_serializer=None, auth=None, headers=None,
enable_user_agent_on_connect=True, **transport_kwargs):
log.info("Creating Client with url '%s'", url)

# check via url that we are using http protocol
self._use_http = re.search('^http', url)

self._closed = False
self._url = url
self._headers = headers
Expand All @@ -62,8 +56,7 @@ def __init__(self, url, traversal_source, protocol_factory=None,
message_serializer = serializer.GraphBinarySerializersV4()

self._message_serializer = message_serializer
self._username = username
self._password = password
self._auth = auth

if transport_factory is None:
try:
Expand All @@ -82,8 +75,7 @@ def transport_factory():
def protocol_factory():
return protocol.GremlinServerHTTPProtocol(
self._message_serializer,
username=self._username,
password=self._password)
auth=self._auth)
self._protocol_factory = protocol_factory

if pool_size is None:
Expand Down Expand Up @@ -163,11 +155,11 @@ def submit_async(self, message, bindings=None, request_options=None):

if isinstance(message, str):
log.debug("fields='%s', gremlin='%s'", str(fields), str(message))
message = request.RequestMessageV4(fields=fields, gremlin=message)
message = request.RequestMessage(fields=fields, gremlin=message)

conn = self._pool.get(True)
if request_options:
message.fields.update({token: request_options[token] for token in TokensV4
message.fields.update({token: request_options[token] for token in request.Tokens
if token in request_options and token != 'bindings'})
if 'bindings' in request_options:
if 'bindings' in message.fields:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
import warnings

from gremlin_python.driver import client, serializer
from gremlin_python.driver.remote_connection import (
RemoteConnection, RemoteTraversal)
from gremlin_python.driver.request import TokensV4
from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
from gremlin_python.driver.request import Tokens

log = logging.getLogger("gremlinpython")

Expand All @@ -34,7 +33,7 @@ class DriverRemoteConnection(RemoteConnection):

def __init__(self, url, traversal_source="g", protocol_factory=None,
transport_factory=None, pool_size=None, max_workers=None,
username="", password="",
auth=None,
message_serializer=None, headers=None,
enable_user_agent_on_connect=True, **transport_kwargs):
log.info("Creating DriverRemoteConnection with url '%s'", str(url))
Expand All @@ -44,8 +43,7 @@ def __init__(self, url, traversal_source="g", protocol_factory=None,
self.__transport_factory = transport_factory
self.__pool_size = pool_size
self.__max_workers = max_workers
self.__username = username
self.__password = password
self.__auth = auth
self.__message_serializer = message_serializer
self.__headers = headers
self.__enable_user_agent_on_connect = enable_user_agent_on_connect
Expand All @@ -59,8 +57,7 @@ def __init__(self, url, traversal_source="g", protocol_factory=None,
pool_size=pool_size,
max_workers=max_workers,
message_serializer=message_serializer,
username=username,
password=password,
auth=auth,
headers=headers,
enable_user_agent_on_connect=enable_user_agent_on_connect,
**transport_kwargs)
Expand Down Expand Up @@ -121,7 +118,7 @@ def is_closed(self):
def extract_request_options(gremlin_lang):
request_options = {}
for os in gremlin_lang.options_strategies:
request_options.update({token: os.configuration[token] for token in TokensV4
request_options.update({token: os.configuration[token] for token in Tokens
if token in os.configuration})
if gremlin_lang.parameters is not None and len(gremlin_lang.parameters) > 0:
request_options["params"] = gremlin_lang.parameters
Expand Down
21 changes: 6 additions & 15 deletions gremlin-python/src/main/python/gremlin_python/driver/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
#
import json
import logging
import abc

Expand Down Expand Up @@ -54,12 +53,9 @@ def write(self, request_message):

class GremlinServerHTTPProtocol(AbstractBaseProtocol):

def __init__(self,
message_serializer,
username='', password=''):
def __init__(self, message_serializer, auth=None):
self._auth = auth
self._message_serializer = message_serializer
self._username = username
self._password = password
self._response_msg = {'status': {'code': 0,
'message': '',
'exception': ''},
Expand All @@ -71,18 +67,13 @@ def connection_made(self, transport):
super(GremlinServerHTTPProtocol, self).connection_made(transport)

def write(self, request_message):

basic_auth = {}
if self._username and self._password:
basic_auth['username'] = self._username
basic_auth['password'] = self._password

content_type = str(self._message_serializer.version, encoding='utf-8')

message = {
'headers': {'CONTENT-TYPE': content_type,
'ACCEPT': content_type},
'headers': {'content-type': content_type,
'accept': content_type},
'payload': self._message_serializer.serialize_message(request_message),
'auth': basic_auth
'auth': self._auth
}

self._transport.write(message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

__author__ = 'David M. Brown ([email protected])'

RequestMessageV4 = collections.namedtuple(
'RequestMessageV4', ['fields', 'gremlin'])
RequestMessage = collections.namedtuple(
'RequestMessage', ['fields', 'gremlin'])

TokensV4 = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent']
Tokens = ['batchSize', 'bindings', 'g', 'gremlin', 'language',
'evaluationTimeout', 'materializeProperties', 'timeoutMs', 'userAgent']
17 changes: 8 additions & 9 deletions gremlin-python/src/main/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,17 @@
import logging
import queue


from gremlin_python.driver.client import Client
from gremlin_python.driver.connection import Connection
from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection
from gremlin_python.driver.protocol import GremlinServerHTTPProtocol
from gremlin_python.driver.serializer import GraphBinarySerializersV4
from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport

from gremlin_python.driver.auth import basic, sigv4

"""HTTP server testing variables"""
gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/')
gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/')
gremlin_server_url = os.environ.get('GREMLIN_SERVER_URL_HTTP', 'http://localhost:{}/gremlin')
gremlin_basic_auth_url = os.environ.get('GREMLIN_SERVER_BASIC_AUTH_URL_HTTP', 'https://localhost:{}/gremlin')
anonymous_url = gremlin_server_url.format(45940)
basic_url = gremlin_basic_auth_url.format(45941)

Expand All @@ -44,15 +43,14 @@
logging.basicConfig(format='%(asctime)s [%(levelname)8s] [%(filename)15s:%(lineno)d - %(funcName)10s()] - %(message)s',
level=logging.DEBUG if verbose_logging else logging.INFO)


"""
Tests below are for the HTTP server with GraphBinaryV4
"""
@pytest.fixture
def connection(request):
protocol = GremlinServerHTTPProtocol(
GraphBinarySerializersV4(),
username='stephen', password='password')
auth=basic('stephen', 'password'))
executor = concurrent.futures.ThreadPoolExecutor(5)
pool = queue.Queue()
try:
Expand Down Expand Up @@ -91,7 +89,8 @@ def authenticated_client(request):
# turn off certificate verification for testing purposes only
ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ssl_opts.verify_mode = ssl.CERT_NONE
client = Client(basic_url, 'gmodern', username='stephen', password='password',
client = Client(basic_url, 'gmodern',
auth=basic('stephen', 'password'),
transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts))
else:
raise ValueError("Invalid authentication option - " + request.param)
Expand Down Expand Up @@ -144,7 +143,7 @@ def fin():
return remote_conn


# TODO: revisit once auth is updated
# TODO: revisit once testing for multiple types of auth is enabled
@pytest.fixture(params=['basic'])
def remote_connection_authenticated(request):
try:
Expand All @@ -153,7 +152,7 @@ def remote_connection_authenticated(request):
ssl_opts = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ssl_opts.verify_mode = ssl.CERT_NONE
remote_conn = DriverRemoteConnection(basic_url, 'gmodern',
username='stephen', password='password',
auth=basic('stephen', 'password'),
transport_factory=lambda: AiohttpHTTPTransport(ssl_options=ssl_opts))
else:
raise ValueError("Invalid authentication option - " + request.param)
Expand Down
Loading

0 comments on commit 9797799

Please sign in to comment.