Compare commits

...

7 Commits

Author SHA1 Message Date
Matt Martz 3692ad259b flake8 fixup 2018-03-23 11:08:35 -05:00
Matt Martz 59d4172446 Add ability to set new custom speedtest servers (not mini) 2018-03-23 11:06:57 -05:00
Matt Martz 5d0c62faec Handle socket failures while testing latency 2018-03-23 09:37:51 -05:00
Matt Martz ced2890261 Create base class SocketTestBase to dedupe code 2018-03-23 09:37:38 -05:00
Matt Martz fedf42e838 flake8 fixes 2018-03-22 17:00:29 -05:00
Matt Martz 4ce4019331 Add socket based latency test 2018-03-22 16:58:34 -05:00
Matt Martz 0ef2f6b04c First pass at adding ability to test using newer socket based connection 2018-03-22 15:49:28 -05:00
1 changed files with 436 additions and 108 deletions

View File

@ -318,6 +318,16 @@ class InvalidSpeedtestMiniServer(SpeedtestException):
"""
class SpeedtestCustomConnectFailure(SpeedtestException):
"""Could not connect to the provided speedtest custom server"""
class InvalidSpeedtestCustomServer(SpeedtestException):
"""Server provided as a speedtest custom server does not actually appear
to be a speedtest custom server
"""
class ShareResultsConnectFailure(SpeedtestException):
"""Could not connect to speedtest.net API to POST results"""
@ -342,6 +352,14 @@ class SpeedtestMissingBestServer(SpeedtestException):
"""get_best_server not called or not able to determine best server"""
def make_source_address_tuple(source_address):
if isinstance(source_address, (list, tuple)):
return source_address
elif source_address:
return (source_address, 0)
return None
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None):
"""Connect to *address* and return the socket object.
@ -383,6 +401,22 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
raise socket.error("getaddrinfo returns an empty list")
def connection_factory(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None):
try:
return socket.create_connection(
address,
timeout,
source_address
)
except (AttributeError, TypeError):
return create_connection(
address,
timeout,
source_address
)
class SpeedtestHTTPConnection(HTTPConnection):
"""Custom HTTPConnection to support source_address across
Python 2.4 - Python 3
@ -400,18 +434,11 @@ class SpeedtestHTTPConnection(HTTPConnection):
def connect(self):
"""Connect to the host and port specified in __init__."""
try:
self.sock = socket.create_connection(
(self.host, self.port),
self.timeout,
self.source_address
)
except (AttributeError, TypeError):
self.sock = create_connection(
(self.host, self.port),
self.timeout,
self.source_address
)
self.sock = connection_factory(
(self.host, self.port),
self.timeout,
self.source_address
)
if HTTPSConnection:
@ -506,18 +533,16 @@ def build_opener(source_address=None, timeout=10):
printer('Timeout set to %d' % timeout, debug=True)
source_address = make_source_address_tuple(source_address)
if source_address:
source_address_tuple = (source_address, 0)
printer('Binding to source address: %r' % (source_address_tuple,),
printer('Binding to source address: %r' % (source_address,),
debug=True)
else:
source_address_tuple = None
handlers = [
ProxyHandler(),
SpeedtestHTTPHandler(source_address=source_address_tuple,
SpeedtestHTTPHandler(source_address=source_address,
timeout=timeout),
SpeedtestHTTPSHandler(source_address=source_address_tuple,
SpeedtestHTTPSHandler(source_address=source_address,
timeout=timeout),
HTTPDefaultErrorHandler(),
HTTPRedirectHandler(),
@ -713,7 +738,7 @@ class HTTPDownloader(threading.Thread):
shutdown_event=None):
threading.Thread.__init__(self)
self.request = request
self.result = [0]
self.result = 0
self.starttime = start
self.timeout = timeout
self.i = i
@ -734,14 +759,69 @@ class HTTPDownloader(threading.Thread):
while (not self._shutdown_event.isSet() and
(timeit.default_timer() - self.starttime) <=
self.timeout):
self.result.append(len(f.read(10240)))
if self.result[-1] == 0:
data = len(f.read(10240))
if data == 0:
break
self.result += data
f.close()
except IOError:
pass
class SocketTestBase(threading.Thread):
def __init__(self, i, address, size, start, timeout, shutdown_event=None,
source_address=None):
threading.Thread.__init__(self)
self.result = 0
self.starttime = start
self.timeout = timeout
self.i = i
self.size = size
self.remaining = self.size
self._address = address
if shutdown_event:
self._shutdown_event = shutdown_event
else:
self._shutdown_event = FakeShutdownEvent()
self.sock = connection_factory(
address,
timeout=timeout,
source_address=source_address
)
class SocketDownloader(SocketTestBase):
def run(self):
try:
if (timeit.default_timer() - self.starttime) <= self.timeout:
self.sock.sendall('HI\n'.encode())
self.sock.recv(1024)
while (self.remaining and not self._shutdown_event.isSet() and
(timeit.default_timer() - self.starttime) <=
self.timeout):
if self.remaining > 1000000:
ask = 1000000
else:
ask = self.remaining
down = 0
self.sock.sendall(('DOWNLOAD %d\n' % ask).encode())
while down < ask:
down += len(self.sock.recv(10240))
self.result += down
self.remaining -= down
self.sock.close()
except IOError:
pass
class HTTPUploaderData(object):
"""File like object to improve cutting off the upload once the timeout
has been reached
@ -759,7 +839,7 @@ class HTTPUploaderData(object):
self._data = None
self.total = [0]
self.total = 0
def pre_allocate(self):
chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'
@ -787,7 +867,7 @@ class HTTPUploaderData(object):
if ((timeit.default_timer() - self.start) <= self.timeout and
not self._shutdown_event.isSet()):
chunk = self.data.read(n)
self.total.append(len(chunk))
self.total += len(chunk)
return chunk
else:
raise SpeedtestUploadTimeout()
@ -835,11 +915,41 @@ class HTTPUploader(threading.Thread):
f = self._opener(request)
f.read(11)
f.close()
self.result = sum(self.request.data.total)
self.result = self.request.data.total
else:
self.result = 0
except (IOError, SpeedtestUploadTimeout):
self.result = sum(self.request.data.total)
self.result = self.request.data.total
class SocketUploader(SocketTestBase):
def run(self):
try:
if (timeit.default_timer() - self.starttime) <= self.timeout:
self.sock.sendall('HI\n'.encode())
self.sock.recv(1024)
while (self.remaining and not self._shutdown_event.isSet() and
(timeit.default_timer() - self.starttime) <=
self.timeout):
if self.remaining > 100000:
give = 100000
else:
give = self.remaining
header = ('UPLOAD %d 0\n' % give).encode()
data = '0'.encode() * (give - len(header))
self.sock.sendall(header)
self.sock.sendall(data)
self.sock.recv(24)
self.result += give
self.remaining -= give
self.sock.close()
except IOError:
pass
class SpeedtestResults(object):
@ -997,7 +1107,7 @@ class Speedtest(object):
"""Class for performing standard speedtest.net testing operations"""
def __init__(self, config=None, source_address=None, timeout=10,
secure=False, shutdown_event=None):
secure=False, shutdown_event=None, use_socket=False):
self.config = {}
self._source_address = source_address
@ -1011,16 +1121,19 @@ class Speedtest(object):
else:
self._shutdown_event = FakeShutdownEvent()
self.get_config()
self._use_socket = use_socket
if config is not None:
self.config.update(config)
else:
self.get_config()
self.servers = {}
self.closest = []
self._best = {}
self.results = SpeedtestResults(
client=self.config['client'],
client=self.config.get('client'),
opener=self._opener,
secure=secure,
)
@ -1105,9 +1218,14 @@ class Speedtest(object):
up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032]
sizes = {
'upload': up_sizes[ratio - 1:],
'download': [350, 500, 750, 1000, 1500, 2000, 2500,
3000, 3500, 4000]
}
if self._use_socket:
sizes['download'] = [245388, 505544, 1118012, 1986284, 4468241,
7907740, 12407926, 17816816, 24262167,
31625365]
else:
sizes['download'] = [350, 500, 750, 1000, 1500, 2000, 2500,
3000, 3500, 4000]
size_count = len(sizes['upload'])
@ -1252,6 +1370,9 @@ class Speedtest(object):
or int(attrib.get('id')) in exclude):
continue
host, port = attrib['host'].split(':')
attrib['host'] = (host, int(port))
try:
d = distance(self.lat_lon,
(float(attrib.get('lat')),
@ -1330,6 +1451,87 @@ class Speedtest(object):
return self.servers
def set_custom_server(self, url, include=None, exclude=None):
request = build_request(url)
uh, e = catch_request(request, opener=self._opener)
if e:
raise SpeedtestCustomConnectFailure(
'Failed to connect to %s' % url
)
else:
text = uh.read()
uh.close()
match = re.search('window.ST_PARAMS = (\{.*\});'.encode(), text)
try:
params = json.loads(match.group(1))
except (TypeError, ValueError):
e = get_exception()
printer('ERROR: %r' % e, debug=True)
raise InvalidSpeedtestCustomServer(
'Invalid Speedtest Custom Server: %s' % url
)
test_globals = params['testGlobals']
config = {
'client': {
'ip': test_globals['ipAddress'],
'lat': test_globals['location']['latitude'],
'lon': test_globals['location']['longitude'],
'country': test_globals['location']['countryCode'],
'isp': test_globals['ispName'],
},
'ignore_servers': [],
'sizes': {
'upload': [524288, 1048576, 7340032],
},
'counts': {
'upload': 17,
'download': 4
},
'threads': {
'upload': 2,
'download': 8
},
'length': {
'upload': 10,
'download': 10
},
'upload_max': 51,
}
if self._use_socket:
config['sizes']['download'] = [245388, 505544, 1118012, 1986284,
4468241, 7907740, 12407926,
17816816, 24262167, 31625365]
else:
config['sizes']['download'] = [350, 500, 750, 1000, 1500, 2000,
2500, 3000, 3500, 4000]
self.config = config
self.results.client = config['client']
servers = {}
for server in params['serverList']:
if include and int(server.get('id')) not in include:
continue
if exclude and int(server.get('id')) in exclude:
continue
host, port = server['host'].split(':')
server['host'] = (host, int(port))
server['country'] = server['cc']
d = server.pop('distance')
server['d'] = d
try:
servers[d].append(server)
except KeyError:
servers[d] = [server]
self.servers = servers
return self.servers
def get_closest_servers(self, limit=5):
"""Limit servers to the closest speedtest.net servers based on
geographic distance
@ -1350,20 +1552,8 @@ class Speedtest(object):
printer('Closest Servers:\n%r' % self.closest, debug=True)
return self.closest
def get_best_server(self, servers=None):
"""Perform a speedtest.net "ping" to determine which speedtest.net
server has the lowest latency
"""
if not servers:
if not self.closest:
servers = self.get_closest_servers()
servers = self.closest
if self._source_address:
source_address_tuple = (self._source_address, 0)
else:
source_address_tuple = None
def _http_latency(self, servers):
source_address = make_source_address_tuple(self._source_address)
user_agent = build_user_agent()
@ -1382,12 +1572,14 @@ class Speedtest(object):
if urlparts[0] == 'https':
h = SpeedtestHTTPSConnection(
urlparts[1],
source_address=source_address_tuple
source_address=source_address,
timeout=self._timeout,
)
else:
h = SpeedtestHTTPConnection(
urlparts[1],
source_address=source_address_tuple
source_address=source_address,
timeout=self._timeout,
)
headers = {'User-Agent': user_agent}
path = '%s?%s' % (urlparts[2], urlparts[4])
@ -1411,11 +1603,77 @@ class Speedtest(object):
avg = round((sum(cum) / 6) * 1000.0, 3)
results[avg] = server
return results
def _socket_latency(self, servers):
source_address = make_source_address_tuple(self._source_address)
results = {}
for server in servers:
cum = []
try:
sock = connection_factory(
server['host'],
timeout=self._timeout,
source_address=source_address
)
sock.sendall('HI\n'.encode())
sock.recv(1024)
except socket.error:
e = get_exception()
printer('ERROR: %r' % e, debug=True)
cum.append(3600 * 3)
continue
for _ in range(0, 3):
printer('%s %s' % ('PING', server['host']),
debug=True)
start = timeit.default_timer()
try:
sock.sendall(
('PING %d\n' %
(int(timeit.time.time()) * 1000,)).encode()
)
resp = sock.recv(1024)
except socket.errror:
e = get_exception()
printer('ERROR: %r' % e, debug=True)
cum.append(3600)
continue
total = (timeit.default_timer() - start)
if resp.startswith('PONG '.encode()):
cum.append(total)
else:
cum.append(3600)
avg = round((sum(cum) / 3) * 1000.0, 3)
results[avg] = server
return results
def get_best_server(self, servers=None):
"""Perform a speedtest.net "ping" to determine which speedtest.net
server has the lowest latency
"""
if not servers:
if not self.closest:
servers = self.get_closest_servers()
servers = self.closest
if self._use_socket:
results = self._socket_latency(servers)
else:
results = self._http_latency(servers)
try:
fastest = sorted(results.keys())[0]
except IndexError:
if self._use_socket:
extra = ' Try the HTTP based tests by removing --socket.'
else:
extra = ''
raise SpeedtestBestServerFailure('Unable to connect to servers to '
'test latency.')
'test latency.%s' % extra)
best = results[fastest]
best['latency'] = fastest
@ -1429,29 +1687,54 @@ class Speedtest(object):
def download(self, callback=do_nothing):
"""Test download speed against speedtest.net"""
urls = []
for size in self.config['sizes']['download']:
for _ in range(0, self.config['counts']['download']):
urls.append('%s/random%sx%s.jpg' %
(os.path.dirname(self.best['url']), size, size))
if self._use_socket:
requests = []
for size in self.config['sizes']['download']:
for _ in range(0, self.config['counts']['download']):
requests.append(size)
printer(
'DOWNLOAD %s %s' % (self.best['host'], size),
debug=True
)
request_count = len(urls)
requests = []
for i, url in enumerate(urls):
requests.append(
build_request(url, bump=i, secure=self._secure)
)
request_count = len(requests)
else:
urls = []
for size in self.config['sizes']['download']:
for _ in range(0, self.config['counts']['download']):
urls.append(
'%s/random%sx%s.jpg' %
(os.path.dirname(self.best['url']), size, size)
)
request_count = len(urls)
requests = []
for i, url in enumerate(urls):
requests.append(
build_request(url, bump=i, secure=self._secure)
)
def producer(q, requests, request_count):
for i, request in enumerate(requests):
thread = HTTPDownloader(
i,
request,
start,
self.config['length']['download'],
opener=self._opener,
shutdown_event=self._shutdown_event
)
if self._use_socket:
thread = SocketDownloader(
i,
self.best['host'],
request,
start,
self.config['length']['download'],
shutdown_event=self._shutdown_event,
source_address=self._source_address
)
else:
thread = HTTPDownloader(
i,
request,
start,
self.config['length']['download'],
opener=self._opener,
shutdown_event=self._shutdown_event
)
thread.start()
q.put(thread, True)
callback(i, request_count, start=True)
@ -1463,7 +1746,7 @@ class Speedtest(object):
thread = q.get(True)
while thread.isAlive():
thread.join(timeout=0.1)
finished.append(sum(thread.result))
finished.append(thread.result)
callback(thread.i, request_count, end=True)
q = Queue(self.config['threads']['download'])
@ -1502,34 +1785,56 @@ class Speedtest(object):
requests = []
for i, size in enumerate(sizes):
# We set ``0`` for ``start`` and handle setting the actual
# ``start`` in ``HTTPUploader`` to get better measurements
data = HTTPUploaderData(
size,
0,
self.config['length']['upload'],
shutdown_event=self._shutdown_event
)
if pre_allocate:
data.pre_allocate()
requests.append(
(
build_request(self.best['url'], data, secure=self._secure),
size
if self._use_socket:
requests.append(size)
printer(
'UPLOAD %s %s' % (self.best['host'], size),
debug=True
)
else:
# We set ``0`` for ``start`` and handle setting the actual
# ``start`` in ``HTTPUploader`` to get better measurements
data = HTTPUploaderData(
size,
0,
self.config['length']['upload'],
shutdown_event=self._shutdown_event
)
if pre_allocate:
data.pre_allocate()
requests.append(
(
build_request(
self.best['url'],
data,
secure=self._secure
),
size
)
)
)
def producer(q, requests, request_count):
for i, request in enumerate(requests[:request_count]):
thread = HTTPUploader(
i,
request[0],
start,
request[1],
self.config['length']['upload'],
opener=self._opener,
shutdown_event=self._shutdown_event
)
if self._use_socket:
thread = SocketUploader(
i,
self.best['host'],
request,
start,
self.config['length']['upload'],
shutdown_event=self._shutdown_event,
source_address=self._source_address
)
else:
thread = HTTPUploader(
i,
request[0],
start,
request[1],
self.config['length']['upload'],
opener=self._opener,
shutdown_event=self._shutdown_event
)
thread.start()
q.put(thread, True)
callback(i, request_count, start=True)
@ -1646,12 +1951,15 @@ def parse_args():
help='Exclude a server from selection. Can be '
'supplied multiple times')
parser.add_argument('--mini', help='URL of the Speedtest Mini server')
parser.add_argument('--custom', help='URL of the Speedtest Custom Server')
parser.add_argument('--source', help='Source IP address to bind to')
parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT,
help='HTTP timeout in seconds. Default 10')
parser.add_argument('--secure', action='store_true',
help='Use HTTPS instead of HTTP when communicating '
'with speedtest.net operated servers')
parser.add_argument('--socket', action='store_true',
help='Use socket test instead of HTTP based tests')
parser.add_argument('--no-pre-allocate', dest='pre_allocate',
action='store_const', default=True, const=False,
help='Do not pre allocate upload data. Pre allocation '
@ -1681,6 +1989,7 @@ def validate_optional_args(args):
"""
optional_args = {
'json': ('json/simplejson python module', json),
'custom': ('json/simplejson python module', json),
'secure': ('SSL support', HTTPSConnection),
}
@ -1761,18 +2070,33 @@ def shell():
printer('Retrieving speedtest.net configuration...', quiet)
try:
kwargs = {}
if args.custom:
kwargs['config'] = {}
speedtest = Speedtest(
source_address=args.source,
timeout=args.timeout,
secure=args.secure
secure=args.secure,
use_socket=args.socket,
**kwargs
)
except (ConfigRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest configuration', error=True)
raise SpeedtestCLIError(get_exception())
if args.custom:
kwargs = {}
if not args.list:
kwargs.update({
'include': args.server,
'exclude': args.exclude,
})
speedtest.set_custom_server(args.custom, **kwargs)
if args.list:
try:
speedtest.get_servers()
if not args.custom:
speedtest.get_servers()
except (ServersRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest server list', error=True)
raise SpeedtestCLIError(get_exception())
@ -1794,21 +2118,25 @@ def shell():
if not args.mini:
printer('Retrieving speedtest.net server list...', quiet)
try:
speedtest.get_servers(servers=args.server, exclude=args.exclude)
except NoMatchedServers:
raise SpeedtestCLIError(
'No matched servers: %s' %
', '.join('%s' % s for s in args.server)
)
except (ServersRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest server list', error=True)
raise SpeedtestCLIError(get_exception())
except InvalidServerIDType:
raise SpeedtestCLIError(
'%s is an invalid server type, must '
'be an int' % ', '.join('%s' % s for s in args.server)
)
if not args.custom:
try:
speedtest.get_servers(
servers=args.server,
exclude=args.exclude
)
except NoMatchedServers:
raise SpeedtestCLIError(
'No matched servers: %s' %
', '.join('%s' % s for s in args.server)
)
except (ServersRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest server list', error=True)
raise SpeedtestCLIError(get_exception())
except InvalidServerIDType:
raise SpeedtestCLIError(
'%s is an invalid server type, must '
'be an int' % ', '.join('%s' % s for s in args.server)
)
if args.server and len(args.server) == 1:
printer('Retrieving information for the selected server...', quiet)