From a9a85861eccd9f9463c633aee07171ecf71861c1 Mon Sep 17 00:00:00 2001 From: RRosio Date: Mon, 8 Aug 2022 08:30:44 -0700 Subject: [PATCH 1/2] update redirect login checks --- notebook/auth/login.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/notebook/auth/login.py b/notebook/auth/login.py index 47cfb79ae0..f1600f95fd 100644 --- a/notebook/auth/login.py +++ b/notebook/auth/login.py @@ -6,7 +6,7 @@ import re import os -from urllib.parse import urlparse +from urllib.parse import urlparse, urlunparse import uuid @@ -42,15 +42,18 @@ def _redirect_safe(self, url, default=None): # instead of %5C, causing `\\` to behave as `//` url = url.replace("\\", "%5C") parsed = urlparse(url) - if parsed.netloc or not (parsed.path + '/').startswith(self.base_url): + path_only = urlunparse(parsed._replace(netloc='', scheme='')) + if url != path_only or not (parsed.path + '/').startswith(self.base_url): # require that next_url be absolute path within our path allow = False # OR pass our cross-origin check - if parsed.netloc: + if url != path_only: # if full URL, run our cross-origin check: origin = f'{parsed.scheme}://{parsed.netloc}' origin = origin.lower() - if self.allow_origin: + if origin == f'{self.request.protocol}://{self.request.host}': + allow = True + elif self.allow_origin: allow = self.allow_origin == origin elif self.allow_origin_pat: allow = bool(self.allow_origin_pat.match(origin)) From 9aacc4dea10cb9843c59ba94108c00309c3bd1c8 Mon Sep 17 00:00:00 2001 From: RRosio Date: Mon, 8 Aug 2022 08:32:49 -0700 Subject: [PATCH 2/2] update tests for login --- notebook/auth/tests/test_login.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/notebook/auth/tests/test_login.py b/notebook/auth/tests/test_login.py index 2b5574204a..9120b539d4 100644 --- a/notebook/auth/tests/test_login.py +++ b/notebook/auth/tests/test_login.py @@ -31,6 +31,8 @@ def test_next_bad(self): "//host" + self.url_prefix + "tree", "https://google.com", "/absolute/not/base_url", + "///jupyter.org", + "/\\some-host", ): url = self.login(next=bad_next) self.assertEqual(url, self.url_prefix) @@ -39,10 +41,14 @@ def test_next_bad(self): def test_next_ok(self): for next_path in ( "tree/", - "//" + self.url_prefix + "tree", + self.base_url() + "has/host", "notebooks/notebook.ipynb", "tree//something", ): - expected = self.url_prefix + next_path + if "://" in next_path: + expected = next_path + else: + expected = self.url_prefix + next_path + actual = self.login(next=expected) self.assertEqual(actual, expected)