Add option to exclude servers, and allow --server and --exclude to be specified multiple times

This commit is contained in:
Matt Martz 2017-05-12 13:02:35 -05:00
parent ca72d40033
commit 6bfa5922c3
1 changed files with 29 additions and 20 deletions

View File

@ -1063,21 +1063,26 @@ class Speedtest(object):
return self.config return self.config
def get_servers(self, servers=None): def get_servers(self, servers=None, exclude=None):
"""Retrieve a the list of speedtest.net servers, optionally filtered """Retrieve a the list of speedtest.net servers, optionally filtered
to servers matching those specified in the ``servers`` argument to servers matching those specified in the ``servers`` argument
""" """
if servers is None: if servers is None:
servers = [] servers = []
if exclude is None:
exclude = []
self.servers.clear() self.servers.clear()
for i, s in enumerate(servers): for server_list in (servers, exclude):
try: for i, s in enumerate(server_list):
servers[i] = int(s) try:
except ValueError: server_list[i] = int(s)
raise InvalidServerIDType('%s is an invalid server type, must ' except ValueError:
'be int' % s) raise InvalidServerIDType(
'%s is an invalid server type, must be int' % s
)
urls = [ urls = [
'://www.speedtest.net/speedtest-servers-static.php', '://www.speedtest.net/speedtest-servers-static.php',
@ -1140,7 +1145,8 @@ class Speedtest(object):
if servers and int(attrib.get('id')) not in servers: if servers and int(attrib.get('id')) not in servers:
continue continue
if int(attrib.get('id')) in self.config['ignore_servers']: if (int(attrib.get('id')) in self.config['ignore_servers']
or int(attrib.get('id')) in exclude):
continue continue
try: try:
@ -1162,7 +1168,7 @@ class Speedtest(object):
except ServersRetrievalError: except ServersRetrievalError:
continue continue
if servers and not self.servers: if (servers or exclude) and not self.servers:
raise NoMatchedServers() raise NoMatchedServers()
return self.servers return self.servers
@ -1518,8 +1524,11 @@ def parse_args():
parser.add_argument('--list', action='store_true', parser.add_argument('--list', action='store_true',
help='Display a list of speedtest.net servers ' help='Display a list of speedtest.net servers '
'sorted by distance') 'sorted by distance')
parser.add_argument('--server', help='Specify a server ID to test against', parser.add_argument('--server', type=PARSER_TYPE_INT, action='append',
type=PARSER_TYPE_INT) help='Specify a server ID to test against')
parser.add_argument('--exclude', type=PARSER_TYPE_INT, action='append',
help='Exclude a server from selection. Can be '
'supplied multiple times')
parser.add_argument('--mini', help='URL of the Speedtest Mini server') parser.add_argument('--mini', help='URL of the Speedtest Mini server')
parser.add_argument('--source', help='Source IP address to bind to') parser.add_argument('--source', help='Source IP address to bind to')
parser.add_argument('--timeout', default=10, type=PARSER_TYPE_INT, parser.add_argument('--timeout', default=10, type=PARSER_TYPE_INT,
@ -1658,26 +1667,26 @@ def shell():
raise raise
sys.exit(0) sys.exit(0)
# Set a filter of servers to retrieve
servers = []
if args.server:
servers.append(args.server)
printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'], printer('Testing from %(isp)s (%(ip)s)...' % speedtest.config['client'],
quiet) quiet)
if not args.mini: if not args.mini:
printer('Retrieving speedtest.net server list...', quiet) printer('Retrieving speedtest.net server list...', quiet)
try: try:
speedtest.get_servers(servers) speedtest.get_servers(servers=args.server, exclude=args.exclude)
except NoMatchedServers: except NoMatchedServers:
raise SpeedtestCLIError('No matched servers: %s' % args.server) raise SpeedtestCLIError(
'No matched servers: %s' %
', '.join('%s' % s for s in args.server)
)
except (ServersRetrievalError,) + HTTP_ERRORS: except (ServersRetrievalError,) + HTTP_ERRORS:
print_('Cannot retrieve speedtest server list') print_('Cannot retrieve speedtest server list')
raise SpeedtestCLIError(get_exception()) raise SpeedtestCLIError(get_exception())
except InvalidServerIDType: except InvalidServerIDType:
raise SpeedtestCLIError('%s is an invalid server type, must ' raise SpeedtestCLIError(
'be an int' % args.server) '%s is an invalid server type, must '
'be an int' % ', '.join('%s' % s for s in args.server)
)
printer('Selecting best server based on ping...', quiet) printer('Selecting best server based on ping...', quiet)
speedtest.get_best_server() speedtest.get_best_server()