From 9f3f5f5ca9e3f5edb9482dab3c622db2a57d7985 Mon Sep 17 00:00:00 2001 From: Panos Vouzis Date: Tue, 14 Jan 2020 19:54:31 -0500 Subject: [PATCH] feat: added interface, ipv4 and ipv6 option as source. --- patch_socket_create_connection.py | 92 ++++++++++++++++++++++++++++ speedtest.py | 99 +++++++++++-------------------- 2 files changed, 125 insertions(+), 66 deletions(-) create mode 100755 patch_socket_create_connection.py diff --git a/patch_socket_create_connection.py b/patch_socket_create_connection.py new file mode 100755 index 0000000..0dc5740 --- /dev/null +++ b/patch_socket_create_connection.py @@ -0,0 +1,92 @@ +import socket +import sys +from IN import SO_BINDTODEVICE + +class CustomSocket(object): + def __init__(self, + network_interface=None, + ipv4_source=None, ipv6_source=None, + network_timeout=None): + + if network_interface is not None: + network_interface = network_interface.strip()[:15] + '\0' + self.network_interface = network_interface + + if network_timeout is not None: + network_timeout = float(network_timeout) + self.network_timeout = network_timeout + + if ipv6_source is not None: + parsed_ipv6_source = self.parse_source_address(ipv6_source) + ipv6_source = self.extract_source_address_from_ipv6(parsed_ipv6_source) + self.ipv6_source = ipv6_source + + if ipv4_source is not None: + ipv4_source = self.parse_source_address(ipv4_source) + self.ipv4_source = ipv4_source + + @staticmethod + def parse_source_address(source_addr): + source_addr = source_addr.split(',') + if len(source_addr) == 1: + return (source_addr[0], 0) + return (source_addr[0], int(source_addr[1])) + + @staticmethod + def extract_source_address_from_ipv6(ipv6_source): + source_ip, source_port = ipv6_source + source_address = [addr for addr in socket.getaddrinfo(source_ip, source_port, socket.AF_INET6, socket.SOCK_STREAM, socket.SOL_TCP)] + if not source_address: + raise ValueError("Couldn't find ipv6 address for source %s" % source_ip) + return source_address[0][-1] + + def get_source_address(self, host): + if ':' in host: + return self.ipv6_source + + return self.ipv4_source + + def create_connection_with_custom_network_interface( + self, address, timeout, source_address=None): + """ + Patched the standard library (v2.7.3) socket.create_connection to + connect to a network interface. + + https://github.com/enthought/Python-2.7.3/blob/master/Lib/socket.py#L537 + """ + host, port = address + + for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + + if source_address is None: + source_address = self.get_source_address(host) + + if self.network_interface: + try: + sock.setsockopt(socket.SOL_SOCKET, + SO_BINDTODEVICE, + self.network_interface) + except Exception as e: + err_msg = "No device exists: {}".format(self.network_interface) + sys.exit(err_msg) + elif source_address: + sock.bind(source_address) + + if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: + sock.settimeout(timeout) + elif self.network_timeout: + sock.settimeout(self.network_timeout) + + sock.connect(sa) + + return sock + except socket.error as err: + if sock: + sock.close() + raise err + + raise error("getaddrinfo returns an empty list") \ No newline at end of file diff --git a/speedtest.py b/speedtest.py index 92a2be0..7a64677 100755 --- a/speedtest.py +++ b/speedtest.py @@ -23,6 +23,8 @@ import math import errno import signal import socket +import argparse + import timeit import datetime import platform @@ -364,47 +366,6 @@ class SpeedtestMissingBestServer(SpeedtestException): """get_best_server not called or not able to determine best server""" -def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, - source_address=None): - """Connect to *address* and return the socket object. - - Convenience function. Connect to *address* (a 2-tuple ``(host, - port)``) and return the socket object. Passing the optional - *timeout* parameter will set the timeout on the socket instance - before attempting to connect. If no *timeout* is supplied, the - global default timeout setting returned by :func:`getdefaulttimeout` - is used. If *source_address* is set it must be a tuple of (host, port) - for the socket to bind as a source address before making the connection. - An host of '' or port 0 tells the OS to use the default. - - Largely vendored from Python 2.7, modified to work with Python 2.4 - """ - - host, port = address - err = None - for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = None - try: - sock = socket.socket(af, socktype, proto) - if timeout is not _GLOBAL_DEFAULT_TIMEOUT: - sock.settimeout(float(timeout)) - if source_address: - sock.bind(source_address) - sock.connect(sa) - return sock - - except socket.error: - err = get_exception() - if sock is not None: - sock.close() - - if err is not None: - raise err - else: - raise socket.error("getaddrinfo returns an empty list") - - class SpeedtestHTTPConnection(HTTPConnection): """Custom HTTPConnection to support source_address across Python 2.4 - Python 3 @@ -422,18 +383,11 @@ class SpeedtestHTTPConnection(HTTPConnection): def connect(self): """Connect to the host and port specified in __init__.""" - try: - self.sock = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - except (AttributeError, TypeError): - self.sock = create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) + self.sock = socket.create_connection( + (self.host, self.port), + self.timeout, + self.source_address + ) if self._tunnel_host: self._tunnel() @@ -459,18 +413,11 @@ if HTTPSConnection: def connect(self): "Connect to a host on a given (SSL) port." - try: - self.sock = socket.create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) - except (AttributeError, TypeError): - self.sock = create_connection( - (self.host, self.port), - self.timeout, - self.source_address - ) + self.sock = socket.create_connection( + (self.host, self.port), + self.timeout, + self.source_address + ) if self._tunnel_host: self._tunnel() @@ -1771,6 +1718,9 @@ def parse_args(): help='Show the version number and exit') parser.add_argument('--debug', action='store_true', help=ARG_SUPPRESS, default=ARG_SUPPRESS) + parser.add_argument('-4', '--ipv4', dest='ipv4', help='IPv4 source address') + parser.add_argument('-6', '--ipv6', dest='ipv6', help='IPv6 source address') + parser.add_argument('-i', '--interface', dest='interface', help='Set network interface') options = parser.parse_args() if isinstance(options, tuple): @@ -1829,6 +1779,23 @@ def shell(): args = parse_args() + ipv4_source = None + ipv6_source = None + + ipv4_source = args.ipv4 + ipv6_source = args.ipv6 + network_interface = args.interface + network_timeout = args.timeout + + from patch_socket_create_connection import CustomSocket + + my_socket = CustomSocket( + ipv4_source=ipv4_source, + ipv6_source=ipv6_source, + network_interface=network_interface, + network_timeout=network_timeout) + socket.create_connection = my_socket.create_connection_with_custom_network_interface + # Print the version and exit if args.version: version() @@ -1997,4 +1964,4 @@ def main(): if __name__ == '__main__': - main() + main() \ No newline at end of file