diff --git a/speedtest.py b/speedtest.py index 7ea05a5..0fe067b 100755 --- a/speedtest.py +++ b/speedtest.py @@ -70,6 +70,7 @@ except ImportError: import xml.etree.ElementTree as ET except ImportError: from xml.dom import minidom as DOM + from xml.parsers.expat import ExpatError ET = None try: @@ -284,7 +285,11 @@ class SpeedtestHTTPError(SpeedtestException): class SpeedtestConfigError(SpeedtestException): - """Configuration provided is invalid""" + """Configuration XML is invalid""" + + +class SpeedtestServersError(SpeedtestException): + """Servers XML is invalid""" class ConfigRetrievalError(SpeedtestHTTPError): @@ -1042,16 +1047,16 @@ class Speedtest(object): uh, e = catch_request(request, opener=self._opener) if e: raise ConfigRetrievalError(e) - configxml = [] + configxml_list = [] stream = get_response_stream(uh) while 1: try: - configxml.append(stream.read(1024)) + configxml_list.append(stream.read(1024)) except (OSError, EOFError): raise ConfigRetrievalError(get_exception()) - if len(configxml[-1]) == 0: + if len(configxml_list[-1]) == 0: break stream.close() uh.close() @@ -1059,10 +1064,18 @@ class Speedtest(object): if int(uh.code) != 200: return None - printer('Config XML:\n%s' % ''.encode().join(configxml), debug=True) + configxml = ''.encode().join(configxml_list) + + printer('Config XML:\n%s' % configxml, debug=True) try: - root = ET.fromstring(''.encode().join(configxml)) + try: + root = ET.fromstring(configxml) + except ET.ParseError: + e = get_exception() + raise SpeedtestConfigError( + 'Malformed speedtest.net configuration: %s' % e + ) server_config = root.find('server-config').attrib download = root.find('download').attrib upload = root.find('upload').attrib @@ -1070,7 +1083,13 @@ class Speedtest(object): client = root.find('client').attrib except AttributeError: - root = DOM.parseString(''.join(configxml)) + try: + root = DOM.parseString(configxml) + except ExpatError: + e = get_exception() + raise SpeedtestConfigError( + 'Malformed speedtest.net configuration: %s' % e + ) server_config = get_attributes_by_tag_name(root, 'server-config') download = get_attributes_by_tag_name(root, 'download') upload = get_attributes_by_tag_name(root, 'upload') @@ -1179,13 +1198,13 @@ class Speedtest(object): stream = get_response_stream(uh) - serversxml = [] + serversxml_list = [] while 1: try: - serversxml.append(stream.read(1024)) + serversxml_list.append(stream.read(1024)) except (OSError, EOFError): raise ServersRetrievalError(get_exception()) - if len(serversxml[-1]) == 0: + if len(serversxml_list[-1]) == 0: break stream.close() @@ -1194,15 +1213,28 @@ class Speedtest(object): if int(uh.code) != 200: raise ServersRetrievalError() - printer('Servers XML:\n%s' % ''.encode().join(serversxml), - debug=True) + serversxml = ''.encode().join(serversxml_list) + + printer('Servers XML:\n%s' % serversxml, debug=True) try: try: - root = ET.fromstring(''.encode().join(serversxml)) + try: + root = ET.fromstring(serversxml) + except ET.ParseError: + e = get_exception() + raise SpeedtestServersError( + 'Malformed speedtest.net server list: %s' % e + ) elements = root.getiterator('server') except AttributeError: - root = DOM.parseString(''.join(serversxml)) + try: + root = DOM.parseString(serversxml) + except ExpatError: + e = get_exception() + raise SpeedtestServersError( + 'Malformed speedtest.net server list: %s' % e + ) elements = root.getElementsByTagName('server') except (SyntaxError, xml.parsers.expat.ExpatError): raise ServersRetrievalError()