Gracefully handle XML parsing errors. Fixes #490 #491

This commit is contained in:
Matt Martz 2018-03-09 09:46:10 -06:00
parent f8aa20ecdf
commit 9c2977acfc
1 changed files with 46 additions and 14 deletions

View File

@ -70,6 +70,7 @@ except ImportError:
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
except ImportError: except ImportError:
from xml.dom import minidom as DOM from xml.dom import minidom as DOM
from xml.parsers.expat import ExpatError
ET = None ET = None
try: try:
@ -284,7 +285,11 @@ class SpeedtestHTTPError(SpeedtestException):
class SpeedtestConfigError(SpeedtestException): class SpeedtestConfigError(SpeedtestException):
"""Configuration provided is invalid""" """Configuration XML is invalid"""
class SpeedtestServersError(SpeedtestException):
"""Servers XML is invalid"""
class ConfigRetrievalError(SpeedtestHTTPError): class ConfigRetrievalError(SpeedtestHTTPError):
@ -1042,16 +1047,16 @@ class Speedtest(object):
uh, e = catch_request(request, opener=self._opener) uh, e = catch_request(request, opener=self._opener)
if e: if e:
raise ConfigRetrievalError(e) raise ConfigRetrievalError(e)
configxml = [] configxml_list = []
stream = get_response_stream(uh) stream = get_response_stream(uh)
while 1: while 1:
try: try:
configxml.append(stream.read(1024)) configxml_list.append(stream.read(1024))
except (OSError, EOFError): except (OSError, EOFError):
raise ConfigRetrievalError(get_exception()) raise ConfigRetrievalError(get_exception())
if len(configxml[-1]) == 0: if len(configxml_list[-1]) == 0:
break break
stream.close() stream.close()
uh.close() uh.close()
@ -1059,10 +1064,18 @@ class Speedtest(object):
if int(uh.code) != 200: if int(uh.code) != 200:
return None return None
printer('Config XML:\n%s' % ''.encode().join(configxml), debug=True) configxml = ''.encode().join(configxml_list)
printer('Config XML:\n%s' % configxml, debug=True)
try: try:
root = ET.fromstring(''.encode().join(configxml)) try:
root = ET.fromstring(configxml)
except ET.ParseError:
e = get_exception()
raise SpeedtestConfigError(
'Malformed speedtest.net configuration: %s' % e
)
server_config = root.find('server-config').attrib server_config = root.find('server-config').attrib
download = root.find('download').attrib download = root.find('download').attrib
upload = root.find('upload').attrib upload = root.find('upload').attrib
@ -1070,7 +1083,13 @@ class Speedtest(object):
client = root.find('client').attrib client = root.find('client').attrib
except AttributeError: except AttributeError:
root = DOM.parseString(''.join(configxml)) try:
root = DOM.parseString(configxml)
except ExpatError:
e = get_exception()
raise SpeedtestConfigError(
'Malformed speedtest.net configuration: %s' % e
)
server_config = get_attributes_by_tag_name(root, 'server-config') server_config = get_attributes_by_tag_name(root, 'server-config')
download = get_attributes_by_tag_name(root, 'download') download = get_attributes_by_tag_name(root, 'download')
upload = get_attributes_by_tag_name(root, 'upload') upload = get_attributes_by_tag_name(root, 'upload')
@ -1179,13 +1198,13 @@ class Speedtest(object):
stream = get_response_stream(uh) stream = get_response_stream(uh)
serversxml = [] serversxml_list = []
while 1: while 1:
try: try:
serversxml.append(stream.read(1024)) serversxml_list.append(stream.read(1024))
except (OSError, EOFError): except (OSError, EOFError):
raise ServersRetrievalError(get_exception()) raise ServersRetrievalError(get_exception())
if len(serversxml[-1]) == 0: if len(serversxml_list[-1]) == 0:
break break
stream.close() stream.close()
@ -1194,15 +1213,28 @@ class Speedtest(object):
if int(uh.code) != 200: if int(uh.code) != 200:
raise ServersRetrievalError() raise ServersRetrievalError()
printer('Servers XML:\n%s' % ''.encode().join(serversxml), serversxml = ''.encode().join(serversxml_list)
debug=True)
printer('Servers XML:\n%s' % serversxml, debug=True)
try: try:
try: try:
root = ET.fromstring(''.encode().join(serversxml)) try:
root = ET.fromstring(serversxml)
except ET.ParseError:
e = get_exception()
raise SpeedtestServersError(
'Malformed speedtest.net server list: %s' % e
)
elements = root.getiterator('server') elements = root.getiterator('server')
except AttributeError: except AttributeError:
root = DOM.parseString(''.join(serversxml)) try:
root = DOM.parseString(serversxml)
except ExpatError:
e = get_exception()
raise SpeedtestServersError(
'Malformed speedtest.net server list: %s' % e
)
elements = root.getElementsByTagName('server') elements = root.getElementsByTagName('server')
except (SyntaxError, xml.parsers.expat.ExpatError): except (SyntaxError, xml.parsers.expat.ExpatError):
raise ServersRetrievalError() raise ServersRetrievalError()