Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SSH port forwarding #8

Merged
merged 2 commits into from
Jul 21, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 44 additions & 19 deletions jupyter_forward/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import getpass
import random
import time
import urllib.parse
from collections import namedtuple

import invoke
Expand All @@ -11,7 +10,6 @@

random.seed(42)


app = typer.Typer(help='Jupyter Lab Port Forwarding Utility')


Expand Down Expand Up @@ -66,7 +64,7 @@ def open_browser(port: int = None, token: str = None, url: str = None):
webbrowser.open(url, new=2)


def setup_port_forwarding(port: int, username: str, hostname: str):
def setup_port_forwarding(port: int, username: str, hostname: str, host: str):
"""
Sets up SSH port forwarding

Expand All @@ -76,25 +74,47 @@ def setup_port_forwarding(port: int, username: str, hostname: str):
port number to use
username : str
hostname : str
host : str
"""
print('*** Setting up port forwarding ***')
command = f'ssh -N -L {port}:localhost:{port} {username}@{hostname}'
command = f'ssh -N -L localhost:{port}:{hostname}:{port} {username}@{host}'
print(command)
invoke.run(command, asynchronous=True)
time.sleep(3)


def parse_stdout(stdout: str):
"""
Parses stdout to determine remote_hostname, port, token, url

Parameters
----------
stdout : str
Contents of the log file/stdout

Returns
-------
dict
A dictionary containing hotname, port, token, and url
"""
import re
import urllib.parse

hostname, port, token, url = None, None, None, None
stdout = stdout.splitlines()
for line in stdout:
line = line.strip()
if line.startswith('http') and ('127.0.0.1' not in line):
result = urllib.parse.urlparse(line)
url = line
urls = set(
re.findall(
r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
stdout,
)
)
for url in urls:
url = url.strip()
if '127.0.0.1' not in url:
result = urllib.parse.urlparse(url)
hostname, port = result.netloc.split(':')
if 'token' in result.query:
token = result.query.split('token=')[-1].strip()
break
return {'hostname': hostname, 'port': port, 'token': token, 'url': url}


Expand Down Expand Up @@ -156,23 +176,23 @@ def start(

# jupyter lab will pipe output to logfile, which should not exist prior to running
# Logfile will be in $TMPDIR if defined on the remote machine, otherwise in $HOME
tmpdir = session.run('echo $TMPDIR', hide=True).stdout.strip()
tmpdir = session.run('echo $TMPDIR', hide='out').stdout.strip()
if len(tmpdir) == 0:
tmpdir = session.run('echo $HOME', hide=True).stdout.strip()
tmpdir = session.run('echo $HOME', hide='out').stdout.strip()
if len(tmpdir) == 0:
tmpdir = '~'
logfile = f'{tmpdir}/.jforward.{port}'
_ = session.run(f'rm -f {logfile}')
log_dir = f'{tmpdir}/.jupyter_forward'
session.run(f'mkdir -p {log_dir}', hide='out')
logfile = f'{log_dir}/jforward.{port}'
session.run(f'rm -f {logfile}', hide='out')

# start jupyter lab on remote machine
command = f'conda activate {conda_env} && jupyter lab --no-browser --ip=`hostname` --port={port} --notebook-dir={notebook_dir}'
jlab_exe = session.run(f'{command} > {logfile} 2>&1', asynchronous=True)
print(f'DEBUG: jlab_exe is of type {type(jlab_exe)}')

_ = session.run(f'{command} > {logfile} 2>&1', asynchronous=True)
# wait for logfile to contain access info, then write it to screen
condition = True
stdout = None
pattern = 'To access the notebook, open this file in a browser:'
pattern = 'The Jupyter Notebook is running at:'
while condition:
try:
result = session.run(f'tail {logfile}', hide='out')
Expand All @@ -184,12 +204,17 @@ def start(
pass

parsed_result = parse_stdout(stdout)
print(parsed_result)
if port_forwarding:
setup_port_forwarding(parsed_result['port'], session.user, parsed_result['hostname'])
setup_port_forwarding(
parsed_result['port'], session.user, parsed_result['hostname'], session.host
)
open_browser(port=parsed_result['port'], token=parsed_result['token'])
else:
open_browser(url=parsed_result['url'])

session.run(f'tail -f {logfile}')


@app.command()
def resume():
Expand Down