diff --git a/speedtest.py b/speedtest.py index 7a99c99..1a00659 100755 --- a/speedtest.py +++ b/speedtest.py @@ -321,6 +321,16 @@ class InvalidSpeedtestMiniServer(SpeedtestException): """ +class SpeedtestCustomConnectFailure(SpeedtestException): + """Could not connect to the provided speedtest custom server""" + + +class InvalidSpeedtestCustomServer(SpeedtestException): + """Server provided as a speedtest custom server does not actually appear + to be a speedtest custom server + """ + + class ShareResultsConnectFailure(SpeedtestException): """Could not connect to speedtest.net API to POST results""" @@ -345,6 +355,14 @@ class SpeedtestMissingBestServer(SpeedtestException): """get_best_server not called or not able to determine best server""" +def make_source_address_tuple(source_address): + if isinstance(source_address, (list, tuple)): + return source_address + elif source_address: + return (source_address, 0) + return None + + def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None): """Connect to *address* and return the socket object. @@ -386,6 +404,22 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, raise socket.error("getaddrinfo returns an empty list") +def connection_factory(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, + source_address=None): + try: + return socket.create_connection( + address, + timeout, + source_address + ) + except (AttributeError, TypeError): + return create_connection( + address, + timeout, + source_address + ) + + class SpeedtestHTTPConnection(HTTPConnection): """Custom HTTPConnection to support source_address across Python 2.4 - Python 3 @@ -401,18 +435,11 @@ class SpeedtestHTTPConnection(HTTPConnection): def connect(self): """Connect to the host and port specified in __init__.""" - try: - self.sock = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - except (AttributeError, TypeError): - self.sock = create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) + self.sock = connection_factory( + (self.host, self.port), + self.timeout, + self.source_address + ) if HTTPSConnection: @@ -519,18 +546,16 @@ def build_opener(source_address=None, timeout=10): printer('Timeout set to %d' % timeout, debug=True) + source_address = make_source_address_tuple(source_address) if source_address: - source_address_tuple = (source_address, 0) - printer('Binding to source address: %r' % (source_address_tuple,), + printer('Binding to source address: %r' % (source_address,), debug=True) - else: - source_address_tuple = None handlers = [ ProxyHandler(), - SpeedtestHTTPHandler(source_address=source_address_tuple, + SpeedtestHTTPHandler(source_address=source_address, timeout=timeout), - SpeedtestHTTPSHandler(source_address=source_address_tuple, + SpeedtestHTTPSHandler(source_address=source_address, timeout=timeout), HTTPDefaultErrorHandler(), HTTPRedirectHandler(), @@ -726,7 +751,7 @@ class HTTPDownloader(threading.Thread): shutdown_event=None): threading.Thread.__init__(self) self.request = request - self.result = [0] + self.result = 0 self.starttime = start self.timeout = timeout self.i = i @@ -747,14 +772,69 @@ class HTTPDownloader(threading.Thread): while (not self._shutdown_event.isSet() and (timeit.default_timer() - self.starttime) <= self.timeout): - self.result.append(len(f.read(10240))) - if self.result[-1] == 0: + data = len(f.read(10240)) + if data == 0: break + self.result += data f.close() except IOError: pass +class SocketTestBase(threading.Thread): + def __init__(self, i, address, size, start, timeout, shutdown_event=None, + source_address=None): + threading.Thread.__init__(self) + self.result = 0 + self.starttime = start + self.timeout = timeout + self.i = i + self.size = size + self.remaining = self.size + + self._address = address + + if shutdown_event: + self._shutdown_event = shutdown_event + else: + self._shutdown_event = FakeShutdownEvent() + + self.sock = connection_factory( + address, + timeout=timeout, + source_address=source_address + ) + + +class SocketDownloader(SocketTestBase): + def run(self): + try: + if (timeit.default_timer() - self.starttime) <= self.timeout: + self.sock.sendall('HI\n'.encode()) + self.sock.recv(1024) + + while (self.remaining and not self._shutdown_event.isSet() and + (timeit.default_timer() - self.starttime) <= + self.timeout): + + if self.remaining > 1000000: + ask = 1000000 + else: + ask = self.remaining + + down = 0 + self.sock.sendall(('DOWNLOAD %d\n' % ask).encode()) + while down < ask: + down += len(self.sock.recv(10240)) + + self.result += down + self.remaining -= down + + self.sock.close() + except IOError: + pass + + class HTTPUploaderData(object): """File like object to improve cutting off the upload once the timeout has been reached @@ -772,7 +852,7 @@ class HTTPUploaderData(object): self._data = None - self.total = [0] + self.total = 0 def pre_allocate(self): chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' @@ -800,7 +880,7 @@ class HTTPUploaderData(object): if ((timeit.default_timer() - self.start) <= self.timeout and not self._shutdown_event.isSet()): chunk = self.data.read(n) - self.total.append(len(chunk)) + self.total += len(chunk) return chunk else: raise SpeedtestUploadTimeout() @@ -848,11 +928,41 @@ class HTTPUploader(threading.Thread): f = self._opener(request) f.read(11) f.close() - self.result = sum(self.request.data.total) + self.result = self.request.data.total else: self.result = 0 except (IOError, SpeedtestUploadTimeout): - self.result = sum(self.request.data.total) + self.result = self.request.data.total + + +class SocketUploader(SocketTestBase): + def run(self): + try: + if (timeit.default_timer() - self.starttime) <= self.timeout: + self.sock.sendall('HI\n'.encode()) + self.sock.recv(1024) + + while (self.remaining and not self._shutdown_event.isSet() and + (timeit.default_timer() - self.starttime) <= + self.timeout): + + if self.remaining > 100000: + give = 100000 + else: + give = self.remaining + + header = ('UPLOAD %d 0\n' % give).encode() + data = '0'.encode() * (give - len(header)) + + self.sock.sendall(header) + self.sock.sendall(data) + self.sock.recv(24) + self.result += give + self.remaining -= give + + self.sock.close() + except IOError: + pass class SpeedtestResults(object): @@ -1010,7 +1120,7 @@ class Speedtest(object): """Class for performing standard speedtest.net testing operations""" def __init__(self, config=None, source_address=None, timeout=10, - secure=False, shutdown_event=None): + secure=False, shutdown_event=None, use_socket=False): self.config = {} self._source_address = source_address @@ -1024,16 +1134,19 @@ class Speedtest(object): else: self._shutdown_event = FakeShutdownEvent() - self.get_config() + self._use_socket = use_socket + if config is not None: self.config.update(config) + else: + self.get_config() self.servers = {} self.closest = [] self._best = {} self.results = SpeedtestResults( - client=self.config['client'], + client=self.config.get('client'), opener=self._opener, secure=secure, ) @@ -1118,9 +1231,14 @@ class Speedtest(object): up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032] sizes = { 'upload': up_sizes[ratio - 1:], - 'download': [350, 500, 750, 1000, 1500, 2000, 2500, - 3000, 3500, 4000] } + if self._use_socket: + sizes['download'] = [245388, 505544, 1118012, 1986284, 4468241, + 7907740, 12407926, 17816816, 24262167, + 31625365] + else: + sizes['download'] = [350, 500, 750, 1000, 1500, 2000, 2500, + 3000, 3500, 4000] size_count = len(sizes['upload']) @@ -1265,6 +1383,9 @@ class Speedtest(object): or int(attrib.get('id')) in exclude): continue + host, port = attrib['host'].split(':') + attrib['host'] = (host, int(port)) + try: d = distance(self.lat_lon, (float(attrib.get('lat')), @@ -1343,6 +1464,87 @@ class Speedtest(object): return self.servers + def set_custom_server(self, url, include=None, exclude=None): + request = build_request(url) + uh, e = catch_request(request, opener=self._opener) + if e: + raise SpeedtestCustomConnectFailure( + 'Failed to connect to %s' % url + ) + else: + text = uh.read() + uh.close() + + match = re.search('window.ST_PARAMS = (\{.*\});'.encode(), text) + + try: + params = json.loads(match.group(1)) + except (TypeError, ValueError): + e = get_exception() + printer('ERROR: %r' % e, debug=True) + raise InvalidSpeedtestCustomServer( + 'Invalid Speedtest Custom Server: %s' % url + ) + + test_globals = params['testGlobals'] + + config = { + 'client': { + 'ip': test_globals['ipAddress'], + 'lat': test_globals['location']['latitude'], + 'lon': test_globals['location']['longitude'], + 'country': test_globals['location']['countryCode'], + 'isp': test_globals['ispName'], + }, + 'ignore_servers': [], + 'sizes': { + 'upload': [524288, 1048576, 7340032], + }, + 'counts': { + 'upload': 17, + 'download': 4 + }, + 'threads': { + 'upload': 2, + 'download': 8 + }, + 'length': { + 'upload': 10, + 'download': 10 + }, + 'upload_max': 51, + } + + if self._use_socket: + config['sizes']['download'] = [245388, 505544, 1118012, 1986284, + 4468241, 7907740, 12407926, + 17816816, 24262167, 31625365] + else: + config['sizes']['download'] = [350, 500, 750, 1000, 1500, 2000, + 2500, 3000, 3500, 4000] + self.config = config + self.results.client = config['client'] + + servers = {} + for server in params['serverList']: + if include and int(server.get('id')) not in include: + continue + if exclude and int(server.get('id')) in exclude: + continue + + host, port = server['host'].split(':') + server['host'] = (host, int(port)) + server['country'] = server['cc'] + d = server.pop('distance') + server['d'] = d + try: + servers[d].append(server) + except KeyError: + servers[d] = [server] + + self.servers = servers + return self.servers + def get_closest_servers(self, limit=5): """Limit servers to the closest speedtest.net servers based on geographic distance @@ -1363,20 +1565,8 @@ class Speedtest(object): printer('Closest Servers:\n%r' % self.closest, debug=True) return self.closest - def get_best_server(self, servers=None): - """Perform a speedtest.net "ping" to determine which speedtest.net - server has the lowest latency - """ - - if not servers: - if not self.closest: - servers = self.get_closest_servers() - servers = self.closest - - if self._source_address: - source_address_tuple = (self._source_address, 0) - else: - source_address_tuple = None + def _http_latency(self, servers): + source_address = make_source_address_tuple(self._source_address) user_agent = build_user_agent() @@ -1395,12 +1585,14 @@ class Speedtest(object): if urlparts[0] == 'https': h = SpeedtestHTTPSConnection( urlparts[1], - source_address=source_address_tuple + source_address=source_address, + timeout=self._timeout, ) else: h = SpeedtestHTTPConnection( urlparts[1], - source_address=source_address_tuple + source_address=source_address, + timeout=self._timeout, ) headers = {'User-Agent': user_agent} path = '%s?%s' % (urlparts[2], urlparts[4]) @@ -1424,11 +1616,77 @@ class Speedtest(object): avg = round((sum(cum) / 6) * 1000.0, 3) results[avg] = server + return results + + def _socket_latency(self, servers): + source_address = make_source_address_tuple(self._source_address) + + results = {} + for server in servers: + cum = [] + try: + sock = connection_factory( + server['host'], + timeout=self._timeout, + source_address=source_address + ) + sock.sendall('HI\n'.encode()) + sock.recv(1024) + except socket.error: + e = get_exception() + printer('ERROR: %r' % e, debug=True) + cum.append(3600 * 3) + continue + + for _ in range(0, 3): + printer('%s %s' % ('PING', server['host']), + debug=True) + start = timeit.default_timer() + try: + sock.sendall( + ('PING %d\n' % + (int(timeit.time.time()) * 1000,)).encode() + ) + resp = sock.recv(1024) + except socket.errror: + e = get_exception() + printer('ERROR: %r' % e, debug=True) + cum.append(3600) + continue + total = (timeit.default_timer() - start) + if resp.startswith('PONG '.encode()): + cum.append(total) + else: + cum.append(3600) + + avg = round((sum(cum) / 3) * 1000.0, 3) + results[avg] = server + + return results + + def get_best_server(self, servers=None): + """Perform a speedtest.net "ping" to determine which speedtest.net + server has the lowest latency + """ + if not servers: + if not self.closest: + servers = self.get_closest_servers() + servers = self.closest + + if self._use_socket: + results = self._socket_latency(servers) + else: + results = self._http_latency(servers) + try: fastest = sorted(results.keys())[0] except IndexError: + if self._use_socket: + extra = ' Try the HTTP based tests by removing --socket.' + else: + extra = '' raise SpeedtestBestServerFailure('Unable to connect to servers to ' - 'test latency.') + 'test latency.%s' % extra) best = results[fastest] best['latency'] = fastest @@ -1442,29 +1700,54 @@ class Speedtest(object): def download(self, callback=do_nothing): """Test download speed against speedtest.net""" - urls = [] - for size in self.config['sizes']['download']: - for _ in range(0, self.config['counts']['download']): - urls.append('%s/random%sx%s.jpg' % - (os.path.dirname(self.best['url']), size, size)) + if self._use_socket: + requests = [] + for size in self.config['sizes']['download']: + for _ in range(0, self.config['counts']['download']): + requests.append(size) + printer( + 'DOWNLOAD %s %s' % (self.best['host'], size), + debug=True + ) - request_count = len(urls) - requests = [] - for i, url in enumerate(urls): - requests.append( - build_request(url, bump=i, secure=self._secure) - ) + request_count = len(requests) + else: + urls = [] + for size in self.config['sizes']['download']: + for _ in range(0, self.config['counts']['download']): + urls.append( + '%s/random%sx%s.jpg' % + (os.path.dirname(self.best['url']), size, size) + ) + + request_count = len(urls) + requests = [] + for i, url in enumerate(urls): + requests.append( + build_request(url, bump=i, secure=self._secure) + ) def producer(q, requests, request_count): for i, request in enumerate(requests): - thread = HTTPDownloader( - i, - request, - start, - self.config['length']['download'], - opener=self._opener, - shutdown_event=self._shutdown_event - ) + if self._use_socket: + thread = SocketDownloader( + i, + self.best['host'], + request, + start, + self.config['length']['download'], + shutdown_event=self._shutdown_event, + source_address=self._source_address + ) + else: + thread = HTTPDownloader( + i, + request, + start, + self.config['length']['download'], + opener=self._opener, + shutdown_event=self._shutdown_event + ) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1476,7 +1759,7 @@ class Speedtest(object): thread = q.get(True) while thread.isAlive(): thread.join(timeout=0.1) - finished.append(sum(thread.result)) + finished.append(thread.result) callback(thread.i, request_count, end=True) q = Queue(self.config['threads']['download']) @@ -1515,34 +1798,56 @@ class Speedtest(object): requests = [] for i, size in enumerate(sizes): - # We set ``0`` for ``start`` and handle setting the actual - # ``start`` in ``HTTPUploader`` to get better measurements - data = HTTPUploaderData( - size, - 0, - self.config['length']['upload'], - shutdown_event=self._shutdown_event - ) - if pre_allocate: - data.pre_allocate() - requests.append( - ( - build_request(self.best['url'], data, secure=self._secure), - size + if self._use_socket: + requests.append(size) + printer( + 'UPLOAD %s %s' % (self.best['host'], size), + debug=True + ) + else: + # We set ``0`` for ``start`` and handle setting the actual + # ``start`` in ``HTTPUploader`` to get better measurements + data = HTTPUploaderData( + size, + 0, + self.config['length']['upload'], + shutdown_event=self._shutdown_event + ) + if pre_allocate: + data.pre_allocate() + requests.append( + ( + build_request( + self.best['url'], + data, + secure=self._secure + ), + size + ) ) - ) def producer(q, requests, request_count): for i, request in enumerate(requests[:request_count]): - thread = HTTPUploader( - i, - request[0], - start, - request[1], - self.config['length']['upload'], - opener=self._opener, - shutdown_event=self._shutdown_event - ) + if self._use_socket: + thread = SocketUploader( + i, + self.best['host'], + request, + start, + self.config['length']['upload'], + shutdown_event=self._shutdown_event, + source_address=self._source_address + ) + else: + thread = HTTPUploader( + i, + request[0], + start, + request[1], + self.config['length']['upload'], + opener=self._opener, + shutdown_event=self._shutdown_event + ) thread.start() q.put(thread, True) callback(i, request_count, start=True) @@ -1659,12 +1964,15 @@ def parse_args(): help='Exclude a server from selection. Can be ' 'supplied multiple times') parser.add_argument('--mini', help='URL of the Speedtest Mini server') + parser.add_argument('--custom', help='URL of the Speedtest Custom Server') parser.add_argument('--source', help='Source IP address to bind to') parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT, help='HTTP timeout in seconds. Default 10') parser.add_argument('--secure', action='store_true', help='Use HTTPS instead of HTTP when communicating ' 'with speedtest.net operated servers') + parser.add_argument('--socket', action='store_true', + help='Use socket test instead of HTTP based tests') parser.add_argument('--no-pre-allocate', dest='pre_allocate', action='store_const', default=True, const=False, help='Do not pre allocate upload data. Pre allocation ' @@ -1694,6 +2002,7 @@ def validate_optional_args(args): """ optional_args = { 'json': ('json/simplejson python module', json), + 'custom': ('json/simplejson python module', json), 'secure': ('SSL support', HTTPSConnection), } @@ -1774,18 +2083,33 @@ def shell(): printer('Retrieving speedtest.net configuration...', quiet) try: + kwargs = {} + if args.custom: + kwargs['config'] = {} speedtest = Speedtest( source_address=args.source, timeout=args.timeout, - secure=args.secure + secure=args.secure, + use_socket=args.socket, + **kwargs ) except (ConfigRetrievalError,) + HTTP_ERRORS: printer('Cannot retrieve speedtest configuration', error=True) raise SpeedtestCLIError(get_exception()) + if args.custom: + kwargs = {} + if not args.list: + kwargs.update({ + 'include': args.server, + 'exclude': args.exclude, + }) + speedtest.set_custom_server(args.custom, **kwargs) + if args.list: try: - speedtest.get_servers() + if not args.custom: + speedtest.get_servers() except (ServersRetrievalError,) + HTTP_ERRORS: printer('Cannot retrieve speedtest server list', error=True) raise SpeedtestCLIError(get_exception()) @@ -1807,21 +2131,25 @@ def shell(): if not args.mini: printer('Retrieving speedtest.net server list...', quiet) - try: - speedtest.get_servers(servers=args.server, exclude=args.exclude) - except NoMatchedServers: - raise SpeedtestCLIError( - 'No matched servers: %s' % - ', '.join('%s' % s for s in args.server) - ) - except (ServersRetrievalError,) + HTTP_ERRORS: - printer('Cannot retrieve speedtest server list', error=True) - raise SpeedtestCLIError(get_exception()) - except InvalidServerIDType: - raise SpeedtestCLIError( - '%s is an invalid server type, must ' - 'be an int' % ', '.join('%s' % s for s in args.server) - ) + if not args.custom: + try: + speedtest.get_servers( + servers=args.server, + exclude=args.exclude + ) + except NoMatchedServers: + raise SpeedtestCLIError( + 'No matched servers: %s' % + ', '.join('%s' % s for s in args.server) + ) + except (ServersRetrievalError,) + HTTP_ERRORS: + printer('Cannot retrieve speedtest server list', error=True) + raise SpeedtestCLIError(get_exception()) + except InvalidServerIDType: + raise SpeedtestCLIError( + '%s is an invalid server type, must ' + 'be an int' % ', '.join('%s' % s for s in args.server) + ) if args.server and len(args.server) == 1: printer('Retrieving information for the selected server...', quiet)