diff --git a/.travis.yml b/.travis.yml index 8863201..20e5ecd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,5 @@ language: python -python: - - 2.7 - addons: apt: sources: @@ -11,26 +8,44 @@ addons: - python2.4 - python2.5 - python2.6 - - pypy + - python3.2 + - python3.3 -env: - - TOXENV=py24 - - TOXENV=py25 - - TOXENV=py26 - - TOXENV=py27 - - TOXENV=py32 - - TOXENV=py33 - - TOXENV=py34 - - TOXENV=py35 - - TOXENV=pypy - - TOXENV=flake8 +matrix: + include: + - python: 2.7 + env: TOXENV=flake8 + - python: 2.7 + env: TOXENV=py24 + - python: 2.7 + env: TOXENV=py25 + - python: 2.7 + env: TOXENV=py26 + - python: 2.7 + env: TOXENV=py27 + - python: 2.7 + env: TOXENV=py32 + - python: 2.7 + env: TOXENV=py33 + - python: 3.4 + env: TOXENV=py34 + - python: 3.5 + env: TOXENV=py35 + - python: 3.6 + env: TOXENV=py36 + - python: pypy + env: TOXENV=pypy + +before_install: + - if [[ $(echo "$TOXENV" | egrep -c "py35") != 0 ]]; then pyenv global system 3.5; fi; install: - - if [[ $(echo "$TOXENV" | egrep -c "(py2[45]|py3[12])") != 0 ]]; then pip install virtualenv==1.7.2 tox==1.3; fi; - - if [[ $(echo "$TOXENV" | egrep -c "(py2[45]|py3[12])") == 0 ]]; then pip install tox; fi; + - if [[ $(echo "$TOXENV" | egrep -c "py32") != 0 ]]; then pip install setuptools==17.1.1; fi; + - if [[ $(echo "$TOXENV" | egrep -c "(py2[45]|py3[12])") != 0 ]]; then pip install virtualenv==1.7.2 tox==1.3; fi; + - if [[ $(echo "$TOXENV" | egrep -c "(py2[45]|py3[12])") == 0 ]]; then pip install tox; fi; script: - - tox + - tox notifications: email: diff --git a/README.rst b/README.rst index 0043b5c..d80188a 100644 --- a/README.rst +++ b/README.rst @@ -17,7 +17,7 @@ speedtest.net Versions -------- -speedtest-cli works with Python 2.4-3.5 +speedtest-cli works with Python 2.4-3.6 .. image:: https://img.shields.io/pypi/pyversions/speedtest-cli.svg :target: https://pypi.python.org/pypi/speedtest-cli/ @@ -77,13 +77,14 @@ Usage usage: speedtest-cli [-h] [--no-download] [--no-upload] [--bytes] [--share] [--simple] [--csv] [--csv-delimiter CSV_DELIMITER] [--csv-header] [--json] [--list] [--server SERVER] - [--mini MINI] [--source SOURCE] [--timeout TIMEOUT] - [--secure] [--no-pre-allocate] [--version] - + [--exclude EXCLUDE] [--mini MINI] [--source SOURCE] + [--timeout TIMEOUT] [--secure] [--no-pre-allocate] + [--version] + Command line interface for testing internet bandwidth using speedtest.net. -------------------------------------------------------------------------- https://github.com/sivel/speedtest-cli - + optional arguments: -h, --help show this help message and exit --no-download Do not perform download test @@ -106,7 +107,10 @@ Usage affected by --bytes --list Display a list of speedtest.net servers sorted by distance - --server SERVER Specify a server ID to test against + --server SERVER Specify a server ID to test against. Can be supplied + multiple times + --exclude EXCLUDE Exclude a server from selection. Can be supplied + multiple times --mini MINI URL of the Speedtest Mini server --source SOURCE Source IP address to bind to --timeout TIMEOUT HTTP timeout in seconds. Default 10 diff --git a/setup.py b/setup.py index 00a8054..f323f81 100644 --- a/setup.py +++ b/setup.py @@ -90,5 +90,6 @@ setup( 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', ] ) diff --git a/speedtest.py b/speedtest.py index 91872d2..96439c7 100755 --- a/speedtest.py +++ b/speedtest.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Copyright 2012-2016 Matt Martz +# Copyright 2012-2018 Matt Martz # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -36,7 +36,7 @@ except ImportError: gzip = None GZIP_BASE = object -__version__ = '1.0.7' +__version__ = '2.0.0b' class FakeShutdownEvent(object): @@ -51,14 +51,8 @@ class FakeShutdownEvent(object): # Some global variables we use -USER_AGENT = None -SOURCE = None -SHUTDOWN_EVENT = FakeShutdownEvent() -SCHEME = 'http' DEBUG = False - -# Used for bound_interface -SOCKET_SOCKET = socket.socket +_GLOBAL_DEFAULT_TIMEOUT = object() # Begin import game to handle Python 2 and Python 3 try: @@ -79,9 +73,15 @@ except ImportError: ET = None try: - from urllib2 import urlopen, Request, HTTPError, URLError + from urllib2 import (urlopen, Request, HTTPError, URLError, + AbstractHTTPHandler, ProxyHandler, + HTTPDefaultErrorHandler, HTTPRedirectHandler, + HTTPErrorProcessor, OpenerDirector) except ImportError: - from urllib.request import urlopen, Request, HTTPError, URLError + from urllib.request import (urlopen, Request, HTTPError, URLError, + AbstractHTTPHandler, ProxyHandler, + HTTPDefaultErrorHandler, HTTPRedirectHandler, + HTTPErrorProcessor, OpenerDirector) try: from httplib import HTTPConnection @@ -124,11 +124,13 @@ try: from argparse import SUPPRESS as ARG_SUPPRESS PARSER_TYPE_INT = int PARSER_TYPE_STR = str + PARSER_TYPE_FLOAT = float except ImportError: from optparse import OptionParser as ArgParser from optparse import SUPPRESS_HELP as ARG_SUPPRESS PARSER_TYPE_INT = 'int' PARSER_TYPE_STR = 'string' + PARSER_TYPE_FLOAT = 'float' try: from cStringIO import StringIO @@ -146,24 +148,25 @@ except ImportError: import builtins from io import TextIOWrapper, FileIO - class _Py3Utf8Stdout(TextIOWrapper): + class _Py3Utf8Output(TextIOWrapper): """UTF-8 encoded wrapper around stdout for py3, to override ASCII stdout """ - def __init__(self, **kwargs): - buf = FileIO(sys.stdout.fileno(), 'w') - super(_Py3Utf8Stdout, self).__init__( + def __init__(self, f, **kwargs): + buf = FileIO(f.fileno(), 'w') + super(_Py3Utf8Output, self).__init__( buf, encoding='utf8', errors='strict' ) def write(self, s): - super(_Py3Utf8Stdout, self).write(s) + super(_Py3Utf8Output, self).write(s) self.flush() _py3_print = getattr(builtins, 'print') - _py3_utf8_stdout = _Py3Utf8Stdout() + _py3_utf8_stdout = _Py3Utf8Output(sys.stdout) + _py3_utf8_stderr = _Py3Utf8Output(sys.stderr) def to_utf8(v): """No-op encode to utf-8 for py3""" @@ -171,7 +174,10 @@ except ImportError: def print_(*args, **kwargs): """Wrapper function for py3 to print, with a utf-8 encoded stdout""" - kwargs['file'] = _py3_utf8_stdout + if kwargs.get('file') == sys.stderr: + kwargs['file'] = _py3_utf8_stderr + else: + kwargs['file'] = kwargs.get('file', _py3_utf8_stdout) _py3_print(*args, **kwargs) else: del __builtin__ @@ -188,7 +194,7 @@ else: Taken from https://pypi.python.org/pypi/six/ - Modified to set encoding to UTF-8 always + Modified to set encoding to UTF-8 always, and to flush after write """ fp = kwargs.pop("file", sys.stdout) if fp is None: @@ -207,6 +213,7 @@ else: errors = "strict" data = data.encode(encoding, errors) fp.write(data) + fp.flush() want_unicode = False sep = kwargs.pop("sep", None) if sep is not None: @@ -320,6 +327,201 @@ class SpeedtestBestServerFailure(SpeedtestException): """Unable to determine best server""" +class SpeedtestMissingBestServer(SpeedtestException): + """get_best_server not called or not able to determine best server""" + + +def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`getdefaulttimeout` + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. + + Largely vendored from Python 2.7, modified to work with Python 2.4 + """ + + host, port = address + err = None + for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + if timeout is not _GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(float(timeout)) + if source_address: + sock.bind(source_address) + sock.connect(sa) + return sock + + except socket.error: + err = get_exception() + if sock is not None: + sock.close() + + if err is not None: + raise err + else: + raise socket.error("getaddrinfo returns an empty list") + + +class SpeedtestHTTPConnection(HTTPConnection): + """Custom HTTPConnection to support source_address across + Python 2.4 - Python 3 + """ + def __init__(self, *args, **kwargs): + source_address = kwargs.pop('source_address', None) + context = kwargs.pop('context', None) + timeout = kwargs.pop('timeout', 10) + + HTTPConnection.__init__(self, *args, **kwargs) + + self.source_address = source_address + self._context = context + self.timeout = timeout + + def connect(self): + """Connect to the host and port specified in __init__.""" + try: + self.sock = socket.create_connection( + (self.host, self.port), + self.timeout, + self.source_address + ) + except (AttributeError, TypeError): + self.sock = create_connection( + (self.host, self.port), + self.timeout, + self.source_address + ) + + +if HTTPSConnection: + class SpeedtestHTTPSConnection(HTTPSConnection, + SpeedtestHTTPConnection): + """Custom HTTPSConnection to support source_address across + Python 2.4 - Python 3 + """ + def connect(self): + "Connect to a host on a given (SSL) port." + + SpeedtestHTTPConnection.connect(self) + + kwargs = {} + if hasattr(ssl, 'SSLContext'): + kwargs['server_hostname'] = self.host + + self.sock = self._context.wrap_socket(self.sock, **kwargs) + + +def _build_connection(connection, source_address, timeout, context=None): + """Cross Python 2.4 - Python 3 callable to build an ``HTTPConnection`` or + ``HTTPSConnection`` with the args we need + + Called from ``http(s)_open`` methods of ``SpeedtestHTTPHandler`` or + ``SpeedtestHTTPSHandler`` + """ + def inner(host, **kwargs): + kwargs.update({ + 'source_address': source_address, + 'timeout': timeout + }) + if context: + kwargs['context'] = context + return connection(host, **kwargs) + return inner + + +class SpeedtestHTTPHandler(AbstractHTTPHandler): + """Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the + args we need for ``source_address`` and ``timeout`` + """ + def __init__(self, debuglevel=0, source_address=None, timeout=10): + AbstractHTTPHandler.__init__(self, debuglevel) + self.source_address = source_address + self.timeout = timeout + + def http_open(self, req): + return self.do_open( + _build_connection( + SpeedtestHTTPConnection, + self.source_address, + self.timeout + ), + req + ) + + http_request = AbstractHTTPHandler.do_request_ + + +class SpeedtestHTTPSHandler(AbstractHTTPHandler): + """Custom ``HTTPSHandler`` that can build a ``HTTPSConnection`` with the + args we need for ``source_address`` and ``timeout`` + """ + def __init__(self, debuglevel=0, context=None, source_address=None, + timeout=10): + AbstractHTTPHandler.__init__(self, debuglevel) + self._context = context + self.source_address = source_address + self.timeout = timeout + + def https_open(self, req): + return self.do_open( + _build_connection( + SpeedtestHTTPSConnection, + self.source_address, + self.timeout, + context=self._context, + ), + req + ) + + https_request = AbstractHTTPHandler.do_request_ + + +def build_opener(source_address=None, timeout=10): + """Function similar to ``urllib2.build_opener`` that will build + an ``OpenerDirector`` with the explicit handlers we want, + ``source_address`` for binding, ``timeout`` and our custom + `User-Agent` + """ + + printer('Timeout set to %d' % timeout, debug=True) + + if source_address: + source_address_tuple = (source_address, 0) + printer('Binding to source address: %r' % (source_address_tuple,), + debug=True) + else: + source_address_tuple = None + + handlers = [ + ProxyHandler(), + SpeedtestHTTPHandler(source_address=source_address_tuple, + timeout=timeout), + SpeedtestHTTPSHandler(source_address=source_address_tuple, + timeout=timeout), + HTTPDefaultErrorHandler(), + HTTPRedirectHandler(), + HTTPErrorProcessor() + ] + + opener = OpenerDirector() + opener.addheaders = [('User-agent', build_user_agent())] + + for handler in handlers: + opener.add_handler(handler) + + return opener + + class GzipDecodedResponse(GZIP_BASE): """A file-like object to decode a response encoded with the gzip method, as described in RFC 1952. @@ -357,14 +559,6 @@ def get_exception(): return sys.exc_info()[1] -def bound_socket(*args, **kwargs): - """Bind socket to a specified source IP address""" - - sock = SOCKET_SOCKET(*args, **kwargs) - sock.bind((SOURCE, 0)) - return sock - - def distance(origin, destination): """Determine distance between 2 sets of [lat,lon] in km""" @@ -387,10 +581,6 @@ def distance(origin, destination): def build_user_agent(): """Build a Mozilla/5.0 compatible User-Agent string""" - global USER_AGENT - if USER_AGENT: - return USER_AGENT - ua_tuple = ( 'Mozilla/5.0', '(%s; U; %s; en-us)' % (platform.system(), platform.architecture()[0]), @@ -398,26 +588,24 @@ def build_user_agent(): '(KHTML, like Gecko)', 'speedtest-cli/%s' % __version__ ) - USER_AGENT = ' '.join(ua_tuple) - printer(USER_AGENT, debug=True) - return USER_AGENT + user_agent = ' '.join(ua_tuple) + printer('User-Agent: %s' % user_agent, debug=True) + return user_agent -def build_request(url, data=None, headers=None, bump=''): +def build_request(url, data=None, headers=None, bump='0', secure=False): """Build a urllib2 request object This function automatically adds a User-Agent header to all requests """ - if not USER_AGENT: - build_user_agent() - if not headers: headers = {} if url[0] == ':': - schemed_url = '%s%s' % (SCHEME, url) + scheme = ('http', 'https')[bool(secure)] + schemed_url = '%s%s' % (scheme, url) else: schemed_url = url @@ -432,7 +620,6 @@ def build_request(url, data=None, headers=None, bump=''): bump) headers.update({ - 'User-Agent': USER_AGENT, 'Cache-Control': 'no-cache', }) @@ -442,14 +629,19 @@ def build_request(url, data=None, headers=None, bump=''): return Request(final_url, data=data, headers=headers) -def catch_request(request): +def catch_request(request, opener=None): """Helper function to catch common exceptions encountered when establishing a connection with a HTTP/HTTPS request """ + if opener: + _open = opener.open + else: + _open = urlopen + try: - uh = urlopen(request) + uh = _open(request) return uh, False except HTTP_ERRORS: e = get_exception() @@ -484,18 +676,19 @@ def get_attributes_by_tag_name(dom, tag_name): return dict(list(elem.attributes.items())) -def print_dots(current, total, start=False, end=False): +def print_dots(shutdown_event): """Built in callback function used by Thread classes for printing status """ + def inner(current, total, start=False, end=False): + if shutdown_event.isSet(): + return - if SHUTDOWN_EVENT.isSet(): - return - - sys.stdout.write('.') - if current + 1 == total and end is True: - sys.stdout.write('\n') - sys.stdout.flush() + sys.stdout.write('.') + if current + 1 == total and end is True: + sys.stdout.write('\n') + sys.stdout.flush() + return inner def do_nothing(*args, **kwargs): @@ -505,19 +698,29 @@ def do_nothing(*args, **kwargs): class HTTPDownloader(threading.Thread): """Thread class for retrieving a URL""" - def __init__(self, i, request, start, timeout): + def __init__(self, i, request, start, timeout, opener=None, + shutdown_event=None): threading.Thread.__init__(self) self.request = request self.result = [0] self.starttime = start self.timeout = timeout self.i = i + if opener: + self._opener = opener.open + else: + self._opener = urlopen + + if shutdown_event: + self._shutdown_event = shutdown_event + else: + self._shutdown_event = FakeShutdownEvent() def run(self): try: if (timeit.default_timer() - self.starttime) <= self.timeout: - f = urlopen(self.request) - while (not SHUTDOWN_EVENT.isSet() and + f = self._opener(self.request) + while (not self._shutdown_event.isSet() and (timeit.default_timer() - self.starttime) <= self.timeout): self.result.append(len(f.read(10240))) @@ -533,11 +736,16 @@ class HTTPUploaderData(object): has been reached """ - def __init__(self, length, start, timeout): + def __init__(self, length, start, timeout, shutdown_event=None): self.length = length self.start = start self.timeout = timeout + if shutdown_event: + self._shutdown_event = shutdown_event + else: + self._shutdown_event = FakeShutdownEvent() + self._data = None self.total = [0] @@ -546,11 +754,17 @@ class HTTPUploaderData(object): chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' multiplier = int(round(int(self.length) / 36.0)) IO = BytesIO or StringIO - self._data = IO( - ('content1=%s' % - (chars * multiplier)[0:int(self.length) - 9] - ).encode() - ) + try: + self._data = IO( + ('content1=%s' % + (chars * multiplier)[0:int(self.length) - 9] + ).encode() + ) + except MemoryError: + raise SpeedtestCLIError( + 'Insufficient memory to pre-allocate upload data. Please ' + 'use --no-pre-allocate' + ) @property def data(self): @@ -560,7 +774,7 @@ class HTTPUploaderData(object): def read(self, n=10240): if ((timeit.default_timer() - self.start) <= self.timeout and - not SHUTDOWN_EVENT.isSet()): + not self._shutdown_event.isSet()): chunk = self.data.read(n) self.total.append(len(chunk)) return chunk @@ -574,7 +788,8 @@ class HTTPUploaderData(object): class HTTPUploader(threading.Thread): """Thread class for putting a URL""" - def __init__(self, i, request, start, size, timeout): + def __init__(self, i, request, start, size, timeout, opener=None, + shutdown_event=None): threading.Thread.__init__(self) self.request = request self.request.data.start = self.starttime = start @@ -583,20 +798,30 @@ class HTTPUploader(threading.Thread): self.timeout = timeout self.i = i + if opener: + self._opener = opener.open + else: + self._opener = urlopen + + if shutdown_event: + self._shutdown_event = shutdown_event + else: + self._shutdown_event = FakeShutdownEvent() + def run(self): request = self.request try: if ((timeit.default_timer() - self.starttime) <= self.timeout and - not SHUTDOWN_EVENT.isSet()): + not self._shutdown_event.isSet()): try: - f = urlopen(request) + f = self._opener(request) except TypeError: # PY24 expects a string or buffer # This also causes issues with Ctrl-C, but we will concede # for the moment that Ctrl-C on PY24 isn't immediate request = build_request(self.request.get_full_url(), data=request.data.read(self.size)) - f = urlopen(request) + f = self._opener(request) f.read(11) f.close() self.result = sum(self.request.data.total) @@ -619,7 +844,8 @@ class SpeedtestResults(object): to get a share results image link. """ - def __init__(self, download=0, upload=0, ping=0, server=None): + def __init__(self, download=0, upload=0, ping=0, server=None, client=None, + opener=None, secure=False): self.download = download self.upload = upload self.ping = ping @@ -627,11 +853,20 @@ class SpeedtestResults(object): self.server = {} else: self.server = server + self.client = client or {} + self._share = None self.timestamp = '%sZ' % datetime.datetime.utcnow().isoformat() self.bytes_received = 0 self.bytes_sent = 0 + if opener: + self._opener = opener + else: + self._opener = build_opener() + + self._secure = secure + def __repr__(self): return repr(self.dict()) @@ -673,8 +908,8 @@ class SpeedtestResults(object): headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'} request = build_request('://www.speedtest.net/api/api.php', data='&'.join(api_data).encode(), - headers=headers) - f, e = catch_request(request) + headers=headers, secure=self._secure) + f, e = catch_request(request, opener=self._opener) if e: raise ShareResultsConnectFailure(e) @@ -708,8 +943,20 @@ class SpeedtestResults(object): 'bytes_sent': self.bytes_sent, 'bytes_received': self.bytes_received, 'share': self._share, + 'client': self.client, } + @staticmethod + def csv_header(delimiter=','): + """Return CSV Headers""" + + row = ['Server ID', 'Sponsor', 'Server Name', 'Timestamp', 'Distance', + 'Ping', 'Download', 'Upload', 'Share', 'IP Address'] + out = StringIO() + writer = csv.writer(out, delimiter=delimiter, lineterminator='') + writer.writerow([to_utf8(v) for v in row]) + return out.getvalue() + def csv(self, delimiter=','): """Return data in CSV format""" @@ -719,7 +966,7 @@ class SpeedtestResults(object): row = [data['server']['id'], data['server']['sponsor'], data['server']['name'], data['timestamp'], data['server']['d'], data['ping'], data['download'], - data['upload']] + data['upload'], self._share or '', self.client['ip']] writer.writerow([to_utf8(v) for v in row]) return out.getvalue() @@ -738,17 +985,43 @@ class SpeedtestResults(object): class Speedtest(object): """Class for performing standard speedtest.net testing operations""" - def __init__(self, config=None): + def __init__(self, config=None, source_address=None, timeout=10, + secure=False, shutdown_event=None): self.config = {} + + self._source_address = source_address + self._timeout = timeout + self._opener = build_opener(source_address, timeout) + + self._secure = secure + + if shutdown_event: + self._shutdown_event = shutdown_event + else: + self._shutdown_event = FakeShutdownEvent() + self.get_config() if config is not None: self.config.update(config) self.servers = {} self.closest = [] - self.best = {} + self._best = {} - self.results = SpeedtestResults() + self.results = SpeedtestResults( + client=self.config['client'], + opener=self._opener, + secure=secure, + ) + + @property + def best(self): + if not self._best: + raise SpeedtestMissingBestServer( + 'get_best_server not called or not able to determine best ' + 'server' + ) + return self._best def get_config(self): """Download the speedtest.net configuration and return only the data @@ -759,8 +1032,8 @@ class Speedtest(object): if gzip: headers['Accept-Encoding'] = 'gzip' request = build_request('://www.speedtest.net/speedtest-config.php', - headers=headers) - uh, e = catch_request(request) + headers=headers, secure=self._secure) + uh, e = catch_request(request, opener=self._opener) if e: raise ConfigRetrievalError(e) configxml = [] @@ -768,7 +1041,10 @@ class Speedtest(object): stream = get_response_stream(uh) while 1: - configxml.append(stream.read(1024)) + try: + configxml.append(stream.read(1024)) + except (OSError, EOFError): + raise ConfigRetrievalError(get_exception()) if len(configxml[-1]) == 0: break stream.close() @@ -777,7 +1053,7 @@ class Speedtest(object): if int(uh.code) != 200: return None - printer(''.encode().join(configxml), debug=True) + printer('Config XML:\n%s' % ''.encode().join(configxml), debug=True) try: root = ET.fromstring(''.encode().join(configxml)) @@ -839,25 +1115,30 @@ class Speedtest(object): self.lat_lon = (float(client['lat']), float(client['lon'])) - printer(self.config, debug=True) + printer('Config:\n%r' % self.config, debug=True) return self.config - def get_servers(self, servers=None): + def get_servers(self, servers=None, exclude=None): """Retrieve a the list of speedtest.net servers, optionally filtered to servers matching those specified in the ``servers`` argument """ if servers is None: servers = [] + if exclude is None: + exclude = [] + self.servers.clear() - for i, s in enumerate(servers): - try: - servers[i] = int(s) - except ValueError: - raise InvalidServerIDType('%s is an invalid server type, must ' - 'be int' % s) + for server_list in (servers, exclude): + for i, s in enumerate(server_list): + try: + server_list[i] = int(s) + except ValueError: + raise InvalidServerIDType( + '%s is an invalid server type, must be int' % s + ) urls = [ '://www.speedtest.net/speedtest-servers-static.php', @@ -873,11 +1154,13 @@ class Speedtest(object): errors = [] for url in urls: try: - request = build_request('%s?threads=%s' % - (url, - self.config['threads']['download']), - headers=headers) - uh, e = catch_request(request) + request = build_request( + '%s?threads=%s' % (url, + self.config['threads']['download']), + headers=headers, + secure=self._secure + ) + uh, e = catch_request(request, opener=self._opener) if e: errors.append('%s' % e) raise ServersRetrievalError() @@ -886,7 +1169,10 @@ class Speedtest(object): serversxml = [] while 1: - serversxml.append(stream.read(1024)) + try: + serversxml.append(stream.read(1024)) + except (OSError, EOFError): + raise ServersRetrievalError(get_exception()) if len(serversxml[-1]) == 0: break @@ -896,7 +1182,8 @@ class Speedtest(object): if int(uh.code) != 200: raise ServersRetrievalError() - printer(''.encode().join(serversxml), debug=True) + printer('Servers XML:\n%s' % ''.encode().join(serversxml), + debug=True) try: try: @@ -917,7 +1204,8 @@ class Speedtest(object): if servers and int(attrib.get('id')) not in servers: continue - if int(attrib.get('id')) in self.config['ignore_servers']: + if (int(attrib.get('id')) in self.config['ignore_servers'] + or int(attrib.get('id')) in exclude): continue try: @@ -934,14 +1222,12 @@ class Speedtest(object): except KeyError: self.servers[d] = [attrib] - printer(''.encode().join(serversxml), debug=True) - break except ServersRetrievalError: continue - if servers and not self.servers: + if (servers or exclude) and not self.servers: raise NoMatchedServers() return self.servers @@ -960,7 +1246,7 @@ class Speedtest(object): url = server request = build_request(url) - uh, e = catch_request(request) + uh, e = catch_request(request, opener=self._opener) if e: raise SpeedtestMiniConnectFailure('Failed to connect to %s' % server) @@ -973,7 +1259,9 @@ class Speedtest(object): if not extension: for ext in ['php', 'asp', 'aspx', 'jsp']: try: - f = urlopen('%s/speedtest/upload.%s' % (url, ext)) + f = self._opener.open( + '%s/speedtest/upload.%s' % (url, ext) + ) except Exception: pass else: @@ -1015,7 +1303,7 @@ class Speedtest(object): continue break - printer(self.closest, debug=True) + printer('Closest Servers:\n%r' % self.closest, debug=True) return self.closest def get_best_server(self, servers=None): @@ -1028,26 +1316,44 @@ class Speedtest(object): servers = self.get_closest_servers() servers = self.closest + if self._source_address: + source_address_tuple = (self._source_address, 0) + else: + source_address_tuple = None + + user_agent = build_user_agent() + results = {} for server in servers: cum = [] url = os.path.dirname(server['url']) - urlparts = urlparse('%s/latency.txt' % url) - printer('%s %s/latency.txt' % ('GET', url), debug=True) - for _ in range(0, 3): + stamp = int(timeit.time.time() * 1000) + latency_url = '%s/latency.txt?x=%s' % (url, stamp) + for i in range(0, 3): + this_latency_url = '%s.%s' % (latency_url, i) + printer('%s %s' % ('GET', this_latency_url), + debug=True) + urlparts = urlparse(latency_url) try: if urlparts[0] == 'https': - h = HTTPSConnection(urlparts[1]) + h = SpeedtestHTTPSConnection( + urlparts[1], + source_address=source_address_tuple + ) else: - h = HTTPConnection(urlparts[1]) - headers = {'User-Agent': USER_AGENT} + h = SpeedtestHTTPConnection( + urlparts[1], + source_address=source_address_tuple + ) + headers = {'User-Agent': user_agent} + path = '%s?%s' % (urlparts[2], urlparts[4]) start = timeit.default_timer() - h.request("GET", urlparts[2], headers=headers) + h.request("GET", path, headers=headers) r = h.getresponse() total = (timeit.default_timer() - start) except HTTP_ERRORS: e = get_exception() - printer('%r' % e, debug=True) + printer('ERROR: %r' % e, debug=True) cum.append(3600) continue @@ -1072,8 +1378,8 @@ class Speedtest(object): self.results.ping = fastest self.results.server = best - self.best.update(best) - printer(best, debug=True) + self._best.update(best) + printer('Best Server:\n%r' % best, debug=True) return best def download(self, callback=do_nothing): @@ -1088,12 +1394,20 @@ class Speedtest(object): request_count = len(urls) requests = [] for i, url in enumerate(urls): - requests.append(build_request(url, bump=i)) + requests.append( + build_request(url, bump=i, secure=self._secure) + ) def producer(q, requests, request_count): for i, request in enumerate(requests): - thread = HTTPDownloader(i, request, start, - self.config['length']['download']) + thread = HTTPDownloader( + i, + request, + start, + self.config['length']['download'], + opener=self._opener, + shutdown_event=self._shutdown_event + ) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1146,20 +1460,32 @@ class Speedtest(object): for i, size in enumerate(sizes): # We set ``0`` for ``start`` and handle setting the actual # ``start`` in ``HTTPUploader`` to get better measurements - data = HTTPUploaderData(size, 0, self.config['length']['upload']) + data = HTTPUploaderData( + size, + 0, + self.config['length']['upload'], + shutdown_event=self._shutdown_event + ) if pre_allocate: data.pre_allocate() requests.append( ( - build_request(self.best['url'], data), + build_request(self.best['url'], data, secure=self._secure), size ) ) def producer(q, requests, request_count): for i, request in enumerate(requests[:request_count]): - thread = HTTPUploader(i, request[0], start, request[1], - self.config['length']['upload']) + thread = HTTPUploader( + i, + request[0], + start, + request[1], + self.config['length']['upload'], + opener=self._opener, + shutdown_event=self._shutdown_event + ) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1195,32 +1521,28 @@ class Speedtest(object): return self.results.upload -def ctrl_c(signum, frame): +def ctrl_c(shutdown_event): """Catch Ctrl-C key sequence and set a SHUTDOWN_EVENT for our threaded operations """ - - SHUTDOWN_EVENT.set() - print_('\nCancelling...') - sys.exit(0) + def inner(signum, frame): + shutdown_event.set() + printer('\nCancelling...', error=True) + sys.exit(0) + return inner def version(): """Print the version""" - print_(__version__) + printer(__version__) sys.exit(0) def csv_header(delimiter=','): """Print the CSV Headers""" - row = ['Server ID', 'Sponsor', 'Server Name', 'Timestamp', 'Distance', - 'Ping', 'Download', 'Upload'] - out = StringIO() - writer = csv.writer(out, delimiter=delimiter, lineterminator='') - writer.writerow([to_utf8(v) for v in row]) - print_(out.getvalue()) + printer(SpeedtestResults.csv_header(delimiter=delimiter)) sys.exit(0) @@ -1273,11 +1595,15 @@ def parse_args(): parser.add_argument('--list', action='store_true', help='Display a list of speedtest.net servers ' 'sorted by distance') - parser.add_argument('--server', help='Specify a server ID to test against', - type=PARSER_TYPE_INT) + parser.add_argument('--server', type=PARSER_TYPE_INT, action='append', + help='Specify a server ID to test against. Can be ' + 'supplied multiple times') + parser.add_argument('--exclude', type=PARSER_TYPE_INT, action='append', + help='Exclude a server from selection. Can be ' + 'supplied multiple times') parser.add_argument('--mini', help='URL of the Speedtest Mini server') parser.add_argument('--source', help='Source IP address to bind to') - parser.add_argument('--timeout', default=10, type=PARSER_TYPE_INT, + parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT, help='HTTP timeout in seconds. Default 10') parser.add_argument('--secure', action='store_true', help='Use HTTPS instead of HTTP when communicating ' @@ -1335,17 +1661,23 @@ def format_speed(speed_bytes_per_second, unit): return '%0.2f %s%s/s' % (speed, seq[i], unit[0]) -def printer(string, quiet=False, debug=False, **kwargs): - """Helper function to print a string only when not quiet""" +def printer(string, quiet=False, debug=False, error=False, **kwargs): + """Helper function print a string with various features""" if debug and not DEBUG: return if debug: - out = '\033[1;30mDEBUG: %s\033[0m' % string + if sys.stdout.isatty(): + out = '\033[1;30mDEBUG: %s\033[0m' % string + else: + out = 'DEBUG: %s' % string else: out = string + if error: + kwargs['file'] = sys.stderr + if not quiet: print_(out, **kwargs) @@ -1353,10 +1685,10 @@ def printer(string, quiet=False, debug=False, **kwargs): def shell(): """Run the full speedtest.net test""" - global SHUTDOWN_EVENT, SOURCE, SCHEME, DEBUG - SHUTDOWN_EVENT = threading.Event() + global DEBUG + shutdown_event = threading.Event() - signal.signal(signal.SIGINT, ctrl_c) + signal.signal(signal.SIGINT, ctrl_c(shutdown_event)) args = parse_args() @@ -1376,25 +1708,12 @@ def shell(): validate_optional_args(args) - socket.setdefaulttimeout(args.timeout) - - # If specified bind to a specific IP address - if args.source: - SOURCE = args.source - socket.socket = bound_socket - - if args.secure: - SCHEME = 'https' - debug = getattr(args, 'debug', False) if debug == 'SUPPRESSHELP': debug = False if debug: DEBUG = True - # Pre-cache the user agent string - build_user_agent() - if args.simple or args.csv or args.json: quiet = True else: @@ -1409,20 +1728,24 @@ def shell(): if quiet or debug: callback = do_nothing else: - callback = print_dots + callback = print_dots(shutdown_event) printer('Retrieving speedtest.net configuration...', quiet) try: - speedtest = Speedtest() - except (ConfigRetrievalError, HTTP_ERRORS): - printer('Cannot retrieve speedtest configuration') + speedtest = Speedtest( + source_address=args.source, + timeout=args.timeout, + secure=args.secure + ) + except (ConfigRetrievalError,) + HTTP_ERRORS: + printer('Cannot retrieve speedtest configuration', error=True) raise SpeedtestCLIError(get_exception()) if args.list: try: speedtest.get_servers() - except (ServersRetrievalError, HTTP_ERRORS): - print_('Cannot retrieve speedtest server list') + except (ServersRetrievalError,) + HTTP_ERRORS: + printer('Cannot retrieve speedtest server list', error=True) raise SpeedtestCLIError(get_exception()) for _, servers in sorted(speedtest.servers.items()): @@ -1430,35 +1753,38 @@ def shell(): line = ('%(id)5s) %(sponsor)s (%(name)s, %(country)s) ' '[%(d)0.2f km]' % server) try: - print_(line) + printer(line) except IOError: e = get_exception() if e.errno != errno.EPIPE: raise sys.exit(0) - # Set a filter of servers to retrieve - servers = [] - if args.server: - servers.append(args.server) - printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'], quiet) if not args.mini: printer('Retrieving speedtest.net server list...', quiet) try: - speedtest.get_servers(servers) + speedtest.get_servers(servers=args.server, exclude=args.exclude) except NoMatchedServers: - raise SpeedtestCLIError('No matched servers: %s' % args.server) - except (ServersRetrievalError, HTTP_ERRORS): - print_('Cannot retrieve speedtest server list') + raise SpeedtestCLIError( + 'No matched servers: %s' % + ', '.join('%s' % s for s in args.server) + ) + except (ServersRetrievalError,) + HTTP_ERRORS: + printer('Cannot retrieve speedtest server list', error=True) raise SpeedtestCLIError(get_exception()) except InvalidServerIDType: - raise SpeedtestCLIError('%s is an invalid server type, must ' - 'be an int' % args.server) + raise SpeedtestCLIError( + '%s is an invalid server type, must ' + 'be an int' % ', '.join('%s' % s for s in args.server) + ) - printer('Selecting best server based on ping...', quiet) + if args.server and len(args.server) == 1: + printer('Retrieving information for the selected server...', quiet) + else: + printer('Selecting best server based on ping...', quiet) speedtest.get_best_server() elif args.mini: speedtest.get_best_server(speedtest.set_mini_server(args.mini)) @@ -1477,7 +1803,7 @@ def shell(): quiet) else: - printer('Skipping download test') + printer('Skipping download test', quiet) if args.upload: printer('Testing upload speed', quiet, @@ -1487,18 +1813,20 @@ def shell(): format_speed(results.upload, args.units), quiet) else: - printer('Skipping upload test') + printer('Skipping upload test', quiet) + + printer('Results:\n%r' % results.dict(), debug=True) if args.simple: print_('Ping: %s ms\nDownload: %s\nUpload: %s' % (results.ping, format_speed(results.download, args.units), format_speed(results.upload, args.units))) elif args.csv: - print_(results.csv(delimiter=args.csv_delimiter)) + printer(results.csv(delimiter=args.csv_delimiter)) elif args.json: if args.share: results.share() - print_(results.json()) + printer(results.json()) if args.share and not machine_format: printer('Share results: %s' % results.share()) @@ -1508,10 +1836,11 @@ def main(): try: shell() except KeyboardInterrupt: - print_('\nCancelling...') + printer('\nCancelling...', error=True) except (SpeedtestException, SystemExit): e = get_exception() - if getattr(e, 'code', 1) != 0: + # Ignore a successful exit, or argparse exit + if getattr(e, 'code', 1) not in (0, 2): raise SystemExit('ERROR: %s' % e) diff --git a/tests/scripts/source.py b/tests/scripts/source.py new file mode 100644 index 0000000..357f4c6 --- /dev/null +++ b/tests/scripts/source.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2018 Matt Martz +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys +import subprocess + +cmd = [sys.executable, 'speedtest.py', '--source', '127.0.0.1'] + +p = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE +) + +stdout, stderr = p.communicate() + +if p.returncode != 1: + raise SystemExit('%s did not fail with exit code 1' % ' '.join(cmd)) + +if 'Invalid argument'.encode() not in stderr: + raise SystemExit( + '"Invalid argument" not found in stderr:\n%s' % stderr.decode() + ) diff --git a/tox.ini b/tox.ini index 477fa1c..8a63b5b 100644 --- a/tox.ini +++ b/tox.ini @@ -6,6 +6,8 @@ commands = {envpython} -V {envpython} -m compileall speedtest.py {envpython} speedtest.py + {envpython} speedtest.py --source 172.17.0.1 + {envpython} tests/scripts/source.py [testenv:flake8] basepython=python @@ -19,3 +21,5 @@ commands = pypy -V pypy -m compileall speedtest.py pypy speedtest.py + pypy speedtest.py --source 172.17.0.1 + pypy tests/scripts/source.py