Removed venv and www from the repository, also added a requirements file

This commit is contained in:
2020-11-22 21:14:46 -03:00
parent 18cf2d335a
commit 199a1e2a61
820 changed files with 15495 additions and 22017 deletions


@@ -4,11 +4,7 @@ urllib3 - Thread-safe connection pooling and re-using.
from __future__ import absolute_import
import warnings
from .connectionpool import (
HTTPConnectionPool,
HTTPSConnectionPool,
connection_from_url
)
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool, connection_from_url
from . import exceptions
from .filepost import encode_multipart_formdata
@@ -24,25 +20,25 @@ from .util.retry import Retry
import logging
from logging import NullHandler
__author__ = 'Andrey Petrov (andrey.petrov@shazow.net)'
__license__ = 'MIT'
__version__ = '1.25.3'
__author__ = "Andrey Petrov (andrey.petrov@shazow.net)"
__license__ = "MIT"
__version__ = "1.25.9"
__all__ = (
'HTTPConnectionPool',
'HTTPSConnectionPool',
'PoolManager',
'ProxyManager',
'HTTPResponse',
'Retry',
'Timeout',
'add_stderr_logger',
'connection_from_url',
'disable_warnings',
'encode_multipart_formdata',
'get_host',
'make_headers',
'proxy_from_url',
"HTTPConnectionPool",
"HTTPSConnectionPool",
"PoolManager",
"ProxyManager",
"HTTPResponse",
"Retry",
"Timeout",
"add_stderr_logger",
"connection_from_url",
"disable_warnings",
"encode_multipart_formdata",
"get_host",
"make_headers",
"proxy_from_url",
)
logging.getLogger(__name__).addHandler(NullHandler())
@@ -59,10 +55,10 @@ def add_stderr_logger(level=logging.DEBUG):
# even if urllib3 is vendored within another package.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(level)
logger.debug('Added a stderr logging handler to logger: %s', __name__)
logger.debug("Added a stderr logging handler to logger: %s", __name__)
return handler
@@ -74,18 +70,17 @@ del NullHandler
# shouldn't be: otherwise, it's very hard for users to use most Python
# mechanisms to silence them.
# SecurityWarning's always go off by default.
warnings.simplefilter('always', exceptions.SecurityWarning, append=True)
warnings.simplefilter("always", exceptions.SecurityWarning, append=True)
# SubjectAltNameWarning's should go off once per host
warnings.simplefilter('default', exceptions.SubjectAltNameWarning, append=True)
warnings.simplefilter("default", exceptions.SubjectAltNameWarning, append=True)
# InsecurePlatformWarning's don't vary between requests, so we keep it default.
warnings.simplefilter('default', exceptions.InsecurePlatformWarning,
append=True)
warnings.simplefilter("default", exceptions.InsecurePlatformWarning, append=True)
# SNIMissingWarnings should go off only once.
warnings.simplefilter('default', exceptions.SNIMissingWarning, append=True)
warnings.simplefilter("default", exceptions.SNIMissingWarning, append=True)
def disable_warnings(category=exceptions.HTTPWarning):
"""
Helper for quickly disabling all urllib3 warnings.
"""
warnings.simplefilter('ignore', category)
warnings.simplefilter("ignore", category)


@@ -1,4 +1,5 @@
from __future__ import absolute_import
try:
from collections.abc import Mapping, MutableMapping
except ImportError:
@@ -6,6 +7,7 @@ except ImportError:
try:
from threading import RLock
except ImportError: # Platform-specific: No threads available
class RLock:
def __enter__(self):
pass
@@ -19,7 +21,7 @@ from .exceptions import InvalidHeader
from .packages.six import iterkeys, itervalues, PY3
__all__ = ['RecentlyUsedContainer', 'HTTPHeaderDict']
__all__ = ["RecentlyUsedContainer", "HTTPHeaderDict"]
_Null = object()
@@ -82,7 +84,9 @@ class RecentlyUsedContainer(MutableMapping):
return len(self._container)
def __iter__(self):
raise NotImplementedError('Iteration over this class is unlikely to be threadsafe.')
raise NotImplementedError(
"Iteration over this class is unlikely to be threadsafe."
)
def clear(self):
with self.lock:
@@ -150,7 +154,7 @@ class HTTPHeaderDict(MutableMapping):
def __getitem__(self, key):
val = self._container[key.lower()]
return ', '.join(val[1:])
return ", ".join(val[1:])
def __delitem__(self, key):
del self._container[key.lower()]
@@ -159,12 +163,13 @@ class HTTPHeaderDict(MutableMapping):
return key.lower() in self._container
def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
if not isinstance(other, Mapping) and not hasattr(other, "keys"):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return (dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged()))
return dict((k.lower(), v) for k, v in self.itermerged()) == dict(
(k.lower(), v) for k, v in other.itermerged()
)
def __ne__(self, other):
return not self.__eq__(other)
@@ -184,9 +189,9 @@ class HTTPHeaderDict(MutableMapping):
yield vals[0]
def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
"""D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
'''
"""
# Using the MutableMapping function directly fails due to the private marker.
# Using ordinary dict.pop would expose the internal structures.
# So let's reinvent the wheel.
@@ -228,8 +233,10 @@ class HTTPHeaderDict(MutableMapping):
with self.add instead of self.__setitem__
"""
if len(args) > 1:
raise TypeError("extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args)))
raise TypeError(
"extend() takes at most 1 positional "
"arguments ({0} given)".format(len(args))
)
other = args[0] if len(args) >= 1 else ()
if isinstance(other, HTTPHeaderDict):
@@ -295,7 +302,7 @@ class HTTPHeaderDict(MutableMapping):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])
yield val[0], ", ".join(val[1:])
def items(self):
return list(self.iteritems())
@@ -306,7 +313,7 @@ class HTTPHeaderDict(MutableMapping):
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
obs_fold_continued_leaders = (' ', '\t')
obs_fold_continued_leaders = (" ", "\t")
headers = []
for line in message.headers:
@@ -316,14 +323,14 @@ class HTTPHeaderDict(MutableMapping):
# in RFC-7230 S3.2.4. This indicates a multiline header, but
# there exists no previous header to which we can attach it.
raise InvalidHeader(
'Header continuation with no previous header: %s' % line
"Header continuation with no previous header: %s" % line
)
else:
key, value = headers[-1]
headers[-1] = (key, value + ' ' + line.strip())
headers[-1] = (key, value + " " + line.strip())
continue
key, value = line.split(':', 1)
key, value = line.split(":", 1)
headers.append((key, value.strip()))
return cls(headers)
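
The ", ".join(val[1:]) lines above are the core of HTTPHeaderDict: repeated header fields are stored individually but read back merged, as RFC 7230 permits. A small sketch of that behavior (note the import is from the library-private module this hunk edits):

from urllib3._collections import HTTPHeaderDict

h = HTTPHeaderDict()
h.add("Set-Cookie", "a=1")
h.add("Set-Cookie", "b=2")

print(h["Set-Cookie"])          # "a=1, b=2" -- merged with ", "
print(h.getlist("Set-Cookie"))  # ['a=1', 'b=2'] -- originals preserved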


@@ -1,4 +1,5 @@
from __future__ import absolute_import
import re
import datetime
import logging
import os
@@ -11,6 +12,7 @@ from .packages.six.moves.http_client import HTTPException # noqa: F401
try: # Compiled with SSL?
import ssl
BaseSSLError = ssl.SSLError
except (ImportError, AttributeError): # Platform-specific: No SSL.
ssl = None
@@ -41,7 +43,7 @@ from .util.ssl_ import (
resolve_ssl_version,
assert_fingerprint,
create_urllib3_context,
ssl_wrap_socket
ssl_wrap_socket,
)
@@ -51,20 +53,18 @@ from ._collections import HTTPHeaderDict
log = logging.getLogger(__name__)
port_by_scheme = {
'http': 80,
'https': 443,
}
port_by_scheme = {"http": 80, "https": 443}
# When updating RECENT_DATE, move it to within two years of the current date,
# and not less than 6 months ago.
# Example: if Today is 2018-01-01, then RECENT_DATE should be any date on or
# after 2016-01-01 (today - 2 years) AND before 2017-07-01 (today - 6 months)
RECENT_DATE = datetime.date(2017, 6, 30)
# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2019, 1, 1)
_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]")
class DummyConnection(object):
"""Used to detect a failed ConnectionCls import."""
pass
@@ -92,7 +92,7 @@ class HTTPConnection(_HTTPConnection, object):
Or you may want to disable the defaults by passing an empty list (e.g., ``[]``).
"""
default_port = port_by_scheme['http']
default_port = port_by_scheme["http"]
#: Disable Nagle's algorithm by default.
#: ``[(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]``
@@ -102,15 +102,15 @@ class HTTPConnection(_HTTPConnection, object):
is_verified = False
def __init__(self, *args, **kw):
if six.PY3:
kw.pop('strict', None)
if not six.PY2:
kw.pop("strict", None)
# Pre-set source_address.
self.source_address = kw.get('source_address')
self.source_address = kw.get("source_address")
#: The socket options provided by the user. If no options are
#: provided, we use the default options.
self.socket_options = kw.pop('socket_options', self.default_socket_options)
self.socket_options = kw.pop("socket_options", self.default_socket_options)
_HTTPConnection.__init__(self, *args, **kw)
@@ -131,7 +131,7 @@ class HTTPConnection(_HTTPConnection, object):
those cases where it's appropriate (i.e., when doing DNS lookup to establish the
actual TCP connection across which we're going to send HTTP requests).
"""
return self._dns_host.rstrip('.')
return self._dns_host.rstrip(".")
@host.setter
def host(self, value):
@@ -150,30 +150,34 @@ class HTTPConnection(_HTTPConnection, object):
"""
extra_kw = {}
if self.source_address:
extra_kw['source_address'] = self.source_address
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw['socket_options'] = self.socket_options
extra_kw["socket_options"] = self.socket_options
try:
conn = connection.create_connection(
(self._dns_host, self.port), self.timeout, **extra_kw)
(self._dns_host, self.port), self.timeout, **extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self, "Connection to %s timed out. (connect timeout=%s)" %
(self.host, self.timeout))
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except SocketError as e:
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e)
self, "Failed to establish a new connection: %s" % e
)
return conn
def _prepare_conn(self, conn):
self.sock = conn
# Google App Engine's httplib does not define _tunnel_host
if getattr(self, '_tunnel_host', None):
if getattr(self, "_tunnel_host", None):
# TODO: Fix tunnel so it doesn't depend on self.sock state.
self._tunnel()
# Mark this connection as not reusable
@@ -183,24 +187,32 @@ class HTTPConnection(_HTTPConnection, object):
conn = self._new_conn()
self._prepare_conn(conn)
def putrequest(self, method, url, *args, **kwargs):
"""Send a request to the server"""
match = _CONTAINS_CONTROL_CHAR_RE.search(method)
if match:
raise ValueError(
"Method cannot contain non-token characters %r (found at least %r)"
% (method, match.group())
)
return _HTTPConnection.putrequest(self, method, url, *args, **kwargs)
def request_chunked(self, method, url, body=None, headers=None):
"""
Alternative to the common request method, which sends the
body with chunked encoding and not as one block
"""
headers = HTTPHeaderDict(headers if headers is not None else {})
skip_accept_encoding = 'accept-encoding' in headers
skip_host = 'host' in headers
skip_accept_encoding = "accept-encoding" in headers
skip_host = "host" in headers
self.putrequest(
method,
url,
skip_accept_encoding=skip_accept_encoding,
skip_host=skip_host
method, url, skip_accept_encoding=skip_accept_encoding, skip_host=skip_host
)
for header, value in headers.items():
self.putheader(header, value)
if 'transfer-encoding' not in headers:
self.putheader('Transfer-Encoding', 'chunked')
if "transfer-encoding" not in headers:
self.putheader("Transfer-Encoding", "chunked")
self.endheaders()
if body is not None:
@@ -211,29 +223,42 @@ class HTTPConnection(_HTTPConnection, object):
if not chunk:
continue
if not isinstance(chunk, bytes):
chunk = chunk.encode('utf8')
chunk = chunk.encode("utf8")
len_str = hex(len(chunk))[2:]
self.send(len_str.encode('utf-8'))
self.send(b'\r\n')
self.send(len_str.encode("utf-8"))
self.send(b"\r\n")
self.send(chunk)
self.send(b'\r\n')
self.send(b"\r\n")
# After the if clause, to always have a closed body
self.send(b'0\r\n\r\n')
self.send(b"0\r\n\r\n")
class HTTPSConnection(HTTPConnection):
default_port = port_by_scheme['https']
default_port = port_by_scheme["https"]
cert_reqs = None
ca_certs = None
ca_cert_dir = None
ca_cert_data = None
ssl_version = None
assert_fingerprint = None
def __init__(self, host, port=None, key_file=None, cert_file=None,
key_password=None, strict=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None, server_hostname=None, **kw):
def __init__(
self,
host,
port=None,
key_file=None,
cert_file=None,
key_password=None,
strict=None,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
ssl_context=None,
server_hostname=None,
**kw
):
HTTPConnection.__init__(self, host, port, strict=strict,
timeout=timeout, **kw)
HTTPConnection.__init__(self, host, port, strict=strict, timeout=timeout, **kw)
self.key_file = key_file
self.cert_file = cert_file
@@ -243,54 +268,20 @@ class HTTPSConnection(HTTPConnection):
# Required property for Google AppEngine 1.9.0 which otherwise causes
# HTTPS requests to go out as HTTP. (See Issue #356)
self._protocol = 'https'
self._protocol = "https"
def connect(self):
conn = self._new_conn()
self._prepare_conn(conn)
# Wrap socket using verification with the root certs in
# trusted_root_certs
default_ssl_context = False
if self.ssl_context is None:
default_ssl_context = True
self.ssl_context = create_urllib3_context(
ssl_version=resolve_ssl_version(self.ssl_version),
cert_reqs=resolve_cert_reqs(self.cert_reqs),
)
# Try to load OS default certs if none are given.
# Works well on Windows (requires Python3.4+)
context = self.ssl_context
if (not self.ca_certs and not self.ca_cert_dir and default_ssl_context
and hasattr(context, 'load_default_certs')):
context.load_default_certs()
self.sock = ssl_wrap_socket(
sock=conn,
keyfile=self.key_file,
certfile=self.cert_file,
key_password=self.key_password,
ssl_context=self.ssl_context,
server_hostname=self.server_hostname
)
class VerifiedHTTPSConnection(HTTPSConnection):
"""
Based on httplib.HTTPSConnection but wraps the socket with
SSL certification.
"""
cert_reqs = None
ca_certs = None
ca_cert_dir = None
ssl_version = None
assert_fingerprint = None
def set_cert(self, key_file=None, cert_file=None,
cert_reqs=None, key_password=None, ca_certs=None,
assert_hostname=None, assert_fingerprint=None,
ca_cert_dir=None):
def set_cert(
self,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
ca_cert_data=None,
):
"""
This method should only be called once, before the connection is used.
"""
@@ -310,6 +301,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
self.assert_fingerprint = assert_fingerprint
self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
self.ca_cert_data = ca_cert_data
def connect(self):
# Add certificate verification
@@ -317,7 +309,7 @@ class VerifiedHTTPSConnection(HTTPSConnection):
hostname = self.host
# Google App Engine's httplib does not define _tunnel_host
if getattr(self, '_tunnel_host', None):
if getattr(self, "_tunnel_host", None):
self.sock = conn
# Calls self._set_hostport(), so self.host is
# self._tunnel_host below.
@@ -334,10 +326,12 @@ class VerifiedHTTPSConnection(HTTPSConnection):
is_time_off = datetime.date.today() < RECENT_DATE
if is_time_off:
warnings.warn((
'System time is way off (before {0}). This will probably '
'lead to SSL verification errors').format(RECENT_DATE),
SystemTimeWarning
warnings.warn(
(
"System time is way off (before {0}). This will probably "
"lead to SSL verification errors"
).format(RECENT_DATE),
SystemTimeWarning,
)
# Wrap socket using verification with the root certs in
@@ -355,8 +349,13 @@ class VerifiedHTTPSConnection(HTTPSConnection):
# Try to load OS default certs if none are given.
# Works well on Windows (requires Python3.4+)
if (not self.ca_certs and not self.ca_cert_dir and default_ssl_context
and hasattr(context, 'load_default_certs')):
if (
not self.ca_certs
and not self.ca_cert_dir
and not self.ca_cert_data
and default_ssl_context
and hasattr(context, "load_default_certs")
):
context.load_default_certs()
self.sock = ssl_wrap_socket(
@@ -366,32 +365,39 @@ class VerifiedHTTPSConnection(HTTPSConnection):
key_password=self.key_password,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
ca_cert_data=self.ca_cert_data,
server_hostname=server_hostname,
ssl_context=context)
ssl_context=context,
)
if self.assert_fingerprint:
assert_fingerprint(self.sock.getpeercert(binary_form=True),
self.assert_fingerprint)
elif context.verify_mode != ssl.CERT_NONE \
and not getattr(context, 'check_hostname', False) \
and self.assert_hostname is not False:
assert_fingerprint(
self.sock.getpeercert(binary_form=True), self.assert_fingerprint
)
elif (
context.verify_mode != ssl.CERT_NONE
and not getattr(context, "check_hostname", False)
and self.assert_hostname is not False
):
# While urllib3 attempts to always turn off hostname matching from
# the TLS library, this cannot always be done. So we check whether
# the TLS Library still thinks it's matching hostnames.
cert = self.sock.getpeercert()
if not cert.get('subjectAltName', ()):
warnings.warn((
'Certificate for {0} has no `subjectAltName`, falling back to check for a '
'`commonName` for now. This feature is being removed by major browsers and '
'deprecated by RFC 2818. (See https://github.com/shazow/urllib3/issues/497 '
'for details.)'.format(hostname)),
SubjectAltNameWarning
if not cert.get("subjectAltName", ()):
warnings.warn(
(
"Certificate for {0} has no `subjectAltName`, falling back to check for a "
"`commonName` for now. This feature is being removed by major browsers and "
"deprecated by RFC 2818. (See https://github.com/urllib3/urllib3/issues/497 "
"for details.)".format(hostname)
),
SubjectAltNameWarning,
)
_match_hostname(cert, self.assert_hostname or server_hostname)
self.is_verified = (
context.verify_mode == ssl.CERT_REQUIRED or
self.assert_fingerprint is not None
context.verify_mode == ssl.CERT_REQUIRED
or self.assert_fingerprint is not None
)
@@ -399,9 +405,10 @@ def _match_hostname(cert, asserted_hostname):
try:
match_hostname(cert, asserted_hostname)
except CertificateError as e:
log.error(
'Certificate did not match expected hostname: %s. '
'Certificate: %s', asserted_hostname, cert
log.warning(
"Certificate did not match expected hostname: %s. Certificate: %s",
asserted_hostname,
cert,
)
# Add cert to exception and reraise so client code can inspect
# the cert when catching the exception, if they want to
@@ -409,9 +416,8 @@ def _match_hostname(cert, asserted_hostname):
raise
if ssl:
# Make a copy for testing.
UnverifiedHTTPSConnection = HTTPSConnection
HTTPSConnection = VerifiedHTTPSConnection
else:
HTTPSConnection = DummyConnection
if not ssl:
HTTPSConnection = DummyConnection # noqa: F811
VerifiedHTTPSConnection = HTTPSConnection
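
Most of this file is mechanical re-quoting, but the request_chunked hunk documents the wire format urllib3 emits for chunked transfer encoding: a hex length, CRLF, the payload, CRLF, and a zero-length terminating chunk. A standalone sketch of the same framing, detached from any socket:

def frame_chunks(body):
    # Mirrors the loop in request_chunked above: hex size line, CRLF,
    # payload, CRLF, then the b"0\r\n\r\n" terminator.
    for chunk in body:
        if not chunk:
            continue
        if not isinstance(chunk, bytes):
            chunk = chunk.encode("utf-8")
        yield hex(len(chunk))[2:].encode("utf-8") + b"\r\n" + chunk + b"\r\n"
    yield b"0\r\n\r\n"

print(b"".join(frame_chunks(["hello", "world!"])))
# b'5\r\nhello\r\n6\r\nworld!\r\n0\r\n\r\n'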


@@ -26,12 +26,14 @@ from .exceptions import (
from .packages.ssl_match_hostname import CertificateError
from .packages import six
from .packages.six.moves import queue
from .packages.rfc3986.normalizers import normalize_host
from .connection import (
port_by_scheme,
DummyConnection,
HTTPConnection, HTTPSConnection, VerifiedHTTPSConnection,
HTTPException, BaseSSLError,
HTTPConnection,
HTTPSConnection,
VerifiedHTTPSConnection,
HTTPException,
BaseSSLError,
)
from .request import RequestMethods
from .response import HTTPResponse
@@ -41,7 +43,13 @@ from .util.request import set_file_position
from .util.response import assert_header_parsing
from .util.retry import Retry
from .util.timeout import Timeout
from .util.url import get_host, Url, NORMALIZABLE_SCHEMES
from .util.url import (
get_host,
parse_url,
Url,
_normalize_host as normalize_host,
_encode_target,
)
from .util.queue import LifoQueue
@@ -57,6 +65,11 @@ class ConnectionPool(object):
"""
Base class for all connection pools, such as
:class:`.HTTPConnectionPool` and :class:`.HTTPSConnectionPool`.
.. note::
ConnectionPool.urlopen() does not normalize or percent-encode target URIs
which is useful if your target server doesn't support percent-encoded
target URIs.
"""
scheme = None
@@ -71,8 +84,7 @@ class ConnectionPool(object):
self.port = port
def __str__(self):
return '%s(host=%r, port=%r)' % (type(self).__name__,
self.host, self.port)
return "%s(host=%r, port=%r)" % (type(self).__name__, self.host, self.port)
def __enter__(self):
return self
@@ -153,15 +165,24 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
:class:`urllib3.connection.HTTPSConnection` instances.
"""
scheme = 'http'
scheme = "http"
ConnectionCls = HTTPConnection
ResponseCls = HTTPResponse
def __init__(self, host, port=None, strict=False,
timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1, block=False,
headers=None, retries=None,
_proxy=None, _proxy_headers=None,
**conn_kw):
def __init__(
self,
host,
port=None,
strict=False,
timeout=Timeout.DEFAULT_TIMEOUT,
maxsize=1,
block=False,
headers=None,
retries=None,
_proxy=None,
_proxy_headers=None,
**conn_kw
):
ConnectionPool.__init__(self, host, port)
RequestMethods.__init__(self, headers)
@@ -195,19 +216,27 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Enable Nagle's algorithm for proxies, to avoid packet fragmentation.
# We cannot know if the user has added default socket options, so we cannot replace the
# list.
self.conn_kw.setdefault('socket_options', [])
self.conn_kw.setdefault("socket_options", [])
def _new_conn(self):
"""
Return a fresh :class:`HTTPConnection`.
"""
self.num_connections += 1
log.debug("Starting new HTTP connection (%d): %s:%s",
self.num_connections, self.host, self.port or "80")
log.debug(
"Starting new HTTP connection (%d): %s:%s",
self.num_connections,
self.host,
self.port or "80",
)
conn = self.ConnectionCls(host=self.host, port=self.port,
timeout=self.timeout.connect_timeout,
strict=self.strict, **self.conn_kw)
conn = self.ConnectionCls(
host=self.host,
port=self.port,
timeout=self.timeout.connect_timeout,
strict=self.strict,
**self.conn_kw
)
return conn
def _get_conn(self, timeout=None):
@@ -231,16 +260,17 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
except queue.Empty:
if self.block:
raise EmptyPoolError(self,
"Pool reached maximum size and no more "
"connections are allowed.")
raise EmptyPoolError(
self,
"Pool reached maximum size and no more connections are allowed.",
)
pass # Oh well, we'll create a new connection then
# If this is a persistent connection, check if it got disconnected
if conn and is_connection_dropped(conn):
log.debug("Resetting dropped connection: %s", self.host)
conn.close()
if getattr(conn, 'auto_open', 1) == 0:
if getattr(conn, "auto_open", 1) == 0:
# This is a proxied connection that has been mutated by
# httplib._tunnel() and cannot be reused (since it would
# attempt to bypass the proxy)
@@ -270,9 +300,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
pass
except queue.Full:
# This should never happen if self.block == True
log.warning(
"Connection pool is full, discarding connection: %s",
self.host)
log.warning("Connection pool is full, discarding connection: %s", self.host)
# Connection never got put back into the pool, close it.
if conn:
@@ -304,21 +332,30 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
"""Is the error actually a timeout? Will raise a ReadTimeout or pass"""
if isinstance(err, SocketTimeout):
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value)
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
# See the above comment about EAGAIN in Python 3. In Python 2 we have
# to specifically catch it and throw the timeout error
if hasattr(err, 'errno') and err.errno in _blocking_errnos:
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value)
if hasattr(err, "errno") and err.errno in _blocking_errnos:
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
# Catch possible read timeouts thrown as SSL errors. If not the
# case, rethrow the original. We need to do this because of:
# http://bugs.python.org/issue10272
if 'timed out' in str(err) or 'did not complete (read)' in str(err): # Python < 2.7.4
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value)
if "timed out" in str(err) or "did not complete (read)" in str(
err
): # Python < 2.7.4
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % timeout_value
)
def _make_request(self, conn, method, url, timeout=_Default, chunked=False,
**httplib_request_kw):
def _make_request(
self, conn, method, url, timeout=_Default, chunked=False, **httplib_request_kw
):
"""
Perform a request on a given urllib connection object taken from our
pool.
@@ -358,7 +395,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
read_timeout = timeout_obj.read_timeout
# App Engine doesn't have a sock attr
if getattr(conn, 'sock', None):
if getattr(conn, "sock", None):
# In Python 3 socket.py will catch EAGAIN and return None when you
# try and read into the file pointer created by http.client, which
# instead raises a BadStatusLine exception. Instead of catching
@@ -366,7 +403,8 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# timeouts, check for a zero timeout before making the request.
if read_timeout == 0:
raise ReadTimeoutError(
self, url, "Read timed out. (read timeout=%s)" % read_timeout)
self, url, "Read timed out. (read timeout=%s)" % read_timeout
)
if read_timeout is Timeout.DEFAULT_TIMEOUT:
conn.sock.settimeout(socket.getdefaulttimeout())
else: # None or a value
@@ -381,26 +419,38 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Python 3
try:
httplib_response = conn.getresponse()
except Exception as e:
# Remove the TypeError from the exception chain in Python 3;
# otherwise it looks like a programming error was the cause.
except BaseException as e:
# Remove the TypeError from the exception chain in
# Python 3 (including for exceptions like SystemExit).
# Otherwise it looks like a bug in the code.
six.raise_from(e, None)
except (SocketTimeout, BaseSSLError, SocketError) as e:
self._raise_timeout(err=e, url=url, timeout_value=read_timeout)
raise
# AppEngine doesn't have a version attr.
http_version = getattr(conn, '_http_vsn_str', 'HTTP/?')
log.debug("%s://%s:%s \"%s %s %s\" %s %s", self.scheme, self.host, self.port,
method, url, http_version, httplib_response.status,
httplib_response.length)
http_version = getattr(conn, "_http_vsn_str", "HTTP/?")
log.debug(
'%s://%s:%s "%s %s %s" %s %s',
self.scheme,
self.host,
self.port,
method,
url,
http_version,
httplib_response.status,
httplib_response.length,
)
try:
assert_header_parsing(httplib_response.msg)
except (HeaderParsingError, TypeError) as hpe: # Platform-specific: Python 3
log.warning(
'Failed to parse headers (url=%s): %s',
self._absolute_url(url), hpe, exc_info=True)
"Failed to parse headers (url=%s): %s",
self._absolute_url(url),
hpe,
exc_info=True,
)
return httplib_response
@@ -430,7 +480,7 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
Check if the given ``url`` is a member of the same host as this
connection pool.
"""
if url.startswith('/'):
if url.startswith("/"):
return True
# TODO: Add optional support for socket.gethostbyname checking.
@@ -446,10 +496,22 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
return (scheme, host, port) == (self.scheme, self.host, self.port)
def urlopen(self, method, url, body=None, headers=None, retries=None,
redirect=True, assert_same_host=True, timeout=_Default,
pool_timeout=None, release_conn=None, chunked=False,
body_pos=None, **response_kw):
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
assert_same_host=True,
timeout=_Default,
pool_timeout=None,
release_conn=None,
chunked=False,
body_pos=None,
**response_kw
):
"""
Get a connection from the pool and perform an HTTP request. This is the
lowest level call for making a request, so you'll need to specify all
@@ -547,12 +609,18 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if release_conn is None:
release_conn = response_kw.get('preload_content', True)
release_conn = response_kw.get("preload_content", True)
# Check host
if assert_same_host and not self.is_same_host(url):
raise HostChangedError(self, url, retries)
# Ensure that the URL we're connecting to is properly encoded
if url.startswith("/"):
url = six.ensure_str(_encode_target(url))
else:
url = six.ensure_str(parse_url(url).url)
conn = None
# Track whether `conn` needs to be released before
@@ -563,13 +631,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
#
# See issue #651 [1] for details.
#
# [1] <https://github.com/shazow/urllib3/issues/651>
# [1] <https://github.com/urllib3/urllib3/issues/651>
release_this_conn = release_conn
# Merge the proxy headers. Only do this in HTTP. We have to copy the
# headers dict so we can safely change it without those changes being
# reflected in anyone else's copy.
if self.scheme == 'http':
if self.scheme == "http":
headers = headers.copy()
headers.update(self.proxy_headers)
@@ -592,15 +660,22 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
conn.timeout = timeout_obj.connect_timeout
is_new_proxy_conn = self.proxy is not None and not getattr(conn, 'sock', None)
is_new_proxy_conn = self.proxy is not None and not getattr(
conn, "sock", None
)
if is_new_proxy_conn:
self._prepare_proxy(conn)
# Make the request on the httplib connection object.
httplib_response = self._make_request(conn, method, url,
timeout=timeout_obj,
body=body, headers=headers,
chunked=chunked)
httplib_response = self._make_request(
conn,
method,
url,
timeout=timeout_obj,
body=body,
headers=headers,
chunked=chunked,
)
# If we're going to release the connection in ``finally:``, then
# the response doesn't need to know about the connection. Otherwise
@@ -609,14 +684,16 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
response_conn = conn if not release_conn else None
# Pass method to Response for length checking
response_kw['request_method'] = method
response_kw["request_method"] = method
# Import httplib's response into our own wrapper object
response = self.ResponseCls.from_httplib(httplib_response,
pool=self,
connection=response_conn,
retries=retries,
**response_kw)
response = self.ResponseCls.from_httplib(
httplib_response,
pool=self,
connection=response_conn,
retries=retries,
**response_kw
)
# Everything went great!
clean_exit = True
@@ -625,20 +702,28 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
# Timed out by queue.
raise EmptyPoolError(self, "No pool connections are available.")
except (TimeoutError, HTTPException, SocketError, ProtocolError,
BaseSSLError, SSLError, CertificateError) as e:
except (
TimeoutError,
HTTPException,
SocketError,
ProtocolError,
BaseSSLError,
SSLError,
CertificateError,
) as e:
# Discard the connection for these exceptions. It will be
# replaced during the next _get_conn() call.
clean_exit = False
if isinstance(e, (BaseSSLError, CertificateError)):
e = SSLError(e)
elif isinstance(e, (SocketError, NewConnectionError)) and self.proxy:
e = ProxyError('Cannot connect to proxy.', e)
e = ProxyError("Cannot connect to proxy.", e)
elif isinstance(e, (SocketError, HTTPException)):
e = ProtocolError('Connection aborted.', e)
e = ProtocolError("Connection aborted.", e)
retries = retries.increment(method, url, error=e, _pool=self,
_stacktrace=sys.exc_info()[2])
retries = retries.increment(
method, url, error=e, _pool=self, _stacktrace=sys.exc_info()[2]
)
retries.sleep()
# Keep track of the error for the retry warning.
@@ -661,77 +746,87 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
if not conn:
# Try again
log.warning("Retrying (%r) after connection "
"broken by '%r': %s", retries, err, url)
return self.urlopen(method, url, body, headers, retries,
redirect, assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, body_pos=body_pos,
**response_kw)
def drain_and_release_conn(response):
try:
# discard any remaining response body, the connection will be
# released back to the pool once the entire response is read
response.read()
except (TimeoutError, HTTPException, SocketError, ProtocolError,
BaseSSLError, SSLError):
pass
log.warning(
"Retrying (%r) after connection broken by '%r': %s", retries, err, url
)
return self.urlopen(
method,
url,
body,
headers,
retries,
redirect,
assert_same_host,
timeout=timeout,
pool_timeout=pool_timeout,
release_conn=release_conn,
chunked=chunked,
body_pos=body_pos,
**response_kw
)
# Handle redirect?
redirect_location = redirect and response.get_redirect_location()
if redirect_location:
if response.status == 303:
method = 'GET'
method = "GET"
try:
retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError:
if retries.raise_on_redirect:
# Drain and release the connection for this response, since
# we're not returning it to be released manually.
drain_and_release_conn(response)
response.drain_conn()
raise
return response
# drain and return the connection to the pool before recursing
drain_and_release_conn(response)
response.drain_conn()
retries.sleep_for_retry(response)
log.debug("Redirecting %s -> %s", url, redirect_location)
return self.urlopen(
method, redirect_location, body, headers,
retries=retries, redirect=redirect,
method,
redirect_location,
body,
headers,
retries=retries,
redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
release_conn=release_conn, body_pos=body_pos,
**response_kw)
timeout=timeout,
pool_timeout=pool_timeout,
release_conn=release_conn,
chunked=chunked,
body_pos=body_pos,
**response_kw
)
# Check if we should retry the HTTP response.
has_retry_after = bool(response.getheader('Retry-After'))
has_retry_after = bool(response.getheader("Retry-After"))
if retries.is_retry(method, response.status, has_retry_after):
try:
retries = retries.increment(method, url, response=response, _pool=self)
except MaxRetryError:
if retries.raise_on_status:
# Drain and release the connection for this response, since
# we're not returning it to be released manually.
drain_and_release_conn(response)
response.drain_conn()
raise
return response
# drain and return the connection to the pool before recursing
drain_and_release_conn(response)
response.drain_conn()
retries.sleep(response)
log.debug("Retry: %s", url)
return self.urlopen(
method, url, body, headers,
retries=retries, redirect=redirect,
method,
url,
body,
headers,
retries=retries,
redirect=redirect,
assert_same_host=assert_same_host,
timeout=timeout, pool_timeout=pool_timeout,
timeout=timeout,
pool_timeout=pool_timeout,
release_conn=release_conn,
body_pos=body_pos, **response_kw)
chunked=chunked,
body_pos=body_pos,
**response_kw
)
return response
@@ -754,21 +849,47 @@ class HTTPSConnectionPool(HTTPConnectionPool):
the connection socket into an SSL socket.
"""
scheme = 'https'
scheme = "https"
ConnectionCls = HTTPSConnection
def __init__(self, host, port=None,
strict=False, timeout=Timeout.DEFAULT_TIMEOUT, maxsize=1,
block=False, headers=None, retries=None,
_proxy=None, _proxy_headers=None,
key_file=None, cert_file=None, cert_reqs=None,
key_password=None, ca_certs=None, ssl_version=None,
assert_hostname=None, assert_fingerprint=None,
ca_cert_dir=None, **conn_kw):
def __init__(
self,
host,
port=None,
strict=False,
timeout=Timeout.DEFAULT_TIMEOUT,
maxsize=1,
block=False,
headers=None,
retries=None,
_proxy=None,
_proxy_headers=None,
key_file=None,
cert_file=None,
cert_reqs=None,
key_password=None,
ca_certs=None,
ssl_version=None,
assert_hostname=None,
assert_fingerprint=None,
ca_cert_dir=None,
**conn_kw
):
HTTPConnectionPool.__init__(self, host, port, strict, timeout, maxsize,
block, headers, retries, _proxy, _proxy_headers,
**conn_kw)
HTTPConnectionPool.__init__(
self,
host,
port,
strict,
timeout,
maxsize,
block,
headers,
retries,
_proxy,
_proxy_headers,
**conn_kw
)
self.key_file = key_file
self.cert_file = cert_file
@@ -787,14 +908,16 @@ class HTTPSConnectionPool(HTTPConnectionPool):
"""
if isinstance(conn, VerifiedHTTPSConnection):
conn.set_cert(key_file=self.key_file,
key_password=self.key_password,
cert_file=self.cert_file,
cert_reqs=self.cert_reqs,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint)
conn.set_cert(
key_file=self.key_file,
key_password=self.key_password,
cert_file=self.cert_file,
cert_reqs=self.cert_reqs,
ca_certs=self.ca_certs,
ca_cert_dir=self.ca_cert_dir,
assert_hostname=self.assert_hostname,
assert_fingerprint=self.assert_fingerprint,
)
conn.ssl_version = self.ssl_version
return conn
@@ -811,12 +934,17 @@ class HTTPSConnectionPool(HTTPConnectionPool):
Return a fresh :class:`httplib.HTTPSConnection`.
"""
self.num_connections += 1
log.debug("Starting new HTTPS connection (%d): %s:%s",
self.num_connections, self.host, self.port or "443")
log.debug(
"Starting new HTTPS connection (%d): %s:%s",
self.num_connections,
self.host,
self.port or "443",
)
if not self.ConnectionCls or self.ConnectionCls is DummyConnection:
raise SSLError("Can't connect to HTTPS URL because the SSL "
"module is not available.")
raise SSLError(
"Can't connect to HTTPS URL because the SSL module is not available."
)
actual_host = self.host
actual_port = self.port
@@ -824,11 +952,16 @@ class HTTPSConnectionPool(HTTPConnectionPool):
actual_host = self.proxy.host
actual_port = self.proxy.port
conn = self.ConnectionCls(host=actual_host, port=actual_port,
timeout=self.timeout.connect_timeout,
strict=self.strict, cert_file=self.cert_file,
key_file=self.key_file, key_password=self.key_password,
**self.conn_kw)
conn = self.ConnectionCls(
host=actual_host,
port=actual_port,
timeout=self.timeout.connect_timeout,
strict=self.strict,
cert_file=self.cert_file,
key_file=self.key_file,
key_password=self.key_password,
**self.conn_kw
)
return self._prepare_conn(conn)
@@ -839,16 +972,19 @@ class HTTPSConnectionPool(HTTPConnectionPool):
super(HTTPSConnectionPool, self)._validate_conn(conn)
# Force connect early to allow us to validate the connection.
if not getattr(conn, 'sock', None): # AppEngine might not have `.sock`
if not getattr(conn, "sock", None): # AppEngine might not have `.sock`
conn.connect()
if not conn.is_verified:
warnings.warn((
'Unverified HTTPS request is being made. '
'Adding certificate verification is strongly advised. See: '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings'),
InsecureRequestWarning)
warnings.warn(
(
"Unverified HTTPS request is being made to host '%s'. "
"Adding certificate verification is strongly advised. See: "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings" % conn.host
),
InsecureRequestWarning,
)
def connection_from_url(url, **kw):
@@ -873,7 +1009,7 @@ def connection_from_url(url, **kw):
"""
scheme, host, port = get_host(url)
port = port or port_by_scheme.get(scheme, 80)
if scheme == 'https':
if scheme == "https":
return HTTPSConnectionPool(host, port=port, **kw)
else:
return HTTPConnectionPool(host, port=port, **kw)
@@ -884,14 +1020,14 @@ def _normalize_host(host, scheme):
Normalize hosts for comparisons and use with sockets.
"""
host = normalize_host(host, scheme)
# httplib doesn't like it when we include brackets in IPv6 addresses
# Specifically, if we include brackets but also pass the port then
# httplib crazily doubles up the square brackets on the Host header.
# Instead, we need to make sure we never pass ``None`` as the port.
# However, for backward compatibility reasons we can't actually
# *assert* that. See http://bugs.python.org/issue28539
if host.startswith('[') and host.endswith(']'):
host = host.strip('[]')
if scheme in NORMALIZABLE_SCHEMES:
host = normalize_host(host)
if host.startswith("[") and host.endswith("]"):
host = host[1:-1]
return host
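
Functionally, the visible changes in this file are the new target-URI encoding in urlopen (via parse_url and _encode_target) and the redirect/retry recursion now forwarding chunked=. A hedged sketch of calling urlopen directly, with an example host and an illustrative Retry budget:

from urllib3 import HTTPConnectionPool
from urllib3.util.retry import Retry

pool = HTTPConnectionPool("example.com", port=80, maxsize=2)
resp = pool.urlopen(
    "GET",
    "/some path",                        # re-encoded by _encode_target above
    retries=Retry(total=3, redirect=2),  # consumed via retries.increment()
    redirect=True,
)
print(resp.status)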


@@ -6,25 +6,31 @@ import os
def is_appengine():
return (is_local_appengine() or
is_prod_appengine() or
is_prod_appengine_mvms())
return is_local_appengine() or is_prod_appengine()
def is_appengine_sandbox():
return is_appengine() and not is_prod_appengine_mvms()
"""Reports if the app is running in the first generation sandbox.
The second generation runtimes are technically still in a sandbox, but it
is much less restrictive, so generally you shouldn't need to check for it.
see https://cloud.google.com/appengine/docs/standard/runtimes
"""
return is_appengine() and os.environ["APPENGINE_RUNTIME"] == "python27"
def is_local_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Development/' in os.environ['SERVER_SOFTWARE'])
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Development/")
def is_prod_appengine():
return ('APPENGINE_RUNTIME' in os.environ and
'Google App Engine/' in os.environ['SERVER_SOFTWARE'] and
not is_prod_appengine_mvms())
return "APPENGINE_RUNTIME" in os.environ and os.environ.get(
"SERVER_SOFTWARE", ""
).startswith("Google App Engine/")
def is_prod_appengine_mvms():
return os.environ.get('GAE_VM', False) == 'true'
"""Deprecated."""
return False
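
Both detectors now match SERVER_SOFTWARE prefixes instead of substrings, and the MVM probe is reduced to a deprecated stub. A sketch of the detection under hypothetical environment values (the module is library-private and reads the environment at call time):

import os

os.environ["APPENGINE_RUNTIME"] = "python27"         # hypothetical values
os.environ["SERVER_SOFTWARE"] = "Development/2.0.0"

from urllib3.contrib import _appengine_environ

print(_appengine_environ.is_local_appengine())    # True
print(_appengine_environ.is_prod_appengine())     # False
print(_appengine_environ.is_appengine_sandbox())  # True: runtime is python27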


@@ -34,29 +34,35 @@ from __future__ import absolute_import
import platform
from ctypes.util import find_library
from ctypes import (
c_void_p, c_int32, c_char_p, c_size_t, c_byte, c_uint32, c_ulong, c_long,
c_bool
c_void_p,
c_int32,
c_char_p,
c_size_t,
c_byte,
c_uint32,
c_ulong,
c_long,
c_bool,
)
from ctypes import CDLL, POINTER, CFUNCTYPE
security_path = find_library('Security')
security_path = find_library("Security")
if not security_path:
raise ImportError('The library Security could not be found')
raise ImportError("The library Security could not be found")
core_foundation_path = find_library('CoreFoundation')
core_foundation_path = find_library("CoreFoundation")
if not core_foundation_path:
raise ImportError('The library CoreFoundation could not be found')
raise ImportError("The library CoreFoundation could not be found")
version = platform.mac_ver()[0]
version_info = tuple(map(int, version.split('.')))
version_info = tuple(map(int, version.split(".")))
if version_info < (10, 8):
raise OSError(
'Only OS X 10.8 and newer are supported, not %s.%s' % (
version_info[0], version_info[1]
)
"Only OS X 10.8 and newer are supported, not %s.%s"
% (version_info[0], version_info[1])
)
Security = CDLL(security_path, use_errno=True)
@@ -129,27 +135,19 @@ try:
Security.SecKeyGetTypeID.argtypes = []
Security.SecKeyGetTypeID.restype = CFTypeID
Security.SecCertificateCreateWithData.argtypes = [
CFAllocatorRef,
CFDataRef
]
Security.SecCertificateCreateWithData.argtypes = [CFAllocatorRef, CFDataRef]
Security.SecCertificateCreateWithData.restype = SecCertificateRef
Security.SecCertificateCopyData.argtypes = [
SecCertificateRef
]
Security.SecCertificateCopyData.argtypes = [SecCertificateRef]
Security.SecCertificateCopyData.restype = CFDataRef
Security.SecCopyErrorMessageString.argtypes = [
OSStatus,
c_void_p
]
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SecIdentityCreateWithCertificate.argtypes = [
CFTypeRef,
SecCertificateRef,
POINTER(SecIdentityRef)
POINTER(SecIdentityRef),
]
Security.SecIdentityCreateWithCertificate.restype = OSStatus
@@ -159,201 +157,126 @@ try:
c_void_p,
Boolean,
c_void_p,
POINTER(SecKeychainRef)
POINTER(SecKeychainRef),
]
Security.SecKeychainCreate.restype = OSStatus
Security.SecKeychainDelete.argtypes = [
SecKeychainRef
]
Security.SecKeychainDelete.argtypes = [SecKeychainRef]
Security.SecKeychainDelete.restype = OSStatus
Security.SecPKCS12Import.argtypes = [
CFDataRef,
CFDictionaryRef,
POINTER(CFArrayRef)
POINTER(CFArrayRef),
]
Security.SecPKCS12Import.restype = OSStatus
SSLReadFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, c_void_p, POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t))
SSLWriteFunc = CFUNCTYPE(
OSStatus, SSLConnectionRef, POINTER(c_byte), POINTER(c_size_t)
)
Security.SSLSetIOFuncs.argtypes = [
SSLContextRef,
SSLReadFunc,
SSLWriteFunc
]
Security.SSLSetIOFuncs.argtypes = [SSLContextRef, SSLReadFunc, SSLWriteFunc]
Security.SSLSetIOFuncs.restype = OSStatus
Security.SSLSetPeerID.argtypes = [
SSLContextRef,
c_char_p,
c_size_t
]
Security.SSLSetPeerID.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerID.restype = OSStatus
Security.SSLSetCertificate.argtypes = [
SSLContextRef,
CFArrayRef
]
Security.SSLSetCertificate.argtypes = [SSLContextRef, CFArrayRef]
Security.SSLSetCertificate.restype = OSStatus
Security.SSLSetCertificateAuthorities.argtypes = [
SSLContextRef,
CFTypeRef,
Boolean
]
Security.SSLSetCertificateAuthorities.argtypes = [SSLContextRef, CFTypeRef, Boolean]
Security.SSLSetCertificateAuthorities.restype = OSStatus
Security.SSLSetConnection.argtypes = [
SSLContextRef,
SSLConnectionRef
]
Security.SSLSetConnection.argtypes = [SSLContextRef, SSLConnectionRef]
Security.SSLSetConnection.restype = OSStatus
Security.SSLSetPeerDomainName.argtypes = [
SSLContextRef,
c_char_p,
c_size_t
]
Security.SSLSetPeerDomainName.argtypes = [SSLContextRef, c_char_p, c_size_t]
Security.SSLSetPeerDomainName.restype = OSStatus
Security.SSLHandshake.argtypes = [
SSLContextRef
]
Security.SSLHandshake.argtypes = [SSLContextRef]
Security.SSLHandshake.restype = OSStatus
Security.SSLRead.argtypes = [
SSLContextRef,
c_char_p,
c_size_t,
POINTER(c_size_t)
]
Security.SSLRead.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLRead.restype = OSStatus
Security.SSLWrite.argtypes = [
SSLContextRef,
c_char_p,
c_size_t,
POINTER(c_size_t)
]
Security.SSLWrite.argtypes = [SSLContextRef, c_char_p, c_size_t, POINTER(c_size_t)]
Security.SSLWrite.restype = OSStatus
Security.SSLClose.argtypes = [
SSLContextRef
]
Security.SSLClose.argtypes = [SSLContextRef]
Security.SSLClose.restype = OSStatus
Security.SSLGetNumberSupportedCiphers.argtypes = [
SSLContextRef,
POINTER(c_size_t)
]
Security.SSLGetNumberSupportedCiphers.argtypes = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberSupportedCiphers.restype = OSStatus
Security.SSLGetSupportedCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t)
POINTER(c_size_t),
]
Security.SSLGetSupportedCiphers.restype = OSStatus
Security.SSLSetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
c_size_t
c_size_t,
]
Security.SSLSetEnabledCiphers.restype = OSStatus
Security.SSLGetNumberEnabledCiphers.argtype = [
SSLContextRef,
POINTER(c_size_t)
]
Security.SSLGetNumberEnabledCiphers.argtype = [SSLContextRef, POINTER(c_size_t)]
Security.SSLGetNumberEnabledCiphers.restype = OSStatus
Security.SSLGetEnabledCiphers.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite),
POINTER(c_size_t)
POINTER(c_size_t),
]
Security.SSLGetEnabledCiphers.restype = OSStatus
Security.SSLGetNegotiatedCipher.argtypes = [
SSLContextRef,
POINTER(SSLCipherSuite)
]
Security.SSLGetNegotiatedCipher.argtypes = [SSLContextRef, POINTER(SSLCipherSuite)]
Security.SSLGetNegotiatedCipher.restype = OSStatus
Security.SSLGetNegotiatedProtocolVersion.argtypes = [
SSLContextRef,
POINTER(SSLProtocol)
POINTER(SSLProtocol),
]
Security.SSLGetNegotiatedProtocolVersion.restype = OSStatus
Security.SSLCopyPeerTrust.argtypes = [
SSLContextRef,
POINTER(SecTrustRef)
]
Security.SSLCopyPeerTrust.argtypes = [SSLContextRef, POINTER(SecTrustRef)]
Security.SSLCopyPeerTrust.restype = OSStatus
Security.SecTrustSetAnchorCertificates.argtypes = [
SecTrustRef,
CFArrayRef
]
Security.SecTrustSetAnchorCertificates.argtypes = [SecTrustRef, CFArrayRef]
Security.SecTrustSetAnchorCertificates.restype = OSStatus
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [
SecTrustRef,
Boolean
]
Security.SecTrustSetAnchorCertificatesOnly.argstypes = [SecTrustRef, Boolean]
Security.SecTrustSetAnchorCertificatesOnly.restype = OSStatus
Security.SecTrustEvaluate.argtypes = [
SecTrustRef,
POINTER(SecTrustResultType)
]
Security.SecTrustEvaluate.argtypes = [SecTrustRef, POINTER(SecTrustResultType)]
Security.SecTrustEvaluate.restype = OSStatus
Security.SecTrustGetCertificateCount.argtypes = [
SecTrustRef
]
Security.SecTrustGetCertificateCount.argtypes = [SecTrustRef]
Security.SecTrustGetCertificateCount.restype = CFIndex
Security.SecTrustGetCertificateAtIndex.argtypes = [
SecTrustRef,
CFIndex
]
Security.SecTrustGetCertificateAtIndex.argtypes = [SecTrustRef, CFIndex]
Security.SecTrustGetCertificateAtIndex.restype = SecCertificateRef
Security.SSLCreateContext.argtypes = [
CFAllocatorRef,
SSLProtocolSide,
SSLConnectionType
SSLConnectionType,
]
Security.SSLCreateContext.restype = SSLContextRef
Security.SSLSetSessionOption.argtypes = [
SSLContextRef,
SSLSessionOption,
Boolean
]
Security.SSLSetSessionOption.argtypes = [SSLContextRef, SSLSessionOption, Boolean]
Security.SSLSetSessionOption.restype = OSStatus
Security.SSLSetProtocolVersionMin.argtypes = [
SSLContextRef,
SSLProtocol
]
Security.SSLSetProtocolVersionMin.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMin.restype = OSStatus
Security.SSLSetProtocolVersionMax.argtypes = [
SSLContextRef,
SSLProtocol
]
Security.SSLSetProtocolVersionMax.argtypes = [SSLContextRef, SSLProtocol]
Security.SSLSetProtocolVersionMax.restype = OSStatus
Security.SecCopyErrorMessageString.argtypes = [
OSStatus,
c_void_p
]
Security.SecCopyErrorMessageString.argtypes = [OSStatus, c_void_p]
Security.SecCopyErrorMessageString.restype = CFStringRef
Security.SSLReadFunc = SSLReadFunc
@@ -369,64 +292,47 @@ try:
Security.OSStatus = OSStatus
Security.kSecImportExportPassphrase = CFStringRef.in_dll(
Security, 'kSecImportExportPassphrase'
Security, "kSecImportExportPassphrase"
)
Security.kSecImportItemIdentity = CFStringRef.in_dll(
Security, 'kSecImportItemIdentity'
Security, "kSecImportItemIdentity"
)
# CoreFoundation time!
CoreFoundation.CFRetain.argtypes = [
CFTypeRef
]
CoreFoundation.CFRetain.argtypes = [CFTypeRef]
CoreFoundation.CFRetain.restype = CFTypeRef
CoreFoundation.CFRelease.argtypes = [
CFTypeRef
]
CoreFoundation.CFRelease.argtypes = [CFTypeRef]
CoreFoundation.CFRelease.restype = None
CoreFoundation.CFGetTypeID.argtypes = [
CFTypeRef
]
CoreFoundation.CFGetTypeID.argtypes = [CFTypeRef]
CoreFoundation.CFGetTypeID.restype = CFTypeID
CoreFoundation.CFStringCreateWithCString.argtypes = [
CFAllocatorRef,
c_char_p,
CFStringEncoding
CFStringEncoding,
]
CoreFoundation.CFStringCreateWithCString.restype = CFStringRef
CoreFoundation.CFStringGetCStringPtr.argtypes = [
CFStringRef,
CFStringEncoding
]
CoreFoundation.CFStringGetCStringPtr.argtypes = [CFStringRef, CFStringEncoding]
CoreFoundation.CFStringGetCStringPtr.restype = c_char_p
CoreFoundation.CFStringGetCString.argtypes = [
CFStringRef,
c_char_p,
CFIndex,
CFStringEncoding
CFStringEncoding,
]
CoreFoundation.CFStringGetCString.restype = c_bool
CoreFoundation.CFDataCreate.argtypes = [
CFAllocatorRef,
c_char_p,
CFIndex
]
CoreFoundation.CFDataCreate.argtypes = [CFAllocatorRef, c_char_p, CFIndex]
CoreFoundation.CFDataCreate.restype = CFDataRef
CoreFoundation.CFDataGetLength.argtypes = [
CFDataRef
]
CoreFoundation.CFDataGetLength.argtypes = [CFDataRef]
CoreFoundation.CFDataGetLength.restype = CFIndex
CoreFoundation.CFDataGetBytePtr.argtypes = [
CFDataRef
]
CoreFoundation.CFDataGetBytePtr.argtypes = [CFDataRef]
CoreFoundation.CFDataGetBytePtr.restype = c_void_p
CoreFoundation.CFDictionaryCreate.argtypes = [
@@ -435,14 +341,11 @@ try:
POINTER(CFTypeRef),
CFIndex,
CFDictionaryKeyCallBacks,
CFDictionaryValueCallBacks
CFDictionaryValueCallBacks,
]
CoreFoundation.CFDictionaryCreate.restype = CFDictionaryRef
CoreFoundation.CFDictionaryGetValue.argtypes = [
CFDictionaryRef,
CFTypeRef
]
CoreFoundation.CFDictionaryGetValue.argtypes = [CFDictionaryRef, CFTypeRef]
CoreFoundation.CFDictionaryGetValue.restype = CFTypeRef
CoreFoundation.CFArrayCreate.argtypes = [
@@ -456,36 +359,30 @@ try:
CoreFoundation.CFArrayCreateMutable.argtypes = [
CFAllocatorRef,
CFIndex,
CFArrayCallBacks
CFArrayCallBacks,
]
CoreFoundation.CFArrayCreateMutable.restype = CFMutableArrayRef
CoreFoundation.CFArrayAppendValue.argtypes = [
CFMutableArrayRef,
c_void_p
]
CoreFoundation.CFArrayAppendValue.argtypes = [CFMutableArrayRef, c_void_p]
CoreFoundation.CFArrayAppendValue.restype = None
CoreFoundation.CFArrayGetCount.argtypes = [
CFArrayRef
]
CoreFoundation.CFArrayGetCount.argtypes = [CFArrayRef]
CoreFoundation.CFArrayGetCount.restype = CFIndex
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [
CFArrayRef,
CFIndex
]
CoreFoundation.CFArrayGetValueAtIndex.argtypes = [CFArrayRef, CFIndex]
CoreFoundation.CFArrayGetValueAtIndex.restype = c_void_p
CoreFoundation.kCFAllocatorDefault = CFAllocatorRef.in_dll(
CoreFoundation, 'kCFAllocatorDefault'
CoreFoundation, "kCFAllocatorDefault"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(
CoreFoundation, "kCFTypeArrayCallBacks"
)
CoreFoundation.kCFTypeArrayCallBacks = c_void_p.in_dll(CoreFoundation, 'kCFTypeArrayCallBacks')
CoreFoundation.kCFTypeDictionaryKeyCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryKeyCallBacks'
CoreFoundation, "kCFTypeDictionaryKeyCallBacks"
)
CoreFoundation.kCFTypeDictionaryValueCallBacks = c_void_p.in_dll(
CoreFoundation, 'kCFTypeDictionaryValueCallBacks'
CoreFoundation, "kCFTypeDictionaryValueCallBacks"
)
CoreFoundation.CFTypeRef = CFTypeRef
@@ -494,7 +391,7 @@ try:
CoreFoundation.CFDictionaryRef = CFDictionaryRef
except (AttributeError):
raise ImportError('Error initializing ctypes')
raise ImportError("Error initializing ctypes")
class CFConst(object):
@@ -502,6 +399,7 @@ class CFConst(object):
A class object that acts as essentially a namespace for CoreFoundation
constants.
"""
kCFStringEncodingUTF8 = CFStringEncoding(0x08000100)
@@ -509,6 +407,7 @@ class SecurityConst(object):
"""
A class object that acts as essentially a namespace for Security constants.
"""
kSSLSessionOptionBreakOnServerAuth = 0
kSSLProtocol2 = 1
@@ -516,6 +415,7 @@ class SecurityConst(object):
kTLSProtocol1 = 4
kTLSProtocol11 = 7
kTLSProtocol12 = 8
# SecureTransport does not support TLS 1.3 even if there's a constant for it
kTLSProtocol13 = 10
kTLSProtocolMaxSupported = 999
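
Every hunk in this file collapses the same ctypes idiom onto fewer lines: assign argtypes and restype on a function looked up from a CDLL so ctypes can marshal arguments and return values. The idiom in isolation, against libc rather than Security.framework (assumes a Unix-like system):

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))

# Declare the C signature: size_t strlen(const char *s);
libc.strlen.argtypes = [ctypes.c_char_p]
libc.strlen.restype = ctypes.c_size_t

print(libc.strlen(b"hello"))  # 5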


@@ -66,22 +66,18 @@ def _cf_string_to_unicode(value):
value_as_void_p = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
string = CoreFoundation.CFStringGetCStringPtr(
value_as_void_p,
CFConst.kCFStringEncodingUTF8
value_as_void_p, CFConst.kCFStringEncodingUTF8
)
if string is None:
buffer = ctypes.create_string_buffer(1024)
result = CoreFoundation.CFStringGetCString(
value_as_void_p,
buffer,
1024,
CFConst.kCFStringEncodingUTF8
value_as_void_p, buffer, 1024, CFConst.kCFStringEncodingUTF8
)
if not result:
raise OSError('Error copying C string from CFStringRef')
raise OSError("Error copying C string from CFStringRef")
string = buffer.value
if string is not None:
string = string.decode('utf-8')
string = string.decode("utf-8")
return string
@@ -97,8 +93,8 @@ def _assert_no_error(error, exception_class=None):
output = _cf_string_to_unicode(cf_error_string)
CoreFoundation.CFRelease(cf_error_string)
if output is None or output == u'':
output = u'OSStatus %s' % error
if output is None or output == u"":
output = u"OSStatus %s" % error
if exception_class is None:
exception_class = ssl.SSLError
@@ -115,8 +111,7 @@ def _cert_array_from_pem(pem_bundle):
pem_bundle = pem_bundle.replace(b"\r\n", b"\n")
der_certs = [
base64.b64decode(match.group(1))
for match in _PEM_CERTS_RE.finditer(pem_bundle)
base64.b64decode(match.group(1)) for match in _PEM_CERTS_RE.finditer(pem_bundle)
]
if not der_certs:
raise ssl.SSLError("No root certificates specified")
@@ -124,7 +119,7 @@ def _cert_array_from_pem(pem_bundle):
cert_array = CoreFoundation.CFArrayCreateMutable(
CoreFoundation.kCFAllocatorDefault,
0,
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks)
ctypes.byref(CoreFoundation.kCFTypeArrayCallBacks),
)
if not cert_array:
raise ssl.SSLError("Unable to allocate memory!")
@@ -186,21 +181,16 @@ def _temporary_keychain():
# some random bytes to password-protect the keychain we're creating, so we
# ask for 40 random bytes.
random_bytes = os.urandom(40)
filename = base64.b16encode(random_bytes[:8]).decode('utf-8')
filename = base64.b16encode(random_bytes[:8]).decode("utf-8")
password = base64.b16encode(random_bytes[8:]) # Must be valid UTF-8
tempdirectory = tempfile.mkdtemp()
keychain_path = os.path.join(tempdirectory, filename).encode('utf-8')
keychain_path = os.path.join(tempdirectory, filename).encode("utf-8")
# We now want to create the keychain itself.
keychain = Security.SecKeychainRef()
status = Security.SecKeychainCreate(
keychain_path,
len(password),
password,
False,
None,
ctypes.byref(keychain)
keychain_path, len(password), password, False, None, ctypes.byref(keychain)
)
_assert_no_error(status)
@@ -219,14 +209,12 @@ def _load_items_from_file(keychain, path):
identities = []
result_array = None
with open(path, 'rb') as f:
with open(path, "rb") as f:
raw_filedata = f.read()
try:
filedata = CoreFoundation.CFDataCreate(
CoreFoundation.kCFAllocatorDefault,
raw_filedata,
len(raw_filedata)
CoreFoundation.kCFAllocatorDefault, raw_filedata, len(raw_filedata)
)
result_array = CoreFoundation.CFArrayRef()
result = Security.SecItemImport(
@@ -237,7 +225,7 @@ def _load_items_from_file(keychain, path):
0, # import flags
None, # key params, can include passphrase in the future
keychain, # The keychain to insert into
ctypes.byref(result_array) # Results
ctypes.byref(result_array), # Results
)
_assert_no_error(result)
@@ -247,9 +235,7 @@ def _load_items_from_file(keychain, path):
# keychain already has them!
result_count = CoreFoundation.CFArrayGetCount(result_array)
for index in range(result_count):
item = CoreFoundation.CFArrayGetValueAtIndex(
result_array, index
)
item = CoreFoundation.CFArrayGetValueAtIndex(result_array, index)
item = ctypes.cast(item, CoreFoundation.CFTypeRef)
if _is_cert(item):
@@ -307,9 +293,7 @@ def _load_client_cert_chain(keychain, *paths):
try:
for file_path in paths:
new_identities, new_certs = _load_items_from_file(
keychain, file_path
)
new_identities, new_certs = _load_items_from_file(keychain, file_path)
identities.extend(new_identities)
certificates.extend(new_certs)
@@ -318,9 +302,7 @@ def _load_client_cert_chain(keychain, *paths):
if not identities:
new_identity = Security.SecIdentityRef()
status = Security.SecIdentityCreateWithCertificate(
keychain,
certificates[0],
ctypes.byref(new_identity)
keychain, certificates[0], ctypes.byref(new_identity)
)
_assert_no_error(status)
identities.append(new_identity)
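
The PEM handling in _cert_array_from_pem above boils down to pulling each base64 body out of the bundle and decoding it to DER. A standalone sketch; the regex is modeled on urllib3's _PEM_CERTS_RE rather than copied from this diff:

import base64
import re

_PEM_RE = re.compile(
    b"-----BEGIN CERTIFICATE-----\n(.*?)\n-----END CERTIFICATE-----", re.DOTALL
)

def pem_to_der(pem_bundle):
    # Normalize line endings, then decode every certificate body found.
    pem_bundle = pem_bundle.replace(b"\r\n", b"\n")
    return [base64.b64decode(m.group(1)) for m in _PEM_RE.finditer(pem_bundle)]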


@@ -50,7 +50,7 @@ from ..exceptions import (
MaxRetryError,
ProtocolError,
TimeoutError,
SSLError
SSLError,
)
from ..request import RequestMethods
@@ -96,23 +96,24 @@ class AppEngineManager(RequestMethods):
Beyond those cases, it will raise normal urllib3 errors.
"""
def __init__(self, headers=None, retries=None, validate_certificate=True,
urlfetch_retries=True):
def __init__(
self,
headers=None,
retries=None,
validate_certificate=True,
urlfetch_retries=True,
):
if not urlfetch:
raise AppEnginePlatformError(
"URLFetch is not available in this environment.")
if is_prod_appengine_mvms():
raise AppEnginePlatformError(
"Use normal urllib3.PoolManager instead of AppEngineManager"
"on Managed VMs, as using URLFetch is not necessary in "
"this environment.")
"URLFetch is not available in this environment."
)
warnings.warn(
"urllib3 is using URLFetch on Google App Engine sandbox instead "
"of sockets. To use sockets directly instead of URLFetch see "
"https://urllib3.readthedocs.io/en/latest/reference/urllib3.contrib.html.",
AppEnginePlatformWarning)
AppEnginePlatformWarning,
)
RequestMethods.__init__(self, headers)
self.validate_certificate = validate_certificate
@@ -127,17 +128,22 @@ class AppEngineManager(RequestMethods):
# Return False to re-raise any potential exceptions
return False
def urlopen(self, method, url, body=None, headers=None,
retries=None, redirect=True, timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw):
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=None,
redirect=True,
timeout=Timeout.DEFAULT_TIMEOUT,
**response_kw
):
retries = self._get_retries(retries, redirect)
try:
follow_redirects = (
redirect and
retries.redirect != 0 and
retries.total)
follow_redirects = redirect and retries.redirect != 0 and retries.total
response = urlfetch.fetch(
url,
payload=body,
@@ -152,44 +158,52 @@ class AppEngineManager(RequestMethods):
raise TimeoutError(self, e)
except urlfetch.InvalidURLError as e:
if 'too large' in str(e):
if "too large" in str(e):
raise AppEnginePlatformError(
"URLFetch request too large, URLFetch only "
"supports requests up to 10mb in size.", e)
"supports requests up to 10mb in size.",
e,
)
raise ProtocolError(e)
except urlfetch.DownloadError as e:
if 'Too many redirects' in str(e):
if "Too many redirects" in str(e):
raise MaxRetryError(self, url, reason=e)
raise ProtocolError(e)
except urlfetch.ResponseTooLargeError as e:
raise AppEnginePlatformError(
"URLFetch response too large, URLFetch only supports"
"responses up to 32mb in size.", e)
"responses up to 32mb in size.",
e,
)
except urlfetch.SSLCertificateError as e:
raise SSLError(e)
except urlfetch.InvalidMethodError as e:
raise AppEnginePlatformError(
"URLFetch does not support method: %s" % method, e)
"URLFetch does not support method: %s" % method, e
)
http_response = self._urlfetch_response_to_http_response(
response, retries=retries, **response_kw)
response, retries=retries, **response_kw
)
# Handle redirect?
redirect_location = redirect and http_response.get_redirect_location()
if redirect_location:
# Check for redirect response
if (self.urlfetch_retries and retries.raise_on_redirect):
if self.urlfetch_retries and retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
else:
if http_response.status == 303:
method = 'GET'
method = "GET"
try:
retries = retries.increment(method, url, response=http_response, _pool=self)
retries = retries.increment(
method, url, response=http_response, _pool=self
)
except MaxRetryError:
if retries.raise_on_redirect:
raise MaxRetryError(self, url, "too many redirects")
@@ -199,22 +213,32 @@ class AppEngineManager(RequestMethods):
log.debug("Redirecting %s -> %s", url, redirect_location)
redirect_url = urljoin(url, redirect_location)
return self.urlopen(
method, redirect_url, body, headers,
retries=retries, redirect=redirect,
timeout=timeout, **response_kw)
method,
redirect_url,
body,
headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
# Check if we should retry the HTTP response.
has_retry_after = bool(http_response.getheader('Retry-After'))
has_retry_after = bool(http_response.getheader("Retry-After"))
if retries.is_retry(method, http_response.status, has_retry_after):
retries = retries.increment(
method, url, response=http_response, _pool=self)
retries = retries.increment(method, url, response=http_response, _pool=self)
log.debug("Retry: %s", url)
retries.sleep(http_response)
return self.urlopen(
method, url,
body=body, headers=headers,
retries=retries, redirect=redirect,
timeout=timeout, **response_kw)
method,
url,
body=body,
headers=headers,
retries=retries,
redirect=redirect,
timeout=timeout,
**response_kw
)
return http_response
@@ -223,18 +247,18 @@ class AppEngineManager(RequestMethods):
if is_prod_appengine():
# Production GAE handles deflate encoding automatically, but does
# not remove the encoding header.
content_encoding = urlfetch_resp.headers.get('content-encoding')
content_encoding = urlfetch_resp.headers.get("content-encoding")
if content_encoding == 'deflate':
del urlfetch_resp.headers['content-encoding']
if content_encoding == "deflate":
del urlfetch_resp.headers["content-encoding"]
transfer_encoding = urlfetch_resp.headers.get('transfer-encoding')
transfer_encoding = urlfetch_resp.headers.get("transfer-encoding")
# We have a full response's content,
# so let's make sure we don't report ourselves as chunked data.
if transfer_encoding == 'chunked':
if transfer_encoding == "chunked":
encodings = transfer_encoding.split(",")
encodings.remove('chunked')
urlfetch_resp.headers['transfer-encoding'] = ','.join(encodings)
encodings.remove("chunked")
urlfetch_resp.headers["transfer-encoding"] = ",".join(encodings)
original_response = HTTPResponse(
# In order for decoding to work, we must present the content as
@@ -262,20 +286,21 @@ class AppEngineManager(RequestMethods):
warnings.warn(
"URLFetch does not support granular timeout settings, "
"reverting to total or default URLFetch timeout.",
AppEnginePlatformWarning)
AppEnginePlatformWarning,
)
return timeout.total
return timeout
def _get_retries(self, retries, redirect):
if not isinstance(retries, Retry):
retries = Retry.from_int(
retries, redirect=redirect, default=self.retries)
retries = Retry.from_int(retries, redirect=redirect, default=self.retries)
if retries.connect or retries.read or retries.redirect:
warnings.warn(
"URLFetch only supports total retries and does not "
"recognize connect, read, or redirect retry parameters.",
AppEnginePlatformWarning)
AppEnginePlatformWarning,
)
return retries
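
A hypothetical usage sketch for the manager above; it only runs inside a first-generation App Engine sandbox, anywhere else the constructor raises AppEnginePlatformError:

from urllib3.contrib.appengine import AppEngineManager

# Only total retries are honored, per the warning in _get_retries().
http = AppEngineManager(retries=3)
resp = http.urlopen("GET", "https://example.com/")
print(resp.status)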


@@ -20,7 +20,7 @@ class NTLMConnectionPool(HTTPSConnectionPool):
Implements an NTLM authentication version of an urllib3 connection pool
"""
scheme = 'https'
scheme = "https"
def __init__(self, user, pw, authurl, *args, **kwargs):
"""
@@ -31,7 +31,7 @@ class NTLMConnectionPool(HTTPSConnectionPool):
super(NTLMConnectionPool, self).__init__(*args, **kwargs)
self.authurl = authurl
self.rawuser = user
user_parts = user.split('\\', 1)
user_parts = user.split("\\", 1)
self.domain = user_parts[0].upper()
self.user = user_parts[1]
self.pw = pw
@@ -40,72 +40,82 @@ class NTLMConnectionPool(HTTPSConnectionPool):
# Performs the NTLM handshake that secures the connection. The socket
# must be kept open while requests are performed.
self.num_connections += 1
log.debug('Starting NTLM HTTPS connection no. %d: https://%s%s',
self.num_connections, self.host, self.authurl)
log.debug(
"Starting NTLM HTTPS connection no. %d: https://%s%s",
self.num_connections,
self.host,
self.authurl,
)
headers = {'Connection': 'Keep-Alive'}
req_header = 'Authorization'
resp_header = 'www-authenticate'
headers = {"Connection": "Keep-Alive"}
req_header = "Authorization"
resp_header = "www-authenticate"
conn = HTTPSConnection(host=self.host, port=self.port)
# Send negotiation message
headers[req_header] = (
'NTLM %s' % ntlm.create_NTLM_NEGOTIATE_MESSAGE(self.rawuser))
log.debug('Request headers: %s', headers)
conn.request('GET', self.authurl, None, headers)
headers[req_header] = "NTLM %s" % ntlm.create_NTLM_NEGOTIATE_MESSAGE(
self.rawuser
)
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
reshdr = dict(res.getheaders())
log.debug('Response status: %s %s', res.status, res.reason)
log.debug('Response headers: %s', reshdr)
log.debug('Response data: %s [...]', res.read(100))
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", reshdr)
log.debug("Response data: %s [...]", res.read(100))
# Remove the reference to the socket, so that it can not be closed by
# the response object (we want to keep the socket open)
res.fp = None
# Server should respond with a challenge message
auth_header_values = reshdr[resp_header].split(', ')
auth_header_values = reshdr[resp_header].split(", ")
auth_header_value = None
for s in auth_header_values:
if s[:5] == 'NTLM ':
if s[:5] == "NTLM ":
auth_header_value = s[5:]
if auth_header_value is None:
raise Exception('Unexpected %s response header: %s' %
(resp_header, reshdr[resp_header]))
raise Exception(
"Unexpected %s response header: %s" % (resp_header, reshdr[resp_header])
)
# Send authentication message
ServerChallenge, NegotiateFlags = \
ntlm.parse_NTLM_CHALLENGE_MESSAGE(auth_header_value)
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(ServerChallenge,
self.user,
self.domain,
self.pw,
NegotiateFlags)
headers[req_header] = 'NTLM %s' % auth_msg
log.debug('Request headers: %s', headers)
conn.request('GET', self.authurl, None, headers)
ServerChallenge, NegotiateFlags = ntlm.parse_NTLM_CHALLENGE_MESSAGE(
auth_header_value
)
auth_msg = ntlm.create_NTLM_AUTHENTICATE_MESSAGE(
ServerChallenge, self.user, self.domain, self.pw, NegotiateFlags
)
headers[req_header] = "NTLM %s" % auth_msg
log.debug("Request headers: %s", headers)
conn.request("GET", self.authurl, None, headers)
res = conn.getresponse()
log.debug('Response status: %s %s', res.status, res.reason)
log.debug('Response headers: %s', dict(res.getheaders()))
log.debug('Response data: %s [...]', res.read()[:100])
log.debug("Response status: %s %s", res.status, res.reason)
log.debug("Response headers: %s", dict(res.getheaders()))
log.debug("Response data: %s [...]", res.read()[:100])
if res.status != 200:
if res.status == 401:
raise Exception('Server rejected request: wrong '
'username or password')
raise Exception('Wrong server response: %s %s' %
(res.status, res.reason))
raise Exception("Server rejected request: wrong username or password")
raise Exception("Wrong server response: %s %s" % (res.status, res.reason))
res.fp = None
log.debug('Connection established')
log.debug("Connection established")
return conn
def urlopen(self, method, url, body=None, headers=None, retries=3,
redirect=True, assert_same_host=True):
def urlopen(
self,
method,
url,
body=None,
headers=None,
retries=3,
redirect=True,
assert_same_host=True,
):
if headers is None:
headers = {}
headers['Connection'] = 'Keep-Alive'
return super(NTLMConnectionPool, self).urlopen(method, url, body,
headers, retries,
redirect,
assert_same_host)
headers["Connection"] = "Keep-Alive"
return super(NTLMConnectionPool, self).urlopen(
method, url, body, headers, retries, redirect, assert_same_host
)
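
A usage sketch for the pool above (host and credentials illustrative; requires the legacy python-ntlm package that provides the ntlm module):

from urllib3.contrib.ntlmpool import NTLMConnectionPool

# The user string carries the domain; it is split on the first backslash.
pool = NTLMConnectionPool(
    "EXAMPLE\\jdoe", "secret", "/", host="intranet.example.com", port=443
)
resp = pool.urlopen("GET", "/")
print(resp.status)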


@@ -47,6 +47,7 @@ import OpenSSL.SSL
from cryptography import x509
from cryptography.hazmat.backends.openssl import backend as openssl_backend
from cryptography.hazmat.backends.openssl.x509 import _Certificate
try:
from cryptography.x509 import UnsupportedExtension
except ImportError:
@@ -54,6 +55,7 @@ except ImportError:
class UnsupportedExtension(Exception):
pass
from socket import timeout, error as SocketError
from io import BytesIO
@@ -71,7 +73,7 @@ import sys
from .. import util
__all__ = ['inject_into_urllib3', 'extract_from_urllib3']
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works.
HAS_SNI = True
@@ -82,25 +84,23 @@ _openssl_versions = {
ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD,
}
if hasattr(ssl, 'PROTOCOL_SSLv3') and hasattr(OpenSSL.SSL, 'SSLv3_METHOD'):
if hasattr(ssl, "PROTOCOL_SSLv3") and hasattr(OpenSSL.SSL, "SSLv3_METHOD"):
_openssl_versions[ssl.PROTOCOL_SSLv3] = OpenSSL.SSL.SSLv3_METHOD
if hasattr(ssl, 'PROTOCOL_TLSv1_1') and hasattr(OpenSSL.SSL, 'TLSv1_1_METHOD'):
if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD
if hasattr(ssl, 'PROTOCOL_TLSv1_2') and hasattr(OpenSSL.SSL, 'TLSv1_2_METHOD'):
if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"):
_openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD
_stdlib_to_openssl_verify = {
ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE,
ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER,
ssl.CERT_REQUIRED:
OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER
+ OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}
_openssl_to_stdlib_verify = dict(
(v, k) for k, v in _stdlib_to_openssl_verify.items()
)
_openssl_to_stdlib_verify = dict((v, k) for k, v in _stdlib_to_openssl_verify.items())
# OpenSSL will only write 16K at a time
SSL_WRITE_BLOCKSIZE = 16384
@@ -113,7 +113,7 @@ log = logging.getLogger(__name__)
def inject_into_urllib3():
'Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.'
"Monkey-patch urllib3 with PyOpenSSL-backed SSL-support."
_validate_dependencies_met()
@@ -126,7 +126,7 @@ def inject_into_urllib3():
def extract_from_urllib3():
'Undo monkey-patching by :func:`inject_into_urllib3`.'
"Undo monkey-patching by :func:`inject_into_urllib3`."
util.SSLContext = orig_util_SSLContext
util.ssl_.SSLContext = orig_util_SSLContext
@@ -143,17 +143,23 @@ def _validate_dependencies_met():
"""
# Method added in `cryptography==1.1`; not available in older versions
from cryptography.x509.extensions import Extensions
if getattr(Extensions, "get_extension_for_class", None) is None:
raise ImportError("'cryptography' module missing required functionality. "
"Try upgrading to v1.3.4 or newer.")
raise ImportError(
"'cryptography' module missing required functionality. "
"Try upgrading to v1.3.4 or newer."
)
# pyOpenSSL 0.14 and above use cryptography for OpenSSL bindings. The _x509
# attribute is only present on those versions.
from OpenSSL.crypto import X509
x509 = X509()
if getattr(x509, "_x509", None) is None:
raise ImportError("'pyOpenSSL' module missing required functionality. "
"Try upgrading to v0.14 or newer.")
raise ImportError(
"'pyOpenSSL' module missing required functionality. "
"Try upgrading to v0.14 or newer."
)
def _dnsname_to_stdlib(name):
@@ -169,6 +175,7 @@ def _dnsname_to_stdlib(name):
If the name cannot be idna-encoded then we return None signalling that
the name given should be skipped.
"""
def idna_encode(name):
"""
Borrowed wholesale from the Python Cryptography Project. It turns out
@@ -178,23 +185,23 @@ def _dnsname_to_stdlib(name):
from pip._vendor import idna
try:
for prefix in [u'*.', u'.']:
for prefix in [u"*.", u"."]:
if name.startswith(prefix):
name = name[len(prefix):]
return prefix.encode('ascii') + idna.encode(name)
name = name[len(prefix) :]
return prefix.encode("ascii") + idna.encode(name)
return idna.encode(name)
except idna.core.IDNAError:
return None
# Don't send IPv6 addresses through the IDNA encoder.
if ':' in name:
if ":" in name:
return name
name = idna_encode(name)
if name is None:
return None
elif sys.version_info >= (3, 0):
name = name.decode('utf-8')
name = name.decode("utf-8")
return name
@@ -213,14 +220,16 @@ def get_subj_alt_name(peer_cert):
# We want to find the SAN extension. Ask Cryptography to locate it (it's
# faster than looping in Python)
try:
ext = cert.extensions.get_extension_for_class(
x509.SubjectAlternativeName
).value
ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
except x509.ExtensionNotFound:
# No such extension, return the empty list.
return []
except (x509.DuplicateExtension, UnsupportedExtension,
x509.UnsupportedGeneralNameType, UnicodeError) as e:
except (
x509.DuplicateExtension,
UnsupportedExtension,
x509.UnsupportedGeneralNameType,
UnicodeError,
) as e:
# A problem has been found with the quality of the certificate. Assume
# no SAN field is present.
log.warning(
@@ -239,23 +248,23 @@ def get_subj_alt_name(peer_cert):
# does with certificates, and so we need to attempt to do the same.
# We also want to skip over names which cannot be idna encoded.
names = [
('DNS', name) for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
("DNS", name)
for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName))
if name is not None
]
names.extend(
('IP Address', str(name))
for name in ext.get_values_for_type(x509.IPAddress)
("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress)
)
return names
class WrappedSocket(object):
'''API-compatibility wrapper for Python OpenSSL's Connection-class.
"""API-compatibility wrapper for Python OpenSSL's Connection-class.
Note: _makefile_refs, _drop() and _reuse() are needed for the garbage
collector of pypy.
'''
"""
def __init__(self, connection, socket, suppress_ragged_eofs=True):
self.connection = connection
@@ -278,18 +287,18 @@ class WrappedSocket(object):
try:
data = self.connection.recv(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
return b''
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return b""
else:
raise SocketError(str(e))
except OpenSSL.SSL.ZeroReturnError:
if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN:
return b''
return b""
else:
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout('The read operation timed out')
raise timeout("The read operation timed out")
else:
return self.recv(*args, **kwargs)
@@ -303,7 +312,7 @@ class WrappedSocket(object):
try:
return self.connection.recv_into(*args, **kwargs)
except OpenSSL.SSL.SysCallError as e:
if self.suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"):
return 0
else:
raise SocketError(str(e))
@@ -314,7 +323,7 @@ class WrappedSocket(object):
raise
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(self.socket, self.socket.gettimeout()):
raise timeout('The read operation timed out')
raise timeout("The read operation timed out")
else:
return self.recv_into(*args, **kwargs)
@@ -339,7 +348,9 @@ class WrappedSocket(object):
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self._send_until_done(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE])
sent = self._send_until_done(
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE]
)
total_sent += sent
def shutdown(self):
@@ -363,15 +374,11 @@ class WrappedSocket(object):
return x509
if binary_form:
return OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_ASN1,
x509)
return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509)
return {
'subject': (
(('commonName', x509.get_subject().CN),),
),
'subjectAltName': get_subj_alt_name(x509)
"subject": ((("commonName", x509.get_subject().CN),),),
"subjectAltName": get_subj_alt_name(x509),
}
def version(self):
@@ -388,9 +395,12 @@ class WrappedSocket(object):
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
makefile = backport_makefile
@@ -403,6 +413,7 @@ class PyOpenSSLContext(object):
for translating the interface of the standard library ``SSLContext`` object
to calls into PyOpenSSL.
"""
def __init__(self, protocol):
self.protocol = _openssl_versions[protocol]
self._ctx = OpenSSL.SSL.Context(self.protocol)
@@ -424,43 +435,48 @@ class PyOpenSSLContext(object):
@verify_mode.setter
def verify_mode(self, value):
self._ctx.set_verify(
_stdlib_to_openssl_verify[value],
_verify_callback
)
self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback)
def set_default_verify_paths(self):
self._ctx.set_default_verify_paths()
def set_ciphers(self, ciphers):
if isinstance(ciphers, six.text_type):
ciphers = ciphers.encode('utf-8')
ciphers = ciphers.encode("utf-8")
self._ctx.set_cipher_list(ciphers)
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
if cafile is not None:
cafile = cafile.encode('utf-8')
cafile = cafile.encode("utf-8")
if capath is not None:
capath = capath.encode('utf-8')
self._ctx.load_verify_locations(cafile, capath)
if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata))
capath = capath.encode("utf-8")
try:
self._ctx.load_verify_locations(cafile, capath)
if cadata is not None:
self._ctx.load_verify_locations(BytesIO(cadata))
except OpenSSL.SSL.Error as e:
raise ssl.SSLError("unable to load trusted certificates: %r" % e)
def load_cert_chain(self, certfile, keyfile=None, password=None):
self._ctx.use_certificate_chain_file(certfile)
if password is not None:
if not isinstance(password, six.binary_type):
password = password.encode('utf-8')
password = password.encode("utf-8")
self._ctx.set_passwd_cb(lambda *_: password)
self._ctx.use_privatekey_file(keyfile or certfile)
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
server_hostname=None):
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
cnx = OpenSSL.SSL.Connection(self._ctx, sock)
if isinstance(server_hostname, six.text_type): # Platform-specific: Python 3
server_hostname = server_hostname.encode('utf-8')
server_hostname = server_hostname.encode("utf-8")
if server_hostname is not None:
cnx.set_tlsext_host_name(server_hostname)
@@ -472,10 +488,10 @@ class PyOpenSSLContext(object):
cnx.do_handshake()
except OpenSSL.SSL.WantReadError:
if not util.wait_for_read(sock, sock.gettimeout()):
raise timeout('select timed out')
raise timeout("select timed out")
continue
except OpenSSL.SSL.Error as e:
raise ssl.SSLError('bad handshake: %r' % e)
raise ssl.SSLError("bad handshake: %r" % e)
break
return WrappedSocket(cnx, sock)
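
The module's two documented entry points wrap all of the above: inject_into_urllib3() monkey-patches urllib3 to run TLS through PyOpenSSL, and extract_from_urllib3() undoes it:

import urllib3
from urllib3.contrib import pyopenssl

pyopenssl.inject_into_urllib3()
try:
    http = urllib3.PoolManager()
    resp = http.request("GET", "https://example.com/")
    print(resp.status)
finally:
    # Restore the stdlib ssl-backed implementation.
    pyopenssl.extract_from_urllib3()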


@@ -62,12 +62,12 @@ import threading
import weakref
from .. import util
from ._securetransport.bindings import (
Security, SecurityConst, CoreFoundation
)
from ._securetransport.bindings import Security, SecurityConst, CoreFoundation
from ._securetransport.low_level import (
_assert_no_error, _cert_array_from_pem, _temporary_keychain,
_load_client_cert_chain
_assert_no_error,
_cert_array_from_pem,
_temporary_keychain,
_load_client_cert_chain,
)
try: # Platform-specific: Python 2
@@ -76,7 +76,7 @@ except ImportError: # Platform-specific: Python 3
_fileobject = None
from ..packages.backports.makefile import backport_makefile
__all__ = ['inject_into_urllib3', 'extract_from_urllib3']
__all__ = ["inject_into_urllib3", "extract_from_urllib3"]
# SNI always works
HAS_SNI = True
@@ -144,31 +144,36 @@ CIPHER_SUITES = [
]
# Basically this is simple: for PROTOCOL_SSLv23 we turn it into a low of
# TLSv1 and a high of TLSv1.3. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+ and TLSv1.3 is macOS 10.13+
# TLSv1 and a high of TLSv1.2. For everything else, we pin to that version.
# TLSv1 to 1.2 are supported on macOS 10.8+
_protocol_to_min_max = {
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocolMaxSupported),
util.PROTOCOL_TLS: (SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol12)
}
if hasattr(ssl, "PROTOCOL_SSLv2"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv2] = (
SecurityConst.kSSLProtocol2, SecurityConst.kSSLProtocol2
SecurityConst.kSSLProtocol2,
SecurityConst.kSSLProtocol2,
)
if hasattr(ssl, "PROTOCOL_SSLv3"):
_protocol_to_min_max[ssl.PROTOCOL_SSLv3] = (
SecurityConst.kSSLProtocol3, SecurityConst.kSSLProtocol3
SecurityConst.kSSLProtocol3,
SecurityConst.kSSLProtocol3,
)
if hasattr(ssl, "PROTOCOL_TLSv1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1] = (
SecurityConst.kTLSProtocol1, SecurityConst.kTLSProtocol1
SecurityConst.kTLSProtocol1,
SecurityConst.kTLSProtocol1,
)
if hasattr(ssl, "PROTOCOL_TLSv1_1"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_1] = (
SecurityConst.kTLSProtocol11, SecurityConst.kTLSProtocol11
SecurityConst.kTLSProtocol11,
SecurityConst.kTLSProtocol11,
)
if hasattr(ssl, "PROTOCOL_TLSv1_2"):
_protocol_to_min_max[ssl.PROTOCOL_TLSv1_2] = (
SecurityConst.kTLSProtocol12, SecurityConst.kTLSProtocol12
SecurityConst.kTLSProtocol12,
SecurityConst.kTLSProtocol12,
)
@@ -218,7 +223,7 @@ def _read_callback(connection_id, data_buffer, data_length_pointer):
while read_count < requested_length:
if timeout is None or timeout >= 0:
if not util.wait_for_read(base_socket, timeout):
raise socket.error(errno.EAGAIN, 'timed out')
raise socket.error(errno.EAGAIN, "timed out")
remaining = requested_length - read_count
buffer = (ctypes.c_char * remaining).from_address(
@@ -274,7 +279,7 @@ def _write_callback(connection_id, data_buffer, data_length_pointer):
while sent < bytes_to_write:
if timeout is None or timeout >= 0:
if not util.wait_for_write(base_socket, timeout):
raise socket.error(errno.EAGAIN, 'timed out')
raise socket.error(errno.EAGAIN, "timed out")
chunk_sent = base_socket.send(data)
sent += chunk_sent
@@ -316,6 +321,7 @@ class WrappedSocket(object):
Note: _makefile_refs, _drop(), and _reuse() are needed for the garbage
collector of PyPy.
"""
def __init__(self, socket):
self.socket = socket
self.context = None
@@ -380,7 +386,7 @@ class WrappedSocket(object):
# We want data in memory, so load it up.
if os.path.isfile(trust_bundle):
with open(trust_bundle, 'rb') as f:
with open(trust_bundle, "rb") as f:
trust_bundle = f.read()
cert_array = None
@@ -394,9 +400,7 @@ class WrappedSocket(object):
# created for this connection, shove our CAs into it, tell ST to
# ignore everything else it knows, and then ask if it can build a
# chain. This is a buuuunch of code.
result = Security.SSLCopyPeerTrust(
self.context, ctypes.byref(trust)
)
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
raise ssl.SSLError("Failed to copy trust reference")
@@ -408,9 +412,7 @@ class WrappedSocket(object):
_assert_no_error(result)
trust_result = Security.SecTrustResultType()
result = Security.SecTrustEvaluate(
trust, ctypes.byref(trust_result)
)
result = Security.SecTrustEvaluate(trust, ctypes.byref(trust_result))
_assert_no_error(result)
finally:
if trust:
@@ -422,23 +424,24 @@ class WrappedSocket(object):
# Ok, now we can look at what the result was.
successes = (
SecurityConst.kSecTrustResultUnspecified,
SecurityConst.kSecTrustResultProceed
SecurityConst.kSecTrustResultProceed,
)
if trust_result.value not in successes:
raise ssl.SSLError(
"certificate verify failed, error code: %d" %
trust_result.value
"certificate verify failed, error code: %d" % trust_result.value
)
def handshake(self,
server_hostname,
verify,
trust_bundle,
min_version,
max_version,
client_cert,
client_key,
client_key_passphrase):
def handshake(
self,
server_hostname,
verify,
trust_bundle,
min_version,
max_version,
client_cert,
client_key,
client_key_passphrase,
):
"""
Actually performs the TLS handshake. This is run automatically by
wrapped socket, and shouldn't be needed in user code.
@@ -468,7 +471,7 @@ class WrappedSocket(object):
# If we have a server hostname, we should set that too.
if server_hostname:
if not isinstance(server_hostname, bytes):
server_hostname = server_hostname.encode('utf-8')
server_hostname = server_hostname.encode("utf-8")
result = Security.SSLSetPeerDomainName(
self.context, server_hostname, len(server_hostname)
@@ -482,13 +485,7 @@ class WrappedSocket(object):
result = Security.SSLSetProtocolVersionMin(self.context, min_version)
_assert_no_error(result)
# TLS 1.3 isn't necessarily enabled by the OS
# so we have to detect when we error out and try
# setting TLS 1.3 if it's allowed. kTLSProtocolMaxSupported
# was added in macOS 10.13 along with kTLSProtocol13.
result = Security.SSLSetProtocolVersionMax(self.context, max_version)
if result != 0 and max_version == SecurityConst.kTLSProtocolMaxSupported:
result = Security.SSLSetProtocolVersionMax(self.context, SecurityConst.kTLSProtocol12)
_assert_no_error(result)
# If there's a trust DB, we need to use it. We do that by telling
@@ -497,9 +494,7 @@ class WrappedSocket(object):
# authing in that case.
if not verify or trust_bundle is not None:
result = Security.SSLSetSessionOption(
self.context,
SecurityConst.kSSLSessionOptionBreakOnServerAuth,
True
self.context, SecurityConst.kSSLSessionOptionBreakOnServerAuth, True
)
_assert_no_error(result)
@@ -509,9 +504,7 @@ class WrappedSocket(object):
self._client_cert_chain = _load_client_cert_chain(
self._keychain, client_cert, client_key
)
result = Security.SSLSetCertificate(
self.context, self._client_cert_chain
)
result = Security.SSLSetCertificate(self.context, self._client_cert_chain)
_assert_no_error(result)
while True:
@@ -562,7 +555,7 @@ class WrappedSocket(object):
# There are some result codes that we want to treat as "not always
# errors". Specifically, those are errSSLWouldBlock,
# errSSLClosedGraceful, and errSSLClosedNoNotify.
if (result == SecurityConst.errSSLWouldBlock):
if result == SecurityConst.errSSLWouldBlock:
# If we didn't process any bytes, then this was just a time out.
# However, we can get errSSLWouldBlock in situations when we *did*
# read some data, and in those cases we should just read "short"
@@ -570,7 +563,10 @@ class WrappedSocket(object):
if processed_bytes.value == 0:
# Timed out, no data read.
raise socket.timeout("recv timed out")
elif result in (SecurityConst.errSSLClosedGraceful, SecurityConst.errSSLClosedNoNotify):
elif result in (
SecurityConst.errSSLClosedGraceful,
SecurityConst.errSSLClosedNoNotify,
):
# The remote peer has closed this connection. We should do so as
# well. Note that we don't actually return here because in
# principle this could actually be fired along with return data.
@@ -609,7 +605,7 @@ class WrappedSocket(object):
def sendall(self, data):
total_sent = 0
while total_sent < len(data):
sent = self.send(data[total_sent:total_sent + SSL_WRITE_BLOCKSIZE])
sent = self.send(data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE])
total_sent += sent
def shutdown(self):
@@ -656,18 +652,14 @@ class WrappedSocket(object):
# instead to just flag to urllib3 that it shouldn't do its own hostname
# validation when using SecureTransport.
if not binary_form:
raise ValueError(
"SecureTransport only supports dumping binary certs"
)
raise ValueError("SecureTransport only supports dumping binary certs")
trust = Security.SecTrustRef()
certdata = None
der_bytes = None
try:
# Grab the trust store.
result = Security.SSLCopyPeerTrust(
self.context, ctypes.byref(trust)
)
result = Security.SSLCopyPeerTrust(self.context, ctypes.byref(trust))
_assert_no_error(result)
if not trust:
# Probably we haven't done the handshake yet. No biggie.
@@ -699,22 +691,24 @@ class WrappedSocket(object):
def version(self):
protocol = Security.SSLProtocol()
result = Security.SSLGetNegotiatedProtocolVersion(self.context, ctypes.byref(protocol))
result = Security.SSLGetNegotiatedProtocolVersion(
self.context, ctypes.byref(protocol)
)
_assert_no_error(result)
if protocol.value == SecurityConst.kTLSProtocol13:
return 'TLSv1.3'
raise ssl.SSLError("SecureTransport does not support TLS 1.3")
elif protocol.value == SecurityConst.kTLSProtocol12:
return 'TLSv1.2'
return "TLSv1.2"
elif protocol.value == SecurityConst.kTLSProtocol11:
return 'TLSv1.1'
return "TLSv1.1"
elif protocol.value == SecurityConst.kTLSProtocol1:
return 'TLSv1'
return "TLSv1"
elif protocol.value == SecurityConst.kSSLProtocol3:
return 'SSLv3'
return "SSLv3"
elif protocol.value == SecurityConst.kSSLProtocol2:
return 'SSLv2'
return "SSLv2"
else:
raise ssl.SSLError('Unknown TLS version: %r' % protocol)
raise ssl.SSLError("Unknown TLS version: %r" % protocol)
def _reuse(self):
self._makefile_refs += 1
@@ -727,16 +721,21 @@ class WrappedSocket(object):
if _fileobject: # Platform-specific: Python 2
def makefile(self, mode, bufsize=-1):
self._makefile_refs += 1
return _fileobject(self, mode, bufsize, close=True)
else: # Platform-specific: Python 3
def makefile(self, mode="r", buffering=None, *args, **kwargs):
# We disable buffering with SecureTransport because it conflicts with
# the buffering that ST does internally (see issue #1153 for more).
buffering = 0
return backport_makefile(self, mode, buffering, *args, **kwargs)
WrappedSocket.makefile = makefile
@@ -746,6 +745,7 @@ class SecureTransportContext(object):
interface of the standard library ``SSLContext`` object to calls into
SecureTransport.
"""
def __init__(self, protocol):
self._min_version, self._max_version = _protocol_to_min_max[protocol]
self._options = 0
@@ -812,16 +812,17 @@ class SecureTransportContext(object):
def set_ciphers(self, ciphers):
# For now, we just require the default cipher string.
if ciphers != util.ssl_.DEFAULT_CIPHERS:
raise ValueError(
"SecureTransport doesn't support custom cipher strings"
)
raise ValueError("SecureTransport doesn't support custom cipher strings")
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
# OK, we only really support cadata and cafile.
if capath is not None:
raise ValueError(
"SecureTransport does not support cert directories"
)
raise ValueError("SecureTransport does not support cert directories")
# Raise if cafile does not exist.
if cafile is not None:
with open(cafile):
pass
self._trust_bundle = cafile or cadata
@@ -830,9 +831,14 @@ class SecureTransportContext(object):
self._client_key = keyfile
self._client_cert_passphrase = password
def wrap_socket(self, sock, server_side=False,
do_handshake_on_connect=True, suppress_ragged_eofs=True,
server_hostname=None):
def wrap_socket(
self,
sock,
server_side=False,
do_handshake_on_connect=True,
suppress_ragged_eofs=True,
server_hostname=None,
):
# So, what do we do here? Firstly, we assert some properties. This is a
# stripped down shim, so there is some functionality we don't support.
# See PEP 543 for the real deal.
@@ -846,8 +852,13 @@ class SecureTransportContext(object):
# Now we can handshake
wrapped_socket.handshake(
server_hostname, self._verify, self._trust_bundle,
self._min_version, self._max_version, self._client_cert,
self._client_key, self._client_key_passphrase
server_hostname,
self._verify,
self._trust_bundle,
self._min_version,
self._max_version,
self._client_cert,
self._client_key,
self._client_key_passphrase,
)
return wrapped_socket
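
Same injection pattern as the PyOpenSSL module, macOS only. Note that with the _protocol_to_min_max change above, PROTOCOL_TLS connections now negotiate at most TLS 1.2:

import urllib3
from urllib3.contrib import securetransport

securetransport.inject_into_urllib3()
http = urllib3.PoolManager()
resp = http.request("GET", "https://example.com/")
print(resp.status)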


@@ -42,23 +42,20 @@ except ImportError:
import warnings
from ..exceptions import DependencyWarning
warnings.warn((
'SOCKS support in urllib3 requires the installation of optional '
'dependencies: specifically, PySocks. For more information, see '
'https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies'
warnings.warn(
(
"SOCKS support in urllib3 requires the installation of optional "
"dependencies: specifically, PySocks. For more information, see "
"https://urllib3.readthedocs.io/en/latest/contrib.html#socks-proxies"
),
DependencyWarning
DependencyWarning,
)
raise
from socket import error as SocketError, timeout as SocketTimeout
from ..connection import (
HTTPConnection, HTTPSConnection
)
from ..connectionpool import (
HTTPConnectionPool, HTTPSConnectionPool
)
from ..connection import HTTPConnection, HTTPSConnection
from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from ..exceptions import ConnectTimeoutError, NewConnectionError
from ..poolmanager import PoolManager
from ..util.url import parse_url
@@ -73,8 +70,9 @@ class SOCKSConnection(HTTPConnection):
"""
A plain-text HTTP connection that connects via a SOCKS proxy.
"""
def __init__(self, *args, **kwargs):
self._socks_options = kwargs.pop('_socks_options')
self._socks_options = kwargs.pop("_socks_options")
super(SOCKSConnection, self).__init__(*args, **kwargs)
def _new_conn(self):
@@ -83,28 +81,30 @@ class SOCKSConnection(HTTPConnection):
"""
extra_kw = {}
if self.source_address:
extra_kw['source_address'] = self.source_address
extra_kw["source_address"] = self.source_address
if self.socket_options:
extra_kw['socket_options'] = self.socket_options
extra_kw["socket_options"] = self.socket_options
try:
conn = socks.create_connection(
(self.host, self.port),
proxy_type=self._socks_options['socks_version'],
proxy_addr=self._socks_options['proxy_host'],
proxy_port=self._socks_options['proxy_port'],
proxy_username=self._socks_options['username'],
proxy_password=self._socks_options['password'],
proxy_rdns=self._socks_options['rdns'],
proxy_type=self._socks_options["socks_version"],
proxy_addr=self._socks_options["proxy_host"],
proxy_port=self._socks_options["proxy_port"],
proxy_username=self._socks_options["username"],
proxy_password=self._socks_options["password"],
proxy_rdns=self._socks_options["rdns"],
timeout=self.timeout,
**extra_kw
)
except SocketTimeout:
raise ConnectTimeoutError(
self, "Connection to %s timed out. (connect timeout=%s)" %
(self.host, self.timeout))
self,
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
except socks.ProxyError as e:
# This is fragile as hell, but it seems to be the only way to raise
@@ -114,23 +114,22 @@ class SOCKSConnection(HTTPConnection):
if isinstance(error, SocketTimeout):
raise ConnectTimeoutError(
self,
"Connection to %s timed out. (connect timeout=%s)" %
(self.host, self.timeout)
"Connection to %s timed out. (connect timeout=%s)"
% (self.host, self.timeout),
)
else:
raise NewConnectionError(
self,
"Failed to establish a new connection: %s" % error
self, "Failed to establish a new connection: %s" % error
)
else:
raise NewConnectionError(
self,
"Failed to establish a new connection: %s" % e
self, "Failed to establish a new connection: %s" % e
)
except SocketError as e: # Defensive: PySocks should catch all these.
raise NewConnectionError(
self, "Failed to establish a new connection: %s" % e)
self, "Failed to establish a new connection: %s" % e
)
return conn
@@ -156,47 +155,53 @@ class SOCKSProxyManager(PoolManager):
A version of the urllib3 ProxyManager that routes connections via the
defined SOCKS proxy.
"""
pool_classes_by_scheme = {
'http': SOCKSHTTPConnectionPool,
'https': SOCKSHTTPSConnectionPool,
"http": SOCKSHTTPConnectionPool,
"https": SOCKSHTTPSConnectionPool,
}
def __init__(self, proxy_url, username=None, password=None,
num_pools=10, headers=None, **connection_pool_kw):
def __init__(
self,
proxy_url,
username=None,
password=None,
num_pools=10,
headers=None,
**connection_pool_kw
):
parsed = parse_url(proxy_url)
if username is None and password is None and parsed.auth is not None:
split = parsed.auth.split(':')
split = parsed.auth.split(":")
if len(split) == 2:
username, password = split
if parsed.scheme == 'socks5':
if parsed.scheme == "socks5":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = False
elif parsed.scheme == 'socks5h':
elif parsed.scheme == "socks5h":
socks_version = socks.PROXY_TYPE_SOCKS5
rdns = True
elif parsed.scheme == 'socks4':
elif parsed.scheme == "socks4":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = False
elif parsed.scheme == 'socks4a':
elif parsed.scheme == "socks4a":
socks_version = socks.PROXY_TYPE_SOCKS4
rdns = True
else:
raise ValueError(
"Unable to determine SOCKS version from %s" % proxy_url
)
raise ValueError("Unable to determine SOCKS version from %s" % proxy_url)
self.proxy_url = proxy_url
socks_options = {
'socks_version': socks_version,
'proxy_host': parsed.host,
'proxy_port': parsed.port,
'username': username,
'password': password,
'rdns': rdns
"socks_version": socks_version,
"proxy_host": parsed.host,
"proxy_port": parsed.port,
"username": username,
"password": password,
"rdns": rdns,
}
connection_pool_kw['_socks_options'] = socks_options
connection_pool_kw["_socks_options"] = socks_options
super(SOCKSProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw
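
A usage sketch for the manager above (proxy address illustrative; requires PySocks). A socks5h:// scheme resolves hostnames on the proxy (rdns=True), while socks5:// resolves them locally:

from urllib3.contrib.socks import SOCKSProxyManager

proxy = SOCKSProxyManager("socks5h://127.0.0.1:9050")
resp = proxy.request("GET", "https://example.com/")
print(resp.status)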


@@ -1,7 +1,6 @@
from __future__ import absolute_import
from .packages.six.moves.http_client import (
IncompleteRead as httplib_IncompleteRead
)
from .packages.six.moves.http_client import IncompleteRead as httplib_IncompleteRead
# Base Exceptions
@@ -17,6 +16,7 @@ class HTTPWarning(Warning):
class PoolError(HTTPError):
"Base exception for errors caused within a pool."
def __init__(self, pool, message):
self.pool = pool
HTTPError.__init__(self, "%s: %s" % (pool, message))
@@ -28,6 +28,7 @@ class PoolError(HTTPError):
class RequestError(PoolError):
"Base exception for PoolErrors that have associated URLs."
def __init__(self, pool, url, message):
self.url = url
PoolError.__init__(self, pool, message)
@@ -44,7 +45,10 @@ class SSLError(HTTPError):
class ProxyError(HTTPError):
"Raised when the connection to a proxy fails."
pass
def __init__(self, message, error, *args):
super(ProxyError, self).__init__(message, error, *args)
self.original_error = error
class DecodeError(HTTPError):
@@ -63,6 +67,7 @@ ConnectionError = ProtocolError
# Leaf Exceptions
class MaxRetryError(RequestError):
"""Raised when the maximum number of retries is exceeded.
@@ -76,8 +81,7 @@ class MaxRetryError(RequestError):
def __init__(self, pool, url, reason=None):
self.reason = reason
message = "Max retries exceeded with url: %s (Caused by %r)" % (
url, reason)
message = "Max retries exceeded with url: %s (Caused by %r)" % (url, reason)
RequestError.__init__(self, pool, url, message)
@@ -93,6 +97,7 @@ class HostChangedError(RequestError):
class TimeoutStateError(HTTPError):
""" Raised when passing an invalid state to a timeout """
pass
@@ -102,6 +107,7 @@ class TimeoutError(HTTPError):
Catching this error will catch both :exc:`ReadTimeoutErrors
<ReadTimeoutError>` and :exc:`ConnectTimeoutErrors <ConnectTimeoutError>`.
"""
pass
@@ -149,8 +155,8 @@ class LocationParseError(LocationValueError):
class ResponseError(HTTPError):
"Used as a container for an error reason supplied in a MaxRetryError."
GENERIC_ERROR = 'too many error responses'
SPECIFIC_ERROR = 'too many {status_code} error responses'
GENERIC_ERROR = "too many error responses"
SPECIFIC_ERROR = "too many {status_code} error responses"
class SecurityWarning(HTTPWarning):
@@ -188,6 +194,21 @@ class DependencyWarning(HTTPWarning):
Warned when an attempt is made to import a module with missing optional
dependencies.
"""
pass
class InvalidProxyConfigurationWarning(HTTPWarning):
"""
Warned when using an HTTPS proxy and an HTTPS URL. Currently
urllib3 doesn't support HTTPS proxies and the proxy will be
contacted via HTTP instead. This warning can be fixed by
changing your HTTPS proxy URL into an HTTP proxy URL.
If you encounter this warning read this:
https://github.com/urllib3/urllib3/issues/1850
"""
pass
@@ -201,6 +222,7 @@ class BodyNotHttplibCompatible(HTTPError):
Body should be httplib.HTTPResponse like (have an fp attribute which
returns raw chunks) for read_chunked().
"""
pass
@@ -212,12 +234,15 @@ class IncompleteRead(HTTPError, httplib_IncompleteRead):
for `partial` to avoid creating large objects on streamed
reads.
"""
def __init__(self, partial, expected):
super(IncompleteRead, self).__init__(partial, expected)
def __repr__(self):
return ('IncompleteRead(%i bytes read, '
'%i more expected)' % (self.partial, self.expected))
return "IncompleteRead(%i bytes read, %i more expected)" % (
self.partial,
self.expected,
)
class InvalidHeader(HTTPError):
@@ -236,8 +261,9 @@ class ProxySchemeUnknown(AssertionError, ValueError):
class HeaderParsingError(HTTPError):
"Raised by assert_header_parsing, but we convert it to a log.warning statement."
def __init__(self, defects, unparsed_data):
message = '%s, unparsed data: %r' % (defects or 'Unknown', unparsed_data)
message = "%s, unparsed data: %r" % (defects or "Unknown", unparsed_data)
super(HeaderParsingError, self).__init__(message)
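
The reworked ProxyError now keeps the underlying exception on .original_error. A minimal sketch of the new attribute:

from urllib3.exceptions import ProxyError

try:
    raise ProxyError("Cannot connect to proxy.", OSError("connection refused"))
except ProxyError as err:
    # The second constructor argument is preserved for callers to inspect.
    print(repr(err.original_error))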


@@ -6,7 +6,7 @@ import re
from .packages import six
def guess_content_type(filename, default='application/octet-stream'):
def guess_content_type(filename, default="application/octet-stream"):
"""
Guess the "Content-Type" of a file.
@@ -41,22 +41,22 @@ def format_header_param_rfc2231(name, value):
if not any(ch in value for ch in '"\\\r\n'):
result = u'%s="%s"' % (name, value)
try:
result.encode('ascii')
result.encode("ascii")
except (UnicodeEncodeError, UnicodeDecodeError):
pass
else:
return result
if not six.PY3: # Python 2:
value = value.encode('utf-8')
if six.PY2: # Python 2:
value = value.encode("utf-8")
# encode_rfc2231 accepts an encoded string and returns an ascii-encoded
# string in Python 2 but accepts and returns unicode strings in Python 3
value = email.utils.encode_rfc2231(value, 'utf-8')
value = '%s*=%s' % (name, value)
value = email.utils.encode_rfc2231(value, "utf-8")
value = "%s*=%s" % (name, value)
if not six.PY3: # Python 2:
value = value.decode('utf-8')
if six.PY2: # Python 2:
value = value.decode("utf-8")
return value
@@ -69,23 +69,21 @@ _HTML5_REPLACEMENTS = {
}
# All control characters from 0x00 to 0x1F *except* 0x1B.
_HTML5_REPLACEMENTS.update({
six.unichr(cc): u"%{:02X}".format(cc)
for cc
in range(0x00, 0x1F+1)
if cc not in (0x1B,)
})
_HTML5_REPLACEMENTS.update(
{
six.unichr(cc): u"%{:02X}".format(cc)
for cc in range(0x00, 0x1F + 1)
if cc not in (0x1B,)
}
)
def _replace_multiple(value, needles_and_replacements):
def replacer(match):
return needles_and_replacements[match.group(0)]
pattern = re.compile(
r"|".join([
re.escape(needle) for needle in needles_and_replacements.keys()
])
r"|".join([re.escape(needle) for needle in needles_and_replacements.keys()])
)
result = pattern.sub(replacer, value)
@@ -140,13 +138,15 @@ class RequestField(object):
An optional callable that is used to encode and format the headers. By
default, this is :func:`format_header_param_html5`.
"""
def __init__(
self,
name,
data,
filename=None,
headers=None,
header_formatter=format_header_param_html5):
self,
name,
data,
filename=None,
headers=None,
header_formatter=format_header_param_html5,
):
self._name = name
self._filename = filename
self.data = data
@@ -156,11 +156,7 @@ class RequestField(object):
self.header_formatter = header_formatter
@classmethod
def from_tuples(
cls,
fieldname,
value,
header_formatter=format_header_param_html5):
def from_tuples(cls, fieldname, value, header_formatter=format_header_param_html5):
"""
A :class:`~urllib3.fields.RequestField` factory from old-style tuple parameters.
@@ -189,7 +185,8 @@ class RequestField(object):
data = value
request_param = cls(
fieldname, data, filename=filename, header_formatter=header_formatter)
fieldname, data, filename=filename, header_formatter=header_formatter
)
request_param.make_multipart(content_type=content_type)
return request_param
@@ -227,7 +224,7 @@ class RequestField(object):
if value is not None:
parts.append(self._render_part(name, value))
return u'; '.join(parts)
return u"; ".join(parts)
def render_headers(self):
"""
@@ -235,21 +232,22 @@ class RequestField(object):
"""
lines = []
sort_keys = ['Content-Disposition', 'Content-Type', 'Content-Location']
sort_keys = ["Content-Disposition", "Content-Type", "Content-Location"]
for sort_key in sort_keys:
if self.headers.get(sort_key, False):
lines.append(u'%s: %s' % (sort_key, self.headers[sort_key]))
lines.append(u"%s: %s" % (sort_key, self.headers[sort_key]))
for header_name, header_value in self.headers.items():
if header_name not in sort_keys:
if header_value:
lines.append(u'%s: %s' % (header_name, header_value))
lines.append(u"%s: %s" % (header_name, header_value))
lines.append(u'\r\n')
return u'\r\n'.join(lines)
lines.append(u"\r\n")
return u"\r\n".join(lines)
def make_multipart(self, content_disposition=None, content_type=None,
content_location=None):
def make_multipart(
self, content_disposition=None, content_type=None, content_location=None
):
"""
Makes this request field into a multipart request field.
@@ -262,11 +260,14 @@ class RequestField(object):
The 'Content-Location' of the request body.
"""
self.headers['Content-Disposition'] = content_disposition or u'form-data'
self.headers['Content-Disposition'] += u'; '.join([
u'', self._render_parts(
((u'name', self._name), (u'filename', self._filename))
)
])
self.headers['Content-Type'] = content_type
self.headers['Content-Location'] = content_location
self.headers["Content-Disposition"] = content_disposition or u"form-data"
self.headers["Content-Disposition"] += u"; ".join(
[
u"",
self._render_parts(
((u"name", self._name), (u"filename", self._filename))
),
]
)
self.headers["Content-Type"] = content_type
self.headers["Content-Location"] = content_location
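
A sketch of the tuple form that from_tuples() accepts, (filename, data, content_type); make_multipart() is invoked internally:

from urllib3.fields import RequestField

field = RequestField.from_tuples("upload", ("report.txt", b"hello", "text/plain"))
# Renders the Content-Disposition and Content-Type lines for the multipart body.
print(field.render_headers())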


@@ -9,7 +9,7 @@ from .packages import six
from .packages.six import b
from .fields import RequestField
writer = codecs.lookup('utf-8')[3]
writer = codecs.lookup("utf-8")[3]
def choose_boundary():
@@ -17,8 +17,8 @@ def choose_boundary():
Our embarrassingly-simple replacement for mimetools.choose_boundary.
"""
boundary = binascii.hexlify(os.urandom(16))
if six.PY3:
boundary = boundary.decode('ascii')
if not six.PY2:
boundary = boundary.decode("ascii")
return boundary
@@ -76,7 +76,7 @@ def encode_multipart_formdata(fields, boundary=None):
boundary = choose_boundary()
for field in iter_field_objects(fields):
body.write(b('--%s\r\n' % (boundary)))
body.write(b("--%s\r\n" % (boundary)))
writer(body).write(field.render_headers())
data = field.data
@@ -89,10 +89,10 @@ def encode_multipart_formdata(fields, boundary=None):
else:
body.write(data)
body.write(b'\r\n')
body.write(b"\r\n")
body.write(b('--%s--\r\n' % (boundary)))
body.write(b("--%s--\r\n" % (boundary)))
content_type = str('multipart/form-data; boundary=%s' % boundary)
content_type = str("multipart/form-data; boundary=%s" % boundary)
return body.getvalue(), content_type
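
A sketch of the encoder above; fields may be a dict, and tuple values follow the (filename, data, content_type) convention from fields.py:

from urllib3.filepost import encode_multipart_formdata

body, content_type = encode_multipart_formdata(
    {"name": "value", "file": ("hello.txt", b"hello world", "text/plain")}
)
print(content_type)  # multipart/form-data; boundary=... (random hex)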


@@ -2,4 +2,4 @@ from __future__ import absolute_import
from . import ssl_match_hostname
__all__ = ('ssl_match_hostname', )
__all__ = ("ssl_match_hostname",)


@@ -11,15 +11,14 @@ import io
from socket import SocketIO
def backport_makefile(self, mode="r", buffering=None, encoding=None,
errors=None, newline=None):
def backport_makefile(
self, mode="r", buffering=None, encoding=None, errors=None, newline=None
):
"""
Backport of ``socket.makefile`` from Python 3.5.
"""
if not set(mode) <= {"r", "w", "b"}:
raise ValueError(
"invalid mode %r (only r, w, b allowed)" % (mode,)
)
raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
writing = "w" in mode
reading = "r" in mode or not writing
assert reading or writing


@@ -1,56 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
An implementation of semantics and validations described in RFC 3986.
See http://rfc3986.readthedocs.io/ for detailed documentation.
:copyright: (c) 2014 Rackspace
:license: Apache v2.0, see LICENSE for details
"""
from .api import iri_reference
from .api import IRIReference
from .api import is_valid_uri
from .api import normalize_uri
from .api import uri_reference
from .api import URIReference
from .api import urlparse
from .parseresult import ParseResult
__title__ = 'rfc3986'
__author__ = 'Ian Stapleton Cordasco'
__author_email__ = 'graffatcolmingov@gmail.com'
__license__ = 'Apache v2.0'
__copyright__ = 'Copyright 2014 Rackspace'
__version__ = '1.3.2'
__all__ = (
'ParseResult',
'URIReference',
'IRIReference',
'is_valid_uri',
'normalize_uri',
'uri_reference',
'iri_reference',
'urlparse',
'__title__',
'__author__',
'__author_email__',
'__license__',
'__copyright__',
'__version__',
)
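
For context on what is being dropped: the standalone rfc3986 distribution exposes the same API as this deleted vendored copy, e.g.:

import rfc3986  # pip install rfc3986

uri = rfc3986.uri_reference("https://user@example.com:8080/path?q=1#frag")
# .host and .port come from the URIMixin properties deleted below.
print(uri.scheme, uri.host, uri.port)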


@@ -1,353 +0,0 @@
"""Module containing the implementation of the URIMixin class."""
import warnings
from . import exceptions as exc
from . import misc
from . import normalizers
from . import validators
class URIMixin(object):
"""Mixin with all shared methods for URIs and IRIs."""
__hash__ = tuple.__hash__
def authority_info(self):
"""Return a dictionary with the ``userinfo``, ``host``, and ``port``.
If the authority is not valid, it will raise a
:class:`~rfc3986.exceptions.InvalidAuthority` Exception.
:returns:
``{'userinfo': 'username:password', 'host': 'www.example.com',
'port': '80'}``
:rtype: dict
:raises rfc3986.exceptions.InvalidAuthority:
If the authority is not ``None`` and can not be parsed.
"""
if not self.authority:
return {'userinfo': None, 'host': None, 'port': None}
match = self._match_subauthority()
if match is None:
# In this case, we have an authority that was parsed from the URI
# Reference, but it cannot be further parsed by our
# misc.SUBAUTHORITY_MATCHER. In this case it must not be a valid
# authority.
raise exc.InvalidAuthority(self.authority.encode(self.encoding))
# We had a match, now let's ensure that it is actually a valid host
# address if it is IPv4
matches = match.groupdict()
host = matches.get('host')
if (host and misc.IPv4_MATCHER.match(host) and not
validators.valid_ipv4_host_address(host)):
# If we have a host, it appears to be IPv4 and it does not have
# valid bytes, it is an InvalidAuthority.
raise exc.InvalidAuthority(self.authority.encode(self.encoding))
return matches
def _match_subauthority(self):
return misc.SUBAUTHORITY_MATCHER.match(self.authority)
@property
def host(self):
"""If present, a string representing the host."""
try:
authority = self.authority_info()
except exc.InvalidAuthority:
return None
return authority['host']
@property
def port(self):
"""If present, the port extracted from the authority."""
try:
authority = self.authority_info()
except exc.InvalidAuthority:
return None
return authority['port']
@property
def userinfo(self):
"""If present, the userinfo extracted from the authority."""
try:
authority = self.authority_info()
except exc.InvalidAuthority:
return None
return authority['userinfo']
def is_absolute(self):
"""Determine if this URI Reference is an absolute URI.
See http://tools.ietf.org/html/rfc3986#section-4.3 for explanation.
:returns: ``True`` if it is an absolute URI, ``False`` otherwise.
:rtype: bool
"""
return bool(misc.ABSOLUTE_URI_MATCHER.match(self.unsplit()))
def is_valid(self, **kwargs):
"""Determine if the URI is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require_scheme: Set to ``True`` if you wish to require the
presence of the scheme component.
:param bool require_authority: Set to ``True`` if you wish to require
the presence of the authority component.
:param bool require_path: Set to ``True`` if you wish to require the
presence of the path component.
:param bool require_query: Set to ``True`` if you wish to require the
presence of the query component.
:param bool require_fragment: Set to ``True`` if you wish to require
the presence of the fragment component.
:returns: ``True`` if the URI is valid. ``False`` otherwise.
:rtype: bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
validators = [
(self.scheme_is_valid, kwargs.get('require_scheme', False)),
(self.authority_is_valid, kwargs.get('require_authority', False)),
(self.path_is_valid, kwargs.get('require_path', False)),
(self.query_is_valid, kwargs.get('require_query', False)),
(self.fragment_is_valid, kwargs.get('require_fragment', False)),
]
return all(v(r) for v, r in validators)
def authority_is_valid(self, require=False):
"""Determine if the authority component is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require:
Set to ``True`` to require the presence of this component.
:returns:
``True`` if the authority is valid. ``False`` otherwise.
:rtype:
bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
try:
self.authority_info()
except exc.InvalidAuthority:
return False
return validators.authority_is_valid(
self.authority,
host=self.host,
require=require,
)
def scheme_is_valid(self, require=False):
"""Determine if the scheme component is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require: Set to ``True`` to require the presence of this
component.
:returns: ``True`` if the scheme is valid. ``False`` otherwise.
:rtype: bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
return validators.scheme_is_valid(self.scheme, require)
def path_is_valid(self, require=False):
"""Determine if the path component is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require: Set to ``True`` to require the presence of this
component.
:returns: ``True`` if the path is valid. ``False`` otherwise.
:rtype: bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
return validators.path_is_valid(self.path, require)
def query_is_valid(self, require=False):
"""Determine if the query component is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require: Set to ``True`` to require the presence of this
component.
:returns: ``True`` if the query is valid. ``False`` otherwise.
:rtype: bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
return validators.query_is_valid(self.query, require)
def fragment_is_valid(self, require=False):
"""Determine if the fragment component is valid.
.. deprecated:: 1.1.0
Use the :class:`~rfc3986.validators.Validator` object instead.
:param bool require: Set to ``True`` to require the presence of this
component.
:returns: ``True`` if the fragment is valid. ``False`` otherwise.
:rtype: bool
"""
warnings.warn("Please use rfc3986.validators.Validator instead. "
"This method will be eventually removed.",
DeprecationWarning)
return validators.fragment_is_valid(self.fragment, require)
def normalized_equality(self, other_ref):
"""Compare this URIReference to another URIReference.
:param URIReference other_ref: (required), The reference with which
we're comparing.
:returns: ``True`` if the references are equal, ``False`` otherwise.
:rtype: bool
"""
return tuple(self.normalize()) == tuple(other_ref.normalize())
def resolve_with(self, base_uri, strict=False):
"""Use an absolute URI Reference to resolve this relative reference.
Assuming this is a relative reference that you would like to resolve,
use the provided base URI to resolve it.
See http://tools.ietf.org/html/rfc3986#section-5 for more information.
:param base_uri: Either a string or URIReference. It must be an
absolute URI or it will raise an exception.
:returns: A new URIReference which is the result of resolving this
reference using ``base_uri``.
:rtype: :class:`URIReference`
:raises rfc3986.exceptions.ResolutionError:
If the ``base_uri`` is not an absolute URI.
"""
if not isinstance(base_uri, URIMixin):
base_uri = type(self).from_string(base_uri)
if not base_uri.is_absolute():
raise exc.ResolutionError(base_uri)
# This is optional per
# http://tools.ietf.org/html/rfc3986#section-5.2.1
base_uri = base_uri.normalize()
# The reference we're resolving
resolving = self
if not strict and resolving.scheme == base_uri.scheme:
resolving = resolving.copy_with(scheme=None)
# http://tools.ietf.org/html/rfc3986#page-32
if resolving.scheme is not None:
target = resolving.copy_with(
path=normalizers.normalize_path(resolving.path)
)
else:
if resolving.authority is not None:
target = resolving.copy_with(
scheme=base_uri.scheme,
path=normalizers.normalize_path(resolving.path)
)
else:
if resolving.path is None:
if resolving.query is not None:
query = resolving.query
else:
query = base_uri.query
target = resolving.copy_with(
scheme=base_uri.scheme,
authority=base_uri.authority,
path=base_uri.path,
query=query
)
else:
if resolving.path.startswith('/'):
path = normalizers.normalize_path(resolving.path)
else:
path = normalizers.normalize_path(
misc.merge_paths(base_uri, resolving.path)
)
target = resolving.copy_with(
scheme=base_uri.scheme,
authority=base_uri.authority,
path=path,
query=resolving.query
)
return target
def unsplit(self):
"""Create a URI string from the components.
:returns: The URI Reference reconstituted as a string.
:rtype: str
"""
# See http://tools.ietf.org/html/rfc3986#section-5.3
result_list = []
if self.scheme:
result_list.extend([self.scheme, ':'])
if self.authority:
result_list.extend(['//', self.authority])
if self.path:
result_list.append(self.path)
if self.query is not None:
result_list.extend(['?', self.query])
if self.fragment is not None:
result_list.extend(['#', self.fragment])
return ''.join(result_list)
def copy_with(self, scheme=misc.UseExisting, authority=misc.UseExisting,
path=misc.UseExisting, query=misc.UseExisting,
fragment=misc.UseExisting):
"""Create a copy of this reference with the new components.
:param str scheme:
(optional) The scheme to use for the new reference.
:param str authority:
(optional) The authority to use for the new reference.
:param str path:
(optional) The path to use for the new reference.
:param str query:
(optional) The query to use for the new reference.
:param str fragment:
(optional) The fragment to use for the new reference.
:returns:
New URIReference with provided components.
:rtype:
URIReference
"""
attributes = {
'scheme': scheme,
'authority': authority,
'path': path,
'query': query,
'fragment': fragment,
}
for key, value in list(attributes.items()):
if value is misc.UseExisting:
del attributes[key]
uri = self._replace(**attributes)
uri.encoding = self.encoding
return uri
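
A short sketch of the resolution logic above, exercised through the top-level `uri_reference` helper (a minimal example, assuming the package imports as `rfc3986`):

from rfc3986 import uri_reference

base = uri_reference('http://example.com/a/b/c')
rel = uri_reference('../d')
print(rel.resolve_with(base).unsplit())   # http://example.com/a/d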

View File

@@ -1,267 +0,0 @@
# -*- coding: utf-8 -*-
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the regular expressions crafted from ABNF."""
import sys
# https://tools.ietf.org/html/rfc3986#page-13
GEN_DELIMS = GENERIC_DELIMITERS = ":/?#[]@"
GENERIC_DELIMITERS_SET = set(GENERIC_DELIMITERS)
# https://tools.ietf.org/html/rfc3986#page-13
SUB_DELIMS = SUB_DELIMITERS = "!$&'()*+,;="
SUB_DELIMITERS_SET = set(SUB_DELIMITERS)
# Escape the '*' for use in regular expressions
SUB_DELIMITERS_RE = r"!$&'()\*+,;="
RESERVED_CHARS_SET = GENERIC_DELIMITERS_SET.union(SUB_DELIMITERS_SET)
ALPHA = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
DIGIT = '0123456789'
# https://tools.ietf.org/html/rfc3986#section-2.3
# unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~"
UNRESERVED = UNRESERVED_CHARS = ALPHA + DIGIT + r'._~-'
UNRESERVED_CHARS_SET = set(UNRESERVED_CHARS)
NON_PCT_ENCODED_SET = RESERVED_CHARS_SET.union(UNRESERVED_CHARS_SET)
# We need to escape the '-' in this case:
UNRESERVED_RE = r'A-Za-z0-9._~\-'
# Percent encoded character values
PERCENT_ENCODED = PCT_ENCODED = '%[A-Fa-f0-9]{2}'
PCHAR = '([' + UNRESERVED_RE + SUB_DELIMITERS_RE + ':@]|%s)' % PCT_ENCODED
# NOTE(sigmavirus24): We're going to use more strict regular expressions
# than appear in Appendix B for scheme. This will prevent over-eager
# consuming of items that aren't schemes.
SCHEME_RE = '[a-zA-Z][a-zA-Z0-9+.-]*'
_AUTHORITY_RE = '[^/?#]*'
_PATH_RE = '[^?#]*'
_QUERY_RE = '[^#]*'
_FRAGMENT_RE = '.*'
# Extracted from http://tools.ietf.org/html/rfc3986#appendix-B
COMPONENT_PATTERN_DICT = {
'scheme': SCHEME_RE,
'authority': _AUTHORITY_RE,
'path': _PATH_RE,
'query': _QUERY_RE,
'fragment': _FRAGMENT_RE,
}
# See http://tools.ietf.org/html/rfc3986#appendix-B
# In this case, we name each of the important matches so we can use
# SRE_Match#groupdict to parse the values out if we so choose. This is also
# modified to ignore other matches that are not important to the parsing of
# the reference so we can also simply use SRE_Match#groups.
URL_PARSING_RE = (
r'(?:(?P<scheme>{scheme}):)?(?://(?P<authority>{authority}))?'
r'(?P<path>{path})(?:\?(?P<query>{query}))?'
r'(?:#(?P<fragment>{fragment}))?'
).format(**COMPONENT_PATTERN_DICT)
# #########################
# Authority Matcher Section
# #########################
# Host patterns, see: http://tools.ietf.org/html/rfc3986#section-3.2.2
# The pattern for a regular name, e.g., www.google.com, api.github.com
REGULAR_NAME_RE = REG_NAME = '((?:{0}|[{1}])*)'.format(
'%[0-9A-Fa-f]{2}', SUB_DELIMITERS_RE + UNRESERVED_RE
)
# The pattern for an IPv4 address, e.g., 192.168.255.255, 127.0.0.1,
IPv4_RE = r'([0-9]{1,3}\.){3}[0-9]{1,3}'
# Hexadecimal characters used in each piece of an IPv6 address
HEXDIG_RE = '[0-9A-Fa-f]{1,4}'
# Least-significant 32 bits of an IPv6 address
LS32_RE = '({hex}:{hex}|{ipv4})'.format(hex=HEXDIG_RE, ipv4=IPv4_RE)
# Substitutions into the following patterns for IPv6 patterns defined
# http://tools.ietf.org/html/rfc3986#page-20
_subs = {'hex': HEXDIG_RE, 'ls32': LS32_RE}
# Below: h16 = hexdig, see: https://tools.ietf.org/html/rfc5234 for details
# about ABNF (Augmented Backus-Naur Form) use in the comments
variations = [
# 6( h16 ":" ) ls32
'(%(hex)s:){6}%(ls32)s' % _subs,
# "::" 5( h16 ":" ) ls32
'::(%(hex)s:){5}%(ls32)s' % _subs,
# [ h16 ] "::" 4( h16 ":" ) ls32
'(%(hex)s)?::(%(hex)s:){4}%(ls32)s' % _subs,
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
'((%(hex)s:)?%(hex)s)?::(%(hex)s:){3}%(ls32)s' % _subs,
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
'((%(hex)s:){0,2}%(hex)s)?::(%(hex)s:){2}%(ls32)s' % _subs,
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
'((%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s' % _subs,
# [ *4( h16 ":" ) h16 ] "::" ls32
'((%(hex)s:){0,4}%(hex)s)?::%(ls32)s' % _subs,
# [ *5( h16 ":" ) h16 ] "::" h16
'((%(hex)s:){0,5}%(hex)s)?::%(hex)s' % _subs,
# [ *6( h16 ":" ) h16 ] "::"
'((%(hex)s:){0,6}%(hex)s)?::' % _subs,
]
IPv6_RE = '(({0})|({1})|({2})|({3})|({4})|({5})|({6})|({7})|({8}))'.format(
*variations
)
IPv_FUTURE_RE = r'v[0-9A-Fa-f]+\.[%s]+' % (
UNRESERVED_RE + SUB_DELIMITERS_RE + ':'
)
# RFC 6874 Zone ID ABNF
ZONE_ID = '(?:[' + UNRESERVED_RE + ']|' + PCT_ENCODED + ')+'
IPv6_ADDRZ_RFC4007_RE = IPv6_RE + '(?:(?:%25|%)' + ZONE_ID + ')?'
IPv6_ADDRZ_RE = IPv6_RE + '(?:%25' + ZONE_ID + ')?'
IP_LITERAL_RE = r'\[({0}|{1})\]'.format(
IPv6_ADDRZ_RFC4007_RE,
IPv_FUTURE_RE,
)
# Pattern for matching the host piece of the authority
HOST_RE = HOST_PATTERN = '({0}|{1}|{2})'.format(
REG_NAME,
IPv4_RE,
IP_LITERAL_RE,
)
USERINFO_RE = '^([' + UNRESERVED_RE + SUB_DELIMITERS_RE + ':]|%s)+' % (
PCT_ENCODED
)
PORT_RE = '[0-9]{1,5}'
# ####################
# Path Matcher Section
# ####################
# See http://tools.ietf.org/html/rfc3986#section-3.3 for more information
# about the path patterns defined below.
segments = {
'segment': PCHAR + '*',
# Non-zero length segment
'segment-nz': PCHAR + '+',
# Non-zero length segment without ":"
'segment-nz-nc': PCHAR.replace(':', '') + '+'
}
# Path types taken from Section 3.3 (linked above)
PATH_EMPTY = '^$'
PATH_ROOTLESS = '%(segment-nz)s(/%(segment)s)*' % segments
PATH_NOSCHEME = '%(segment-nz-nc)s(/%(segment)s)*' % segments
PATH_ABSOLUTE = '/(%s)?' % PATH_ROOTLESS
PATH_ABEMPTY = '(/%(segment)s)*' % segments
PATH_RE = '^(%s|%s|%s|%s|%s)$' % (
PATH_ABEMPTY, PATH_ABSOLUTE, PATH_NOSCHEME, PATH_ROOTLESS, PATH_EMPTY
)
FRAGMENT_RE = QUERY_RE = (
'^([/?:@' + UNRESERVED_RE + SUB_DELIMITERS_RE + ']|%s)*$' % PCT_ENCODED
)
# ##########################
# Relative reference matcher
# ##########################
# See http://tools.ietf.org/html/rfc3986#section-4.2 for details
RELATIVE_PART_RE = '(//%s%s|%s|%s|%s)' % (
COMPONENT_PATTERN_DICT['authority'],
PATH_ABEMPTY,
PATH_ABSOLUTE,
PATH_NOSCHEME,
PATH_EMPTY,
)
# See http://tools.ietf.org/html/rfc3986#section-3 for definition
HIER_PART_RE = '(//%s%s|%s|%s|%s)' % (
COMPONENT_PATTERN_DICT['authority'],
PATH_ABEMPTY,
PATH_ABSOLUTE,
PATH_ROOTLESS,
PATH_EMPTY,
)
# ###############
# IRIs / RFC 3987
# ###############
# Only wide-unicode gets the high-ranges of UCSCHAR
if sys.maxunicode > 0xFFFF: # pragma: no cover
IPRIVATE = u'\uE000-\uF8FF\U000F0000-\U000FFFFD\U00100000-\U0010FFFD'
UCSCHAR_RE = (
u'\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF'
u'\U00010000-\U0001FFFD\U00020000-\U0002FFFD'
u'\U00030000-\U0003FFFD\U00040000-\U0004FFFD'
u'\U00050000-\U0005FFFD\U00060000-\U0006FFFD'
u'\U00070000-\U0007FFFD\U00080000-\U0008FFFD'
u'\U00090000-\U0009FFFD\U000A0000-\U000AFFFD'
u'\U000B0000-\U000BFFFD\U000C0000-\U000CFFFD'
u'\U000D0000-\U000DFFFD\U000E1000-\U000EFFFD'
)
else: # pragma: no cover
IPRIVATE = u'\uE000-\uF8FF'
UCSCHAR_RE = (
u'\u00A0-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF'
)
IUNRESERVED_RE = u'A-Za-z0-9\\._~\\-' + UCSCHAR_RE
IPCHAR = u'([' + IUNRESERVED_RE + SUB_DELIMITERS_RE + u':@]|%s)' % PCT_ENCODED
isegments = {
'isegment': IPCHAR + u'*',
# Non-zero length segment
'isegment-nz': IPCHAR + u'+',
# Non-zero length segment without ":"
'isegment-nz-nc': IPCHAR.replace(':', '') + u'+'
}
IPATH_ROOTLESS = u'%(isegment-nz)s(/%(isegment)s)*' % isegments
IPATH_NOSCHEME = u'%(isegment-nz-nc)s(/%(isegment)s)*' % isegments
IPATH_ABSOLUTE = u'/(?:%s)?' % IPATH_ROOTLESS
IPATH_ABEMPTY = u'(?:/%(isegment)s)*' % isegments
IPATH_RE = u'^(?:%s|%s|%s|%s|%s)$' % (
IPATH_ABEMPTY, IPATH_ABSOLUTE, IPATH_NOSCHEME, IPATH_ROOTLESS, PATH_EMPTY
)
IREGULAR_NAME_RE = IREG_NAME = u'(?:{0}|[{1}])*'.format(
u'%[0-9A-Fa-f]{2}', SUB_DELIMITERS_RE + IUNRESERVED_RE
)
IHOST_RE = IHOST_PATTERN = u'({0}|{1}|{2})'.format(
IREG_NAME,
IPv4_RE,
IP_LITERAL_RE,
)
IUSERINFO_RE = u'^(?:[' + IUNRESERVED_RE + SUB_DELIMITERS_RE + u':]|%s)+' % (
PCT_ENCODED
)
IFRAGMENT_RE = (u'^(?:[/?:@' + IUNRESERVED_RE + SUB_DELIMITERS_RE
+ u']|%s)*$' % PCT_ENCODED)
IQUERY_RE = (u'^(?:[/?:@' + IUNRESERVED_RE + SUB_DELIMITERS_RE
+ IPRIVATE + u']|%s)*$' % PCT_ENCODED)
IRELATIVE_PART_RE = u'(//%s%s|%s|%s|%s)' % (
COMPONENT_PATTERN_DICT['authority'],
IPATH_ABEMPTY,
IPATH_ABSOLUTE,
IPATH_NOSCHEME,
PATH_EMPTY,
)
IHIER_PART_RE = u'(//%s%s|%s|%s|%s)' % (
COMPONENT_PATTERN_DICT['authority'],
IPATH_ABEMPTY,
IPATH_ABSOLUTE,
IPATH_ROOTLESS,
PATH_EMPTY,
)
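
These patterns are plain strings that callers compile themselves; a minimal sketch (module path assumed to be `rfc3986.abnf_regexp`):

import re
from rfc3986 import abnf_regexp

ipv6 = re.compile('^%s$' % abnf_regexp.IPv6_RE)
print(bool(ipv6.match('2001:db8::1')))     # True
print(bool(ipv6.match('not-an-address')))  # False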

View File

@@ -1,106 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing the simple and functional API for rfc3986.
This module defines functions and provides access to the public attributes
and classes of rfc3986.
"""
from .iri import IRIReference
from .parseresult import ParseResult
from .uri import URIReference
def uri_reference(uri, encoding='utf-8'):
"""Parse a URI string into a URIReference.
This is a convenience function. You could achieve the same end by using
``URIReference.from_string(uri)``.
:param str uri: The URI which needs to be parsed into a reference.
:param str encoding: The encoding of the string provided
:returns: A parsed URI
:rtype: :class:`URIReference`
"""
return URIReference.from_string(uri, encoding)
def iri_reference(iri, encoding='utf-8'):
"""Parse a IRI string into an IRIReference.
This is a convenience function. You could achieve the same end by using
``IRIReference.from_string(iri)``.
:param str iri: The IRI which needs to be parsed into a reference.
:param str encoding: The encoding of the string provided
:returns: A parsed IRI
:rtype: :class:`IRIReference`
"""
return IRIReference.from_string(iri, encoding)
def is_valid_uri(uri, encoding='utf-8', **kwargs):
"""Determine if the URI given is valid.
This is a convenience function. You could use either
``uri_reference(uri).is_valid()`` or
``URIReference.from_string(uri).is_valid()`` to achieve the same result.
:param str uri: The URI to be validated.
:param str encoding: The encoding of the string provided
:param bool require_scheme: Set to ``True`` if you wish to require the
presence of the scheme component.
:param bool require_authority: Set to ``True`` if you wish to require the
presence of the authority component.
:param bool require_path: Set to ``True`` if you wish to require the
presence of the path component.
:param bool require_query: Set to ``True`` if you wish to require the
presence of the query component.
:param bool require_fragment: Set to ``True`` if you wish to require the
presence of the fragment component.
:returns: ``True`` if the URI is valid, ``False`` otherwise.
:rtype: bool
"""
return URIReference.from_string(uri, encoding).is_valid(**kwargs)
def normalize_uri(uri, encoding='utf-8'):
"""Normalize the given URI.
This is a convenience function. You could use either
``uri_reference(uri).normalize().unsplit()`` or
``URIReference.from_string(uri).normalize().unsplit()`` instead.
:param str uri: The URI to be normalized.
:param str encoding: The encoding of the string provided
:returns: The normalized URI.
:rtype: str
"""
normalized_reference = URIReference.from_string(uri, encoding).normalize()
return normalized_reference.unsplit()
def urlparse(uri, encoding='utf-8'):
"""Parse a given URI and return a ParseResult.
This is a partial replacement of the standard library's urlparse function.
:param str uri: The URI to be parsed.
:param str encoding: The encoding of the string provided.
:returns: A parsed URI
:rtype: :class:`~rfc3986.parseresult.ParseResult`
"""
return ParseResult.from_string(uri, encoding, strict=False)
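
A minimal sketch of these convenience functions; note that `is_valid_uri` goes through the deprecated `is_valid` and therefore emits a DeprecationWarning:

from rfc3986 import api

print(api.is_valid_uri('https://example.com', require_scheme=True))  # True
print(api.normalize_uri('HTTPS://EXAMPLE.com/'))  # https://example.com/
print(api.urlparse('https://example.com:8080/x').port)  # 8080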

View File

@@ -1,298 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2017 Ian Stapleton Cordasco
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the logic for the URIBuilder object."""
from . import compat
from . import normalizers
from . import uri
class URIBuilder(object):
"""Object to aid in building up a URI Reference from parts.
.. note::
This object should be instantiated by the user, but it's recommended
that it is not provided with arguments. Instead, use the available
methods to populate the fields.
"""
def __init__(self, scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment=None):
"""Initialize our URI builder.
:param str scheme:
(optional)
:param str userinfo:
(optional)
:param str host:
(optional)
:param int port:
(optional)
:param str path:
(optional)
:param str query:
(optional)
:param str fragment:
(optional)
"""
self.scheme = scheme
self.userinfo = userinfo
self.host = host
self.port = port
self.path = path
self.query = query
self.fragment = fragment
def __repr__(self):
"""Provide a convenient view of our builder object."""
formatstr = ('URIBuilder(scheme={b.scheme}, userinfo={b.userinfo}, '
'host={b.host}, port={b.port}, path={b.path}, '
'query={b.query}, fragment={b.fragment})')
return formatstr.format(b=self)
def add_scheme(self, scheme):
"""Add a scheme to our builder object.
After normalizing, this will generate a new URIBuilder instance with
the specified scheme and all other attributes the same.
.. code-block:: python
>>> URIBuilder().add_scheme('HTTPS')
URIBuilder(scheme='https', userinfo=None, host=None, port=None,
path=None, query=None, fragment=None)
"""
scheme = normalizers.normalize_scheme(scheme)
return URIBuilder(
scheme=scheme,
userinfo=self.userinfo,
host=self.host,
port=self.port,
path=self.path,
query=self.query,
fragment=self.fragment,
)
def add_credentials(self, username, password):
"""Add credentials as the userinfo portion of the URI.
.. code-block:: python
>>> URIBuilder().add_credentials('root', 's3crete')
URIBuilder(scheme=None, userinfo='root:s3crete', host=None,
port=None, path=None, query=None, fragment=None)
>>> URIBuilder().add_credentials('root', None)
URIBuilder(scheme=None, userinfo='root', host=None,
port=None, path=None, query=None, fragment=None)
"""
if username is None:
raise ValueError('Username cannot be None')
userinfo = normalizers.normalize_username(username)
if password is not None:
userinfo = '{}:{}'.format(
userinfo,
normalizers.normalize_password(password),
)
return URIBuilder(
scheme=self.scheme,
userinfo=userinfo,
host=self.host,
port=self.port,
path=self.path,
query=self.query,
fragment=self.fragment,
)
def add_host(self, host):
"""Add hostname to the URI.
.. code-block:: python
>>> URIBuilder().add_host('google.com')
URIBuilder(scheme=None, userinfo=None, host='google.com',
port=None, path=None, query=None, fragment=None)
"""
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=normalizers.normalize_host(host),
port=self.port,
path=self.path,
query=self.query,
fragment=self.fragment,
)
def add_port(self, port):
"""Add port to the URI.
.. code-block:: python
>>> URIBuilder().add_port(80)
URIBuilder(scheme=None, userinfo=None, host=None, port='80',
path=None, query=None, fragment=None)
>>> URIBuilder().add_port(443)
URIBuilder(scheme=None, userinfo=None, host=None, port='443',
path=None, query=None, fragment=None)
"""
port_int = int(port)
if port_int < 0:
raise ValueError(
'ports are not allowed to be negative. You provided {}'.format(
port_int,
)
)
if port_int > 65535:
raise ValueError(
'ports are not allowed to be larger than 65535. '
'You provided {}'.format(
port_int,
)
)
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=self.host,
port='{}'.format(port_int),
path=self.path,
query=self.query,
fragment=self.fragment,
)
def add_path(self, path):
"""Add a path to the URI.
.. code-block:: python
>>> URIBuilder().add_path('sigmavirus24/rfc3985')
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path='/sigmavirus24/rfc3986', query=None, fragment=None)
>>> URIBuilder().add_path('/checkout.php')
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path='/checkout.php', query=None, fragment=None)
"""
if not path.startswith('/'):
path = '/{}'.format(path)
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=self.host,
port=self.port,
path=normalizers.normalize_path(path),
query=self.query,
fragment=self.fragment,
)
def add_query_from(self, query_items):
"""Generate and add a query a dictionary or list of tuples.
.. code-block:: python
>>> URIBuilder().add_query_from({'a': 'b c'})
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path=None, query='a=b+c', fragment=None)
>>> URIBuilder().add_query_from([('a', 'b c')])
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path=None, query='a=b+c', fragment=None)
"""
query = normalizers.normalize_query(compat.urlencode(query_items))
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=self.host,
port=self.port,
path=self.path,
query=query,
fragment=self.fragment,
)
def add_query(self, query):
"""Add a pre-formated query string to the URI.
.. code-block:: python
>>> URIBuilder().add_query('a=b&c=d')
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path=None, query='a=b&c=d', fragment=None)
"""
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=self.host,
port=self.port,
path=self.path,
query=normalizers.normalize_query(query),
fragment=self.fragment,
)
def add_fragment(self, fragment):
"""Add a fragment to the URI.
.. code-block:: python
>>> URIBuilder().add_fragment('section-2.6.1')
URIBuilder(scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment='section-2.6.1')
"""
return URIBuilder(
scheme=self.scheme,
userinfo=self.userinfo,
host=self.host,
port=self.port,
path=self.path,
query=self.query,
fragment=normalizers.normalize_fragment(fragment),
)
def finalize(self):
"""Create a URIReference from our builder.
.. code-block:: python
>>> URIBuilder().add_scheme('https').add_host('github.com'
... ).add_path('sigmavirus24/rfc3986').finalize().unsplit()
'https://github.com/sigmavirus24/rfc3986'
>>> URIBuilder().add_scheme('https').add_host('github.com'
... ).add_path('sigmavirus24/rfc3986').add_credentials(
... 'sigmavirus24', 'not-re@l').finalize().unsplit()
'https://sigmavirus24:not-re%40l@github.com/sigmavirus24/rfc3986'
"""
return uri.URIReference(
self.scheme,
normalizers.normalize_authority(
(self.userinfo, self.host, self.port)
),
self.path,
self.query,
self.fragment,
)
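
Each `add_*` method returns a fresh builder, so calls chain naturally; a minimal sketch:

from rfc3986.builder import URIBuilder

url = (URIBuilder()
       .add_scheme('https')
       .add_host('api.example.com')
       .add_port(443)
       .add_path('v1/users')
       .add_query_from({'page': '2'})
       .finalize()
       .unsplit())
print(url)  # https://api.example.com:443/v1/users?page=2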

View File

@@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compatibility module for Python 2 and 3 support."""
import sys
try:
from urllib.parse import quote as urlquote
except ImportError: # Python 2.x
from urllib import quote as urlquote
try:
from urllib.parse import urlencode
except ImportError: # Python 2.x
from urllib import urlencode
__all__ = (
'to_bytes',
'to_str',
'urlquote',
'urlencode',
)
PY3 = (3, 0) <= sys.version_info < (4, 0)
PY2 = (2, 6) <= sys.version_info < (2, 8)
if PY3:
unicode = str # Python 3.x
def to_str(b, encoding='utf-8'):
"""Ensure that b is text in the specified encoding."""
if hasattr(b, 'decode') and not isinstance(b, unicode):
b = b.decode(encoding)
return b
def to_bytes(s, encoding='utf-8'):
"""Ensure that s is converted to bytes from the encoding."""
if hasattr(s, 'encode') and not isinstance(s, bytes):
s = s.encode(encoding)
return s
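
A minimal sketch of the helpers; on Python 3, `to_str` decodes bytes and `to_bytes` encodes text, and each passes the other type through unchanged:

from rfc3986.compat import to_bytes, to_str

print(to_str(b'caf\xc3\xa9'))   # café
print(to_bytes(u'café'))        # b'caf\xc3\xa9'
print(to_str('already text'))   # already text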

View File

@@ -1,118 +0,0 @@
# -*- coding: utf-8 -*-
"""Exceptions module for rfc3986."""
from . import compat
class RFC3986Exception(Exception):
"""Base class for all rfc3986 exception classes."""
pass
class InvalidAuthority(RFC3986Exception):
"""Exception when the authority string is invalid."""
def __init__(self, authority):
"""Initialize the exception with the invalid authority."""
super(InvalidAuthority, self).__init__(
u"The authority ({0}) is not valid.".format(
compat.to_str(authority)))
class InvalidPort(RFC3986Exception):
"""Exception when the port is invalid."""
def __init__(self, port):
"""Initialize the exception with the invalid port."""
super(InvalidPort, self).__init__(
'The port ("{0}") is not valid.'.format(port))
class ResolutionError(RFC3986Exception):
"""Exception to indicate a failure to resolve a URI."""
def __init__(self, uri):
"""Initialize the error with the failed URI."""
super(ResolutionError, self).__init__(
"{0} is not an absolute URI.".format(uri.unsplit()))
class ValidationError(RFC3986Exception):
"""Exception raised during Validation of a URI."""
pass
class MissingComponentError(ValidationError):
"""Exception raised when a required component is missing."""
def __init__(self, uri, *component_names):
"""Initialize the error with the missing component name."""
verb = 'was'
if len(component_names) > 1:
verb = 'were'
self.uri = uri
self.components = sorted(component_names)
components = ', '.join(self.components)
super(MissingComponentError, self).__init__(
"{} {} required but missing".format(components, verb),
uri,
self.components,
)
class UnpermittedComponentError(ValidationError):
"""Exception raised when a component has an unpermitted value."""
def __init__(self, component_name, component_value, allowed_values):
"""Initialize the error with the unpermitted component."""
super(UnpermittedComponentError, self).__init__(
"{} was required to be one of {!r} but was {!r}".format(
component_name, list(sorted(allowed_values)), component_value,
),
component_name,
component_value,
allowed_values,
)
self.component_name = component_name
self.component_value = component_value
self.allowed_values = allowed_values
class PasswordForbidden(ValidationError):
"""Exception raised when a URL has a password in the userinfo section."""
def __init__(self, uri):
"""Initialize the error with the URI that failed validation."""
unsplit = getattr(uri, 'unsplit', lambda: uri)
super(PasswordForbidden, self).__init__(
'"{}" contained a password when validation forbade it'.format(
unsplit()
)
)
self.uri = uri
class InvalidComponentsError(ValidationError):
"""Exception raised when one or more components are invalid."""
def __init__(self, uri, *component_names):
"""Initialize the error with the invalid component name(s)."""
verb = 'was'
if len(component_names) > 1:
verb = 'were'
self.uri = uri
self.components = sorted(component_names)
components = ', '.join(self.components)
super(InvalidComponentsError, self).__init__(
"{} {} found to be invalid".format(components, verb),
uri,
self.components,
)
class MissingDependencyError(RFC3986Exception):
"""Exception raised when an IRI is encoded without the 'idna' module."""

View File

@@ -1,147 +0,0 @@
"""Module containing the implementation of the IRIReference class."""
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Copyright (c) 2015 Ian Stapleton Cordasco
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from . import compat
from . import exceptions
from . import misc
from . import normalizers
from . import uri
try:
from pip._vendor import idna
except ImportError: # pragma: no cover
idna = None
class IRIReference(namedtuple('IRIReference', misc.URI_COMPONENTS),
uri.URIMixin):
"""Immutable object representing a parsed IRI Reference.
Can be encoded into a URIReference object via the procedure
specified in RFC 3987 Section 3.1
.. note::
The IRI submodule is a new interface and may possibly change in
the future. Check for changes to the interface when upgrading.
"""
slots = ()
def __new__(cls, scheme, authority, path, query, fragment,
encoding='utf-8'):
"""Create a new IRIReference."""
ref = super(IRIReference, cls).__new__(
cls,
scheme or None,
authority or None,
path or None,
query,
fragment)
ref.encoding = encoding
return ref
def __eq__(self, other):
"""Compare this reference to another."""
other_ref = other
if isinstance(other, tuple):
other_ref = self.__class__(*other)
elif not isinstance(other, IRIReference):
try:
other_ref = self.__class__.from_string(other)
except TypeError:
raise TypeError(
'Unable to compare {0}() to {1}()'.format(
type(self).__name__, type(other).__name__))
# See http://tools.ietf.org/html/rfc3986#section-6.2
return tuple(self) == tuple(other_ref)
def _match_subauthority(self):
return misc.ISUBAUTHORITY_MATCHER.match(self.authority)
@classmethod
def from_string(cls, iri_string, encoding='utf-8'):
"""Parse a IRI reference from the given unicode IRI string.
:param str iri_string: Unicode IRI to be parsed into a reference.
:param str encoding: The encoding of the string provided
:returns: :class:`IRIReference` or subclass thereof
"""
iri_string = compat.to_str(iri_string, encoding)
split_iri = misc.IRI_MATCHER.match(iri_string).groupdict()
return cls(
split_iri['scheme'], split_iri['authority'],
normalizers.encode_component(split_iri['path'], encoding),
normalizers.encode_component(split_iri['query'], encoding),
normalizers.encode_component(split_iri['fragment'], encoding),
encoding,
)
def encode(self, idna_encoder=None): # noqa: C901
"""Encode an IRIReference into a URIReference instance.
If the ``idna`` module is installed or the ``rfc3986[idna]``
extra is used then unicode characters in the IRI host
component will be encoded with IDNA2008.
:param idna_encoder:
Function that encodes each part of the host component.
If not given and the ``idna`` module is unavailable, an
exception is raised when the IRI contains a host component.
:rtype: uri.URIReference
:returns: A URI reference
"""
authority = self.authority
if authority:
if idna_encoder is None:
if idna is None: # pragma: no cover
raise exceptions.MissingDependencyError(
"Could not import the 'idna' module "
"and the IRI hostname requires encoding"
)
def idna_encoder(name):
if any(ord(c) > 128 for c in name):
try:
return idna.encode(name.lower(),
strict=True,
std3_rules=True)
except idna.IDNAError:
raise exceptions.InvalidAuthority(self.authority)
return name
authority = ""
if self.host:
authority = ".".join([compat.to_str(idna_encoder(part))
for part in self.host.split(".")])
if self.userinfo is not None:
authority = (normalizers.encode_component(
self.userinfo, self.encoding) + '@' + authority)
if self.port is not None:
authority += ":" + str(self.port)
return uri.URIReference(self.scheme,
authority,
path=self.path,
query=self.query,
fragment=self.fragment,
encoding=self.encoding)
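
A sketch of IRI-to-URI encoding; the host step needs the `idna` module (or a caller-supplied `idna_encoder`), which is assumed available here, and the expected output follows from the code above:

from rfc3986 import iri_reference

iri = iri_reference(u'http://\u0441\u0430\u0439\u0442.example/\u043f\u0443\u0442\u044c')
print(iri.encode().unsplit())
# -> http://xn--80aswg.example/%D0%BF%D1%83%D1%82%D1%8C
#    (IDNA-encoded host, percent-encoded path)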

View File

@@ -1,124 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing compiled regular expressions and constants.
This module contains important constants, patterns, and compiled regular
expressions for parsing and validating URIs and their components.
"""
import re
from . import abnf_regexp
# These are enumerated for the named tuple used as a superclass of
# URIReference
URI_COMPONENTS = ['scheme', 'authority', 'path', 'query', 'fragment']
important_characters = {
'generic_delimiters': abnf_regexp.GENERIC_DELIMITERS,
'sub_delimiters': abnf_regexp.SUB_DELIMITERS,
# We need to escape the '*' in this case
're_sub_delimiters': abnf_regexp.SUB_DELIMITERS_RE,
'unreserved_chars': abnf_regexp.UNRESERVED_CHARS,
# We need to escape the '-' in this case:
're_unreserved': abnf_regexp.UNRESERVED_RE,
}
# For details about delimiters and reserved characters, see:
# http://tools.ietf.org/html/rfc3986#section-2.2
GENERIC_DELIMITERS = abnf_regexp.GENERIC_DELIMITERS_SET
SUB_DELIMITERS = abnf_regexp.SUB_DELIMITERS_SET
RESERVED_CHARS = abnf_regexp.RESERVED_CHARS_SET
# For details about unreserved characters, see:
# http://tools.ietf.org/html/rfc3986#section-2.3
UNRESERVED_CHARS = abnf_regexp.UNRESERVED_CHARS_SET
NON_PCT_ENCODED = abnf_regexp.NON_PCT_ENCODED_SET
URI_MATCHER = re.compile(abnf_regexp.URL_PARSING_RE)
SUBAUTHORITY_MATCHER = re.compile((
'^(?:(?P<userinfo>{0})@)?' # userinfo
'(?P<host>{1})' # host
':?(?P<port>{2})?$' # port
).format(abnf_regexp.USERINFO_RE,
abnf_regexp.HOST_PATTERN,
abnf_regexp.PORT_RE))
HOST_MATCHER = re.compile('^' + abnf_regexp.HOST_RE + '$')
IPv4_MATCHER = re.compile('^' + abnf_regexp.IPv4_RE + '$')
IPv6_MATCHER = re.compile(r'^\[' + abnf_regexp.IPv6_ADDRZ_RFC4007_RE + r'\]$')
# Used by host validator
IPv6_NO_RFC4007_MATCHER = re.compile(r'^\[%s\]$' % (
abnf_regexp.IPv6_ADDRZ_RE
))
# Matcher used to validate path components
PATH_MATCHER = re.compile(abnf_regexp.PATH_RE)
# ##################################
# Query and Fragment Matcher Section
# ##################################
QUERY_MATCHER = re.compile(abnf_regexp.QUERY_RE)
FRAGMENT_MATCHER = QUERY_MATCHER
# Scheme validation, see: http://tools.ietf.org/html/rfc3986#section-3.1
SCHEME_MATCHER = re.compile('^{0}$'.format(abnf_regexp.SCHEME_RE))
RELATIVE_REF_MATCHER = re.compile(r'^%s(\?%s)?(#%s)?$' % (
abnf_regexp.RELATIVE_PART_RE,
abnf_regexp.QUERY_RE,
abnf_regexp.FRAGMENT_RE,
))
# See http://tools.ietf.org/html/rfc3986#section-4.3
ABSOLUTE_URI_MATCHER = re.compile(r'^%s:%s(\?%s)?$' % (
abnf_regexp.COMPONENT_PATTERN_DICT['scheme'],
abnf_regexp.HIER_PART_RE,
abnf_regexp.QUERY_RE[1:-1],
))
# ###############
# IRIs / RFC 3987
# ###############
IRI_MATCHER = re.compile(abnf_regexp.URL_PARSING_RE, re.UNICODE)
ISUBAUTHORITY_MATCHER = re.compile((
u'^(?:(?P<userinfo>{0})@)?' # iuserinfo
u'(?P<host>{1})' # ihost
u':?(?P<port>{2})?$' # port
).format(abnf_regexp.IUSERINFO_RE,
abnf_regexp.IHOST_RE,
abnf_regexp.PORT_RE), re.UNICODE)
# Path merger as defined in http://tools.ietf.org/html/rfc3986#section-5.2.3
def merge_paths(base_uri, relative_path):
"""Merge a base URI's path with a relative URI's path."""
if base_uri.path is None and base_uri.authority is not None:
return '/' + relative_path
else:
path = base_uri.path or ''
index = path.rfind('/')
return path[:index] + '/' + relative_path
UseExisting = object()
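
A minimal sketch of `merge_paths` (the RFC 3986 section 5.2.3 helper above), using the top-level `uri_reference` helper to build base references:

from rfc3986 import misc, uri_reference

base = uri_reference('http://example.com/a/b')
print(misc.merge_paths(base, 'c/d'))  # /a/c/d

rootless = uri_reference('http://example.com')
print(misc.merge_paths(rootless, 'c'))  # /c (authority present, path absent)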

View File

@@ -1,167 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module with functions to normalize components."""
import re
from . import compat
from . import misc
def normalize_scheme(scheme):
"""Normalize the scheme component."""
return scheme.lower()
def normalize_authority(authority):
"""Normalize an authority tuple to a string."""
userinfo, host, port = authority
result = ''
if userinfo:
result += normalize_percent_characters(userinfo) + '@'
if host:
result += normalize_host(host)
if port:
result += ':' + port
return result
def normalize_username(username):
"""Normalize a username to make it safe to include in userinfo."""
return compat.urlquote(username)
def normalize_password(password):
"""Normalize a password to make safe for userinfo."""
return compat.urlquote(password)
def normalize_host(host):
"""Normalize a host string."""
if misc.IPv6_MATCHER.match(host):
percent = host.find('%')
if percent != -1:
percent_25 = host.find('%25')
# Replace RFC 4007 IPv6 Zone ID delimiter '%' with '%25'
# from RFC 6874. If the host is '[<IPv6 addr>%25]' then we
# assume RFC 4007 and normalize to '[<IPv6 addr>%2525]'
if percent_25 == -1 or percent < percent_25 or \
(percent == percent_25 and percent_25 == len(host) - 4):
host = host.replace('%', '%25', 1)
# Don't normalize the casing of the Zone ID
return host[:percent].lower() + host[percent:]
return host.lower()
def normalize_path(path):
"""Normalize the path string."""
if not path:
return path
path = normalize_percent_characters(path)
return remove_dot_segments(path)
def normalize_query(query):
"""Normalize the query string."""
if not query:
return query
return normalize_percent_characters(query)
def normalize_fragment(fragment):
"""Normalize the fragment string."""
if not fragment:
return fragment
return normalize_percent_characters(fragment)
PERCENT_MATCHER = re.compile('%[A-Fa-f0-9]{2}')
def normalize_percent_characters(s):
"""All percent characters should be upper-cased.
For example, ``"%3afoo%DF%ab"`` should be turned into ``"%3Afoo%DF%AB"``.
"""
matches = set(PERCENT_MATCHER.findall(s))
for m in matches:
if not m.isupper():
s = s.replace(m, m.upper())
return s
def remove_dot_segments(s):
"""Remove dot segments from the string.
See also Section 5.2.4 of :rfc:`3986`.
"""
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
segments = s.split('/') # Turn the path into a list of segments
output = [] # Initialize the variable to use to store output
for segment in segments:
# '.' is the current directory, so ignore it, it is superfluous
if segment == '.':
continue
# Anything other than '..', should be appended to the output
elif segment != '..':
output.append(segment)
# In this case segment == '..', if we can, we should pop the last
# element
elif output:
output.pop()
# If the path starts with '/' and the output is empty or the first string
# is non-empty
if s.startswith('/') and (not output or output[0]):
output.insert(0, '')
# If the path starts with '/.' or '/..' ensure we add one more empty
# string to add a trailing '/'
if s.endswith(('/.', '/..')):
output.append('')
return '/'.join(output)
def encode_component(uri_component, encoding):
"""Encode the specific component in the provided encoding."""
if uri_component is None:
return uri_component
# Try to see if the component we're encoding is already percent-encoded
# so we can skip all '%' characters but still encode all others.
percent_encodings = len(PERCENT_MATCHER.findall(
compat.to_str(uri_component, encoding)))
uri_bytes = compat.to_bytes(uri_component, encoding)
is_percent_encoded = percent_encodings == uri_bytes.count(b'%')
encoded_uri = bytearray()
for i in range(0, len(uri_bytes)):
# Will return a single character bytestring on both Python 2 & 3
byte = uri_bytes[i:i+1]
byte_ord = ord(byte)
if ((is_percent_encoded and byte == b'%')
or (byte_ord < 128 and byte.decode() in misc.NON_PCT_ENCODED)):
encoded_uri.extend(byte)
continue
encoded_uri.extend('%{0:02x}'.format(byte_ord).encode().upper())
return encoded_uri.decode(encoding)
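
A minimal sketch of the normalizer helpers above:

from rfc3986 import normalizers

print(normalizers.normalize_percent_characters('%3afoo%DF%ab'))  # %3Afoo%DF%AB
print(normalizers.remove_dot_segments('/a/b/c/./../../g'))       # /a/g
print(normalizers.encode_component('a b', 'utf-8'))              # a%20b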

View File

@@ -1,385 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015 Ian Stapleton Cordasco
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the urlparse compatibility logic."""
from collections import namedtuple
from . import compat
from . import exceptions
from . import misc
from . import normalizers
from . import uri
__all__ = ('ParseResult', 'ParseResultBytes')
PARSED_COMPONENTS = ('scheme', 'userinfo', 'host', 'port', 'path', 'query',
'fragment')
class ParseResultMixin(object):
def _generate_authority(self, attributes):
# I swear I did not align the comparisons below. That's just how they
# happened to align based on pep8 and attribute lengths.
userinfo, host, port = (attributes[p]
for p in ('userinfo', 'host', 'port'))
if (self.userinfo != userinfo or
self.host != host or
self.port != port):
if port:
port = '{0}'.format(port)
return normalizers.normalize_authority(
(compat.to_str(userinfo, self.encoding),
compat.to_str(host, self.encoding),
port)
)
return self.authority
def geturl(self):
"""Shim to match the standard library method."""
return self.unsplit()
@property
def hostname(self):
"""Shim to match the standard library."""
return self.host
@property
def netloc(self):
"""Shim to match the standard library."""
return self.authority
@property
def params(self):
"""Shim to match the standard library."""
return self.query
class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS),
ParseResultMixin):
"""Implementation of urlparse compatibility class.
This uses the URIReference logic to handle compatibility with the
urlparse.ParseResult class.
"""
slots = ()
def __new__(cls, scheme, userinfo, host, port, path, query, fragment,
uri_ref, encoding='utf-8'):
"""Create a new ParseResult."""
parse_result = super(ParseResult, cls).__new__(
cls,
scheme or None,
userinfo or None,
host,
port or None,
path or None,
query,
fragment)
parse_result.encoding = encoding
parse_result.reference = uri_ref
return parse_result
@classmethod
def from_parts(cls, scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment=None, encoding='utf-8'):
"""Create a ParseResult instance from its parts."""
authority = ''
if userinfo is not None:
authority += userinfo + '@'
if host is not None:
authority += host
if port is not None:
authority += ':{0}'.format(port)
uri_ref = uri.URIReference(scheme=scheme,
authority=authority,
path=path,
query=query,
fragment=fragment,
encoding=encoding).normalize()
userinfo, host, port = authority_from(uri_ref, strict=True)
return cls(scheme=uri_ref.scheme,
userinfo=userinfo,
host=host,
port=port,
path=uri_ref.path,
query=uri_ref.query,
fragment=uri_ref.fragment,
uri_ref=uri_ref,
encoding=encoding)
@classmethod
def from_string(cls, uri_string, encoding='utf-8', strict=True,
lazy_normalize=True):
"""Parse a URI from the given unicode URI string.
:param str uri_string: Unicode URI to be parsed into a reference.
:param str encoding: The encoding of the string provided
:param bool strict: Parse strictly according to :rfc:`3986` if True.
If False, parse similarly to the standard library's urlparse
function.
:returns: :class:`ParseResult` or subclass thereof
"""
reference = uri.URIReference.from_string(uri_string, encoding)
if not lazy_normalize:
reference = reference.normalize()
userinfo, host, port = authority_from(reference, strict)
return cls(scheme=reference.scheme,
userinfo=userinfo,
host=host,
port=port,
path=reference.path,
query=reference.query,
fragment=reference.fragment,
uri_ref=reference,
encoding=encoding)
@property
def authority(self):
"""Return the normalized authority."""
return self.reference.authority
def copy_with(self, scheme=misc.UseExisting, userinfo=misc.UseExisting,
host=misc.UseExisting, port=misc.UseExisting,
path=misc.UseExisting, query=misc.UseExisting,
fragment=misc.UseExisting):
"""Create a copy of this instance replacing with specified parts."""
attributes = zip(PARSED_COMPONENTS,
(scheme, userinfo, host, port, path, query, fragment))
attrs_dict = {}
for name, value in attributes:
if value is misc.UseExisting:
value = getattr(self, name)
attrs_dict[name] = value
authority = self._generate_authority(attrs_dict)
ref = self.reference.copy_with(scheme=attrs_dict['scheme'],
authority=authority,
path=attrs_dict['path'],
query=attrs_dict['query'],
fragment=attrs_dict['fragment'])
return ParseResult(uri_ref=ref, encoding=self.encoding, **attrs_dict)
def encode(self, encoding=None):
"""Convert to an instance of ParseResultBytes."""
encoding = encoding or self.encoding
attrs = dict(
zip(PARSED_COMPONENTS,
(attr.encode(encoding) if hasattr(attr, 'encode') else attr
for attr in self)))
return ParseResultBytes(
uri_ref=self.reference,
encoding=encoding,
**attrs
)
def unsplit(self, use_idna=False):
"""Create a URI string from the components.
:returns: The parsed URI reconstituted as a string.
:rtype: str
"""
parse_result = self
if use_idna and self.host:
hostbytes = self.host.encode('idna')
host = hostbytes.decode(self.encoding)
parse_result = self.copy_with(host=host)
return parse_result.reference.unsplit()
class ParseResultBytes(namedtuple('ParseResultBytes', PARSED_COMPONENTS),
ParseResultMixin):
"""Compatibility shim for the urlparse.ParseResultBytes object."""
def __new__(cls, scheme, userinfo, host, port, path, query, fragment,
uri_ref, encoding='utf-8', lazy_normalize=True):
"""Create a new ParseResultBytes instance."""
parse_result = super(ParseResultBytes, cls).__new__(
cls,
scheme or None,
userinfo or None,
host,
port or None,
path or None,
query or None,
fragment or None)
parse_result.encoding = encoding
parse_result.reference = uri_ref
parse_result.lazy_normalize = lazy_normalize
return parse_result
@classmethod
def from_parts(cls, scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment=None, encoding='utf-8',
lazy_normalize=True):
"""Create a ParseResult instance from its parts."""
authority = ''
if userinfo is not None:
authority += userinfo + '@'
if host is not None:
authority += host
if port is not None:
authority += ':{0}'.format(int(port))
uri_ref = uri.URIReference(scheme=scheme,
authority=authority,
path=path,
query=query,
fragment=fragment,
encoding=encoding)
if not lazy_normalize:
uri_ref = uri_ref.normalize()
to_bytes = compat.to_bytes
userinfo, host, port = authority_from(uri_ref, strict=True)
return cls(scheme=to_bytes(scheme, encoding),
userinfo=to_bytes(userinfo, encoding),
host=to_bytes(host, encoding),
port=port,
path=to_bytes(path, encoding),
query=to_bytes(query, encoding),
fragment=to_bytes(fragment, encoding),
uri_ref=uri_ref,
encoding=encoding,
lazy_normalize=lazy_normalize)
@classmethod
def from_string(cls, uri_string, encoding='utf-8', strict=True,
lazy_normalize=True):
"""Parse a URI from the given unicode URI string.
:param str uri_string: Unicode URI to be parsed into a reference.
:param str encoding: The encoding of the string provided
:param bool strict: Parse strictly according to :rfc:`3986` if True.
If False, parse similarly to the standard library's urlparse
function.
:returns: :class:`ParseResultBytes` or subclass thereof
"""
reference = uri.URIReference.from_string(uri_string, encoding)
if not lazy_normalize:
reference = reference.normalize()
userinfo, host, port = authority_from(reference, strict)
to_bytes = compat.to_bytes
return cls(scheme=to_bytes(reference.scheme, encoding),
userinfo=to_bytes(userinfo, encoding),
host=to_bytes(host, encoding),
port=port,
path=to_bytes(reference.path, encoding),
query=to_bytes(reference.query, encoding),
fragment=to_bytes(reference.fragment, encoding),
uri_ref=reference,
encoding=encoding,
lazy_normalize=lazy_normalize)
@property
def authority(self):
"""Return the normalized authority."""
return self.reference.authority.encode(self.encoding)
def copy_with(self, scheme=misc.UseExisting, userinfo=misc.UseExisting,
host=misc.UseExisting, port=misc.UseExisting,
path=misc.UseExisting, query=misc.UseExisting,
fragment=misc.UseExisting, lazy_normalize=True):
"""Create a copy of this instance replacing with specified parts."""
attributes = zip(PARSED_COMPONENTS,
(scheme, userinfo, host, port, path, query, fragment))
attrs_dict = {}
for name, value in attributes:
if value is misc.UseExisting:
value = getattr(self, name)
if not isinstance(value, bytes) and hasattr(value, 'encode'):
value = value.encode(self.encoding)
attrs_dict[name] = value
authority = self._generate_authority(attrs_dict)
to_str = compat.to_str
ref = self.reference.copy_with(
scheme=to_str(attrs_dict['scheme'], self.encoding),
authority=to_str(authority, self.encoding),
path=to_str(attrs_dict['path'], self.encoding),
query=to_str(attrs_dict['query'], self.encoding),
fragment=to_str(attrs_dict['fragment'], self.encoding)
)
if not lazy_normalize:
ref = ref.normalize()
return ParseResultBytes(
uri_ref=ref,
encoding=self.encoding,
lazy_normalize=lazy_normalize,
**attrs_dict
)
def unsplit(self, use_idna=False):
"""Create a URI bytes object from the components.
:returns: The parsed URI reconstituted as a string.
:rtype: bytes
"""
parse_result = self
if use_idna and self.host:
# self.host is bytes, to encode to idna, we need to decode it
# first
host = self.host.decode(self.encoding)
hostbytes = host.encode('idna')
parse_result = self.copy_with(host=hostbytes)
if self.lazy_normalize:
parse_result = parse_result.copy_with(lazy_normalize=False)
uri = parse_result.reference.unsplit()
return uri.encode(self.encoding)
def split_authority(authority):
# Initialize our expected return values
userinfo = host = port = None
# Initialize an extra var we may need to use
extra_host = None
# Set-up rest in case there is no userinfo portion
rest = authority
if '@' in authority:
userinfo, rest = authority.rsplit('@', 1)
# Handle IPv6 host addresses
if rest.startswith('['):
host, rest = rest.split(']', 1)
host += ']'
if ':' in rest:
extra_host, port = rest.split(':', 1)
elif not host and rest:
host = rest
if extra_host and not host:
host = extra_host
return userinfo, host, port
def authority_from(reference, strict):
try:
subauthority = reference.authority_info()
except exceptions.InvalidAuthority:
if strict:
raise
userinfo, host, port = split_authority(reference.authority)
else:
# Thanks to Richard Barrell for this idea:
# https://twitter.com/0x2ba22e11/status/617338811975139328
userinfo, host, port = (subauthority.get(p)
for p in ('userinfo', 'host', 'port'))
if port:
try:
port = int(port)
except ValueError:
raise exceptions.InvalidPort(port)
return userinfo, host, port
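
A sketch of the shim above; `from_string` yields the familiar stdlib-style attributes, plus `copy_with` for deriving new URLs:

from rfc3986.parseresult import ParseResult

result = ParseResult.from_string('https://user@example.com:8443/p?q=1#f')
print(result.scheme, result.userinfo, result.host, result.port)
# https user example.com 8443
print(result.hostname, result.netloc)  # example.com user@example.com:8443
print(result.copy_with(port=9000).unsplit())
# https://user@example.com:9000/p?q=1#f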

View File

@@ -1,153 +0,0 @@
"""Module containing the implementation of the URIReference class."""
# -*- coding: utf-8 -*-
# Copyright (c) 2014 Rackspace
# Copyright (c) 2015 Ian Stapleton Cordasco
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from . import compat
from . import misc
from . import normalizers
from ._mixin import URIMixin
class URIReference(namedtuple('URIReference', misc.URI_COMPONENTS), URIMixin):
"""Immutable object representing a parsed URI Reference.
.. note::
This class is not intended to be directly instantiated by the user.
This object exposes attributes for the following components of a
URI:
- scheme
- authority
- path
- query
- fragment
.. attribute:: scheme
The scheme that was parsed for the URI Reference. For example,
``http``, ``https``, ``smtp``, ``imap``, etc.
.. attribute:: authority
Component of the URI that contains the user information, host,
and port sub-components. For example,
``google.com``, ``127.0.0.1:5000``, ``username@[::1]``,
``username:password@example.com:443``, etc.
.. attribute:: path
The path that was parsed for the given URI Reference. For example,
``/``, ``/index.php``, etc.
.. attribute:: query
The query component for a given URI Reference. For example, ``a=b``,
``a=b%20c``, ``a=b+c``, ``a=b,c=d,e=%20f``, etc.
.. attribute:: fragment
The fragment component of a URI. For example, ``section-3.1``.
This class also provides extra attributes for easier access to information
like the subcomponents of the authority component.
.. attribute:: userinfo
The user information parsed from the authority.
.. attribute:: host
The hostname, IPv4 address, or IPv6 address parsed from the authority.
.. attribute:: port
The port parsed from the authority.
"""
slots = ()
def __new__(cls, scheme, authority, path, query, fragment,
encoding='utf-8'):
"""Create a new URIReference."""
ref = super(URIReference, cls).__new__(
cls,
scheme or None,
authority or None,
path or None,
query,
fragment)
ref.encoding = encoding
return ref
__hash__ = tuple.__hash__
def __eq__(self, other):
"""Compare this reference to another."""
other_ref = other
if isinstance(other, tuple):
other_ref = URIReference(*other)
elif not isinstance(other, URIReference):
try:
other_ref = URIReference.from_string(other)
except TypeError:
raise TypeError(
'Unable to compare URIReference() to {0}()'.format(
type(other).__name__))
# See http://tools.ietf.org/html/rfc3986#section-6.2
naive_equality = tuple(self) == tuple(other_ref)
return naive_equality or self.normalized_equality(other_ref)
def normalize(self):
"""Normalize this reference as described in Section 6.2.2.
This is not an in-place normalization. Instead this creates a new
URIReference.
:returns: A new reference object with normalized components.
:rtype: URIReference
"""
# See http://tools.ietf.org/html/rfc3986#section-6.2.2 for logic in
# this method.
return URIReference(normalizers.normalize_scheme(self.scheme or ''),
normalizers.normalize_authority(
(self.userinfo, self.host, self.port)),
normalizers.normalize_path(self.path or ''),
normalizers.normalize_query(self.query),
normalizers.normalize_fragment(self.fragment),
self.encoding)
@classmethod
def from_string(cls, uri_string, encoding='utf-8'):
"""Parse a URI reference from the given unicode URI string.
:param str uri_string: Unicode URI to be parsed into a reference.
:param str encoding: The encoding of the string provided
:returns: :class:`URIReference` or subclass thereof
"""
uri_string = compat.to_str(uri_string, encoding)
split_uri = misc.URI_MATCHER.match(uri_string).groupdict()
return cls(
split_uri['scheme'], split_uri['authority'],
normalizers.encode_component(split_uri['path'], encoding),
normalizers.encode_component(split_uri['query'], encoding),
normalizers.encode_component(split_uri['fragment'], encoding),
encoding,
)
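As a hedged usage sketch of the class above (assuming rfc3986 is installed; exact normalization details may vary between versions):

from rfc3986.uri import URIReference

ref = URIReference.from_string('HTTP://Example.COM/%7efoo')
norm = ref.normalize()
print(norm.scheme, norm.authority)        # http example.com
print(ref == 'http://example.com/~foo')   # True, via normalized_equality()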

View File

@@ -1,450 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2017 Ian Stapleton Cordasco
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing the validation logic for rfc3986."""
from . import exceptions
from . import misc
from . import normalizers
class Validator(object):
"""Object used to configure validation of all objects in rfc3986.
.. versionadded:: 1.0
Example usage::
>>> from rfc3986 import api, validators
>>> uri = api.uri_reference('https://github.com/')
>>> validator = validators.Validator().require_presence_of(
... 'scheme', 'host', 'path',
... ).allow_schemes(
... 'http', 'https',
... ).allow_hosts(
... '127.0.0.1', 'github.com',
... )
>>> validator.validate(uri)
>>> invalid_uri = api.uri_reference('imap://mail.google.com')
>>> validator.validate(invalid_uri)
Traceback (most recent call last):
...
rfc3986.exceptions.MissingComponentError: ('path was required but
missing', URIReference(scheme=u'imap', authority=u'mail.google.com',
path=None, query=None, fragment=None), ['path'])
"""
COMPONENT_NAMES = frozenset([
'scheme',
'userinfo',
'host',
'port',
'path',
'query',
'fragment',
])
def __init__(self):
"""Initialize our default validations."""
self.allowed_schemes = set()
self.allowed_hosts = set()
self.allowed_ports = set()
self.allow_password = True
self.required_components = {
'scheme': False,
'userinfo': False,
'host': False,
'port': False,
'path': False,
'query': False,
'fragment': False,
}
self.validated_components = self.required_components.copy()
def allow_schemes(self, *schemes):
"""Require the scheme to be one of the provided schemes.
.. versionadded:: 1.0
:param schemes:
Schemes, without ``://`` that are allowed.
:returns:
The validator instance.
:rtype:
Validator
"""
for scheme in schemes:
self.allowed_schemes.add(normalizers.normalize_scheme(scheme))
return self
def allow_hosts(self, *hosts):
"""Require the host to be one of the provided hosts.
.. versionadded:: 1.0
:param hosts:
Hosts that are allowed.
:returns:
The validator instance.
:rtype:
Validator
"""
for host in hosts:
self.allowed_hosts.add(normalizers.normalize_host(host))
return self
def allow_ports(self, *ports):
"""Require the port to be one of the provided ports.
.. versionadded:: 1.0
:param ports:
Ports that are allowed.
:returns:
The validator instance.
:rtype:
Validator
"""
for port in ports:
port_int = int(port, base=10)
if 0 <= port_int <= 65535:
self.allowed_ports.add(port)
return self
def allow_use_of_password(self):
"""Allow passwords to be present in the URI.
.. versionadded:: 1.0
:returns:
The validator instance.
:rtype:
Validator
"""
self.allow_password = True
return self
def forbid_use_of_password(self):
"""Prevent passwords from being included in the URI.
.. versionadded:: 1.0
:returns:
The validator instance.
:rtype:
Validator
"""
self.allow_password = False
return self
def check_validity_of(self, *components):
"""Check the validity of the components provided.
This can be specified repeatedly.
.. versionadded:: 1.1
:param components:
Names of components from :attr:`Validator.COMPONENT_NAMES`.
:returns:
The validator instance.
:rtype:
Validator
"""
components = [c.lower() for c in components]
for component in components:
if component not in self.COMPONENT_NAMES:
raise ValueError(
'"{}" is not a valid component'.format(component)
)
self.validated_components.update({
component: True for component in components
})
return self
def require_presence_of(self, *components):
"""Require the components provided.
This can be specified repeatedly.
.. versionadded:: 1.0
:param components:
Names of components from :attr:`Validator.COMPONENT_NAMES`.
:returns:
The validator instance.
:rtype:
Validator
"""
components = [c.lower() for c in components]
for component in components:
if component not in self.COMPONENT_NAMES:
raise ValueError(
'"{}" is not a valid component'.format(component)
)
self.required_components.update({
component: True for component in components
})
return self
def validate(self, uri):
"""Check a URI for conditions specified on this validator.
.. versionadded:: 1.0
:param uri:
Parsed URI to validate.
:type uri:
rfc3986.uri.URIReference
:raises MissingComponentError:
When a required component is missing.
:raises UnpermittedComponentError:
When a component is not one of those allowed.
:raises PasswordForbidden:
When a password is present in the userinfo component but is
not permitted by configuration.
:raises InvalidComponentsError:
When a component was found to be invalid.
"""
if not self.allow_password:
check_password(uri)
required_components = [
component
for component, required in self.required_components.items()
if required
]
validated_components = [
component
for component, required in self.validated_components.items()
if required
]
if required_components:
ensure_required_components_exist(uri, required_components)
if validated_components:
ensure_components_are_valid(uri, validated_components)
ensure_one_of(self.allowed_schemes, uri, 'scheme')
ensure_one_of(self.allowed_hosts, uri, 'host')
ensure_one_of(self.allowed_ports, uri, 'port')
def check_password(uri):
"""Assert that there is no password present in the uri."""
userinfo = uri.userinfo
if not userinfo:
return
credentials = userinfo.split(':', 1)
if len(credentials) <= 1:
return
raise exceptions.PasswordForbidden(uri)
def ensure_one_of(allowed_values, uri, attribute):
"""Assert that the uri's attribute is one of the allowed values."""
value = getattr(uri, attribute)
if value is not None and allowed_values and value not in allowed_values:
raise exceptions.UnpermittedComponentError(
attribute, value, allowed_values,
)
def ensure_required_components_exist(uri, required_components):
"""Assert that all required components are present in the URI."""
missing_components = sorted([
component
for component in required_components
if getattr(uri, component) is None
])
if missing_components:
raise exceptions.MissingComponentError(uri, *missing_components)
def is_valid(value, matcher, require):
"""Determine if a value is valid based on the provided matcher.
:param str value:
Value to validate.
:param matcher:
Compiled regular expression to use to validate the value.
:param require:
Whether or not the value is required.
"""
if require:
return (value is not None
and matcher.match(value))
# require is False, so a missing value is acceptable here
return value is None or matcher.match(value)
def authority_is_valid(authority, host=None, require=False):
"""Determine if the authority string is valid.
:param str authority:
The authority to validate.
:param str host:
(optional) The host portion of the authority to validate.
:param bool require:
(optional) Specify if authority must not be None.
:returns:
``True`` if valid, ``False`` otherwise
:rtype:
bool
"""
validated = is_valid(authority, misc.SUBAUTHORITY_MATCHER, require)
if validated and host is not None:
return host_is_valid(host, require)
return validated
def host_is_valid(host, require=False):
"""Determine if the host string is valid.
:param str host:
The host to validate.
:param bool require:
(optional) Specify if host must not be None.
:returns:
``True`` if valid, ``False`` otherwise
:rtype:
bool
"""
validated = is_valid(host, misc.HOST_MATCHER, require)
if validated and host is not None and misc.IPv4_MATCHER.match(host):
return valid_ipv4_host_address(host)
elif validated and host is not None and misc.IPv6_MATCHER.match(host):
return misc.IPv6_NO_RFC4007_MATCHER.match(host) is not None
return validated
def scheme_is_valid(scheme, require=False):
"""Determine if the scheme is valid.
:param str scheme:
The scheme string to validate.
:param bool require:
(optional) Set to ``True`` to require the presence of a scheme.
:returns:
``True`` if the scheme is valid. ``False`` otherwise.
:rtype:
bool
"""
return is_valid(scheme, misc.SCHEME_MATCHER, require)
def path_is_valid(path, require=False):
"""Determine if the path component is valid.
:param str path:
The path string to validate.
:param bool require:
(optional) Set to ``True`` to require the presence of a path.
:returns:
``True`` if the path is valid. ``False`` otherwise.
:rtype:
bool
"""
return is_valid(path, misc.PATH_MATCHER, require)
def query_is_valid(query, require=False):
"""Determine if the query component is valid.
:param str query:
The query string to validate.
:param bool require:
(optional) Set to ``True`` to require the presence of a query.
:returns:
``True`` if the query is valid. ``False`` otherwise.
:rtype:
bool
"""
return is_valid(query, misc.QUERY_MATCHER, require)
def fragment_is_valid(fragment, require=False):
"""Determine if the fragment component is valid.
:param str fragment:
The fragment string to validate.
:param bool require:
(optional) Set to ``True`` to require the presence of a fragment.
:returns:
``True`` if the fragment is valid. ``False`` otherwise.
:rtype:
bool
"""
return is_valid(fragment, misc.FRAGMENT_MATCHER, require)
def valid_ipv4_host_address(host):
"""Determine if the given host is a valid IPv4 address."""
# If the host exists, and it might be IPv4, check each byte in the
# address.
return all([0 <= int(byte, base=10) <= 255 for byte in host.split('.')])
_COMPONENT_VALIDATORS = {
'scheme': scheme_is_valid,
'path': path_is_valid,
'query': query_is_valid,
'fragment': fragment_is_valid,
}
_SUBAUTHORITY_VALIDATORS = set(['userinfo', 'host', 'port'])
def subauthority_component_is_valid(uri, component):
"""Determine if the userinfo, host, and port are valid."""
try:
subauthority_dict = uri.authority_info()
except exceptions.InvalidAuthority:
return False
# If we can parse the authority into sub-components and we're not
# validating the port, we can assume it's valid.
if component == 'host':
return host_is_valid(subauthority_dict['host'])
elif component != 'port':
return True
try:
port = int(subauthority_dict['port'])
except TypeError:
# If the port wasn't provided it'll be None and int(None) raises a
# TypeError
return True
return (0 <= port <= 65535)
def ensure_components_are_valid(uri, validated_components):
"""Assert that all components are valid in the URI."""
invalid_components = set([])
for component in validated_components:
if component in _SUBAUTHORITY_VALIDATORS:
if not subauthority_component_is_valid(uri, component):
invalid_components.add(component)
# Python's peephole optimizer means that while this continue *is*
# actually executed, coverage.py cannot detect that. See also,
# https://bitbucket.org/ned/coveragepy/issues/198/continue-marked-as-not-covered
continue # nocov: Python 2.7, 3.3, 3.4
validator = _COMPONENT_VALIDATORS[component]
if not validator(getattr(uri, component)):
invalid_components.add(component)
if invalid_components:
raise exceptions.InvalidComponentsError(uri, *invalid_components)
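A minimal sketch of the Validator in action, complementing the doctest above (assuming the rfc3986 package layout shown here):

from rfc3986 import api, exceptions, validators

validator = (
    validators.Validator()
    .allow_schemes('https')
    .require_presence_of('scheme', 'host')
    .forbid_use_of_password()
)
try:
    validator.validate(api.uri_reference('https://user:secret@example.com/'))
except exceptions.PasswordForbidden:
    print('rejected: userinfo carries a password')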

View File

@@ -1,6 +1,4 @@
"""Utilities for writing code that runs on Python 2 and 3"""
# Copyright (c) 2010-2015 Benjamin Peterson
# Copyright (c) 2010-2019 Benjamin Peterson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -20,6 +18,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utilities for writing code that runs on Python 2 and 3"""
from __future__ import absolute_import
import functools
@@ -29,7 +29,7 @@ import sys
import types
__author__ = "Benjamin Peterson <benjamin@python.org>"
__version__ = "1.10.0"
__version__ = "1.12.0"
# Useful for very coarse version differentiation.
@@ -38,15 +38,15 @@ PY3 = sys.version_info[0] == 3
PY34 = sys.version_info[0:2] >= (3, 4)
if PY3:
string_types = str,
integer_types = int,
class_types = type,
string_types = (str,)
integer_types = (int,)
class_types = (type,)
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
else:
string_types = basestring,
string_types = (basestring,)
integer_types = (int, long)
class_types = (type, types.ClassType)
text_type = unicode
@@ -58,9 +58,9 @@ else:
else:
# It's possible to have sizeof(long) != sizeof(Py_ssize_t).
class X(object):
def __len__(self):
return 1 << 31
try:
len(X())
except OverflowError:
@@ -84,7 +84,6 @@ def _import_module(name):
class _LazyDescr(object):
def __init__(self, name):
self.name = name
@@ -101,7 +100,6 @@ class _LazyDescr(object):
class MovedModule(_LazyDescr):
def __init__(self, name, old, new=None):
super(MovedModule, self).__init__(name)
if PY3:
@@ -122,7 +120,6 @@ class MovedModule(_LazyDescr):
class _LazyModule(types.ModuleType):
def __init__(self, name):
super(_LazyModule, self).__init__(name)
self.__doc__ = self.__class__.__doc__
@@ -137,7 +134,6 @@ class _LazyModule(types.ModuleType):
class MovedAttribute(_LazyDescr):
def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None):
super(MovedAttribute, self).__init__(name)
if PY3:
@@ -221,28 +217,36 @@ class _SixMetaPathImporter(object):
Required, if is_package is implemented"""
self.__get_module(fullname) # eventually raises ImportError
return None
get_source = get_code # same as get_code
_importer = _SixMetaPathImporter(__name__)
class _MovedItems(_LazyModule):
"""Lazy loading of moved objects"""
__path__ = [] # mark as package
_moved_attributes = [
MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"),
MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"),
MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"),
MovedAttribute(
"filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"
),
MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"),
MovedAttribute("intern", "__builtin__", "sys"),
MovedAttribute("map", "itertools", "builtins", "imap", "map"),
MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"),
MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"),
MovedAttribute("getoutput", "commands", "subprocess"),
MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"),
MovedAttribute(
"reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"
),
MovedAttribute("reduce", "__builtin__", "functools"),
MovedAttribute("shlex_quote", "pipes", "shlex", "quote"),
MovedAttribute("StringIO", "StringIO", "io"),
@@ -251,7 +255,9 @@ _moved_attributes = [
MovedAttribute("UserString", "UserString", "collections"),
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
MovedAttribute(
"zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"
),
MovedModule("builtins", "__builtin__"),
MovedModule("configparser", "ConfigParser"),
MovedModule("copyreg", "copy_reg"),
@@ -262,10 +268,13 @@ _moved_attributes = [
MovedModule("html_entities", "htmlentitydefs", "html.entities"),
MovedModule("html_parser", "HTMLParser", "html.parser"),
MovedModule("http_client", "httplib", "http.client"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"),
MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"),
MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"),
MovedModule(
"email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"
),
MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"),
MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"),
MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"),
MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"),
@@ -283,15 +292,12 @@ _moved_attributes = [
MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"),
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"),
MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"),
MovedModule("tkinter_colorchooser", "tkColorChooser",
"tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog",
"tkinter.commondialog"),
MovedModule("tkinter_colorchooser", "tkColorChooser", "tkinter.colorchooser"),
MovedModule("tkinter_commondialog", "tkCommonDialog", "tkinter.commondialog"),
MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"),
MovedModule("tkinter_font", "tkFont", "tkinter.font"),
MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog",
"tkinter.simpledialog"),
MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", "tkinter.simpledialog"),
MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"),
MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"),
MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"),
@@ -301,9 +307,7 @@ _moved_attributes = [
]
# Add windows specific modules.
if sys.platform == "win32":
_moved_attributes += [
MovedModule("winreg", "_winreg"),
]
_moved_attributes += [MovedModule("winreg", "_winreg")]
for attr in _moved_attributes:
setattr(_MovedItems, attr.name, attr)
@@ -337,10 +341,14 @@ _urllib_parse_moved_attributes = [
MovedAttribute("quote_plus", "urllib", "urllib.parse"),
MovedAttribute("unquote", "urllib", "urllib.parse"),
MovedAttribute("unquote_plus", "urllib", "urllib.parse"),
MovedAttribute(
"unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"
),
MovedAttribute("urlencode", "urllib", "urllib.parse"),
MovedAttribute("splitquery", "urllib", "urllib.parse"),
MovedAttribute("splittag", "urllib", "urllib.parse"),
MovedAttribute("splituser", "urllib", "urllib.parse"),
MovedAttribute("splitvalue", "urllib", "urllib.parse"),
MovedAttribute("uses_fragment", "urlparse", "urllib.parse"),
MovedAttribute("uses_netloc", "urlparse", "urllib.parse"),
MovedAttribute("uses_params", "urlparse", "urllib.parse"),
@@ -353,8 +361,11 @@ del attr
Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes
_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse", "moves.urllib.parse")
_importer._add_module(
Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"),
"moves.urllib_parse",
"moves.urllib.parse",
)
class Module_six_moves_urllib_error(_LazyModule):
@@ -373,8 +384,11 @@ del attr
Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes
_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error", "moves.urllib.error")
_importer._add_module(
Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"),
"moves.urllib_error",
"moves.urllib.error",
)
class Module_six_moves_urllib_request(_LazyModule):
@@ -416,6 +430,8 @@ _urllib_request_moved_attributes = [
MovedAttribute("URLopener", "urllib", "urllib.request"),
MovedAttribute("FancyURLopener", "urllib", "urllib.request"),
MovedAttribute("proxy_bypass", "urllib", "urllib.request"),
MovedAttribute("parse_http_list", "urllib2", "urllib.request"),
MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"),
]
for attr in _urllib_request_moved_attributes:
setattr(Module_six_moves_urllib_request, attr.name, attr)
@@ -423,8 +439,11 @@ del attr
Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes
_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request", "moves.urllib.request")
_importer._add_module(
Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"),
"moves.urllib_request",
"moves.urllib.request",
)
class Module_six_moves_urllib_response(_LazyModule):
@@ -444,8 +463,11 @@ del attr
Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes
_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response", "moves.urllib.response")
_importer._add_module(
Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"),
"moves.urllib_response",
"moves.urllib.response",
)
class Module_six_moves_urllib_robotparser(_LazyModule):
@@ -454,21 +476,27 @@ class Module_six_moves_urllib_robotparser(_LazyModule):
_urllib_robotparser_moved_attributes = [
MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"),
MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser")
]
for attr in _urllib_robotparser_moved_attributes:
setattr(Module_six_moves_urllib_robotparser, attr.name, attr)
del attr
Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes
Module_six_moves_urllib_robotparser._moved_attributes = (
_urllib_robotparser_moved_attributes
)
_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser", "moves.urllib.robotparser")
_importer._add_module(
Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"),
"moves.urllib_robotparser",
"moves.urllib.robotparser",
)
class Module_six_moves_urllib(types.ModuleType):
"""Create a six.moves.urllib namespace that resembles the Python 3 namespace"""
__path__ = [] # mark as package
parse = _importer._get_module("moves.urllib_parse")
error = _importer._get_module("moves.urllib_error")
@@ -477,10 +505,12 @@ class Module_six_moves_urllib(types.ModuleType):
robotparser = _importer._get_module("moves.urllib_robotparser")
def __dir__(self):
return ['parse', 'error', 'request', 'response', 'robotparser']
return ["parse", "error", "request", "response", "robotparser"]
_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"),
"moves.urllib")
_importer._add_module(
Module_six_moves_urllib(__name__ + ".moves.urllib"), "moves.urllib"
)
def add_move(move):
@@ -520,19 +550,24 @@ else:
try:
advance_iterator = next
except NameError:
def advance_iterator(it):
return it.next()
next = advance_iterator
try:
callable = callable
except NameError:
def callable(obj):
return any("__call__" in klass.__dict__ for klass in type(obj).__mro__)
if PY3:
def get_unbound_function(unbound):
return unbound
@@ -543,6 +578,7 @@ if PY3:
Iterator = object
else:
def get_unbound_function(unbound):
return unbound.im_func
@@ -553,13 +589,13 @@ else:
return types.MethodType(func, None, cls)
class Iterator(object):
def next(self):
return type(self).__next__(self)
callable = callable
_add_doc(get_unbound_function,
"""Get the function out of a possibly unbound function""")
_add_doc(
get_unbound_function, """Get the function out of a possibly unbound function"""
)
get_method_function = operator.attrgetter(_meth_func)
@@ -571,6 +607,7 @@ get_function_globals = operator.attrgetter(_func_globals)
if PY3:
def iterkeys(d, **kw):
return iter(d.keys(**kw))
@@ -589,6 +626,7 @@ if PY3:
viewitems = operator.methodcaller("items")
else:
def iterkeys(d, **kw):
return d.iterkeys(**kw)
@@ -609,28 +647,33 @@ else:
_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.")
_add_doc(itervalues, "Return an iterator over the values of a dictionary.")
_add_doc(iteritems,
"Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(iterlists,
"Return an iterator over the (key, [values]) pairs of a dictionary.")
_add_doc(iteritems, "Return an iterator over the (key, value) pairs of a dictionary.")
_add_doc(
iterlists, "Return an iterator over the (key, [values]) pairs of a dictionary."
)
if PY3:
def b(s):
return s.encode("latin-1")
def u(s):
return s
unichr = chr
import struct
int2byte = struct.Struct(">B").pack
del struct
byte2int = operator.itemgetter(0)
indexbytes = operator.getitem
iterbytes = iter
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
del io
_assertCountEqual = "assertCountEqual"
if sys.version_info[1] <= 1:
_assertRaisesRegex = "assertRaisesRegexp"
@@ -639,12 +682,15 @@ if PY3:
_assertRaisesRegex = "assertRaisesRegex"
_assertRegex = "assertRegex"
else:
def b(s):
return s
# Workaround for standalone backslash
def u(s):
return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape")
return unicode(s.replace(r"\\", r"\\\\"), "unicode_escape")
unichr = unichr
int2byte = chr
@@ -653,8 +699,10 @@ else:
def indexbytes(buf, i):
return ord(buf[i])
iterbytes = functools.partial(itertools.imap, ord)
import StringIO
StringIO = BytesIO = StringIO.StringIO
_assertCountEqual = "assertItemsEqual"
_assertRaisesRegex = "assertRaisesRegexp"
@@ -679,13 +727,19 @@ if PY3:
exec_ = getattr(moves.builtins, "exec")
def reraise(tp, value, tb=None):
if value is None:
value = tp()
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
try:
if value is None:
value = tp()
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
finally:
value = None
tb = None
else:
def exec_(_code_, _globs_=None, _locs_=None):
"""Execute code in a namespace."""
if _globs_ is None:
@@ -698,28 +752,45 @@ else:
_locs_ = _globs_
exec("""exec _code_ in _globs_, _locs_""")
exec_("""def reraise(tp, value, tb=None):
raise tp, value, tb
""")
exec_(
"""def reraise(tp, value, tb=None):
try:
raise tp, value, tb
finally:
tb = None
"""
)
if sys.version_info[:2] == (3, 2):
exec_("""def raise_from(value, from_value):
if from_value is None:
raise value
raise value from from_value
""")
exec_(
"""def raise_from(value, from_value):
try:
if from_value is None:
raise value
raise value from from_value
finally:
value = None
"""
)
elif sys.version_info[:2] > (3, 2):
exec_("""def raise_from(value, from_value):
raise value from from_value
""")
exec_(
"""def raise_from(value, from_value):
try:
raise value from from_value
finally:
value = None
"""
)
else:
def raise_from(value, from_value):
raise value
print_ = getattr(moves.builtins, "print", None)
if print_ is None:
def print_(*args, **kwargs):
"""The new-style print function for Python 2.4 and 2.5."""
fp = kwargs.pop("file", sys.stdout)
@@ -730,14 +801,17 @@ if print_ is None:
if not isinstance(data, basestring):
data = str(data)
# If the file has an encoding, encode unicode with it.
if (isinstance(fp, file) and
isinstance(data, unicode) and
fp.encoding is not None):
if (
isinstance(fp, file)
and isinstance(data, unicode)
and fp.encoding is not None
):
errors = getattr(fp, "errors", None)
if errors is None:
errors = "strict"
data = data.encode(fp.encoding, errors)
fp.write(data)
want_unicode = False
sep = kwargs.pop("sep", None)
if sep is not None:
@@ -773,6 +847,8 @@ if print_ is None:
write(sep)
write(arg)
write(end)
if sys.version_info[:2] < (3, 3):
_print = print_
@@ -783,16 +859,24 @@ if sys.version_info[:2] < (3, 3):
if flush and fp is not None:
fp.flush()
_add_doc(reraise, """Reraise an exception.""")
if sys.version_info[0:2] < (3, 4):
def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES):
def wraps(
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
def wrapper(f):
f = functools.wraps(wrapped, assigned, updated)(f)
f.__wrapped__ = wrapped
return f
return wrapper
else:
wraps = functools.wraps
@@ -802,29 +886,95 @@ def with_metaclass(meta, *bases):
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(meta):
class metaclass(type):
def __new__(cls, name, this_bases, d):
return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, "temporary_class", (), {})
def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
slots = orig_vars.get("__slots__")
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
orig_vars.pop("__dict__", None)
orig_vars.pop("__weakref__", None)
if hasattr(cls, "__qualname__"):
orig_vars["__qualname__"] = cls.__qualname__
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
def ensure_binary(s, encoding="utf-8", errors="strict"):
"""Coerce **s** to six.binary_type.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> encoded to `bytes`
- `bytes` -> `bytes`
"""
if isinstance(s, text_type):
return s.encode(encoding, errors)
elif isinstance(s, binary_type):
return s
else:
raise TypeError("not expecting type '%s'" % type(s))
def ensure_str(s, encoding="utf-8", errors="strict"):
"""Coerce *s* to `str`.
For Python 2:
- `unicode` -> encoded to `str`
- `str` -> `str`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
if not isinstance(s, (text_type, binary_type)):
raise TypeError("not expecting type '%s'" % type(s))
if PY2 and isinstance(s, text_type):
s = s.encode(encoding, errors)
elif PY3 and isinstance(s, binary_type):
s = s.decode(encoding, errors)
return s
def ensure_text(s, encoding="utf-8", errors="strict"):
"""Coerce *s* to six.text_type.
For Python 2:
- `unicode` -> `unicode`
- `str` -> `unicode`
For Python 3:
- `str` -> `str`
- `bytes` -> decoded to `str`
"""
if isinstance(s, binary_type):
return s.decode(encoding, errors)
elif isinstance(s, text_type):
return s
else:
raise TypeError("not expecting type '%s'" % type(s))
def python_2_unicode_compatible(klass):
"""
A decorator that defines __unicode__ and __str__ methods under Python 2.
@@ -834,12 +984,13 @@ def python_2_unicode_compatible(klass):
returning text and apply this decorator to the class.
"""
if PY2:
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
if "__str__" not in klass.__dict__:
raise ValueError(
"@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." % klass.__name__
)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
klass.__str__ = lambda self: self.__unicode__().encode("utf-8")
return klass
@@ -859,8 +1010,10 @@ if sys.meta_path:
# be floating around. Therefore, we can't use isinstance() to check for
# the six meta path importer, since the other six instance will have
# inserted an importer with different class.
if (type(importer).__name__ == "_SixMetaPathImporter" and
importer.name == __name__):
if (
type(importer).__name__ == "_SixMetaPathImporter"
and importer.name == __name__
):
del sys.meta_path[i]
break
del i, importer
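A short, hedged sketch of the ensure_* helpers introduced in the hunk above; both asserts should hold on Python 2 and Python 3:

import six

assert six.ensure_binary(u'caf\xe9') == b'caf\xc3\xa9'  # text -> UTF-8 bytes
assert six.ensure_text(b'caf\xc3\xa9') == u'caf\xe9'    # bytes -> text
# ensure_str returns the native str type: bytes on Python 2, text on Python 3.
print(type(six.ensure_str(b'abc')))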

View File

@@ -16,4 +16,4 @@ except ImportError:
from ._implementation import CertificateError, match_hostname
# Not needed, but documenting what we provide.
__all__ = ('CertificateError', 'match_hostname')
__all__ = ("CertificateError", "match_hostname")

View File

@@ -15,7 +15,7 @@ try:
except ImportError:
ipaddress = None
__version__ = '3.5.0.1'
__version__ = "3.5.0.1"
class CertificateError(ValueError):
@@ -33,18 +33,19 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# Ported from python3-syntax:
# leftmost, *remainder = dn.split(r'.')
parts = dn.split(r'.')
parts = dn.split(r".")
leftmost = parts[0]
remainder = parts[1:]
wildcards = leftmost.count('*')
wildcards = leftmost.count("*")
if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a
# reasonable choice.
raise CertificateError(
"too many wildcards in certificate DNS name: " + repr(dn))
"too many wildcards in certificate DNS name: " + repr(dn)
)
# speed up common case w/o wildcards
if not wildcards:
@@ -53,11 +54,11 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
# RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label.
if leftmost == '*':
if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless
# fragment.
pats.append('[^.]+')
elif leftmost.startswith('xn--') or hostname.startswith('xn--'):
pats.append("[^.]+")
elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or
@@ -65,21 +66,22 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
pats.append(re.escape(leftmost))
else:
# Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))
pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards
for frag in remainder:
pats.append(re.escape(frag))
pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname)
def _to_unicode(obj):
if isinstance(obj, str) and sys.version_info < (3,):
obj = unicode(obj, encoding='ascii', errors='strict')
obj = unicode(obj, encoding="ascii", errors="strict")
return obj
def _ipaddress_match(ipname, host_ip):
"""Exact matching of IP addresses.
@@ -101,9 +103,11 @@ def match_hostname(cert, hostname):
returns nothing.
"""
if not cert:
raise ValueError("empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED")
raise ValueError(
"empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED"
)
try:
# Divergence from upstream: ipaddress can't handle byte str
host_ip = ipaddress.ip_address(_to_unicode(hostname))
@@ -122,35 +126,35 @@ def match_hostname(cert, hostname):
else:
raise
dnsnames = []
san = cert.get('subjectAltName', ())
san = cert.get("subjectAltName", ())
for key, value in san:
if key == 'DNS':
if key == "DNS":
if host_ip is None and _dnsname_match(value, hostname):
return
dnsnames.append(value)
elif key == 'IP Address':
elif key == "IP Address":
if host_ip is not None and _ipaddress_match(value, host_ip):
return
dnsnames.append(value)
if not dnsnames:
# The subject is only checked when there is no dNSName entry
# in subjectAltName
for sub in cert.get('subject', ()):
for sub in cert.get("subject", ()):
for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name
# must be used.
if key == 'commonName':
if key == "commonName":
if _dnsname_match(value, hostname):
return
dnsnames.append(value)
if len(dnsnames) > 1:
raise CertificateError("hostname %r "
"doesn't match either of %s"
% (hostname, ', '.join(map(repr, dnsnames))))
raise CertificateError(
"hostname %r "
"doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1:
raise CertificateError("hostname %r "
"doesn't match %r"
% (hostname, dnsnames[0]))
raise CertificateError("hostname %r doesn't match %r" % (hostname, dnsnames[0]))
else:
raise CertificateError("no appropriate commonName or "
"subjectAltName fields were found")
raise CertificateError(
"no appropriate commonName or subjectAltName fields were found"
)
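To illustrate the wildcard handling above, a hedged sketch with a hand-built certificate dict of the shape match_hostname() expects (illustrative values only):

from urllib3.packages.ssl_match_hostname import CertificateError, match_hostname

cert = {'subjectAltName': (('DNS', '*.example.com'), ('DNS', 'example.com'))}
match_hostname(cert, 'www.example.com')  # returns None: '*' matches one label
try:
    match_hostname(cert, 'www.other.org')
except CertificateError as exc:
    print(exc)  # hostname doesn't match either SAN entry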

View File

@@ -2,11 +2,17 @@ from __future__ import absolute_import
import collections
import functools
import logging
import warnings
from ._collections import RecentlyUsedContainer
from .connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from .connectionpool import port_by_scheme
from .exceptions import LocationValueError, MaxRetryError, ProxySchemeUnknown
from .exceptions import (
LocationValueError,
MaxRetryError,
ProxySchemeUnknown,
InvalidProxyConfigurationWarning,
)
from .packages import six
from .packages.six.moves.urllib.parse import urljoin
from .request import RequestMethods
@@ -14,48 +20,55 @@ from .util.url import parse_url
from .util.retry import Retry
__all__ = ['PoolManager', 'ProxyManager', 'proxy_from_url']
__all__ = ["PoolManager", "ProxyManager", "proxy_from_url"]
log = logging.getLogger(__name__)
SSL_KEYWORDS = ('key_file', 'cert_file', 'cert_reqs', 'ca_certs',
'ssl_version', 'ca_cert_dir', 'ssl_context',
'key_password')
SSL_KEYWORDS = (
"key_file",
"cert_file",
"cert_reqs",
"ca_certs",
"ssl_version",
"ca_cert_dir",
"ssl_context",
"key_password",
)
# All known keyword arguments that could be provided to the pool manager, its
# pools, or the underlying connections. This is used to construct a pool key.
_key_fields = (
'key_scheme', # str
'key_host', # str
'key_port', # int
'key_timeout', # int or float or Timeout
'key_retries', # int or Retry
'key_strict', # bool
'key_block', # bool
'key_source_address', # str
'key_key_file', # str
'key_key_password', # str
'key_cert_file', # str
'key_cert_reqs', # str
'key_ca_certs', # str
'key_ssl_version', # str
'key_ca_cert_dir', # str
'key_ssl_context', # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
'key_maxsize', # int
'key_headers', # dict
'key__proxy', # parsed proxy url
'key__proxy_headers', # dict
'key_socket_options', # list of (level (int), optname (int), value (int or str)) tuples
'key__socks_options', # dict
'key_assert_hostname', # bool or string
'key_assert_fingerprint', # str
'key_server_hostname', # str
"key_scheme", # str
"key_host", # str
"key_port", # int
"key_timeout", # int or float or Timeout
"key_retries", # int or Retry
"key_strict", # bool
"key_block", # bool
"key_source_address", # str
"key_key_file", # str
"key_key_password", # str
"key_cert_file", # str
"key_cert_reqs", # str
"key_ca_certs", # str
"key_ssl_version", # str
"key_ca_cert_dir", # str
"key_ssl_context", # instance of ssl.SSLContext or urllib3.util.ssl_.SSLContext
"key_maxsize", # int
"key_headers", # dict
"key__proxy", # parsed proxy url
"key__proxy_headers", # dict
"key_socket_options", # list of (level (int), optname (int), value (int or str)) tuples
"key__socks_options", # dict
"key_assert_hostname", # bool or string
"key_assert_fingerprint", # str
"key_server_hostname", # str
)
#: The namedtuple class used to construct keys for the connection pool.
#: All custom key schemes should include the fields in this key at a minimum.
PoolKey = collections.namedtuple('PoolKey', _key_fields)
PoolKey = collections.namedtuple("PoolKey", _key_fields)
def _default_key_normalizer(key_class, request_context):
@@ -80,24 +93,24 @@ def _default_key_normalizer(key_class, request_context):
"""
# Since we mutate the dictionary, make a copy first
context = request_context.copy()
context['scheme'] = context['scheme'].lower()
context['host'] = context['host'].lower()
context["scheme"] = context["scheme"].lower()
context["host"] = context["host"].lower()
# These are both dictionaries and need to be transformed into frozensets
for key in ('headers', '_proxy_headers', '_socks_options'):
for key in ("headers", "_proxy_headers", "_socks_options"):
if key in context and context[key] is not None:
context[key] = frozenset(context[key].items())
# The socket_options key may be a list and needs to be transformed into a
# tuple.
socket_opts = context.get('socket_options')
socket_opts = context.get("socket_options")
if socket_opts is not None:
context['socket_options'] = tuple(socket_opts)
context["socket_options"] = tuple(socket_opts)
# Map the kwargs to the names in the namedtuple - this is necessary since
# namedtuples can't have fields starting with '_'.
for key in list(context.keys()):
context['key_' + key] = context.pop(key)
context["key_" + key] = context.pop(key)
# Default to ``None`` for keys missing from the context
for field in key_class._fields:
@@ -112,14 +125,11 @@ def _default_key_normalizer(key_class, request_context):
#: Each PoolManager makes a copy of this dictionary so they can be configured
#: globally here, or individually on the instance.
key_fn_by_scheme = {
'http': functools.partial(_default_key_normalizer, PoolKey),
'https': functools.partial(_default_key_normalizer, PoolKey),
"http": functools.partial(_default_key_normalizer, PoolKey),
"https": functools.partial(_default_key_normalizer, PoolKey),
}
pool_classes_by_scheme = {
'http': HTTPConnectionPool,
'https': HTTPSConnectionPool,
}
pool_classes_by_scheme = {"http": HTTPConnectionPool, "https": HTTPSConnectionPool}
class PoolManager(RequestMethods):
@@ -155,8 +165,7 @@ class PoolManager(RequestMethods):
def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools,
dispose_func=lambda p: p.close())
self.pools = RecentlyUsedContainer(num_pools, dispose_func=lambda p: p.close())
# Locally set the pool classes and keys so other PoolManagers can
# override them.
@@ -189,10 +198,10 @@ class PoolManager(RequestMethods):
# this function has historically only used the scheme, host, and port
# in the positional args. When an API change is acceptable these can
# be removed.
for key in ('scheme', 'host', 'port'):
for key in ("scheme", "host", "port"):
request_context.pop(key, None)
if scheme == 'http':
if scheme == "http":
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)
@@ -207,7 +216,7 @@ class PoolManager(RequestMethods):
"""
self.pools.clear()
def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
"""
Get a :class:`ConnectionPool` based on the host, port, and scheme.
@@ -222,11 +231,11 @@ class PoolManager(RequestMethods):
raise LocationValueError("No host specified.")
request_context = self._merge_pool_kwargs(pool_kwargs)
request_context['scheme'] = scheme or 'http'
request_context["scheme"] = scheme or "http"
if not port:
port = port_by_scheme.get(request_context['scheme'].lower(), 80)
request_context['port'] = port
request_context['host'] = host
port = port_by_scheme.get(request_context["scheme"].lower(), 80)
request_context["port"] = port
request_context["host"] = host
return self.connection_from_context(request_context)
@@ -237,7 +246,7 @@ class PoolManager(RequestMethods):
``request_context`` must at least contain the ``scheme`` key and its
value must be a key in ``key_fn_by_scheme`` instance variable.
"""
scheme = request_context['scheme'].lower()
scheme = request_context["scheme"].lower()
pool_key_constructor = self.key_fn_by_scheme[scheme]
pool_key = pool_key_constructor(request_context)
@@ -259,9 +268,9 @@ class PoolManager(RequestMethods):
return pool
# Make a fresh ConnectionPool of the desired type
scheme = request_context['scheme']
host = request_context['host']
port = request_context['port']
scheme = request_context["scheme"]
host = request_context["host"]
port = request_context["port"]
pool = self._new_pool(scheme, host, port, request_context=request_context)
self.pools[pool_key] = pool
@@ -279,8 +288,9 @@ class PoolManager(RequestMethods):
not used.
"""
u = parse_url(url)
return self.connection_from_host(u.host, port=u.port, scheme=u.scheme,
pool_kwargs=pool_kwargs)
return self.connection_from_host(
u.host, port=u.port, scheme=u.scheme, pool_kwargs=pool_kwargs
)
def _merge_pool_kwargs(self, override):
"""
@@ -314,11 +324,11 @@ class PoolManager(RequestMethods):
u = parse_url(url)
conn = self.connection_from_host(u.host, port=u.port, scheme=u.scheme)
kw['assert_same_host'] = False
kw['redirect'] = False
kw["assert_same_host"] = False
kw["redirect"] = False
if 'headers' not in kw:
kw['headers'] = self.headers.copy()
if "headers" not in kw:
kw["headers"] = self.headers.copy()
if self.proxy is not None and u.scheme == "http":
response = conn.urlopen(method, url, **kw)
@@ -334,33 +344,37 @@ class PoolManager(RequestMethods):
# RFC 7231, Section 6.4.4
if response.status == 303:
method = 'GET'
method = "GET"
retries = kw.get('retries')
retries = kw.get("retries")
if not isinstance(retries, Retry):
retries = Retry.from_int(retries, redirect=redirect)
# Strip headers marked as unsafe to forward to the redirected location.
# Check remove_headers_on_redirect to avoid a potential network call within
# conn.is_same_host() which may use socket.gethostbyname() in the future.
if (retries.remove_headers_on_redirect
and not conn.is_same_host(redirect_location)):
headers = list(six.iterkeys(kw['headers']))
if retries.remove_headers_on_redirect and not conn.is_same_host(
redirect_location
):
headers = list(six.iterkeys(kw["headers"]))
for header in headers:
if header.lower() in retries.remove_headers_on_redirect:
kw['headers'].pop(header, None)
kw["headers"].pop(header, None)
try:
retries = retries.increment(method, url, response=response, _pool=conn)
except MaxRetryError:
if retries.raise_on_redirect:
response.drain_conn()
raise
return response
kw['retries'] = retries
kw['redirect'] = redirect
kw["retries"] = retries
kw["redirect"] = redirect
log.info("Redirecting %s -> %s", url, redirect_location)
response.drain_conn()
return self.urlopen(method, redirect_location, **kw)
@@ -391,12 +405,21 @@ class ProxyManager(PoolManager):
"""
def __init__(self, proxy_url, num_pools=10, headers=None,
proxy_headers=None, **connection_pool_kw):
def __init__(
self,
proxy_url,
num_pools=10,
headers=None,
proxy_headers=None,
**connection_pool_kw
):
if isinstance(proxy_url, HTTPConnectionPool):
proxy_url = '%s://%s:%i' % (proxy_url.scheme, proxy_url.host,
proxy_url.port)
proxy_url = "%s://%s:%i" % (
proxy_url.scheme,
proxy_url.host,
proxy_url.port,
)
proxy = parse_url(proxy_url)
if not proxy.port:
port = port_by_scheme.get(proxy.scheme, 80)
@@ -408,45 +431,59 @@ class ProxyManager(PoolManager):
self.proxy = proxy
self.proxy_headers = proxy_headers or {}
connection_pool_kw['_proxy'] = self.proxy
connection_pool_kw['_proxy_headers'] = self.proxy_headers
connection_pool_kw["_proxy"] = self.proxy
connection_pool_kw["_proxy_headers"] = self.proxy_headers
super(ProxyManager, self).__init__(
num_pools, headers, **connection_pool_kw)
super(ProxyManager, self).__init__(num_pools, headers, **connection_pool_kw)
def connection_from_host(self, host, port=None, scheme='http', pool_kwargs=None):
def connection_from_host(self, host, port=None, scheme="http", pool_kwargs=None):
if scheme == "https":
return super(ProxyManager, self).connection_from_host(
host, port, scheme, pool_kwargs=pool_kwargs)
host, port, scheme, pool_kwargs=pool_kwargs
)
return super(ProxyManager, self).connection_from_host(
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs)
self.proxy.host, self.proxy.port, self.proxy.scheme, pool_kwargs=pool_kwargs
)
def _set_proxy_headers(self, url, headers=None):
"""
Sets headers needed by proxies: specifically, the Accept and Host
headers. Only sets headers not provided by the user.
"""
headers_ = {'Accept': '*/*'}
headers_ = {"Accept": "*/*"}
netloc = parse_url(url).netloc
if netloc:
headers_['Host'] = netloc
headers_["Host"] = netloc
if headers:
headers_.update(headers)
return headers_
def _validate_proxy_scheme_url_selection(self, url_scheme):
if url_scheme == "https" and self.proxy.scheme == "https":
warnings.warn(
"Your proxy configuration specified an HTTPS scheme for the proxy. "
"Are you sure you want to use HTTPS to contact the proxy? "
"This most likely indicates an error in your configuration. "
"Read this issue for more info: "
"https://github.com/urllib3/urllib3/issues/1850",
InvalidProxyConfigurationWarning,
stacklevel=3,
)
def urlopen(self, method, url, redirect=True, **kw):
"Same as HTTP(S)ConnectionPool.urlopen, ``url`` must be absolute."
u = parse_url(url)
self._validate_proxy_scheme_url_selection(u.scheme)
if u.scheme == "http":
# For proxied HTTPS requests, httplib sets the necessary headers
# on the CONNECT to the proxy. For HTTP, we'll definitely
# need to set 'Host' at the very least.
headers = kw.get('headers', self.headers)
kw['headers'] = self._set_proxy_headers(url, headers)
headers = kw.get("headers", self.headers)
kw["headers"] = self._set_proxy_headers(url, headers)
return super(ProxyManager, self).urlopen(method, url, redirect=redirect, **kw)
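For context, a hedged usage sketch of the two managers above (URLs are placeholders and assume a reachable server/proxy):

import urllib3

http = urllib3.PoolManager(num_pools=10)
r = http.request('GET', 'http://example.com/')
print(r.status)

# ProxyManager funnels plain-HTTP requests through the proxy pool and sets
# Accept/Host headers; per the new check above, an https:// proxy_url now
# emits InvalidProxyConfigurationWarning.
proxied = urllib3.ProxyManager('http://127.0.0.1:3128/')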

View File

@@ -4,7 +4,7 @@ from .filepost import encode_multipart_formdata
from .packages.six.moves.urllib.parse import urlencode
__all__ = ['RequestMethods']
__all__ = ["RequestMethods"]
class RequestMethods(object):
@@ -36,16 +36,25 @@ class RequestMethods(object):
explicitly.
"""
_encode_url_methods = {'DELETE', 'GET', 'HEAD', 'OPTIONS'}
_encode_url_methods = {"DELETE", "GET", "HEAD", "OPTIONS"}
def __init__(self, headers=None):
self.headers = headers or {}
def urlopen(self, method, url, body=None, headers=None,
encode_multipart=True, multipart_boundary=None,
**kw): # Abstract
raise NotImplementedError("Classes extending RequestMethods must implement "
"their own ``urlopen`` method.")
def urlopen(
self,
method,
url,
body=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**kw
): # Abstract
raise NotImplementedError(
"Classes extending RequestMethods must implement "
"their own ``urlopen`` method."
)
def request(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
@@ -60,19 +69,18 @@ class RequestMethods(object):
"""
method = method.upper()
urlopen_kw['request_url'] = url
urlopen_kw["request_url"] = url
if method in self._encode_url_methods:
return self.request_encode_url(method, url, fields=fields,
headers=headers,
**urlopen_kw)
return self.request_encode_url(
method, url, fields=fields, headers=headers, **urlopen_kw
)
else:
return self.request_encode_body(method, url, fields=fields,
headers=headers,
**urlopen_kw)
return self.request_encode_body(
method, url, fields=fields, headers=headers, **urlopen_kw
)
def request_encode_url(self, method, url, fields=None, headers=None,
**urlopen_kw):
def request_encode_url(self, method, url, fields=None, headers=None, **urlopen_kw):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the url. This is useful for request methods like GET, HEAD, DELETE, etc.
@@ -80,17 +88,24 @@ class RequestMethods(object):
if headers is None:
headers = self.headers
extra_kw = {'headers': headers}
extra_kw = {"headers": headers}
extra_kw.update(urlopen_kw)
if fields:
url += '?' + urlencode(fields)
url += "?" + urlencode(fields)
return self.urlopen(method, url, **extra_kw)
def request_encode_body(self, method, url, fields=None, headers=None,
encode_multipart=True, multipart_boundary=None,
**urlopen_kw):
def request_encode_body(
self,
method,
url,
fields=None,
headers=None,
encode_multipart=True,
multipart_boundary=None,
**urlopen_kw
):
"""
Make a request using :meth:`urlopen` with the ``fields`` encoded in
the body. This is useful for request methods like POST, PUT, PATCH, etc.
@@ -129,22 +144,28 @@ class RequestMethods(object):
if headers is None:
headers = self.headers
extra_kw = {'headers': {}}
extra_kw = {"headers": {}}
if fields:
if 'body' in urlopen_kw:
if "body" in urlopen_kw:
raise TypeError(
"request got values for both 'fields' and 'body', can only specify one.")
"request got values for both 'fields' and 'body', can only specify one."
)
if encode_multipart:
body, content_type = encode_multipart_formdata(fields, boundary=multipart_boundary)
body, content_type = encode_multipart_formdata(
fields, boundary=multipart_boundary
)
else:
body, content_type = urlencode(fields), 'application/x-www-form-urlencoded'
body, content_type = (
urlencode(fields),
"application/x-www-form-urlencoded",
)
extra_kw['body'] = body
extra_kw['headers'] = {'Content-Type': content_type}
extra_kw["body"] = body
extra_kw["headers"] = {"Content-Type": content_type}
extra_kw['headers'].update(headers)
extra_kw["headers"].update(headers)
extra_kw.update(urlopen_kw)
return self.urlopen(method, url, **extra_kw)
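To make the URL/body split above concrete, a brief sketch (again assuming a live endpoint; the URLs are placeholders):

import urllib3

http = urllib3.PoolManager()
# GET is in _encode_url_methods, so fields are appended as a query string:
http.request('GET', 'http://example.com/search', fields={'q': 'urllib3'})
# POST routes through request_encode_body(); with encode_multipart=True (the
# default) the fields are encoded as multipart/form-data:
http.request('POST', 'http://example.com/upload', fields={'name': 'demo'})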

View File

@@ -13,8 +13,14 @@ except ImportError:
from ._collections import HTTPHeaderDict
from .exceptions import (
BodyNotHttplibCompatible, ProtocolError, DecodeError, ReadTimeoutError,
ResponseNotChunked, IncompleteRead, InvalidHeader
BodyNotHttplibCompatible,
ProtocolError,
DecodeError,
ReadTimeoutError,
ResponseNotChunked,
IncompleteRead,
InvalidHeader,
HTTPError,
)
from .packages.six import string_types as basestring, PY3
from .packages.six.moves import http_client as httplib
@@ -25,10 +31,9 @@ log = logging.getLogger(__name__)
class DeflateDecoder(object):
def __init__(self):
self._first_try = True
self._data = b''
self._data = b""
self._obj = zlib.decompressobj()
def __getattr__(self, name):
@@ -65,7 +70,6 @@ class GzipDecoderState(object):
class GzipDecoder(object):
def __init__(self):
self._obj = zlib.decompressobj(16 + zlib.MAX_WBITS)
self._state = GzipDecoderState.FIRST_MEMBER
@@ -96,6 +100,7 @@ class GzipDecoder(object):
if brotli is not None:
class BrotliDecoder(object):
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
@@ -104,14 +109,14 @@ if brotli is not None:
self._obj = brotli.Decompressor()
def decompress(self, data):
if hasattr(self._obj, 'decompress'):
if hasattr(self._obj, "decompress"):
return self._obj.decompress(data)
return self._obj.process(data)
def flush(self):
if hasattr(self._obj, 'flush'):
if hasattr(self._obj, "flush"):
return self._obj.flush()
return b''
return b""
class MultiDecoder(object):
@@ -124,7 +129,7 @@ class MultiDecoder(object):
"""
def __init__(self, modes):
self._decoders = [_get_decoder(m.strip()) for m in modes.split(',')]
self._decoders = [_get_decoder(m.strip()) for m in modes.split(",")]
def flush(self):
return self._decoders[0].flush()
@@ -136,13 +141,13 @@ class MultiDecoder(object):
def _get_decoder(mode):
if ',' in mode:
if "," in mode:
return MultiDecoder(mode)
if mode == 'gzip':
if mode == "gzip":
return GzipDecoder()
if brotli is not None and mode == 'br':
if brotli is not None and mode == "br":
return BrotliDecoder()
return DeflateDecoder()
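A quick hedged sketch of the decoder dispatch above, using the module-private _get_decoder helper (a comma-separated mode would instead yield a MultiDecoder):

import zlib
from urllib3.response import _get_decoder

payload = zlib.compress(b'hello world')  # zlib-wrapped, i.e. 'deflate'
decoder = _get_decoder('deflate')
print(decoder.decompress(payload) + decoder.flush())  # b'hello world'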
@@ -181,16 +186,31 @@ class HTTPResponse(io.IOBase):
value of Content-Length header, if present. Otherwise, raise error.
"""
CONTENT_DECODERS = ['gzip', 'deflate']
CONTENT_DECODERS = ["gzip", "deflate"]
if brotli is not None:
CONTENT_DECODERS += ['br']
CONTENT_DECODERS += ["br"]
REDIRECT_STATUSES = [301, 302, 303, 307, 308]
def __init__(self, body='', headers=None, status=0, version=0, reason=None,
strict=0, preload_content=True, decode_content=True,
original_response=None, pool=None, connection=None, msg=None,
retries=None, enforce_content_length=False,
request_method=None, request_url=None):
def __init__(
self,
body="",
headers=None,
status=0,
version=0,
reason=None,
strict=0,
preload_content=True,
decode_content=True,
original_response=None,
pool=None,
connection=None,
msg=None,
retries=None,
enforce_content_length=False,
request_method=None,
request_url=None,
auto_close=True,
):
if isinstance(headers, HTTPHeaderDict):
self.headers = headers
@@ -203,6 +223,7 @@ class HTTPResponse(io.IOBase):
self.decode_content = decode_content
self.retries = retries
self.enforce_content_length = enforce_content_length
self.auto_close = auto_close
self._decoder = None
self._body = None
@@ -218,13 +239,13 @@ class HTTPResponse(io.IOBase):
self._pool = pool
self._connection = connection
if hasattr(body, 'read'):
if hasattr(body, "read"):
self._fp = body
# Are we using the chunked-style of transfer encoding?
self.chunked = False
self.chunk_left = None
tr_enc = self.headers.get('transfer-encoding', '').lower()
tr_enc = self.headers.get("transfer-encoding", "").lower()
# Don't incur the penalty of creating a list and then discarding it
encodings = (enc.strip() for enc in tr_enc.split(","))
if "chunked" in encodings:
@@ -246,7 +267,7 @@ class HTTPResponse(io.IOBase):
location. ``False`` if not a redirect status code.
"""
if self.status in self.REDIRECT_STATUSES:
return self.headers.get('location')
return self.headers.get("location")
return False
@@ -257,6 +278,17 @@ class HTTPResponse(io.IOBase):
self._pool._put_conn(self._connection)
self._connection = None
def drain_conn(self):
"""
Read and discard any remaining HTTP response data in the response connection.
Unread data in the HTTPResponse connection blocks the connection from being released back to the pool.
"""
try:
self.read()
except (HTTPError, SocketError, BaseSSLError, HTTPException):
pass
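The new drain_conn above exists for callers that stream bodies and stop early. A minimal sketch of the intended call order (URL hypothetical):

import urllib3

http = urllib3.PoolManager()
r = http.request("GET", "http://example.com/big", preload_content=False)
r.read(1024)      # consume only the first KiB
r.drain_conn()    # read and discard the rest so the socket is clean
r.release_conn()  # now the connection can safely return to the pool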
@property
def data(self):
# For backwards-compat with urllib3 0.4 and earlier.
@@ -285,18 +317,20 @@ class HTTPResponse(io.IOBase):
"""
Set initial length value for Response content if available.
"""
length = self.headers.get('content-length')
length = self.headers.get("content-length")
if length is not None:
if self.chunked:
# This Response will fail with an IncompleteRead if it can't be
# received as chunked. This method falls back to attempt reading
# the response before raising an exception.
log.warning("Received response with both Content-Length and "
"Transfer-Encoding set. This is expressly forbidden "
"by RFC 7230 sec 3.3.2. Ignoring Content-Length and "
"attempting to process response as Transfer-Encoding: "
"chunked.")
log.warning(
"Received response with both Content-Length and "
"Transfer-Encoding set. This is expressly forbidden "
"by RFC 7230 sec 3.3.2. Ignoring Content-Length and "
"attempting to process response as Transfer-Encoding: "
"chunked."
)
return None
try:
@@ -305,10 +339,12 @@ class HTTPResponse(io.IOBase):
# (e.g. Content-Length: 42, 42). This line ensures the values
# are all valid ints and that as long as the `set` length is 1,
# all values are the same. Otherwise, the header is invalid.
lengths = set([int(val) for val in length.split(',')])
lengths = set([int(val) for val in length.split(",")])
if len(lengths) > 1:
raise InvalidHeader("Content-Length contained multiple "
"unmatching values (%s)" % length)
raise InvalidHeader(
"Content-Length contained multiple "
"unmatching values (%s)" % length
)
length = lengths.pop()
except ValueError:
length = None
@@ -324,7 +360,7 @@ class HTTPResponse(io.IOBase):
status = 0
# Check for responses that shouldn't include a body
if status in (204, 304) or 100 <= status < 200 or request_method == 'HEAD':
if status in (204, 304) or 100 <= status < 200 or request_method == "HEAD":
length = 0
return length
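The set-based dedup above means a repeated-but-consistent Content-Length header is tolerated while conflicting values raise; for example:

set(int(v) for v in "42, 42".split(","))  # {42}      -> length accepted as 42
set(int(v) for v in "42, 43".split(","))  # {42, 43}  -> InvalidHeader is raised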
@@ -335,14 +371,16 @@ class HTTPResponse(io.IOBase):
"""
# Note: content-encoding value should be case-insensitive, per RFC 7230
# Section 3.2
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
if self._decoder is None:
if content_encoding in self.CONTENT_DECODERS:
self._decoder = _get_decoder(content_encoding)
elif ',' in content_encoding:
elif "," in content_encoding:
encodings = [
e.strip() for e in content_encoding.split(',')
if e.strip() in self.CONTENT_DECODERS]
e.strip()
for e in content_encoding.split(",")
if e.strip() in self.CONTENT_DECODERS
]
if len(encodings):
self._decoder = _get_decoder(content_encoding)
@@ -361,10 +399,12 @@ class HTTPResponse(io.IOBase):
if self._decoder:
data = self._decoder.decompress(data)
except self.DECODER_ERROR_CLASSES as e:
content_encoding = self.headers.get('content-encoding', '').lower()
content_encoding = self.headers.get("content-encoding", "").lower()
raise DecodeError(
"Received response with content-encoding: %s, but "
"failed to decode it." % content_encoding, e)
"failed to decode it." % content_encoding,
e,
)
if flush_decoder:
data += self._flush_decoder()
@@ -376,10 +416,10 @@ class HTTPResponse(io.IOBase):
being used.
"""
if self._decoder:
buf = self._decoder.decompress(b'')
buf = self._decoder.decompress(b"")
return buf + self._decoder.flush()
return b''
return b""
@contextmanager
def _error_catcher(self):
@@ -399,20 +439,20 @@ class HTTPResponse(io.IOBase):
except SocketTimeout:
# FIXME: Ideally we'd like to include the url in the ReadTimeoutError but
# there is yet no clean way to get at it from this context.
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except BaseSSLError as e:
# FIXME: Is there a better way to differentiate between SSLErrors?
if 'read operation timed out' not in str(e): # Defensive:
if "read operation timed out" not in str(e): # Defensive:
# This shouldn't happen but just in case we're missing an edge
# case, let's avoid swallowing SSL errors.
raise
raise ReadTimeoutError(self._pool, None, 'Read timed out.')
raise ReadTimeoutError(self._pool, None, "Read timed out.")
except (HTTPException, SocketError) as e:
# This includes IncompleteRead.
raise ProtocolError('Connection broken: %r' % e, e)
raise ProtocolError("Connection broken: %r" % e, e)
# If no exception is thrown, we should avoid cleaning up
# unnecessarily.
@@ -467,17 +507,19 @@ class HTTPResponse(io.IOBase):
return
flush_decoder = False
data = None
fp_closed = getattr(self._fp, "closed", False)
with self._error_catcher():
if amt is None:
# cStringIO doesn't like amt=None
data = self._fp.read()
data = self._fp.read() if not fp_closed else b""
flush_decoder = True
else:
cache_content = False
data = self._fp.read(amt)
if amt != 0 and not data: # Platform-specific: Buggy versions of Python.
data = self._fp.read(amt) if not fp_closed else b""
if (
amt != 0 and not data
): # Platform-specific: Buggy versions of Python.
# Close the connection when no data is returned
#
# This is redundant to what httplib/http.client _should_
@@ -487,7 +529,10 @@ class HTTPResponse(io.IOBase):
# no harm in redundantly calling close.
self._fp.close()
flush_decoder = True
if self.enforce_content_length and self.length_remaining not in (0, None):
if self.enforce_content_length and self.length_remaining not in (
0,
None,
):
# This is an edge case that httplib failed to cover due
# to concerns of backward compatibility. We're
# addressing it here to make sure IncompleteRead is
@@ -507,7 +552,7 @@ class HTTPResponse(io.IOBase):
return data
def stream(self, amt=2**16, decode_content=None):
def stream(self, amt=2 ** 16, decode_content=None):
"""
A generator wrapper for the read() method. A call will block until
``amt`` bytes have been read from the connection or until the
@@ -552,15 +597,17 @@ class HTTPResponse(io.IOBase):
headers = HTTPHeaderDict.from_httplib(headers)
# HTTPResponse objects in Python 3 don't have a .strict attribute
strict = getattr(r, 'strict', 0)
resp = ResponseCls(body=r,
headers=headers,
status=r.status,
version=r.version,
reason=r.reason,
strict=strict,
original_response=r,
**response_kw)
strict = getattr(r, "strict", 0)
resp = ResponseCls(
body=r,
headers=headers,
status=r.status,
version=r.version,
reason=r.reason,
strict=strict,
original_response=r,
**response_kw
)
return resp
# Backwards-compatibility methods for httplib.HTTPResponse
@@ -582,13 +629,18 @@ class HTTPResponse(io.IOBase):
if self._connection:
self._connection.close()
if not self.auto_close:
io.IOBase.close(self)
@property
def closed(self):
if self._fp is None:
if not self.auto_close:
return io.IOBase.closed.__get__(self)
elif self._fp is None:
return True
elif hasattr(self._fp, 'isclosed'):
elif hasattr(self._fp, "isclosed"):
return self._fp.isclosed()
elif hasattr(self._fp, 'closed'):
elif hasattr(self._fp, "closed"):
return self._fp.closed
else:
return True
@@ -599,11 +651,17 @@ class HTTPResponse(io.IOBase):
elif hasattr(self._fp, "fileno"):
return self._fp.fileno()
else:
raise IOError("The file-like object this HTTPResponse is wrapped "
"around has no file descriptor")
raise IOError(
"The file-like object this HTTPResponse is wrapped "
"around has no file descriptor"
)
def flush(self):
if self._fp is not None and hasattr(self._fp, 'flush'):
if (
self._fp is not None
and hasattr(self._fp, "flush")
and not getattr(self._fp, "closed", False)
):
return self._fp.flush()
def readable(self):
@@ -616,7 +674,7 @@ class HTTPResponse(io.IOBase):
if len(temp) == 0:
return 0
else:
b[:len(temp)] = temp
b[: len(temp)] = temp
return len(temp)
def supports_chunked_reads(self):
@@ -626,7 +684,7 @@ class HTTPResponse(io.IOBase):
attribute. If it is present we assume it returns raw chunks as
processed by read_chunked().
"""
return hasattr(self._fp, 'fp')
return hasattr(self._fp, "fp")
def _update_chunk_length(self):
# First, we'll figure out length of a chunk and then
@@ -634,7 +692,7 @@ class HTTPResponse(io.IOBase):
if self.chunk_left is not None:
return
line = self._fp.fp.readline()
line = line.split(b';', 1)[0]
line = line.split(b";", 1)[0]
try:
self.chunk_left = int(line, 16)
except ValueError:
@@ -683,11 +741,13 @@ class HTTPResponse(io.IOBase):
if not self.chunked:
raise ResponseNotChunked(
"Response is not chunked. "
"Header 'transfer-encoding: chunked' is missing.")
"Header 'transfer-encoding: chunked' is missing."
)
if not self.supports_chunked_reads():
raise BodyNotHttplibCompatible(
"Body should be httplib.HTTPResponse like. "
"It should have have an fp attribute which returns raw chunks.")
"It should have have an fp attribute which returns raw chunks."
)
with self._error_catcher():
# Don't bother reading the body of a HEAD request.
@@ -705,8 +765,9 @@ class HTTPResponse(io.IOBase):
if self.chunk_left == 0:
break
chunk = self._handle_chunk(amt)
decoded = self._decode(chunk, decode_content=decode_content,
flush_decoder=False)
decoded = self._decode(
chunk, decode_content=decode_content, flush_decoder=False
)
if decoded:
yield decoded
@@ -724,7 +785,7 @@ class HTTPResponse(io.IOBase):
if not line:
# Some sites may not end with '\r\n'.
break
if line == b'\r\n':
if line == b"\r\n":
break
# We read everything; close the "file".
@@ -743,7 +804,7 @@ class HTTPResponse(io.IOBase):
return self._request_url
def __iter__(self):
buffer = [b""]
buffer = []
for chunk in self.stream(decode_content=True):
if b"\n" in chunk:
chunk = chunk.split(b"\n")
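For context, __iter__ streams the decoded body and yields it split on b"\n"; the buffer = [] fix appears to stop an empty body from yielding a spurious empty line at the end. A usage sketch, assuming the `http` pool manager from the earlier examples and a hypothetical `process` handler:

r = http.request("GET", "http://example.com/lines", preload_content=False)
for line in r:    # each item is a bytes line, usually ending in b"\n"
    process(line)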

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import
# For backwards compatibility, provide imports that used to be here.
from .connection import is_connection_dropped
from .request import make_headers
@@ -14,43 +15,32 @@ from .ssl_ import (
ssl_wrap_socket,
PROTOCOL_TLS,
)
from .timeout import (
current_time,
Timeout,
)
from .timeout import current_time, Timeout
from .retry import Retry
from .url import (
get_host,
parse_url,
split_first,
Url,
)
from .wait import (
wait_for_read,
wait_for_write
)
from .url import get_host, parse_url, split_first, Url
from .wait import wait_for_read, wait_for_write
__all__ = (
'HAS_SNI',
'IS_PYOPENSSL',
'IS_SECURETRANSPORT',
'SSLContext',
'PROTOCOL_TLS',
'Retry',
'Timeout',
'Url',
'assert_fingerprint',
'current_time',
'is_connection_dropped',
'is_fp_closed',
'get_host',
'parse_url',
'make_headers',
'resolve_cert_reqs',
'resolve_ssl_version',
'split_first',
'ssl_wrap_socket',
'wait_for_read',
'wait_for_write'
"HAS_SNI",
"IS_PYOPENSSL",
"IS_SECURETRANSPORT",
"SSLContext",
"PROTOCOL_TLS",
"Retry",
"Timeout",
"Url",
"assert_fingerprint",
"current_time",
"is_connection_dropped",
"is_fp_closed",
"get_host",
"parse_url",
"make_headers",
"resolve_cert_reqs",
"resolve_ssl_version",
"split_first",
"ssl_wrap_socket",
"wait_for_read",
"wait_for_write",
)

View File

@@ -14,7 +14,7 @@ def is_connection_dropped(conn): # Platform-specific
Note: For platforms like AppEngine, this will always return ``False`` to
let the platform handle connection recycling transparently for us.
"""
sock = getattr(conn, 'sock', False)
sock = getattr(conn, "sock", False)
if sock is False: # Platform-specific: AppEngine
return False
if sock is None: # Connection already closed (such as by httplib).
@@ -30,8 +30,12 @@ def is_connection_dropped(conn): # Platform-specific
# library test suite. Added to its signature is only `socket_options`.
# One additional modification is that we avoid binding to IPv6 servers
# discovered in DNS if the system doesn't have IPv6 functionality.
def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None, socket_options=None):
def create_connection(
address,
timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
source_address=None,
socket_options=None,
):
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
@@ -45,8 +49,8 @@ def create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
"""
host, port = address
if host.startswith('['):
host = host.strip('[]')
if host.startswith("["):
host = host.strip("[]")
err = None
# Using the value from allowed_gai_family() in the context of getaddrinfo lets
@@ -117,7 +121,7 @@ def _has_ipv6(host):
# has_ipv6 returns true if cPython was compiled with IPv6 support.
# It does not tell us if the system has IPv6 support enabled. To
# determine that we must bind to an IPv6 address.
# https://github.com/shazow/urllib3/pull/611
# https://github.com/urllib3/urllib3/pull/611
# https://bugs.python.org/issue658327
try:
sock = socket.socket(socket.AF_INET6)
@@ -131,4 +135,4 @@ def _has_ipv6(host):
return has_ipv6
HAS_IPV6 = _has_ipv6('::1')
HAS_IPV6 = _has_ipv6("::1")

View File

@@ -4,19 +4,25 @@ from base64 import b64encode
from ..packages.six import b, integer_types
from ..exceptions import UnrewindableBodyError
ACCEPT_ENCODING = 'gzip,deflate'
ACCEPT_ENCODING = "gzip,deflate"
try:
import brotli as _unused_module_brotli # noqa: F401
except ImportError:
pass
else:
ACCEPT_ENCODING += ',br'
ACCEPT_ENCODING += ",br"
_FAILEDTELL = object()
def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
basic_auth=None, proxy_basic_auth=None, disable_cache=None):
def make_headers(
keep_alive=None,
accept_encoding=None,
user_agent=None,
basic_auth=None,
proxy_basic_auth=None,
disable_cache=None,
):
"""
Shortcuts for generating request headers.
@@ -56,27 +62,27 @@ def make_headers(keep_alive=None, accept_encoding=None, user_agent=None,
if isinstance(accept_encoding, str):
pass
elif isinstance(accept_encoding, list):
accept_encoding = ','.join(accept_encoding)
accept_encoding = ",".join(accept_encoding)
else:
accept_encoding = ACCEPT_ENCODING
headers['accept-encoding'] = accept_encoding
headers["accept-encoding"] = accept_encoding
if user_agent:
headers['user-agent'] = user_agent
headers["user-agent"] = user_agent
if keep_alive:
headers['connection'] = 'keep-alive'
headers["connection"] = "keep-alive"
if basic_auth:
headers['authorization'] = 'Basic ' + \
b64encode(b(basic_auth)).decode('utf-8')
headers["authorization"] = "Basic " + b64encode(b(basic_auth)).decode("utf-8")
if proxy_basic_auth:
headers['proxy-authorization'] = 'Basic ' + \
b64encode(b(proxy_basic_auth)).decode('utf-8')
headers["proxy-authorization"] = "Basic " + b64encode(
b(proxy_basic_auth)
).decode("utf-8")
if disable_cache:
headers['cache-control'] = 'no-cache'
headers["cache-control"] = "no-cache"
return headers
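make_headers is public API, so the effect of the branches above is easy to check interactively; a small sketch:

from urllib3.util import make_headers

make_headers(
    keep_alive=True,
    accept_encoding=True,
    user_agent="my-app/1.0",
    basic_auth="user:secret",
)
# {'accept-encoding': 'gzip,deflate', 'user-agent': 'my-app/1.0',
#  'connection': 'keep-alive', 'authorization': 'Basic dXNlcjpzZWNyZXQ='}
# (',br' is appended to accept-encoding when brotli is importable)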
@@ -88,7 +94,7 @@ def set_file_position(body, pos):
"""
if pos is not None:
rewind_body(body, pos)
elif getattr(body, 'tell', None) is not None:
elif getattr(body, "tell", None) is not None:
try:
pos = body.tell()
except (IOError, OSError):
@@ -110,16 +116,20 @@ def rewind_body(body, body_pos):
:param int pos:
Position to seek to in file.
"""
body_seek = getattr(body, 'seek', None)
body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types):
try:
body_seek(body_pos)
except (IOError, OSError):
raise UnrewindableBodyError("An error occurred when rewinding request "
"body for redirect/retry.")
raise UnrewindableBodyError(
"An error occurred when rewinding request body for redirect/retry."
)
elif body_pos is _FAILEDTELL:
raise UnrewindableBodyError("Unable to record file position for rewinding "
"request body during a redirect/retry.")
raise UnrewindableBodyError(
"Unable to record file position for rewinding "
"request body during a redirect/retry."
)
else:
raise ValueError("body_pos must be of type integer, "
"instead it was %s." % type(body_pos))
raise ValueError(
"body_pos must be of type integer, instead it was %s." % type(body_pos)
)
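These two helpers cooperate: set_file_position records where the body started before the first send, and rewind_body seeks back before a retry or redirect. A sketch assuming a seekable file object (the filename is hypothetical):

from urllib3.util.request import set_file_position, rewind_body

body = open("payload.bin", "rb")
pos = set_file_position(body, None)  # records body.tell() on the first attempt
body.read()                          # ... the request body is consumed ...
rewind_body(body, pos)               # seek back before resending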

View File

@@ -52,11 +52,10 @@ def assert_header_parsing(headers):
# This will fail silently if we pass in the wrong kind of parameter.
# To make debugging easier add an explicit check.
if not isinstance(headers, httplib.HTTPMessage):
raise TypeError('expected httplib.Message, got {0}.'.format(
type(headers)))
raise TypeError("expected httplib.Message, got {0}.".format(type(headers)))
defects = getattr(headers, 'defects', None)
get_payload = getattr(headers, 'get_payload', None)
defects = getattr(headers, "defects", None)
get_payload = getattr(headers, "get_payload", None)
unparsed_data = None
if get_payload:
@@ -84,4 +83,4 @@ def is_response_to_head(response):
method = response._method
if isinstance(method, int): # Platform-specific: Appengine
return method == 3
return method.upper() == 'HEAD'
return method.upper() == "HEAD"

View File

@@ -13,6 +13,7 @@ from ..exceptions import (
ReadTimeoutError,
ResponseError,
InvalidHeader,
ProxyError,
)
from ..packages import six
@@ -21,8 +22,9 @@ log = logging.getLogger(__name__)
# Data structure for representing the metadata of requests that result in a retry.
RequestHistory = namedtuple('RequestHistory', ["method", "url", "error",
"status", "redirect_location"])
RequestHistory = namedtuple(
"RequestHistory", ["method", "url", "error", "status", "redirect_location"]
)
class Retry(object):
@@ -146,21 +148,33 @@ class Retry(object):
request.
"""
DEFAULT_METHOD_WHITELIST = frozenset([
'HEAD', 'GET', 'PUT', 'DELETE', 'OPTIONS', 'TRACE'])
DEFAULT_METHOD_WHITELIST = frozenset(
["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"]
)
RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503])
DEFAULT_REDIRECT_HEADERS_BLACKLIST = frozenset(['Authorization'])
DEFAULT_REDIRECT_HEADERS_BLACKLIST = frozenset(["Authorization"])
#: Maximum backoff time.
BACKOFF_MAX = 120
def __init__(self, total=10, connect=None, read=None, redirect=None, status=None,
method_whitelist=DEFAULT_METHOD_WHITELIST, status_forcelist=None,
backoff_factor=0, raise_on_redirect=True, raise_on_status=True,
history=None, respect_retry_after_header=True,
remove_headers_on_redirect=DEFAULT_REDIRECT_HEADERS_BLACKLIST):
def __init__(
self,
total=10,
connect=None,
read=None,
redirect=None,
status=None,
method_whitelist=DEFAULT_METHOD_WHITELIST,
status_forcelist=None,
backoff_factor=0,
raise_on_redirect=True,
raise_on_status=True,
history=None,
respect_retry_after_header=True,
remove_headers_on_redirect=DEFAULT_REDIRECT_HEADERS_BLACKLIST,
):
self.total = total
self.connect = connect
@@ -179,20 +193,25 @@ class Retry(object):
self.raise_on_status = raise_on_status
self.history = history or tuple()
self.respect_retry_after_header = respect_retry_after_header
self.remove_headers_on_redirect = frozenset([
h.lower() for h in remove_headers_on_redirect])
self.remove_headers_on_redirect = frozenset(
[h.lower() for h in remove_headers_on_redirect]
)
def new(self, **kw):
params = dict(
total=self.total,
connect=self.connect, read=self.read, redirect=self.redirect, status=self.status,
connect=self.connect,
read=self.read,
redirect=self.redirect,
status=self.status,
method_whitelist=self.method_whitelist,
status_forcelist=self.status_forcelist,
backoff_factor=self.backoff_factor,
raise_on_redirect=self.raise_on_redirect,
raise_on_status=self.raise_on_status,
history=self.history,
remove_headers_on_redirect=self.remove_headers_on_redirect
remove_headers_on_redirect=self.remove_headers_on_redirect,
respect_retry_after_header=self.respect_retry_after_header,
)
params.update(kw)
return type(self)(**params)
@@ -217,8 +236,11 @@ class Retry(object):
:rtype: float
"""
# We want to consider only the last consecutive errors sequence (Ignore redirects).
consecutive_errors_len = len(list(takewhile(lambda x: x.redirect_location is None,
reversed(self.history))))
consecutive_errors_len = len(
list(
takewhile(lambda x: x.redirect_location is None, reversed(self.history))
)
)
if consecutive_errors_len <= 1:
return 0
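Concretely, get_backoff_time evaluates backoff_factor * (2 ** (consecutive_errors_len - 1)), capped at BACKOFF_MAX, and the early return 0 above makes the first retry immediate:

from urllib3.util.retry import Retry

retry = Retry(total=5, backoff_factor=0.5)
# sleeps before successive retries: 0s, 1s, 2s, 4s, 8s
# (0.5 * 2 ** (n - 1), never more than Retry.BACKOFF_MAX == 120 seconds)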
@@ -274,7 +296,7 @@ class Retry(object):
this method will return immediately.
"""
if response:
if self.respect_retry_after_header and response:
slept = self.sleep_for_retry(response)
if slept:
return
@@ -285,6 +307,8 @@ class Retry(object):
""" Errors when we're fairly sure that the server did not receive the
request, so it should be safe to retry.
"""
if isinstance(err, ProxyError):
err = err.original_error
return isinstance(err, ConnectTimeoutError)
def _is_read_error(self, err):
@@ -315,8 +339,12 @@ class Retry(object):
if self.status_forcelist and status_code in self.status_forcelist:
return True
return (self.total and self.respect_retry_after_header and
has_retry_after and (status_code in self.RETRY_AFTER_STATUS_CODES))
return (
self.total
and self.respect_retry_after_header
and has_retry_after
and (status_code in self.RETRY_AFTER_STATUS_CODES)
)
def is_exhausted(self):
""" Are we out of retries? """
@@ -327,8 +355,15 @@ class Retry(object):
return min(retry_counts) < 0
def increment(self, method=None, url=None, response=None, error=None,
_pool=None, _stacktrace=None):
def increment(
self,
method=None,
url=None,
response=None,
error=None,
_pool=None,
_stacktrace=None,
):
""" Return a new Retry object with incremented retry counters.
:param response: A response object, or None, if the server did not
@@ -351,7 +386,7 @@ class Retry(object):
read = self.read
redirect = self.redirect
status_count = self.status
cause = 'unknown'
cause = "unknown"
status = None
redirect_location = None
@@ -373,7 +408,7 @@ class Retry(object):
# Redirect retry?
if redirect is not None:
redirect -= 1
cause = 'too many redirects'
cause = "too many redirects"
redirect_location = response.get_redirect_location()
status = response.status
@@ -384,16 +419,21 @@ class Retry(object):
if response and response.status:
if status_count is not None:
status_count -= 1
cause = ResponseError.SPECIFIC_ERROR.format(
status_code=response.status)
cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status)
status = response.status
history = self.history + (RequestHistory(method, url, error, status, redirect_location),)
history = self.history + (
RequestHistory(method, url, error, status, redirect_location),
)
new_retry = self.new(
total=total,
connect=connect, read=read, redirect=redirect, status=status_count,
history=history)
connect=connect,
read=read,
redirect=redirect,
status=status_count,
history=history,
)
if new_retry.is_exhausted():
raise MaxRetryError(_pool, url, error or ResponseError(cause))
@@ -403,9 +443,10 @@ class Retry(object):
return new_retry
def __repr__(self):
return ('{cls.__name__}(total={self.total}, connect={self.connect}, '
'read={self.read}, redirect={self.redirect}, status={self.status})').format(
cls=type(self), self=self)
return (
"{cls.__name__}(total={self.total}, connect={self.connect}, "
"read={self.read}, redirect={self.redirect}, status={self.status})"
).format(cls=type(self), self=self)
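Note the fix in new() above: respect_retry_after_header is now copied into derived instances, which matters because urllib3 calls new() internally on every increment. A quick check, assuming this version of the method:

from urllib3.util.retry import Retry

retry = Retry(total=3, respect_retry_after_header=False)
derived = retry.new(total=2)
assert derived.respect_retry_after_header is False  # was silently reset before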
# For backwards compatibility (equivalent to pre-v1.9):

View File

@@ -2,14 +2,14 @@ from __future__ import absolute_import
import errno
import warnings
import hmac
import re
import sys
from binascii import hexlify, unhexlify
from hashlib import md5, sha1, sha256
from .url import IPV4_RE, BRACELESS_IPV6_ADDRZ_RE
from ..exceptions import SSLError, InsecurePlatformWarning, SNIMissingWarning
from ..packages import six
from ..packages.rfc3986 import abnf_regexp
SSLContext = None
@@ -18,11 +18,7 @@ IS_PYOPENSSL = False
IS_SECURETRANSPORT = False
# Maps the length of a digest to a possible hash function producing this digest
HASHFUNC_MAP = {
32: md5,
40: sha1,
64: sha256,
}
HASHFUNC_MAP = {32: md5, 40: sha1, 64: sha256}
def _const_compare_digest_backport(a, b):
@@ -38,18 +34,7 @@ def _const_compare_digest_backport(a, b):
return result == 0
_const_compare_digest = getattr(hmac, 'compare_digest',
_const_compare_digest_backport)
# Borrow rfc3986's regular expressions for IPv4
# and IPv6 addresses for use in is_ipaddress()
_IP_ADDRESS_REGEX = re.compile(
r'^(?:%s|%s|%s)$' % (
abnf_regexp.IPv4_RE,
abnf_regexp.IPv6_RE,
abnf_regexp.IPv6_ADDRZ_RFC4007_RE
)
)
_const_compare_digest = getattr(hmac, "compare_digest", _const_compare_digest_backport)
try: # Test for SSL features
import ssl
@@ -60,10 +45,12 @@ except ImportError:
try: # Platform-specific: Python 3.6
from ssl import PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
try:
from ssl import PROTOCOL_SSLv23 as PROTOCOL_TLS
PROTOCOL_SSLv23 = PROTOCOL_TLS
except ImportError:
PROTOCOL_SSLv23 = PROTOCOL_TLS = 2
@@ -93,26 +80,29 @@ except ImportError:
# insecure ciphers for security reasons.
# - NOTE: TLS 1.3 cipher suites are managed through a different interface
# not exposed by CPython (yet!) and are enabled by default if they're available.
DEFAULT_CIPHERS = ':'.join([
'ECDHE+AESGCM',
'ECDHE+CHACHA20',
'DHE+AESGCM',
'DHE+CHACHA20',
'ECDH+AESGCM',
'DH+AESGCM',
'ECDH+AES',
'DH+AES',
'RSA+AESGCM',
'RSA+AES',
'!aNULL',
'!eNULL',
'!MD5',
'!DSS',
])
DEFAULT_CIPHERS = ":".join(
[
"ECDHE+AESGCM",
"ECDHE+CHACHA20",
"DHE+AESGCM",
"DHE+CHACHA20",
"ECDH+AESGCM",
"DH+AESGCM",
"ECDH+AES",
"DH+AES",
"RSA+AESGCM",
"RSA+AES",
"!aNULL",
"!eNULL",
"!MD5",
"!DSS",
]
)
try:
from ssl import SSLContext # Modern SSL?
except ImportError:
class SSLContext(object): # Platform-specific: Python 2
def __init__(self, protocol_version):
self.protocol = protocol_version
@@ -129,32 +119,35 @@ except ImportError:
self.certfile = certfile
self.keyfile = keyfile
def load_verify_locations(self, cafile=None, capath=None):
def load_verify_locations(self, cafile=None, capath=None, cadata=None):
self.ca_certs = cafile
if capath is not None:
raise SSLError("CA directories not supported in older Pythons")
if cadata is not None:
raise SSLError("CA data not supported in older Pythons")
def set_ciphers(self, cipher_suite):
self.ciphers = cipher_suite
def wrap_socket(self, socket, server_hostname=None, server_side=False):
warnings.warn(
'A true SSLContext object is not available. This prevents '
'urllib3 from configuring SSL appropriately and may cause '
'certain SSL connections to fail. You can upgrade to a newer '
'version of Python to solve this. For more information, see '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings',
InsecurePlatformWarning
"A true SSLContext object is not available. This prevents "
"urllib3 from configuring SSL appropriately and may cause "
"certain SSL connections to fail. You can upgrade to a newer "
"version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings",
InsecurePlatformWarning,
)
kwargs = {
'keyfile': self.keyfile,
'certfile': self.certfile,
'ca_certs': self.ca_certs,
'cert_reqs': self.verify_mode,
'ssl_version': self.protocol,
'server_side': server_side,
"keyfile": self.keyfile,
"certfile": self.certfile,
"ca_certs": self.ca_certs,
"cert_reqs": self.verify_mode,
"ssl_version": self.protocol,
"server_side": server_side,
}
return wrap_socket(socket, ciphers=self.ciphers, **kwargs)
@@ -169,12 +162,11 @@ def assert_fingerprint(cert, fingerprint):
Fingerprint as string of hexdigits, can be interspersed by colons.
"""
fingerprint = fingerprint.replace(':', '').lower()
fingerprint = fingerprint.replace(":", "").lower()
digest_length = len(fingerprint)
hashfunc = HASHFUNC_MAP.get(digest_length)
if not hashfunc:
raise SSLError(
'Fingerprint of invalid length: {0}'.format(fingerprint))
raise SSLError("Fingerprint of invalid length: {0}".format(fingerprint))
# We need encode() here for py32; works on py2 and py3.
fingerprint_bytes = unhexlify(fingerprint.encode())
@@ -182,15 +174,18 @@ def assert_fingerprint(cert, fingerprint):
cert_digest = hashfunc(cert).digest()
if not _const_compare_digest(cert_digest, fingerprint_bytes):
raise SSLError('Fingerprints did not match. Expected "{0}", got "{1}".'
.format(fingerprint, hexlify(cert_digest)))
raise SSLError(
'Fingerprints did not match. Expected "{0}", got "{1}".'.format(
fingerprint, hexlify(cert_digest)
)
)
def resolve_cert_reqs(candidate):
"""
Resolves the argument to a numeric constant, which can be passed to
the wrap_socket function/method from the ssl module.
Defaults to :data:`ssl.CERT_NONE`.
Defaults to :data:`ssl.CERT_REQUIRED`.
If given a string it is assumed to be the name of the constant in the
:mod:`ssl` module or its abbreviation.
(So you can specify `REQUIRED` instead of `CERT_REQUIRED`.)
@@ -203,7 +198,7 @@ def resolve_cert_reqs(candidate):
if isinstance(candidate, str):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, 'CERT_' + candidate)
res = getattr(ssl, "CERT_" + candidate)
return res
return candidate
@@ -219,14 +214,15 @@ def resolve_ssl_version(candidate):
if isinstance(candidate, str):
res = getattr(ssl, candidate, None)
if res is None:
res = getattr(ssl, 'PROTOCOL_' + candidate)
res = getattr(ssl, "PROTOCOL_" + candidate)
return res
return candidate
def create_urllib3_context(ssl_version=None, cert_reqs=None,
options=None, ciphers=None):
def create_urllib3_context(
ssl_version=None, cert_reqs=None, options=None, ciphers=None
):
"""All arguments have the same meaning as ``ssl_wrap_socket``.
By default, this function does a lot of the same work that
@@ -279,18 +275,41 @@ def create_urllib3_context(ssl_version=None, cert_reqs=None,
context.options |= options
# Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is
# necessary for conditional client cert authentication with TLS 1.3.
# The attribute is None for OpenSSL <= 1.1.0 or does not exist in older
# versions of Python. We only enable on Python 3.7.4+ or if certificate
# verification is enabled to work around Python issue #37428
# See: https://bugs.python.org/issue37428
if (cert_reqs == ssl.CERT_REQUIRED or sys.version_info >= (3, 7, 4)) and getattr(
context, "post_handshake_auth", None
) is not None:
context.post_handshake_auth = True
context.verify_mode = cert_reqs
if getattr(context, 'check_hostname', None) is not None: # Platform-specific: Python 3.2
if (
getattr(context, "check_hostname", None) is not None
): # Platform-specific: Python 3.2
# We do our own verification, including fingerprints and alternative
# hostnames. So disable it here
context.check_hostname = False
return context
def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
ca_certs=None, server_hostname=None,
ssl_version=None, ciphers=None, ssl_context=None,
ca_cert_dir=None, key_password=None):
def ssl_wrap_socket(
sock,
keyfile=None,
certfile=None,
cert_reqs=None,
ca_certs=None,
server_hostname=None,
ssl_version=None,
ciphers=None,
ssl_context=None,
ca_cert_dir=None,
key_password=None,
ca_cert_data=None,
):
"""
All arguments except for server_hostname, ssl_context, and ca_cert_dir have
the same meaning as they do when using :func:`ssl.wrap_socket`.
@@ -308,18 +327,20 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
SSLContext.load_verify_locations().
:param key_password:
Optional password if the keyfile is encrypted.
:param ca_cert_data:
Optional string containing CA certificates in PEM format suitable for
passing as the cadata parameter to SSLContext.load_verify_locations()
"""
context = ssl_context
if context is None:
# Note: This branch of code and all the variables in it are no longer
# used by urllib3 itself. We should consider deprecating and removing
# this code.
context = create_urllib3_context(ssl_version, cert_reqs,
ciphers=ciphers)
context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers)
if ca_certs or ca_cert_dir:
if ca_certs or ca_cert_dir or ca_cert_data:
try:
context.load_verify_locations(ca_certs, ca_cert_dir)
context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data)
except IOError as e: # Platform-specific: Python 2.7
raise SSLError(e)
# Py33 raises FileNotFoundError which subclasses OSError
@@ -329,7 +350,7 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
raise SSLError(e)
raise
elif ssl_context is None and hasattr(context, 'load_default_certs'):
elif ssl_context is None and hasattr(context, "load_default_certs"):
# try to load OS default certs; works well on Windows (requires Python 3.4+)
context.load_default_certs()
@@ -349,20 +370,21 @@ def ssl_wrap_socket(sock, keyfile=None, certfile=None, cert_reqs=None,
# extension should not be used according to RFC3546 Section 3.1
# We shouldn't warn the user if SNI isn't available but we would
# not be using SNI anyways due to IP address for server_hostname.
if ((server_hostname is not None and not is_ipaddress(server_hostname))
or IS_SECURETRANSPORT):
if (
server_hostname is not None and not is_ipaddress(server_hostname)
) or IS_SECURETRANSPORT:
if HAS_SNI and server_hostname is not None:
return context.wrap_socket(sock, server_hostname=server_hostname)
warnings.warn(
'An HTTPS request has been made, but the SNI (Server Name '
'Indication) extension to TLS is not available on this platform. '
'This may cause the server to present an incorrect TLS '
'certificate, which can cause validation failures. You can upgrade to '
'a newer version of Python to solve this. For more information, see '
'https://urllib3.readthedocs.io/en/latest/advanced-usage.html'
'#ssl-warnings',
SNIMissingWarning
"An HTTPS request has been made, but the SNI (Server Name "
"Indication) extension to TLS is not available on this platform. "
"This may cause the server to present an incorrect TLS "
"certificate, which can cause validation failures. You can upgrade to "
"a newer version of Python to solve this. For more information, see "
"https://urllib3.readthedocs.io/en/latest/advanced-usage.html"
"#ssl-warnings",
SNIMissingWarning,
)
return context.wrap_socket(sock)
@@ -375,18 +397,18 @@ def is_ipaddress(hostname):
:param str hostname: Hostname to examine.
:return: True if the hostname is an IP address, False otherwise.
"""
if six.PY3 and isinstance(hostname, bytes):
if not six.PY2 and isinstance(hostname, bytes):
# IDN A-label bytes are ASCII compatible.
hostname = hostname.decode('ascii')
return _IP_ADDRESS_REGEX.match(hostname) is not None
hostname = hostname.decode("ascii")
return bool(IPV4_RE.match(hostname) or BRACELESS_IPV6_ADDRZ_RE.match(hostname))
def _is_key_file_encrypted(key_file):
"""Detects if a key file is encrypted or not."""
with open(key_file, 'r') as f:
with open(key_file, "r") as f:
for line in f:
# Look for Proc-Type: 4,ENCRYPTED
if 'ENCRYPTED' in line:
if "ENCRYPTED" in line:
return True
return False
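The new ca_cert_data parameter accepts in-memory PEM data instead of a file path. A minimal sketch, assuming a PEM bundle string is already at hand (host and bundle are hypothetical, and the handshake will only succeed if that bundle actually signs the server's certificate):

import socket
from urllib3.util.ssl_ import create_urllib3_context, ssl_wrap_socket

ca_pem = open("ca-bundle.pem").read()  # hypothetical CA bundle
sock = socket.create_connection(("example.com", 443))
tls_sock = ssl_wrap_socket(
    sock,
    ssl_context=create_urllib3_context(),
    server_hostname="example.com",
    ca_cert_data=ca_pem,  # passed through to load_verify_locations(cadata=...)
)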

View File

@@ -1,4 +1,5 @@
from __future__ import absolute_import
# The default socket timeout, used by httplib to indicate that no timeout was
# specified by the user
from socket import _GLOBAL_DEFAULT_TIMEOUT
@@ -45,19 +46,20 @@ class Timeout(object):
:type total: integer, float, or None
:param connect:
The maximum amount of time to wait for a connection attempt to a server
to succeed. Omitting the parameter will default the connect timeout to
the system default, probably `the global default timeout in socket.py
The maximum amount of time (in seconds) to wait for a connection
attempt to a server to succeed. Omitting the parameter will default the
connect timeout to the system default, probably `the global default
timeout in socket.py
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
None will set an infinite timeout for connection attempts.
:type connect: integer, float, or None
:param read:
The maximum amount of time to wait between consecutive
read operations for a response from the server. Omitting
the parameter will default the read timeout to the system
default, probably `the global default timeout in socket.py
The maximum amount of time (in seconds) to wait between consecutive
read operations for a response from the server. Omitting the parameter
will default the read timeout to the system default, probably `the
global default timeout in socket.py
<http://hg.python.org/cpython/file/603b4d593758/Lib/socket.py#l535>`_.
None will set an infinite timeout.
@@ -91,14 +93,21 @@ class Timeout(object):
DEFAULT_TIMEOUT = _GLOBAL_DEFAULT_TIMEOUT
def __init__(self, total=None, connect=_Default, read=_Default):
self._connect = self._validate_timeout(connect, 'connect')
self._read = self._validate_timeout(read, 'read')
self.total = self._validate_timeout(total, 'total')
self._connect = self._validate_timeout(connect, "connect")
self._read = self._validate_timeout(read, "read")
self.total = self._validate_timeout(total, "total")
self._start_connect = None
def __str__(self):
return '%s(connect=%r, read=%r, total=%r)' % (
type(self).__name__, self._connect, self._read, self.total)
def __repr__(self):
return "%s(connect=%r, read=%r, total=%r)" % (
type(self).__name__,
self._connect,
self._read,
self.total,
)
# __str__ provided for backwards compatibility
__str__ = __repr__
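For reference, the object being repr'd here is the public per-request timeout configuration; a short usage sketch ('http' is a PoolManager as in the earlier sketches):

from urllib3.util import Timeout

t = Timeout(connect=2.0, read=7.0)
repr(t)  # "Timeout(connect=2.0, read=7.0, total=None)"
http.request("GET", "http://example.com/", timeout=t)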
@classmethod
def _validate_timeout(cls, value, name):
@@ -118,23 +127,31 @@ class Timeout(object):
return value
if isinstance(value, bool):
raise ValueError("Timeout cannot be a boolean value. It must "
"be an int, float or None.")
raise ValueError(
"Timeout cannot be a boolean value. It must "
"be an int, float or None."
)
try:
float(value)
except (TypeError, ValueError):
raise ValueError("Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value))
raise ValueError(
"Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
try:
if value <= 0:
raise ValueError("Attempted to set %s timeout to %s, but the "
"timeout cannot be set to a value less "
"than or equal to 0." % (name, value))
raise ValueError(
"Attempted to set %s timeout to %s, but the "
"timeout cannot be set to a value less "
"than or equal to 0." % (name, value)
)
except TypeError:
# Python 3
raise ValueError("Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value))
raise ValueError(
"Timeout value %s was %s, but it must be an "
"int, float or None." % (name, value)
)
return value
@@ -166,8 +183,7 @@ class Timeout(object):
# We can't use copy.deepcopy because that will also create a new object
# for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to
# detect the user default.
return Timeout(connect=self._connect, read=self._read,
total=self.total)
return Timeout(connect=self._connect, read=self._read, total=self.total)
def start_connect(self):
""" Start the timeout clock, used during a connect() attempt
@@ -183,14 +199,15 @@ class Timeout(object):
def get_connect_duration(self):
""" Gets the time elapsed since the call to :meth:`start_connect`.
:return: Elapsed time.
:return: Elapsed time in seconds.
:rtype: float
:raises urllib3.exceptions.TimeoutStateError: if you attempt
to get duration for a timer that hasn't been started.
"""
if self._start_connect is None:
raise TimeoutStateError("Can't get connect duration for timer "
"that has not started.")
raise TimeoutStateError(
"Can't get connect duration for timer that has not started."
)
return current_time() - self._start_connect
@property
@@ -228,15 +245,16 @@ class Timeout(object):
:raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect`
has not yet been called on this object.
"""
if (self.total is not None and
self.total is not self.DEFAULT_TIMEOUT and
self._read is not None and
self._read is not self.DEFAULT_TIMEOUT):
if (
self.total is not None
and self.total is not self.DEFAULT_TIMEOUT
and self._read is not None
and self._read is not self.DEFAULT_TIMEOUT
):
# In case the connect timeout has not yet been established.
if self._start_connect is None:
return self._read
return max(0, min(self.total - self.get_connect_duration(),
self._read))
return max(0, min(self.total - self.get_connect_duration(), self._read))
elif self.total is not None and self.total is not self.DEFAULT_TIMEOUT:
return max(0, self.total - self.get_connect_duration())
else:

View File

@@ -3,41 +3,108 @@ import re
from collections import namedtuple
from ..exceptions import LocationParseError
from ..packages import six, rfc3986
from ..packages.rfc3986.exceptions import RFC3986Exception, ValidationError
from ..packages.rfc3986.validators import Validator
from ..packages.rfc3986 import abnf_regexp, normalizers, compat, misc
from ..packages import six
url_attrs = ['scheme', 'auth', 'host', 'port', 'path', 'query', 'fragment']
url_attrs = ["scheme", "auth", "host", "port", "path", "query", "fragment"]
# We only want to normalize urls with an HTTP(S) scheme.
# urllib3 infers URLs without a scheme (None) to be http.
NORMALIZABLE_SCHEMES = ('http', 'https', None)
NORMALIZABLE_SCHEMES = ("http", "https", None)
# Regex for detecting URLs with schemes. RFC 3986 Section 3.1
SCHEME_REGEX = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+\-]*:|/)")
# Almost all of these patterns were derived from the
# 'rfc3986' module: https://github.com/python-hyper/rfc3986
PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}")
SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)")
URI_RE = re.compile(
r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?"
r"(?://([^\\/?#]*))?"
r"([^?#]*)"
r"(?:\?([^#]*))?"
r"(?:#(.*))?$",
re.UNICODE | re.DOTALL,
)
PATH_CHARS = abnf_regexp.UNRESERVED_CHARS_SET | abnf_regexp.SUB_DELIMITERS_SET | {':', '@', '/'}
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {'?'}
IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}"
HEX_PAT = "[0-9A-Fa-f]{1,4}"
LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=HEX_PAT, ipv4=IPV4_PAT)
_subs = {"hex": HEX_PAT, "ls32": LS32_PAT}
_variations = [
# 6( h16 ":" ) ls32
"(?:%(hex)s:){6}%(ls32)s",
# "::" 5( h16 ":" ) ls32
"::(?:%(hex)s:){5}%(ls32)s",
# [ h16 ] "::" 4( h16 ":" ) ls32
"(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s",
# [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32
"(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s",
# [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32
"(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s",
# [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32
"(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s",
# [ *4( h16 ":" ) h16 ] "::" ls32
"(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s",
# [ *5( h16 ":" ) h16 ] "::" h16
"(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s",
# [ *6( h16 ":" ) h16 ] "::"
"(?:(?:%(hex)s:){0,6}%(hex)s)?::",
]
UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._!\-~"
IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")"
ZONE_ID_PAT = "(?:%25|%)(?:[" + UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+"
IPV6_ADDRZ_PAT = r"\[" + IPV6_PAT + r"(?:" + ZONE_ID_PAT + r")?\]"
REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*"
TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$")
IPV4_RE = re.compile("^" + IPV4_PAT + "$")
IPV6_RE = re.compile("^" + IPV6_PAT + "$")
IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT + "$")
BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + IPV6_ADDRZ_PAT[2:-2] + "$")
ZONE_ID_RE = re.compile("(" + ZONE_ID_PAT + r")\]$")
SUBAUTHORITY_PAT = (u"^(?:(.*)@)?(%s|%s|%s)(?::([0-9]{0,5}))?$") % (
REG_NAME_PAT,
IPV4_PAT,
IPV6_ADDRZ_PAT,
)
SUBAUTHORITY_RE = re.compile(SUBAUTHORITY_PAT, re.UNICODE | re.DOTALL)
UNRESERVED_CHARS = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~"
)
SUB_DELIM_CHARS = set("!$&'()*+,;=")
USERINFO_CHARS = UNRESERVED_CHARS | SUB_DELIM_CHARS | {":"}
PATH_CHARS = USERINFO_CHARS | {"@", "/"}
QUERY_CHARS = FRAGMENT_CHARS = PATH_CHARS | {"?"}
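These patterns replace the ones previously borrowed from the rfc3986 package; behavior-wise they classify hosts like so:

from urllib3.util.url import IPV4_RE, IPV6_ADDRZ_RE

IPV4_RE.match("93.184.216.34")           # matches: dotted-quad IPv4 literal
IPV6_ADDRZ_RE.match("[::1]")             # matches: bracketed IPv6 literal
IPV6_ADDRZ_RE.match("[fe80::1%25eth0]")  # matches: IPv6 with RFC 6874 zone id
IPV4_RE.match("example.com")             # None: treated as a reg-name instead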
class Url(namedtuple('Url', url_attrs)):
class Url(namedtuple("Url", url_attrs)):
"""
Data structure for representing an HTTP URL. Used as a return value for
:func:`parse_url`. Both the scheme and host are normalized as they are
both case-insensitive according to RFC 3986.
"""
__slots__ = ()
def __new__(cls, scheme=None, auth=None, host=None, port=None, path=None,
query=None, fragment=None):
if path and not path.startswith('/'):
path = '/' + path
def __new__(
cls,
scheme=None,
auth=None,
host=None,
port=None,
path=None,
query=None,
fragment=None,
):
if path and not path.startswith("/"):
path = "/" + path
if scheme is not None:
scheme = scheme.lower()
return super(Url, cls).__new__(cls, scheme, auth, host, port, path,
query, fragment)
return super(Url, cls).__new__(
cls, scheme, auth, host, port, path, query, fragment
)
@property
def hostname(self):
@@ -47,10 +114,10 @@ class Url(namedtuple('Url', url_attrs)):
@property
def request_uri(self):
"""Absolute path including the query string."""
uri = self.path or '/'
uri = self.path or "/"
if self.query is not None:
uri += '?' + self.query
uri += "?" + self.query
return uri
@@ -58,7 +125,7 @@ class Url(namedtuple('Url', url_attrs)):
def netloc(self):
"""Network location including host and port"""
if self.port:
return '%s:%d' % (self.host, self.port)
return "%s:%d" % (self.host, self.port)
return self.host
@property
@@ -81,23 +148,23 @@ class Url(namedtuple('Url', url_attrs)):
'http://username:password@host.com:80/path?query#fragment'
"""
scheme, auth, host, port, path, query, fragment = self
url = u''
url = u""
# We use "is not None" we want things to happen with empty strings (or 0 port)
if scheme is not None:
url += scheme + u'://'
url += scheme + u"://"
if auth is not None:
url += auth + u'@'
url += auth + u"@"
if host is not None:
url += host
if port is not None:
url += u':' + str(port)
url += u":" + str(port)
if path is not None:
url += path
if query is not None:
url += u'?' + query
url += u"?" + query
if fragment is not None:
url += u'#' + fragment
url += u"#" + fragment
return url
@@ -135,48 +202,140 @@ def split_first(s, delims):
min_delim = d
if min_idx is None or min_idx < 0:
return s, '', None
return s, "", None
return s[:min_idx], s[min_idx + 1:], min_delim
return s[:min_idx], s[min_idx + 1 :], min_delim
def _encode_invalid_chars(component, allowed_chars, encoding='utf-8'):
def _encode_invalid_chars(component, allowed_chars, encoding="utf-8"):
"""Percent-encodes a URI component without reapplying
onto an already percent-encoded component. Based on
rfc3986.normalizers.encode_component()
onto an already percent-encoded component.
"""
if component is None:
return component
component = six.ensure_text(component)
# Normalize existing percent-encoded bytes.
# Try to see if the component we're encoding is already percent-encoded
# so we can skip all '%' characters but still encode all others.
percent_encodings = len(normalizers.PERCENT_MATCHER.findall(
compat.to_str(component, encoding)))
uri_bytes = component.encode('utf-8', 'surrogatepass')
is_percent_encoded = percent_encodings == uri_bytes.count(b'%')
component, percent_encodings = PERCENT_RE.subn(
lambda match: match.group(0).upper(), component
)
uri_bytes = component.encode("utf-8", "surrogatepass")
is_percent_encoded = percent_encodings == uri_bytes.count(b"%")
encoded_component = bytearray()
for i in range(0, len(uri_bytes)):
# Will return a single character bytestring on both Python 2 & 3
byte = uri_bytes[i:i+1]
byte = uri_bytes[i : i + 1]
byte_ord = ord(byte)
if ((is_percent_encoded and byte == b'%')
or (byte_ord < 128 and byte.decode() in allowed_chars)):
encoded_component.extend(byte)
if (is_percent_encoded and byte == b"%") or (
byte_ord < 128 and byte.decode() in allowed_chars
):
encoded_component += byte
continue
encoded_component.extend('%{0:02x}'.format(byte_ord).encode().upper())
encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper()))
return encoded_component.decode(encoding)
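The net behavior: bytes outside allowed_chars are percent-encoded, while an already-encoded component passes through with its escapes uppercased. For instance (these helpers are private, imported here only for illustration):

from urllib3.util.url import PATH_CHARS, _encode_invalid_chars

_encode_invalid_chars("/path with spaces", PATH_CHARS)  # '/path%20with%20spaces'
_encode_invalid_chars("/kept%2fintact", PATH_CHARS)     # '/kept%2Fintact'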
def _remove_path_dot_segments(path):
# See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code
segments = path.split("/") # Turn the path into a list of segments
output = [] # Initialize the variable to use to store output
for segment in segments:
# '.' is the current directory, so ignore it, it is superfluous
if segment == ".":
continue
# Anything other than '..', should be appended to the output
elif segment != "..":
output.append(segment)
# In this case segment == '..', if we can, we should pop the last
# element
elif output:
output.pop()
# If the path starts with '/' and the output is empty or the first string
# is non-empty
if path.startswith("/") and (not output or output[0]):
output.insert(0, "")
# If the path starts with '/.' or '/..' ensure we add one more empty
# string to add a trailing '/'
if path.endswith(("/.", "/..")):
output.append("")
return "/".join(output)
def _normalize_host(host, scheme):
if host:
if isinstance(host, six.binary_type):
host = six.ensure_str(host)
if scheme in NORMALIZABLE_SCHEMES:
is_ipv6 = IPV6_ADDRZ_RE.match(host)
if is_ipv6:
match = ZONE_ID_RE.search(host)
if match:
start, end = match.span(1)
zone_id = host[start:end]
if zone_id.startswith("%25") and zone_id != "%25":
zone_id = zone_id[3:]
else:
zone_id = zone_id[1:]
zone_id = "%" + _encode_invalid_chars(zone_id, UNRESERVED_CHARS)
return host[:start].lower() + zone_id + host[end:]
else:
return host.lower()
elif not IPV4_RE.match(host):
return six.ensure_str(
b".".join([_idna_encode(label) for label in host.split(".")])
)
return host
def _idna_encode(name):
if name and any([ord(x) > 128 for x in name]):
try:
from pip._vendor import idna
except ImportError:
six.raise_from(
LocationParseError("Unable to parse URL without the 'idna' module"),
None,
)
try:
return idna.encode(name.lower(), strict=True, std3_rules=True)
except idna.IDNAError:
six.raise_from(
LocationParseError(u"Name '%s' is not a valid IDNA label" % name), None
)
return name.lower().encode("ascii")
def _encode_target(target):
"""Percent-encodes a request target so that there are no invalid characters"""
path, query = TARGET_RE.match(target).groups()
target = _encode_invalid_chars(path, PATH_CHARS)
query = _encode_invalid_chars(query, QUERY_CHARS)
if query is not None:
target += "?" + query
return target
def parse_url(url):
"""
Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is
performed to parse incomplete urls. Fields not provided will be None.
This parser is RFC 3986 compliant.
The parser logic and helper functions are based heavily on
work done in the ``rfc3986`` module.
:param str url: URL to parse into a :class:`.Url` namedtuple.
Partly backwards-compatible with :mod:`urlparse`.
@@ -194,90 +353,72 @@ def parse_url(url):
# Empty
return Url()
is_string = not isinstance(url, six.binary_type)
# RFC 3986 doesn't like URLs that have a host but don't start
# with a scheme and we support URLs like that so we need to
# detect that problem and add an empty scheme indication.
# We don't get hurt on path-only URLs here as it's stripped
# off and given an empty scheme anyways.
if not SCHEME_REGEX.search(url):
source_url = url
if not SCHEME_RE.search(url):
url = "//" + url
def idna_encode(name):
if name and any([ord(x) > 128 for x in name]):
try:
from pip._vendor import idna
except ImportError:
raise LocationParseError("Unable to parse URL without the 'idna' module")
try:
return idna.encode(name.lower(), strict=True, std3_rules=True)
except idna.IDNAError:
raise LocationParseError(u"Name '%s' is not a valid IDNA label" % name)
return name
try:
split_iri = misc.IRI_MATCHER.match(compat.to_str(url)).groupdict()
iri_ref = rfc3986.IRIReference(
split_iri['scheme'], split_iri['authority'],
_encode_invalid_chars(split_iri['path'], PATH_CHARS),
_encode_invalid_chars(split_iri['query'], QUERY_CHARS),
_encode_invalid_chars(split_iri['fragment'], FRAGMENT_CHARS)
)
has_authority = iri_ref.authority is not None
uri_ref = iri_ref.encode(idna_encoder=idna_encode)
except (ValueError, RFC3986Exception):
return six.raise_from(LocationParseError(url), None)
scheme, authority, path, query, fragment = URI_RE.match(url).groups()
normalize_uri = scheme is None or scheme.lower() in NORMALIZABLE_SCHEMES
# rfc3986 strips the authority if it's invalid
if has_authority and uri_ref.authority is None:
raise LocationParseError(url)
if scheme:
scheme = scheme.lower()
# Only normalize schemes we understand to not break http+unix
# or other schemes that don't follow RFC 3986.
if uri_ref.scheme is None or uri_ref.scheme.lower() in NORMALIZABLE_SCHEMES:
uri_ref = uri_ref.normalize()
if authority:
auth, host, port = SUBAUTHORITY_RE.match(authority).groups()
if auth and normalize_uri:
auth = _encode_invalid_chars(auth, USERINFO_CHARS)
if port == "":
port = None
else:
auth, host, port = None, None, None
# Validate all URIReference components and ensure that all
# components that were set before are still set after
# normalization has completed.
validator = Validator()
try:
validator.check_validity_of(
*validator.COMPONENT_NAMES
).validate(uri_ref)
except ValidationError:
return six.raise_from(LocationParseError(url), None)
if port is not None:
port = int(port)
if not (0 <= port <= 65535):
raise LocationParseError(url)
host = _normalize_host(host, scheme)
if normalize_uri and path:
path = _remove_path_dot_segments(path)
path = _encode_invalid_chars(path, PATH_CHARS)
if normalize_uri and query:
query = _encode_invalid_chars(query, QUERY_CHARS)
if normalize_uri and fragment:
fragment = _encode_invalid_chars(fragment, FRAGMENT_CHARS)
except (ValueError, AttributeError):
return six.raise_from(LocationParseError(source_url), None)
# For the sake of backwards compatibility we put empty
# string values for path if there are any defined values
# beyond the path in the URL.
# TODO: Remove this when we break backwards compatibility.
path = uri_ref.path
if not path:
if (uri_ref.query is not None
or uri_ref.fragment is not None):
if query is not None or fragment is not None:
path = ""
else:
path = None
# Ensure that each part of the URL is a `str` for
# backwards compatibility.
def to_input_type(x):
if x is None:
return None
elif not is_string and not isinstance(x, six.binary_type):
return x.encode('utf-8')
return x
if isinstance(url, six.text_type):
ensure_func = six.ensure_text
else:
ensure_func = six.ensure_str
def ensure_type(x):
return x if x is None else ensure_func(x)
return Url(
scheme=to_input_type(uri_ref.scheme),
auth=to_input_type(uri_ref.userinfo),
host=to_input_type(uri_ref.host),
port=int(uri_ref.port) if uri_ref.port is not None else None,
path=to_input_type(path),
query=to_input_type(uri_ref.query),
fragment=to_input_type(uri_ref.fragment)
scheme=ensure_type(scheme),
auth=ensure_type(auth),
host=ensure_type(host),
port=port,
path=ensure_type(path),
query=ensure_type(query),
fragment=ensure_type(fragment),
)
@@ -286,4 +427,4 @@ def get_host(url):
Deprecated. Use :func:`parse_url` instead.
"""
p = parse_url(url)
return p.scheme or 'http', p.hostname, p.port
return p.scheme or "http", p.hostname, p.port
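End to end, the rewritten parser produces the same Url tuple as before, with the normalization now done in-house; a short sketch:

from urllib3.util import parse_url

u = parse_url("https://user:pw@example.com:8443/a/../b?q=1#frag")
# Url(scheme='https', auth='user:pw', host='example.com', port=8443,
#     path='/b', query='q=1', fragment='frag')
u.request_uri  # '/b?q=1'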

View File

@@ -2,6 +2,7 @@ import errno
from functools import partial
import select
import sys
try:
from time import monotonic
except ImportError:
@@ -40,6 +41,8 @@ if sys.version_info >= (3, 5):
# Modern Python, that retries syscalls by default
def _retry_on_intr(fn, timeout):
return fn(timeout)
else:
# Old and broken Pythons.
def _retry_on_intr(fn, timeout):