diff --git a/speedtest.py b/speedtest.py
index 6d9fe6f..3f046d9 100755
--- a/speedtest.py
+++ b/speedtest.py
@@ -322,7 +322,7 @@ class SpeedtestConfigError(SpeedtestException):
class SpeedtestServersError(SpeedtestException):
- """Servers XML is invalid"""
+ """Servers XML or JSON is invalid"""
class ConfigRetrievalError(SpeedtestHTTPError):
@@ -1239,18 +1239,35 @@ class Speedtest(object):
return self.config
- def json_to_xml(self,json_url=None):
- if json_url:
- r = requests.get(json_url)
- if r.status_code == 200:
- json_data = json.loads(r.text)[0]
- message = '\n\n'
+ def json_to_xml(self,data=None, server_id_list=None):
+ """Converts text data representing a link with json or json text to XML"""
+ if data:
try:
- message += f''
- except (KeyError,SyntaxError) as e:
- pass
- message += "\n\n\n"
- return message.replace("&","").replace("%","").encode()
+ r = requests.get(data)
+ except requests.exceptions.MissingSchema:
+ raise SpeedtestServersError("Invalid --custom link")
+ if r.status_code == 200:
+ message = '\n\n'
+ try:
+ json_data = json.loads(r.text)
+ if server_id_list and len(server_id_list)>=1:
+ for server_json in json_data:
+ if int(server_json["id"]) in server_id_list:
+ json_data = server_json
+ try:
+ message += f''
+ except (KeyError,SyntaxError) as e:
+ pass
+ else:
+ json_data = json_data[0]
+ try:
+ message += f''
+ except (KeyError,SyntaxError) as e:
+ pass
+ except json.decoder.JSONDecodeError:
+ raise SpeedtestServersError("Invalid json data provided by the link")
+ message += "\n\n\n"
+ return message.replace("&","").replace("%","").encode()
def get_servers(self, servers=None, exclude=None, custom_server=None):
@@ -1287,7 +1304,10 @@ class Speedtest(object):
errors = []
if custom_server:
- serversxml = "".encode().join([self.json_to_xml(custom_server)])
+ if custom_server and servers:
+ serversxml = "".encode().join([self.json_to_xml(custom_server,servers)])
+ else:
+ serversxml = "".encode().join([self.json_to_xml(custom_server)])
try:
try:
try:
@@ -1990,7 +2010,9 @@ def shell():
if not args.mini:
printer('Retrieving speedtest.net server list...', quiet)
try:
- speedtest.get_servers(servers=args.server, exclude=args.exclude,custom_server=args.custom)
+ speedtest.get_servers(servers=args.server,
+ exclude=args.exclude,
+ custom_server=args.custom)
except NoMatchedServers:
raise SpeedtestCLIError(
'No matched servers: %s' %