diff --git a/speedtest.py b/speedtest.py index faed5f8..dae9ff0 100755 --- a/speedtest.py +++ b/speedtest.py @@ -758,7 +758,7 @@ class HTTPDownloader(threading.Thread): pass -class SocketDownloader(threading.Thread): +class SocketTestBase(threading.Thread): def __init__(self, i, address, size, start, timeout, shutdown_event=None, source_address=None): threading.Thread.__init__(self) @@ -780,6 +780,8 @@ class SocketDownloader(threading.Thread): source_address=source_address ) + +class SocketDownloader(SocketTestBase): def run(self): try: if (timeit.default_timer() - self.starttime) <= self.timeout: @@ -908,28 +910,7 @@ class HTTPUploader(threading.Thread): self.result = self.request.data.total -class SocketUploader(threading.Thread): - def __init__(self, i, address, size, start, timeout, shutdown_event=None, - source_address=None): - threading.Thread.__init__(self) - self.result = 0 - self.starttime = start - self.timeout = timeout - self.i = i - self.size = size - self.remaining = self.size - - if shutdown_event: - self._shutdown_event = shutdown_event - else: - self._shutdown_event = FakeShutdownEvent() - - self.sock = connection_factory( - address, - timeout=timeout, - source_address=source_address - ) - +class SocketUploader(SocketTestBase): def run(self): try: if (timeit.default_timer() - self.starttime) <= self.timeout: