Don't override socket.socket for binding, eliminiate globals SOURCE and USER_AGENT

This commit is contained in:
Matt Martz 2017-05-02 10:56:31 -05:00
parent 20e5d12a5c
commit 10b3b09f02
1 changed files with 237 additions and 61 deletions

View File

@ -51,14 +51,10 @@ class FakeShutdownEvent(object):
# Some global variables we use # Some global variables we use
USER_AGENT = None
SOURCE = None
SHUTDOWN_EVENT = FakeShutdownEvent() SHUTDOWN_EVENT = FakeShutdownEvent()
SCHEME = 'http' SCHEME = 'http'
DEBUG = False DEBUG = False
_GLOBAL_DEFAULT_TIMEOUT = object()
# Used for bound_interface
SOCKET_SOCKET = socket.socket
# Begin import game to handle Python 2 and Python 3 # Begin import game to handle Python 2 and Python 3
try: try:
@ -79,9 +75,15 @@ except ImportError:
ET = None ET = None
try: try:
from urllib2 import urlopen, Request, HTTPError, URLError from urllib2 import (urlopen, Request, HTTPError, URLError,
AbstractHTTPHandler, ProxyHandler,
HTTPDefaultErrorHandler, HTTPRedirectHandler,
HTTPErrorProcessor, OpenerDirector)
except ImportError: except ImportError:
from urllib.request import urlopen, Request, HTTPError, URLError from urllib.request import (urlopen, Request, HTTPError, URLError,
AbstractHTTPHandler, ProxyHandler,
HTTPDefaultErrorHandler, HTTPRedirectHandler,
HTTPErrorProcessor, OpenerDirector)
try: try:
from httplib import HTTPConnection from httplib import HTTPConnection
@ -320,6 +322,165 @@ class SpeedtestBestServerFailure(SpeedtestException):
"""Unable to determine best server""" """Unable to determine best server"""
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None):
"""Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host,
port)``) and return the socket object. Passing the optional
*timeout* parameter will set the timeout on the socket instance
before attempting to connect. If no *timeout* is supplied, the
global default timeout setting returned by :func:`getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection.
An host of '' or port 0 tells the OS to use the default.
"""
host, port = address
err = None
for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket.socket(af, socktype, proto)
if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(float(timeout))
if source_address:
sock.bind(source_address)
sock.connect(sa)
return sock
except socket.error:
err = get_exception()
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise socket.error("getaddrinfo returns an empty list")
class SpeedtestHTTPConnection(HTTPConnection):
def __init__(self, *args, **kwargs):
source_address = kwargs.pop('source_address', None)
context = kwargs.pop('context', None)
timeout = kwargs.pop('timeout', 10)
HTTPConnection.__init__(self, *args, **kwargs)
self.source_address = source_address
self._context = context
self.timeout = timeout
try:
self._create_connection = socket.create_connection
except AttributeError:
self._create_connection = create_connection
def connect(self):
"""Connect to the host and port specified in __init__."""
self.sock = self._create_connection(
(self.host, self.port),
self.timeout,
self.source_address
)
if HTTPSConnection:
class SpeedtestHTTPSConnection(HTTPSConnection,
SpeedtestHTTPConnection):
def connect(self):
"Connect to a host on a given (SSL) port."
SpeedtestHTTPConnection.connect(self)
kwargs = {}
if hasattr(ssl, 'SSLContext'):
kwargs['server_hostname'] = self.host
self.sock = self._context.wrap_socket(self.sock, **kwargs)
def _build_connection(connection, source_address, timeout, context=None):
def inner(host, **kwargs):
kwargs.update({
'source_address': source_address,
'timeout': timeout
})
if context:
kwargs['context'] = context
return connection(host, **kwargs)
return inner
class SpeedtestHTTPHandler(AbstractHTTPHandler):
def __init__(self, debuglevel=0, source_address=None, timeout=10):
AbstractHTTPHandler.__init__(self, debuglevel)
self.source_address = source_address
self.timeout = timeout
def http_open(self, req):
return self.do_open(
_build_connection(
SpeedtestHTTPConnection,
self.source_address,
self.timeout
),
req
)
http_request = AbstractHTTPHandler.do_request_
class SpeedtestHTTPSHandler(AbstractHTTPHandler):
def __init__(self, debuglevel=0, context=None, source_address=None,
timeout=10):
AbstractHTTPHandler.__init__(self, debuglevel)
self._context = context
self.source_address = source_address
self.timeout = timeout
def https_open(self, req):
return self.do_open(
_build_connection(
SpeedtestHTTPSConnection,
self.source_address,
timeout,
context=self._context,
),
req
)
https_request = AbstractHTTPHandler.do_request_
def build_opener(source_address=None, timeout=10):
if source_address:
source_address_tuple = (source_address, 0)
else:
source_address_tuple = None
handlers = [
ProxyHandler(),
SpeedtestHTTPHandler(source_address=source_address_tuple,
timeout=timeout),
SpeedtestHTTPSHandler(source_address=source_address_tuple,
timeout=timeout),
HTTPDefaultErrorHandler(),
HTTPRedirectHandler(),
HTTPErrorProcessor()
]
opener = OpenerDirector()
opener.addheaders = [('User-agent', build_user_agent())]
for handler in handlers:
opener.add_handler(handler)
return opener
class GzipDecodedResponse(GZIP_BASE): class GzipDecodedResponse(GZIP_BASE):
"""A file-like object to decode a response encoded with the gzip """A file-like object to decode a response encoded with the gzip
method, as described in RFC 1952. method, as described in RFC 1952.
@ -357,14 +518,6 @@ def get_exception():
return sys.exc_info()[1] return sys.exc_info()[1]
def bound_socket(*args, **kwargs):
"""Bind socket to a specified source IP address"""
sock = SOCKET_SOCKET(*args, **kwargs)
sock.bind((SOURCE, 0))
return sock
def distance(origin, destination): def distance(origin, destination):
"""Determine distance between 2 sets of [lat,lon] in km""" """Determine distance between 2 sets of [lat,lon] in km"""
@ -387,10 +540,6 @@ def distance(origin, destination):
def build_user_agent(): def build_user_agent():
"""Build a Mozilla/5.0 compatible User-Agent string""" """Build a Mozilla/5.0 compatible User-Agent string"""
global USER_AGENT
if USER_AGENT:
return USER_AGENT
ua_tuple = ( ua_tuple = (
'Mozilla/5.0', 'Mozilla/5.0',
'(%s; U; %s; en-us)' % (platform.system(), platform.architecture()[0]), '(%s; U; %s; en-us)' % (platform.system(), platform.architecture()[0]),
@ -398,9 +547,9 @@ def build_user_agent():
'(KHTML, like Gecko)', '(KHTML, like Gecko)',
'speedtest-cli/%s' % __version__ 'speedtest-cli/%s' % __version__
) )
USER_AGENT = ' '.join(ua_tuple) user_agent = ' '.join(ua_tuple)
printer(USER_AGENT, debug=True) printer(user_agent, debug=True)
return USER_AGENT return user_agent
def build_request(url, data=None, headers=None, bump=''): def build_request(url, data=None, headers=None, bump=''):
@ -410,9 +559,6 @@ def build_request(url, data=None, headers=None, bump=''):
""" """
if not USER_AGENT:
build_user_agent()
if not headers: if not headers:
headers = {} headers = {}
@ -432,7 +578,6 @@ def build_request(url, data=None, headers=None, bump=''):
bump) bump)
headers.update({ headers.update({
'User-Agent': USER_AGENT,
'Cache-Control': 'no-cache', 'Cache-Control': 'no-cache',
}) })
@ -442,14 +587,19 @@ def build_request(url, data=None, headers=None, bump=''):
return Request(final_url, data=data, headers=headers) return Request(final_url, data=data, headers=headers)
def catch_request(request): def catch_request(request, opener=None):
"""Helper function to catch common exceptions encountered when """Helper function to catch common exceptions encountered when
establishing a connection with a HTTP/HTTPS request establishing a connection with a HTTP/HTTPS request
""" """
if opener:
_open = opener.open
else:
_open = urlopen
try: try:
uh = urlopen(request) uh = _open(request)
return uh, False return uh, False
except HTTP_ERRORS: except HTTP_ERRORS:
e = get_exception() e = get_exception()
@ -505,18 +655,22 @@ def do_nothing(*args, **kwargs):
class HTTPDownloader(threading.Thread): class HTTPDownloader(threading.Thread):
"""Thread class for retrieving a URL""" """Thread class for retrieving a URL"""
def __init__(self, i, request, start, timeout): def __init__(self, i, request, start, timeout, opener=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.request = request self.request = request
self.result = [0] self.result = [0]
self.starttime = start self.starttime = start
self.timeout = timeout self.timeout = timeout
self.i = i self.i = i
if opener:
self._opener = opener.open
else:
self._opener = urlopen
def run(self): def run(self):
try: try:
if (timeit.default_timer() - self.starttime) <= self.timeout: if (timeit.default_timer() - self.starttime) <= self.timeout:
f = urlopen(self.request) f = self._opener(self.request)
while (not SHUTDOWN_EVENT.isSet() and while (not SHUTDOWN_EVENT.isSet() and
(timeit.default_timer() - self.starttime) <= (timeit.default_timer() - self.starttime) <=
self.timeout): self.timeout):
@ -574,7 +728,7 @@ class HTTPUploaderData(object):
class HTTPUploader(threading.Thread): class HTTPUploader(threading.Thread):
"""Thread class for putting a URL""" """Thread class for putting a URL"""
def __init__(self, i, request, start, size, timeout): def __init__(self, i, request, start, size, timeout, opener=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.request = request self.request = request
self.request.data.start = self.starttime = start self.request.data.start = self.starttime = start
@ -583,20 +737,25 @@ class HTTPUploader(threading.Thread):
self.timeout = timeout self.timeout = timeout
self.i = i self.i = i
if opener:
self._opener = opener.open
else:
self._opener = urlopen
def run(self): def run(self):
request = self.request request = self.request
try: try:
if ((timeit.default_timer() - self.starttime) <= self.timeout and if ((timeit.default_timer() - self.starttime) <= self.timeout and
not SHUTDOWN_EVENT.isSet()): not SHUTDOWN_EVENT.isSet()):
try: try:
f = urlopen(request) f = self._opener(request)
except TypeError: except TypeError:
# PY24 expects a string or buffer # PY24 expects a string or buffer
# This also causes issues with Ctrl-C, but we will concede # This also causes issues with Ctrl-C, but we will concede
# for the moment that Ctrl-C on PY24 isn't immediate # for the moment that Ctrl-C on PY24 isn't immediate
request = build_request(self.request.get_full_url(), request = build_request(self.request.get_full_url(),
data=request.data.read(self.size)) data=request.data.read(self.size))
f = urlopen(request) f = self._opener(request)
f.read(11) f.read(11)
f.close() f.close()
self.result = sum(self.request.data.total) self.result = sum(self.request.data.total)
@ -619,7 +778,7 @@ class SpeedtestResults(object):
to get a share results image link. to get a share results image link.
""" """
def __init__(self, download=0, upload=0, ping=0, server=None): def __init__(self, download=0, upload=0, ping=0, server=None, opener=None):
self.download = download self.download = download
self.upload = upload self.upload = upload
self.ping = ping self.ping = ping
@ -632,6 +791,11 @@ class SpeedtestResults(object):
self.bytes_received = 0 self.bytes_received = 0
self.bytes_sent = 0 self.bytes_sent = 0
if opener:
self._opener = opener
else:
self._opener = build_opener()
def __repr__(self): def __repr__(self):
return repr(self.dict()) return repr(self.dict())
@ -674,7 +838,7 @@ class SpeedtestResults(object):
request = build_request('://www.speedtest.net/api/api.php', request = build_request('://www.speedtest.net/api/api.php',
data='&'.join(api_data).encode(), data='&'.join(api_data).encode(),
headers=headers) headers=headers)
f, e = catch_request(request) f, e = catch_request(request, opener=_self.opener)
if e: if e:
raise ShareResultsConnectFailure(e) raise ShareResultsConnectFailure(e)
@ -738,8 +902,13 @@ class SpeedtestResults(object):
class Speedtest(object): class Speedtest(object):
"""Class for performing standard speedtest.net testing operations""" """Class for performing standard speedtest.net testing operations"""
def __init__(self, config=None): def __init__(self, config=None, source_address=None, timeout=10):
self.config = {} self.config = {}
self._source_address = source_address
self._timeout = timeout
self._opener = build_opener(source_address, timeout)
self.get_config() self.get_config()
if config is not None: if config is not None:
self.config.update(config) self.config.update(config)
@ -748,7 +917,7 @@ class Speedtest(object):
self.closest = [] self.closest = []
self.best = {} self.best = {}
self.results = SpeedtestResults() self.results = SpeedtestResults(opener=self._opener)
def get_config(self): def get_config(self):
"""Download the speedtest.net configuration and return only the data """Download the speedtest.net configuration and return only the data
@ -760,7 +929,7 @@ class Speedtest(object):
headers['Accept-Encoding'] = 'gzip' headers['Accept-Encoding'] = 'gzip'
request = build_request('://www.speedtest.net/speedtest-config.php', request = build_request('://www.speedtest.net/speedtest-config.php',
headers=headers) headers=headers)
uh, e = catch_request(request) uh, e = catch_request(request, opener=self._opener)
if e: if e:
raise ConfigRetrievalError(e) raise ConfigRetrievalError(e)
configxml = [] configxml = []
@ -877,7 +1046,7 @@ class Speedtest(object):
(url, (url,
self.config['threads']['download']), self.config['threads']['download']),
headers=headers) headers=headers)
uh, e = catch_request(request) uh, e = catch_request(request, opener=self._opener)
if e: if e:
errors.append('%s' % e) errors.append('%s' % e)
raise ServersRetrievalError() raise ServersRetrievalError()
@ -960,7 +1129,7 @@ class Speedtest(object):
url = server url = server
request = build_request(url) request = build_request(url)
uh, e = catch_request(request) uh, e = catch_request(request, opener=self._opener)
if e: if e:
raise SpeedtestMiniConnectFailure('Failed to connect to %s' % raise SpeedtestMiniConnectFailure('Failed to connect to %s' %
server) server)
@ -973,7 +1142,9 @@ class Speedtest(object):
if not extension: if not extension:
for ext in ['php', 'asp', 'aspx', 'jsp']: for ext in ['php', 'asp', 'aspx', 'jsp']:
try: try:
f = urlopen('%s/speedtest/upload.%s' % (url, ext)) f = self._opener.open(
'%s/speedtest/upload.%s' % (url, ext)
)
except: except:
pass pass
else: else:
@ -1028,6 +1199,13 @@ class Speedtest(object):
servers = self.get_closest_servers() servers = self.get_closest_servers()
servers = self.closest servers = self.closest
if self._source_address:
source_address_tuple = (self._source_address, 0)
else:
source_address_tuple = None
user_agent = build_user_agent()
results = {} results = {}
for server in servers: for server in servers:
cum = [] cum = []
@ -1037,10 +1215,16 @@ class Speedtest(object):
for _ in range(0, 3): for _ in range(0, 3):
try: try:
if urlparts[0] == 'https': if urlparts[0] == 'https':
h = HTTPSConnection(urlparts[1]) h = SpeedtestHTTPSConnection(
urlparts[1],
source_address=source_address_tuple
)
else: else:
h = HTTPConnection(urlparts[1]) h = SpeedtestHTTPConnection(
headers = {'User-Agent': USER_AGENT} urlparts[1],
source_address=source_address_tuple
)
headers = {'User-Agent': user_agent}
start = timeit.default_timer() start = timeit.default_timer()
h.request("GET", urlparts[2], headers=headers) h.request("GET", urlparts[2], headers=headers)
r = h.getresponse() r = h.getresponse()
@ -1093,7 +1277,8 @@ class Speedtest(object):
def producer(q, requests, request_count): def producer(q, requests, request_count):
for i, request in enumerate(requests): for i, request in enumerate(requests):
thread = HTTPDownloader(i, request, start, thread = HTTPDownloader(i, request, start,
self.config['length']['download']) self.config['length']['download'],
opener=self._opener)
thread.start() thread.start()
q.put(thread, True) q.put(thread, True)
callback(i, request_count, start=True) callback(i, request_count, start=True)
@ -1159,7 +1344,8 @@ class Speedtest(object):
def producer(q, requests, request_count): def producer(q, requests, request_count):
for i, request in enumerate(requests[:request_count]): for i, request in enumerate(requests[:request_count]):
thread = HTTPUploader(i, request[0], start, request[1], thread = HTTPUploader(i, request[0], start, request[1],
self.config['length']['upload']) self.config['length']['upload'],
opener=self._opener)
thread.start() thread.start()
q.put(thread, True) q.put(thread, True)
callback(i, request_count, start=True) callback(i, request_count, start=True)
@ -1338,7 +1524,7 @@ def printer(string, quiet=False, debug=False, **kwargs):
def shell(): def shell():
"""Run the full speedtest.net test""" """Run the full speedtest.net test"""
global SHUTDOWN_EVENT, SOURCE, SCHEME, DEBUG global SHUTDOWN_EVENT, SCHEME, DEBUG
SHUTDOWN_EVENT = threading.Event() SHUTDOWN_EVENT = threading.Event()
signal.signal(signal.SIGINT, ctrl_c) signal.signal(signal.SIGINT, ctrl_c)
@ -1361,13 +1547,6 @@ def shell():
validate_optional_args(args) validate_optional_args(args)
socket.setdefaulttimeout(args.timeout)
# If specified bind to a specific IP address
if args.source:
SOURCE = args.source
socket.socket = bound_socket
if args.secure: if args.secure:
SCHEME = 'https' SCHEME = 'https'
@ -1377,9 +1556,6 @@ def shell():
if debug: if debug:
DEBUG = True DEBUG = True
# Pre-cache the user agent string
build_user_agent()
if args.simple or args.csv or args.json: if args.simple or args.csv or args.json:
quiet = True quiet = True
else: else:
@ -1398,15 +1574,15 @@ def shell():
printer('Retrieving speedtest.net configuration...', quiet) printer('Retrieving speedtest.net configuration...', quiet)
try: try:
speedtest = Speedtest() speedtest = Speedtest(source_address=args.source, timeout=args.timeout)
except (ConfigRetrievalError, HTTP_ERRORS): except (ConfigRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest configuration') printer('Cannot retrieve speedtest configuration')
raise SpeedtestCLIError(get_exception()) raise SpeedtestCLIError(get_exception())
if args.list: if args.list:
try: try:
speedtest.get_servers() speedtest.get_servers()
except (ServersRetrievalError, HTTP_ERRORS): except (ServersRetrievalError,) + HTTP_ERRORS:
print_('Cannot retrieve speedtest server list') print_('Cannot retrieve speedtest server list')
raise SpeedtestCLIError(get_exception()) raise SpeedtestCLIError(get_exception())
@ -1436,7 +1612,7 @@ def shell():
speedtest.get_servers(servers) speedtest.get_servers(servers)
except NoMatchedServers: except NoMatchedServers:
raise SpeedtestCLIError('No matched servers: %s' % args.server) raise SpeedtestCLIError('No matched servers: %s' % args.server)
except (ServersRetrievalError, HTTP_ERRORS): except (ServersRetrievalError,) + HTTP_ERRORS:
print_('Cannot retrieve speedtest server list') print_('Cannot retrieve speedtest server list')
raise SpeedtestCLIError(get_exception()) raise SpeedtestCLIError(get_exception())
except InvalidServerIDType: except InvalidServerIDType: