Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSH tunnel support for S3Store #882

Merged
merged 19 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions src/maggma/stores/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from monty.msgpack import default as monty_default

from maggma.core import Sort, Store
from maggma.stores.ssh_tunnel import SSHTunnel
from maggma.utils import grouper, to_isoformat_ceil_ms

try:
Expand Down Expand Up @@ -40,6 +41,7 @@
sub_dir: Optional[str] = None,
s3_workers: int = 1,
s3_resource_kwargs: Optional[dict] = None,
ssh_tunnel: Optional[SSHTunnel] = None,
key: str = "fs_id",
store_hash: bool = True,
unpack_data: bool = True,
Expand All @@ -59,11 +61,16 @@
aws_session_token (string) -- AWS temporary session token
region_name (string) -- Default region when creating new connections
compress: compress files inserted into the store
endpoint_url: endpoint_url to allow interface to minio service
sub_dir: (optional) subdirectory of the s3 bucket to store the data
endpoint_url: endpoint_url to allow interface to minio service; ignored if
`ssh_tunnel` is provided, in which case the endpoint_url is inferred
sub_dir: (optional) subdirectory of the s3 bucket to store the data
s3_workers: number of concurrent S3 puts to run
s3_resource_kwargs: additional kwargs to pass to the boto3 session resource
ssh_tunnel: optional SSH tunnel to use for the S3 connection
key: main key to index on
store_hash: store the sha1 hash right before insertion to the database.
unpack_data: whether to decompress and unpack byte data when querying from the bucket.
unpack_data: whether to decompress and unpack byte data when querying from
the bucket
searchable_fields: fields to keep in the index store
"""
if boto3 is None:
Expand All @@ -79,6 +86,7 @@
self.s3_bucket: Any = None
self.s3_workers = s3_workers
self.s3_resource_kwargs = s3_resource_kwargs if s3_resource_kwargs is not None else {}
self.ssh_tunnel = ssh_tunnel
self.unpack_data = unpack_data
self.searchable_fields = searchable_fields if searchable_fields is not None else []
self.store_hash = store_hash
Expand Down Expand Up @@ -107,7 +115,8 @@
def connect(self, *args, **kwargs): # lgtm[py/conflicting-attributes]
"""Connect to the source data."""
session = self._get_session()
resource = session.resource("s3", endpoint_url=self.endpoint_url, **self.s3_resource_kwargs)
endpoint_url = self._get_endpoint_url()
resource = session.resource("s3", endpoint_url=endpoint_url, **self.s3_resource_kwargs)

if not self.s3:
self.s3 = resource
Expand All @@ -127,6 +136,9 @@
self.s3 = None
self.s3_bucket = None

if self.ssh_tunnel is not None:
self.ssh_tunnel.stop()

Check warning on line 140 in src/maggma/stores/aws.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/aws.py#L140

Added line #L140 was not covered by tests

@property
def _collection(self):
"""
Expand Down Expand Up @@ -266,7 +278,7 @@

Args:
key: single key to index
unique: Whether or not this index contains only unique keys
unique: whether or not this index contains only unique keys

Returns:
bool indicating if the index exists/was created
Expand Down Expand Up @@ -322,19 +334,30 @@
self.index.update(search_docs, key=self.key)

def _get_session(self):
if self.ssh_tunnel is not None:
self.ssh_tunnel.start()

Check warning on line 338 in src/maggma/stores/aws.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/aws.py#L338

Added line #L338 was not covered by tests

if not hasattr(self._thread_local, "s3_bucket"):
if isinstance(self.s3_profile, dict):
return Session(**self.s3_profile)
return Session(profile_name=self.s3_profile)
return None

def _get_endpoint_url(self):
if self.ssh_tunnel is None:
return self.endpoint_url
else:
host, port = self.ssh_tunnel.local_address
return f"http://{host}:{port}"

Check warning on line 351 in src/maggma/stores/aws.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/aws.py#L350-L351

Added lines #L350 - L351 were not covered by tests

def _get_bucket(self):
"""If on the main thread return the bucket created above, else create a new bucket on each thread."""
if threading.current_thread().name == "MainThread":
return self.s3_bucket
if not hasattr(self._thread_local, "s3_bucket"):
session = self._get_session()
resource = session.resource("s3", endpoint_url=self.endpoint_url)
endpoint_url = self._get_endpoint_url()
resource = session.resource("s3", endpoint_url=endpoint_url)
self._thread_local.s3_bucket = resource.Bucket(self.bucket)
return self._thread_local.s3_bucket

Expand Down
3 changes: 2 additions & 1 deletion src/maggma/stores/gridfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ruamel import yaml

from maggma.core import Sort, Store, StoreError
from maggma.stores.mongolike import MongoStore, SSHTunnel
from maggma.stores.mongolike import MongoStore
from maggma.stores.ssh_tunnel import SSHTunnel

# https://github.com/mongodb/specifications/
# blob/master/source/gridfs/gridfs-spec.rst#terms
Expand Down
87 changes: 4 additions & 83 deletions src/maggma/stores/mongolike.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import warnings
from itertools import chain, groupby
from pathlib import Path
from socket import socket

from ruamel import yaml

from maggma.stores.ssh_tunnel import SSHTunnel

try:
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union
except ImportError:
Expand All @@ -23,12 +24,11 @@
import orjson
from monty.dev import requires
from monty.io import zopen
from monty.json import MSONable, jsanitize
from monty.json import jsanitize
from monty.serialization import loadfn
from pydash import get, has, set_
from pymongo import MongoClient, ReplaceOne, uri_parser
from pymongo.errors import ConfigurationError, DocumentTooLarge, OperationFailure
from sshtunnel import SSHTunnelForwarder

from maggma.core import Sort, Store, StoreError
from maggma.utils import confirm_field_index, to_dt
Expand All @@ -39,79 +39,6 @@
MontyClient = None


class SSHTunnel(MSONable):
    """
    Serializable wrapper around an ``SSHTunnelForwarder`` that forwards a port
    on 127.0.0.1 to a remote server through an SSH tunnel server.
    """

    # Forwarders keyed by remote server address, looked up in __init__.
    # NOTE(review): nothing inside this class ever adds an entry to this
    # mapping -- confirm whether new tunnels are meant to be registered here.
    __TUNNELS: Dict[str, SSHTunnelForwarder] = {}

    def __init__(
        self,
        tunnel_server_address: str,
        remote_server_address: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        private_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            tunnel_server_address: "host:port" string for the SSH tunnel server
            remote_server_address: "host:port" string for the server to connect to
            username: optional username for the ssh tunnel server
            password: optional password for the ssh tunnel server; when a
                private_key is given it is used as the private key password instead
            private_key: ssh private key to authenticate to the tunnel server
            kwargs: any extra args passed to the SSHTunnelForwarder
        """
        self.tunnel_server_address = tunnel_server_address
        self.remote_server_address = remote_server_address
        self.username = username
        self.password = password
        self.private_key = private_key
        self.kwargs = kwargs

        # Reuse a forwarder already registered for this remote address.
        if remote_server_address in SSHTunnel.__TUNNELS:
            self.tunnel = SSHTunnel.__TUNNELS[remote_server_address]
            return

        # Local end of the tunnel: loopback on a port chosen by the OS.
        loopback = "127.0.0.1"
        local_bind_address = (loopback, _find_free_port(loopback))

        ssh_host, ssh_port_text = tunnel_server_address.split(":")
        ssh_endpoint = (ssh_host, int(ssh_port_text))

        remote_host, remote_port_text = remote_server_address.split(":")
        remote_endpoint = (remote_host, int(remote_port_text))

        # With a private key the password unlocks the key; without one it is
        # the login password for the tunnel server.
        uses_key = private_key is not None
        login_password = None if uses_key else password
        key_password = password if uses_key else None

        self.tunnel = SSHTunnelForwarder(
            ssh_address_or_host=ssh_endpoint,
            local_bind_address=local_bind_address,
            remote_bind_address=remote_endpoint,
            ssh_username=username,
            ssh_password=login_password,
            ssh_private_key_password=key_password,
            ssh_pkey=private_key,
            **kwargs,
        )

    def start(self):
        """Start the underlying forwarder unless it reports itself active."""
        if self.tunnel.is_active:
            return
        self.tunnel.start()

    def stop(self):
        """Stop the underlying forwarder when it reports the tunnel as up."""
        if not self.tunnel.tunnel_is_up:
            return
        self.tunnel.stop()

    @property
    def local_address(self) -> Tuple[str, int]:
        """The (host, port) pair the forwarder is bound to locally."""
        return self.tunnel.local_bind_address


class MongoStore(Store):
"""
A Store that connects to a Mongo collection
Expand Down Expand Up @@ -742,7 +669,7 @@ def __init__(

super().__init__(**kwargs)

def connect(self, force_reset: bool =False):
def connect(self, force_reset: bool = False):
"""
Loads the files into the collection in memory

Expand Down Expand Up @@ -1005,9 +932,3 @@ def update(self, docs: Union[List[Dict], Dict], key: Union[List, str, None] = No
search_doc = {k: d[k] for k in key} if isinstance(key, list) else {key: d[key]}

self._collection.replace_one(search_doc, d, upsert=True)


def _find_free_port(address="0.0.0.0"):
s = socket()
s.bind((address, 0)) # Bind to a free port provided by the host.
return s.getsockname()[1] # Return the port number assigned.
89 changes: 89 additions & 0 deletions src/maggma/stores/ssh_tunnel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from socket import socket
from typing import Dict, Optional, Tuple

from monty.json import MSONable
from sshtunnel import SSHTunnelForwarder


class SSHTunnel(MSONable):
__TUNNELS: Dict[str, SSHTunnelForwarder] = {}

def __init__(
self,
tunnel_server_address: str,
remote_server_address: str,
local_port: Optional[int] = None,
username: Optional[str] = None,
password: Optional[str] = None,
private_key: Optional[str] = None,
**kwargs,
):
"""
Args:
tunnel_server_address: string address with port for the SSH tunnel server
remote_server_address: string address with port for the server to connect to
local_port: optional port to use for the local address (127.0.0.1);
if `None`, a random open port will be automatically selected
username: optional username for the ssh tunnel server
password: optional password for the ssh tunnel server; If a private_key is
supplied this password is assumed to be the private key password
private_key: ssh private key to authenticate to the tunnel server
kwargs: any extra args passed to the SSHTunnelForwarder
"""

self.tunnel_server_address = tunnel_server_address
self.remote_server_address = remote_server_address
self.local_port = local_port
self.username = username
self.password = password
self.private_key = private_key
self.kwargs = kwargs

Check warning on line 40 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L34-L40

Added lines #L34 - L40 were not covered by tests

if remote_server_address in SSHTunnel.__TUNNELS:
self.tunnel = SSHTunnel.__TUNNELS[remote_server_address]

Check warning on line 43 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L42-L43

Added lines #L42 - L43 were not covered by tests
else:
if local_port is None:
local_port = _find_free_port("127.0.0.1")
local_bind_address = ("127.0.0.1", local_port)

Check warning on line 47 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L45-L47

Added lines #L45 - L47 were not covered by tests

ssh_address, ssh_port = tunnel_server_address.split(":")
ssh_port = int(ssh_port) # type: ignore

Check warning on line 50 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L49-L50

Added lines #L49 - L50 were not covered by tests

remote_bind_address, remote_bind_port = remote_server_address.split(":")
remote_bind_port = int(remote_bind_port) # type: ignore

Check warning on line 53 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L52-L53

Added lines #L52 - L53 were not covered by tests

if private_key is not None:
ssh_password = None
ssh_private_key_password = password

Check warning on line 57 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L55-L57

Added lines #L55 - L57 were not covered by tests
else:
ssh_password = password
ssh_private_key_password = None

Check warning on line 60 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L59-L60

Added lines #L59 - L60 were not covered by tests

self.tunnel = SSHTunnelForwarder(

Check warning on line 62 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L62

Added line #L62 was not covered by tests
ssh_address_or_host=(ssh_address, ssh_port),
local_bind_address=local_bind_address,
remote_bind_address=(remote_bind_address, remote_bind_port),
ssh_username=username,
ssh_password=ssh_password,
ssh_private_key_password=ssh_private_key_password,
ssh_pkey=private_key,
**kwargs,
)

def start(self):
if not self.tunnel.is_active:
self.tunnel.start()

Check warning on line 75 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L74-L75

Added lines #L74 - L75 were not covered by tests

def stop(self):
if self.tunnel.tunnel_is_up:
self.tunnel.stop()

Check warning on line 79 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L78-L79

Added lines #L78 - L79 were not covered by tests

@property
def local_address(self) -> Tuple[str, int]:
return self.tunnel.local_bind_address

Check warning on line 83 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L83

Added line #L83 was not covered by tests


def _find_free_port(address="0.0.0.0"):
s = socket()
s.bind((address, 0)) # Bind to a free port provided by the host.

Check warning

Code scanning / CodeQL

Binding a socket to all network interfaces Medium

'0.0.0.0' binds a socket to all interfaces.
return s.getsockname()[1] # Return the port number assigned.

Check warning on line 89 in src/maggma/stores/ssh_tunnel.py

View check run for this annotation

Codecov / codecov/patch

src/maggma/stores/ssh_tunnel.py#L87-L89

Added lines #L87 - L89 were not covered by tests
3 changes: 2 additions & 1 deletion tests/stores/test_ssh_tunnel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import paramiko
import pymongo
import pytest
from maggma.stores.mongolike import MongoStore, SSHTunnel
from maggma.stores.mongolike import MongoStore
from maggma.stores.ssh_tunnel import SSHTunnel
from monty.serialization import dumpfn, loadfn
from paramiko.ssh_exception import AuthenticationException, NoValidConnectionsError, SSHException

Expand Down