From 10b3b09f02bfd9636af4a10b84abcd3c26035949 Mon Sep 17 00:00:00 2001 From: Matt Martz Date: Tue, 2 May 2017 10:56:31 -0500 Subject: [PATCH] Don't override socket.socket for binding, eliminiate globals SOURCE and USER_AGENT --- speedtest.py | 298 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 237 insertions(+), 61 deletions(-) diff --git a/speedtest.py b/speedtest.py index cb4a374..115b2f8 100755 --- a/speedtest.py +++ b/speedtest.py @@ -51,14 +51,10 @@ class FakeShutdownEvent(object): # Some global variables we use -USER_AGENT = None -SOURCE = None SHUTDOWN_EVENT = FakeShutdownEvent() SCHEME = 'http' DEBUG = False - -# Used for bound_interface -SOCKET_SOCKET = socket.socket +_GLOBAL_DEFAULT_TIMEOUT = object() # Begin import game to handle Python 2 and Python 3 try: @@ -79,9 +75,15 @@ except ImportError: ET = None try: - from urllib2 import urlopen, Request, HTTPError, URLError + from urllib2 import (urlopen, Request, HTTPError, URLError, + AbstractHTTPHandler, ProxyHandler, + HTTPDefaultErrorHandler, HTTPRedirectHandler, + HTTPErrorProcessor, OpenerDirector) except ImportError: - from urllib.request import urlopen, Request, HTTPError, URLError + from urllib.request import (urlopen, Request, HTTPError, URLError, + AbstractHTTPHandler, ProxyHandler, + HTTPDefaultErrorHandler, HTTPRedirectHandler, + HTTPErrorProcessor, OpenerDirector) try: from httplib import HTTPConnection @@ -320,6 +322,165 @@ class SpeedtestBestServerFailure(SpeedtestException): """Unable to determine best server""" +def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`getdefaulttimeout` + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. + """ + + host, port = address + err = None + for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + if timeout is not _GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(float(timeout)) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except socket.error: + err = get_exception() + if sock is not None: + sock.close() + + if err is not None: + raise err + else: + raise socket.error("getaddrinfo returns an empty list") + + +class SpeedtestHTTPConnection(HTTPConnection): + def __init__(self, *args, **kwargs): + source_address = kwargs.pop('source_address', None) + context = kwargs.pop('context', None) + timeout = kwargs.pop('timeout', 10) + + HTTPConnection.__init__(self, *args, **kwargs) + + self.source_address = source_address + self._context = context + self.timeout = timeout + + try: + self._create_connection = socket.create_connection + except AttributeError: + self._create_connection = create_connection + + def connect(self): + """Connect to the host and port specified in __init__.""" + self.sock = self._create_connection( + (self.host, self.port), + self.timeout, + self.source_address + ) + + +if HTTPSConnection: + class SpeedtestHTTPSConnection(HTTPSConnection, + SpeedtestHTTPConnection): + def connect(self): + "Connect to a host on a given (SSL) port." + + SpeedtestHTTPConnection.connect(self) + + kwargs = {} + if hasattr(ssl, 'SSLContext'): + kwargs['server_hostname'] = self.host + + self.sock = self._context.wrap_socket(self.sock, **kwargs) + + +def _build_connection(connection, source_address, timeout, context=None): + def inner(host, **kwargs): + kwargs.update({ + 'source_address': source_address, + 'timeout': timeout + }) + if context: + kwargs['context'] = context + return connection(host, **kwargs) + return inner + + +class SpeedtestHTTPHandler(AbstractHTTPHandler): + def __init__(self, debuglevel=0, source_address=None, timeout=10): + AbstractHTTPHandler.__init__(self, debuglevel) + self.source_address = source_address + self.timeout = timeout + + def http_open(self, req): + return self.do_open( + _build_connection( + SpeedtestHTTPConnection, + self.source_address, + self.timeout + ), + req + ) + + http_request = AbstractHTTPHandler.do_request_ + + +class SpeedtestHTTPSHandler(AbstractHTTPHandler): + def __init__(self, debuglevel=0, context=None, source_address=None, + timeout=10): + AbstractHTTPHandler.__init__(self, debuglevel) + self._context = context + self.source_address = source_address + self.timeout = timeout + + def https_open(self, req): + return self.do_open( + _build_connection( + SpeedtestHTTPSConnection, + self.source_address, + timeout, + context=self._context, + ), + req + ) + + https_request = AbstractHTTPHandler.do_request_ + + +def build_opener(source_address=None, timeout=10): + if source_address: + source_address_tuple = (source_address, 0) + else: + source_address_tuple = None + + handlers = [ + ProxyHandler(), + SpeedtestHTTPHandler(source_address=source_address_tuple, + timeout=timeout), + SpeedtestHTTPSHandler(source_address=source_address_tuple, + timeout=timeout), + HTTPDefaultErrorHandler(), + HTTPRedirectHandler(), + HTTPErrorProcessor() + ] + + opener = OpenerDirector() + opener.addheaders = [('User-agent', build_user_agent())] + + for handler in handlers: + opener.add_handler(handler) + + return opener + + class GzipDecodedResponse(GZIP_BASE): """A file-like object to decode a response encoded with the gzip method, as described in RFC 1952. @@ -357,14 +518,6 @@ def get_exception(): return sys.exc_info()[1] -def bound_socket(*args, **kwargs): - """Bind socket to a specified source IP address""" - - sock = SOCKET_SOCKET(*args, **kwargs) - sock.bind((SOURCE, 0)) - return sock - - def distance(origin, destination): """Determine distance between 2 sets of [lat,lon] in km""" @@ -387,10 +540,6 @@ def distance(origin, destination): def build_user_agent(): """Build a Mozilla/5.0 compatible User-Agent string""" - global USER_AGENT - if USER_AGENT: - return USER_AGENT - ua_tuple = ( 'Mozilla/5.0', '(%s; U; %s; en-us)' % (platform.system(), platform.architecture()[0]), @@ -398,9 +547,9 @@ def build_user_agent(): '(KHTML, like Gecko)', 'speedtest-cli/%s' % __version__ ) - USER_AGENT = ' '.join(ua_tuple) - printer(USER_AGENT, debug=True) - return USER_AGENT + user_agent = ' '.join(ua_tuple) + printer(user_agent, debug=True) + return user_agent def build_request(url, data=None, headers=None, bump=''): @@ -410,9 +559,6 @@ def build_request(url, data=None, headers=None, bump=''): """ - if not USER_AGENT: - build_user_agent() - if not headers: headers = {} @@ -432,7 +578,6 @@ def build_request(url, data=None, headers=None, bump=''): bump) headers.update({ - 'User-Agent': USER_AGENT, 'Cache-Control': 'no-cache', }) @@ -442,14 +587,19 @@ def build_request(url, data=None, headers=None, bump=''): return Request(final_url, data=data, headers=headers) -def catch_request(request): +def catch_request(request, opener=None): """Helper function to catch common exceptions encountered when establishing a connection with a HTTP/HTTPS request """ + if opener: + _open = opener.open + else: + _open = urlopen + try: - uh = urlopen(request) + uh = _open(request) return uh, False except HTTP_ERRORS: e = get_exception() @@ -505,18 +655,22 @@ def do_nothing(*args, **kwargs): class HTTPDownloader(threading.Thread): """Thread class for retrieving a URL""" - def __init__(self, i, request, start, timeout): + def __init__(self, i, request, start, timeout, opener=None): threading.Thread.__init__(self) self.request = request self.result = [0] self.starttime = start self.timeout = timeout self.i = i + if opener: + self._opener = opener.open + else: + self._opener = urlopen def run(self): try: if (timeit.default_timer() - self.starttime) <= self.timeout: - f = urlopen(self.request) + f = self._opener(self.request) while (not SHUTDOWN_EVENT.isSet() and (timeit.default_timer() - self.starttime) <= self.timeout): @@ -574,7 +728,7 @@ class HTTPUploaderData(object): class HTTPUploader(threading.Thread): """Thread class for putting a URL""" - def __init__(self, i, request, start, size, timeout): + def __init__(self, i, request, start, size, timeout, opener=None): threading.Thread.__init__(self) self.request = request self.request.data.start = self.starttime = start @@ -583,20 +737,25 @@ class HTTPUploader(threading.Thread): self.timeout = timeout self.i = i + if opener: + self._opener = opener.open + else: + self._opener = urlopen + def run(self): request = self.request try: if ((timeit.default_timer() - self.starttime) <= self.timeout and not SHUTDOWN_EVENT.isSet()): try: - f = urlopen(request) + f = self._opener(request) except TypeError: # PY24 expects a string or buffer # This also causes issues with Ctrl-C, but we will concede # for the moment that Ctrl-C on PY24 isn't immediate request = build_request(self.request.get_full_url(), data=request.data.read(self.size)) - f = urlopen(request) + f = self._opener(request) f.read(11) f.close() self.result = sum(self.request.data.total) @@ -619,7 +778,7 @@ class SpeedtestResults(object): to get a share results image link. """ - def __init__(self, download=0, upload=0, ping=0, server=None): + def __init__(self, download=0, upload=0, ping=0, server=None, opener=None): self.download = download self.upload = upload self.ping = ping @@ -632,6 +791,11 @@ class SpeedtestResults(object): self.bytes_received = 0 self.bytes_sent = 0 + if opener: + self._opener = opener + else: + self._opener = build_opener() + def __repr__(self): return repr(self.dict()) @@ -674,7 +838,7 @@ class SpeedtestResults(object): request = build_request('://www.speedtest.net/api/api.php', data='&'.join(api_data).encode(), headers=headers) - f, e = catch_request(request) + f, e = catch_request(request, opener=_self.opener) if e: raise ShareResultsConnectFailure(e) @@ -738,8 +902,13 @@ class SpeedtestResults(object): class Speedtest(object): """Class for performing standard speedtest.net testing operations""" - def __init__(self, config=None): + def __init__(self, config=None, source_address=None, timeout=10): self.config = {} + + self._source_address = source_address + self._timeout = timeout + self._opener = build_opener(source_address, timeout) + self.get_config() if config is not None: self.config.update(config) @@ -748,7 +917,7 @@ class Speedtest(object): self.closest = [] self.best = {} - self.results = SpeedtestResults() + self.results = SpeedtestResults(opener=self._opener) def get_config(self): """Download the speedtest.net configuration and return only the data @@ -760,7 +929,7 @@ class Speedtest(object): headers['Accept-Encoding'] = 'gzip' request = build_request('://www.speedtest.net/speedtest-config.php', headers=headers) - uh, e = catch_request(request) + uh, e = catch_request(request, opener=self._opener) if e: raise ConfigRetrievalError(e) configxml = [] @@ -877,7 +1046,7 @@ class Speedtest(object): (url, self.config['threads']['download']), headers=headers) - uh, e = catch_request(request) + uh, e = catch_request(request, opener=self._opener) if e: errors.append('%s' % e) raise ServersRetrievalError() @@ -960,7 +1129,7 @@ class Speedtest(object): url = server request = build_request(url) - uh, e = catch_request(request) + uh, e = catch_request(request, opener=self._opener) if e: raise SpeedtestMiniConnectFailure('Failed to connect to %s' % server) @@ -973,7 +1142,9 @@ class Speedtest(object): if not extension: for ext in ['php', 'asp', 'aspx', 'jsp']: try: - f = urlopen('%s/speedtest/upload.%s' % (url, ext)) + f = self._opener.open( + '%s/speedtest/upload.%s' % (url, ext) + ) except: pass else: @@ -1028,6 +1199,13 @@ class Speedtest(object): servers = self.get_closest_servers() servers = self.closest + if self._source_address: + source_address_tuple = (self._source_address, 0) + else: + source_address_tuple = None + + user_agent = build_user_agent() + results = {} for server in servers: cum = [] @@ -1037,10 +1215,16 @@ class Speedtest(object): for _ in range(0, 3): try: if urlparts[0] == 'https': - h = HTTPSConnection(urlparts[1]) + h = SpeedtestHTTPSConnection( + urlparts[1], + source_address=source_address_tuple + ) else: - h = HTTPConnection(urlparts[1]) - headers = {'User-Agent': USER_AGENT} + h = SpeedtestHTTPConnection( + urlparts[1], + source_address=source_address_tuple + ) + headers = {'User-Agent': user_agent} start = timeit.default_timer() h.request("GET", urlparts[2], headers=headers) r = h.getresponse() @@ -1093,7 +1277,8 @@ class Speedtest(object): def producer(q, requests, request_count): for i, request in enumerate(requests): thread = HTTPDownloader(i, request, start, - self.config['length']['download']) + self.config['length']['download'], + opener=self._opener) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1159,7 +1344,8 @@ class Speedtest(object): def producer(q, requests, request_count): for i, request in enumerate(requests[:request_count]): thread = HTTPUploader(i, request[0], start, request[1], - self.config['length']['upload']) + self.config['length']['upload'], + opener=self._opener) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1338,7 +1524,7 @@ def printer(string, quiet=False, debug=False, **kwargs): def shell(): """Run the full speedtest.net test""" - global SHUTDOWN_EVENT, SOURCE, SCHEME, DEBUG + global SHUTDOWN_EVENT, SCHEME, DEBUG SHUTDOWN_EVENT = threading.Event() signal.signal(signal.SIGINT, ctrl_c) @@ -1361,13 +1547,6 @@ def shell(): validate_optional_args(args) - socket.setdefaulttimeout(args.timeout) - - # If specified bind to a specific IP address - if args.source: - SOURCE = args.source - socket.socket = bound_socket - if args.secure: SCHEME = 'https' @@ -1377,9 +1556,6 @@ def shell(): if debug: DEBUG = True - # Pre-cache the user agent string - build_user_agent() - if args.simple or args.csv or args.json: quiet = True else: @@ -1398,15 +1574,15 @@ def shell(): printer('Retrieving speedtest.net configuration...', quiet) try: - speedtest = Speedtest() - except (ConfigRetrievalError, HTTP_ERRORS): + speedtest = Speedtest(source_address=args.source, timeout=args.timeout) + except (ConfigRetrievalError,) + HTTP_ERRORS: printer('Cannot retrieve speedtest configuration') raise SpeedtestCLIError(get_exception()) if args.list: try: speedtest.get_servers() - except (ServersRetrievalError, HTTP_ERRORS): + except (ServersRetrievalError,) + HTTP_ERRORS: print_('Cannot retrieve speedtest server list') raise SpeedtestCLIError(get_exception()) @@ -1436,7 +1612,7 @@ def shell(): speedtest.get_servers(servers) except NoMatchedServers: raise SpeedtestCLIError('No matched servers: %s' % args.server) - except (ServersRetrievalError, HTTP_ERRORS): + except (ServersRetrievalError,) + HTTP_ERRORS: print_('Cannot retrieve speedtest server list') raise SpeedtestCLIError(get_exception()) except InvalidServerIDType: