Support gzip encoding if available

This commit is contained in:
Matt Martz 2016-11-02 19:47:07 -05:00
parent 4280c448cf
commit 59880107a7
1 changed files with 81 additions and 12 deletions

View File

@ -29,6 +29,13 @@ import platform
import threading import threading
import xml.parsers.expat import xml.parsers.expat
try:
import gzip
GZIP_BASE = gzip.GzipFile
except ImportError:
gzip = None
GZIP_BASE = object
__version__ = '1.0.0' __version__ = '1.0.0'
@ -125,11 +132,13 @@ except ImportError:
try: try:
from cStringIO import StringIO from cStringIO import StringIO
BytesIO = None
except ImportError: except ImportError:
try: try:
from io import StringIO from io import StringIO, BytesIO
except ImportError: except ImportError:
from StringIO import StringIO from StringIO import StringIO
BytesIO = None
try: try:
import builtins import builtins
@ -216,15 +225,19 @@ class SpeedtestException(Exception):
"""Base exception for this module""" """Base exception for this module"""
class SpeedtestHTTPError(SpeedtestException):
"""Base HTTP exception for this module"""
class SpeedtestConfigError(SpeedtestException): class SpeedtestConfigError(SpeedtestException):
"""Configuration provided is invalid""" """Configuration provided is invalid"""
class ConfigRetrievalError(SpeedtestException): class ConfigRetrievalError(SpeedtestHTTPError):
"""Could not retrieve config.php""" """Could not retrieve config.php"""
class ServersRetrievalError(SpeedtestException): class ServersRetrievalError(SpeedtestHTTPError):
"""Could not retrieve speedtest-servers.php""" """Could not retrieve speedtest-servers.php"""
@ -266,6 +279,30 @@ class SpeedtestBestServerFailure(SpeedtestException):
"""Unable to determine best server""" """Unable to determine best server"""
class GzipDecodedResponse(GZIP_BASE):
"""A file-like object to decode a response encoded with the gzip
method, as described in RFC 1952.
Largely copied from ``xmlrpclib``/``xmlrpc.client`` and modified
to work for py2.4-py3
"""
def __init__(self, response):
# response doesn't support tell() and read(), required by
# GzipFile
if not gzip:
raise SpeedtestHTTPError('HTTP response body is gzip encoded, '
'but gzip support is not available')
IO = BytesIO or StringIO
self.io = IO(response.read())
gzip.GzipFile.__init__(self, mode='rb', fileobj=self.io)
def close(self):
try:
gzip.GzipFile.close(self)
finally:
self.io.close()
def bound_socket(*args, **kwargs): def bound_socket(*args, **kwargs):
"""Bind socket to a specified source IP address""" """Bind socket to a specified source IP address"""
@ -365,6 +402,23 @@ def catch_request(request):
return None, e return None, e
def get_response_stream(response):
"""Helper function to return either a Gzip reader if
``Content-Encoding`` is ``gzip`` otherwise the response itself
"""
try:
getheader = response.headers.getheader
except AttributeError:
getheader = response.getheader
if getheader('content-encoding') == 'gzip':
return GzipDecodedResponse(response)
return response
def get_attributes_by_tag_name(dom, tag_name): def get_attributes_by_tag_name(dom, tag_name):
"""Retrieve an attribute from an XML document and return it in a """Retrieve an attribute from an XML document and return it in a
consistent format consistent format
@ -639,21 +693,28 @@ class Speedtest(object):
we are interested in we are interested in
""" """
request = build_request('://www.speedtest.net/speedtest-config.php') headers = {}
if gzip:
headers['Accept-Encoding'] = 'gzip'
request = build_request('://www.speedtest.net/speedtest-config.php',
headers=headers)
uh, e = catch_request(request) uh, e = catch_request(request)
if e: if e:
raise ConfigRetrievalError(e) raise ConfigRetrievalError(e)
configxml = [] configxml = []
stream = get_response_stream(uh)
while 1: while 1:
configxml.append(uh.read(10240)) configxml.append(stream.read(10240))
if len(configxml[-1]) == 0: if len(configxml[-1]) == 0:
break break
stream.close()
uh.close()
if int(uh.code) != 200: if int(uh.code) != 200:
return None return None
uh.close()
printer(''.encode().join(configxml), debug=True) printer(''.encode().join(configxml), debug=True)
try: try:
@ -734,28 +795,36 @@ class Speedtest(object):
'http://c.speedtest.net/speedtest-servers.php', 'http://c.speedtest.net/speedtest-servers.php',
] ]
headers = {}
if gzip:
headers['Accept-Encoding'] = 'gzip'
errors = [] errors = []
for url in urls: for url in urls:
try: try:
request = build_request('%s?threads=%s' % request = build_request('%s?threads=%s' %
(url, (url,
self.config['threads']['download'])) self.config['threads']['download']),
headers=headers)
uh, e = catch_request(request) uh, e = catch_request(request)
if e: if e:
errors.append('%s' % e) errors.append('%s' % e)
raise ServersRetrievalError raise ServersRetrievalError
stream = get_response_stream(uh)
serversxml = [] serversxml = []
while 1: while 1:
serversxml.append(uh.read(10240)) serversxml.append(stream.read(10240))
if len(serversxml[-1]) == 0: if len(serversxml[-1]) == 0:
break break
if int(uh.code) != 200:
uh.close()
raise ServersRetrievalError
stream.close()
uh.close() uh.close()
if int(uh.code) != 200:
raise ServersRetrievalError
printer(''.encode().join(serversxml), debug=True) printer(''.encode().join(serversxml), debug=True)
try: try: