diff --git a/speedtest.py b/speedtest.py index 0053237..878edb7 100755 --- a/speedtest.py +++ b/speedtest.py @@ -52,7 +52,6 @@ class FakeShutdownEvent(object): # Some global variables we use SHUTDOWN_EVENT = FakeShutdownEvent() -SCHEME = 'http' DEBUG = False _GLOBAL_DEFAULT_TIMEOUT = object() @@ -554,7 +553,7 @@ def build_user_agent(): return user_agent -def build_request(url, data=None, headers=None, bump=''): +def build_request(url, data=None, headers=None, bump='', secure=False): """Build a urllib2 request object This function automatically adds a User-Agent header to all requests @@ -565,7 +564,9 @@ def build_request(url, data=None, headers=None, bump=''): headers = {} if url[0] == ':': - schemed_url = '%s%s' % (SCHEME, url) + scheme = ('http', 'https')[bool(secure)] + print scheme + schemed_url = '%s%s' % (scheme, url) else: schemed_url = url @@ -780,7 +781,8 @@ class SpeedtestResults(object): to get a share results image link. """ - def __init__(self, download=0, upload=0, ping=0, server=None, opener=None): + def __init__(self, download=0, upload=0, ping=0, server=None, opener=None, + secure=False): self.download = download self.upload = upload self.ping = ping @@ -798,6 +800,8 @@ class SpeedtestResults(object): else: self._opener = build_opener() + self._secure = secure + def __repr__(self): return repr(self.dict()) @@ -839,7 +843,7 @@ class SpeedtestResults(object): headers = {'Referer': 'http://c.speedtest.net/flash/speedtest.swf'} request = build_request('://www.speedtest.net/api/api.php', data='&'.join(api_data).encode(), - headers=headers) + headers=headers, secure=self._secure) f, e = catch_request(request, opener=self._opener) if e: raise ShareResultsConnectFailure(e) @@ -904,13 +908,16 @@ class SpeedtestResults(object): class Speedtest(object): """Class for performing standard speedtest.net testing operations""" - def __init__(self, config=None, source_address=None, timeout=10): + def __init__(self, config=None, source_address=None, timeout=10, + secure=False): self.config = {} self._source_address = source_address self._timeout = timeout self._opener = build_opener(source_address, timeout) + self._secure = secure + self.get_config() if config is not None: self.config.update(config) @@ -919,7 +926,7 @@ class Speedtest(object): self.closest = [] self.best = {} - self.results = SpeedtestResults(opener=self._opener) + self.results = SpeedtestResults(opener=self._opener, secure=secure) def get_config(self): """Download the speedtest.net configuration and return only the data @@ -930,7 +937,7 @@ class Speedtest(object): if gzip: headers['Accept-Encoding'] = 'gzip' request = build_request('://www.speedtest.net/speedtest-config.php', - headers=headers) + headers=headers, secure=self._secure) uh, e = catch_request(request, opener=self._opener) if e: raise ConfigRetrievalError(e) @@ -1044,10 +1051,12 @@ class Speedtest(object): errors = [] for url in urls: try: - request = build_request('%s?threads=%s' % - (url, - self.config['threads']['download']), - headers=headers) + request = build_request( + '%s?threads=%s' % (url, + self.config['threads']['download']), + headers=headers, + secure=self._secure + ) uh, e = catch_request(request, opener=self._opener) if e: errors.append('%s' % e) @@ -1274,7 +1283,9 @@ class Speedtest(object): request_count = len(urls) requests = [] for i, url in enumerate(urls): - requests.append(build_request(url, bump=i)) + requests.append( + build_request(url, bump=i, secure=self._secure) + ) def producer(q, requests, request_count): for i, request in enumerate(requests): @@ -1338,7 +1349,7 @@ class Speedtest(object): data.pre_allocate() requests.append( ( - build_request(self.best['url'], data), + build_request(self.best['url'], data, secure=self._secure), size ) ) @@ -1526,7 +1537,7 @@ def printer(string, quiet=False, debug=False, **kwargs): def shell(): """Run the full speedtest.net test""" - global SHUTDOWN_EVENT, SCHEME, DEBUG + global SHUTDOWN_EVENT, DEBUG SHUTDOWN_EVENT = threading.Event() signal.signal(signal.SIGINT, ctrl_c) @@ -1549,9 +1560,6 @@ def shell(): validate_optional_args(args) - if args.secure: - SCHEME = 'https' - debug = getattr(args, 'debug', False) if debug == 'SUPPRESSHELP': debug = False @@ -1576,7 +1584,11 @@ def shell(): printer('Retrieving speedtest.net configuration...', quiet) try: - speedtest = Speedtest(source_address=args.source, timeout=args.timeout) + speedtest = Speedtest( + source_address=args.source, + timeout=args.timeout, + secure=args.secure + ) except (ConfigRetrievalError,) + HTTP_ERRORS: printer('Cannot retrieve speedtest configuration') raise SpeedtestCLIError(get_exception())