diff --git a/test/test_utils.py b/test/test_utils.py index 2e3cd0179..3f45b0bd1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -70,6 +70,7 @@ from youtube_dl.utils import ( lowercase_escape, url_basename, base_url, + urljoin, urlencode_postdata, urshift, update_url_query, @@ -445,6 +446,19 @@ class TestUtil(unittest.TestCase): self.assertEqual(base_url('http://foo.de/bar/baz'), 'http://foo.de/bar/') self.assertEqual(base_url('http://foo.de/bar/baz?x=z/x/c'), 'http://foo.de/bar/') + def test_urljoin(self): + self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de/', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin(None, 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('', 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin(['foobar'], 'http://foo.de/a/b/c.txt'), 'http://foo.de/a/b/c.txt') + self.assertEqual(urljoin('http://foo.de/', None), None) + self.assertEqual(urljoin('http://foo.de/', ''), None) + self.assertEqual(urljoin('http://foo.de/', ['foobar']), None) + def test_parse_age_limit(self): self.assertEqual(parse_age_limit(None), None) self.assertEqual(parse_age_limit(False), None) diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index 3d4951ad9..694e9a600 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -1700,6 +1700,16 @@ def base_url(url): return re.match(r'https?://[^?#&]+/', url).group() +def urljoin(base, path): + if not isinstance(path, compat_str) or not path: + return None + if re.match(r'https?://', path): + return path + if not isinstance(base, compat_str) or not re.match(r'https?://', base): + return None + return compat_urlparse.urljoin(base, path) + + class HEADRequest(compat_urllib_request.Request): def get_method(self): return 'HEAD'