diff --git a/speedtest_cli.py b/speedtest_cli.py index 21f2245..c38e1b7 100755 --- a/speedtest_cli.py +++ b/speedtest_cli.py @@ -80,6 +80,14 @@ try: except ImportError: from md5 import md5 +try: + from cStringIO import StringIO +except ImportError: + try: + from io import StringIO + except ImportError: + from StringIO import StringIO + try: from argparse import ArgumentParser as ArgParser PARSER_TYPE_INT = int @@ -192,22 +200,23 @@ def print_dots(current, total, start=False, end=False): class HTTPDownloader(threading.Thread): """Thread class for retrieving a URL""" - def __init__(self, i, url, start): + def __init__(self, i, url, start, timeout): self.url = url self.result = None self.starttime = start + self.timeout = timeout self.i = i threading.Thread.__init__(self) def run(self): self.result = [0] try: - if (time.time() - self.starttime) <= 10: + if (time.time() - self.starttime) <= self.timeout: req = Request(self.url) req.add_header('User-Agent', USER_AGENT) f = urlopen(req) while (1 and not shutdown_event.isSet() and - (time.time() - self.starttime) <= 10): + (time.time() - self.starttime) <= self.timeout): self.result.append(len(f.read(10240))) if self.result[-1] == 0: break @@ -216,34 +225,62 @@ class HTTPDownloader(threading.Thread): pass +class SpeedtestUploadTimeout(Exception): + pass + + +class HTTPUploaderData(object): + def __init__(self, length, start, timeout): + self.length = length + self.start = start + self.timeout = timeout + + chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' + data = chars * (int(round(int(length) / 36.0))) + self.data = StringIO() + self.data.write(('content1=%s' % data[0:int(length) - 9]).encode()) + self.data.seek(0) + + self.total = [0] + + def read(self, n=10240): + if (time.time() - self.start) <= self.timeout: + chunk = self.data.read(n) + self.total.append(len(chunk)) + return chunk + else: + raise SpeedtestUploadTimeout + + def __len__(self): + return self.length + + class HTTPUploader(threading.Thread): """Thread class for uploading to a URL""" - def __init__(self, i, url, start, size): + def __init__(self, i, url, start, size, timeout): self.url = url - chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ' - data = chars * (int(round(int(size) / 36.0))) - self.data = ('content1=%s' % data[0:int(size) - 9]).encode() - del data + self.data = HTTPUploaderData(size, start, timeout) self.result = None self.starttime = start + self.timeout = timeout self.i = i threading.Thread.__init__(self) def run(self): try: - if ((time.time() - self.starttime) <= 10 and + if ((time.time() - self.starttime) <= self.timeout and not shutdown_event.isSet()): req = Request(self.url, self.data) req.add_header('User-Agent', USER_AGENT) f = urlopen(req) f.read(11) f.close() - self.result = len(self.data) + self.result = sum(self.data.total) else: self.result = 0 except: - self.result = 0 + self.result = sum(self.data.total) class SpeedtestException(Exception): @@ -510,12 +547,16 @@ class Speedtest(object): threads = dict(upload=int(upload['threads']), download=int(server_config['threadcount'])) + length = dict(upload=int(upload['testlength']), + download=int(download['testlength'])) + self.config.update({ 'client': client, 'ignore_servers': ignore_servers, 'sizes': sizes, 'counts': counts, - 'threads': threads + 'threads': threads, + 'length': length, }) self.lat_lon = (float(client['lat']), float(client['lon'])) @@ -711,7 +752,8 @@ class Speedtest(object): def producer(q, urls, url_count): for i, url in enumerate(urls): - thread = HTTPDownloader(i, url, start) + thread = HTTPDownloader(i, url, start, + self.config['length']['download']) thread.start() q.put(thread, True) if not shutdown_event.isSet() and callback: @@ -758,7 +800,8 @@ class Speedtest(object): def producer(q, sizes, size_count): for i, size in enumerate(sizes): - thread = HTTPUploader(i, self.best['url'], start, size) + thread = HTTPUploader(i, self.best['url'], start, size, + self.config['length']['upload']) thread.start() q.put(thread, True) if not shutdown_event.isSet() and callback: