Add support for binding to a specific interface

This commit is contained in:
Vladimir Vukicevic 2021-03-08 11:17:26 -08:00
parent c58ad3367b
commit 7cbd3cf338
1 changed files with 59 additions and 13 deletions

View File

@ -36,6 +36,13 @@ except ImportError:
gzip = None gzip = None
GZIP_BASE = object GZIP_BASE = object
try:
import IN
SO_BINDTODEVICE = IN.SO_BINDTODEVICE
except ImportError:
SO_BINDTODEVICE = None
__version__ = '2.1.2' __version__ = '2.1.2'
@ -365,7 +372,7 @@ class SpeedtestMissingBestServer(SpeedtestException):
def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
source_address=None): source_address=None, interface=None):
"""Connect to *address* and return the socket object. """Connect to *address* and return the socket object.
Convenience function. Connect to *address* (a 2-tuple ``(host, Convenience function. Connect to *address* (a 2-tuple ``(host,
@ -375,7 +382,9 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
global default timeout setting returned by :func:`getdefaulttimeout` global default timeout setting returned by :func:`getdefaulttimeout`
is used. If *source_address* is set it must be a tuple of (host, port) is used. If *source_address* is set it must be a tuple of (host, port)
for the socket to bind as a source address before making the connection. for the socket to bind as a source address before making the connection.
An host of '' or port 0 tells the OS to use the default. An host of '' or port 0 tells the OS to use the default. If *interface*
is set it must be the name of an interface to bind to. This may require
root privileges.
Largely vendored from Python 2.7, modified to work with Python 2.4 Largely vendored from Python 2.7, modified to work with Python 2.4
""" """
@ -389,6 +398,8 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
sock = socket.socket(af, socktype, proto) sock = socket.socket(af, socktype, proto)
if timeout is not _GLOBAL_DEFAULT_TIMEOUT: if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(float(timeout)) sock.settimeout(float(timeout))
if interface:
sock.setsockopt(socket.SOL_SOCKET, SO_BINDTODEVICE, str(interface + '\0').encode('utf-8'))
if source_address: if source_address:
sock.bind(source_address) sock.bind(source_address)
sock.connect(sa) sock.connect(sa)
@ -411,6 +422,7 @@ class SpeedtestHTTPConnection(HTTPConnection):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
source_address = kwargs.pop('source_address', None) source_address = kwargs.pop('source_address', None)
interface = kwargs.pop('interface', None);
timeout = kwargs.pop('timeout', 10) timeout = kwargs.pop('timeout', 10)
self._tunnel_host = None self._tunnel_host = None
@ -418,21 +430,30 @@ class SpeedtestHTTPConnection(HTTPConnection):
HTTPConnection.__init__(self, *args, **kwargs) HTTPConnection.__init__(self, *args, **kwargs)
self.source_address = source_address self.source_address = source_address
self.interface = interface
self.timeout = timeout self.timeout = timeout
def connect(self): def connect(self):
"""Connect to the host and port specified in __init__.""" """Connect to the host and port specified in __init__."""
fallback = False
try: try:
# force fallback
if self.interface:
raise AttributeError()
self.sock = socket.create_connection( self.sock = socket.create_connection(
(self.host, self.port), (self.host, self.port),
self.timeout, self.timeout,
self.source_address self.source_address
) )
except (AttributeError, TypeError): except (AttributeError, TypeError):
fallback = True
if fallback:
self.sock = create_connection( self.sock = create_connection(
(self.host, self.port), (self.host, self.port),
self.timeout, self.timeout,
self.source_address self.source_address,
interface=self.interface
) )
if self._tunnel_host: if self._tunnel_host:
@ -448,6 +469,7 @@ if HTTPSConnection:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
source_address = kwargs.pop('source_address', None) source_address = kwargs.pop('source_address', None)
interface = kwargs.pop('interface', None);
timeout = kwargs.pop('timeout', 10) timeout = kwargs.pop('timeout', 10)
self._tunnel_host = None self._tunnel_host = None
@ -456,20 +478,30 @@ if HTTPSConnection:
self.timeout = timeout self.timeout = timeout
self.source_address = source_address self.source_address = source_address
self.interface = interface
def connect(self): def connect(self):
"Connect to a host on a given (SSL) port." "Connect to a host on a given (SSL) port."
fallback = False
try: try:
# force fallback
if self.interface:
raise AttributeError()
self.sock = socket.create_connection( self.sock = socket.create_connection(
(self.host, self.port), (self.host, self.port),
self.timeout, self.timeout,
self.source_address self.source_address
) )
except (AttributeError, TypeError): except (AttributeError, TypeError):
fallback = True
if fallback:
self.sock = create_connection( self.sock = create_connection(
(self.host, self.port), (self.host, self.port),
self.timeout, self.timeout,
self.source_address self.source_address,
interface=self.interface
) )
if self._tunnel_host: if self._tunnel_host:
@ -506,7 +538,7 @@ if HTTPSConnection:
) )
def _build_connection(connection, source_address, timeout, context=None): def _build_connection(connection, source_address, timeout, context=None, interface=None):
"""Cross Python 2.4 - Python 3 callable to build an ``HTTPConnection`` or """Cross Python 2.4 - Python 3 callable to build an ``HTTPConnection`` or
``HTTPSConnection`` with the args we need ``HTTPSConnection`` with the args we need
@ -516,6 +548,7 @@ def _build_connection(connection, source_address, timeout, context=None):
def inner(host, **kwargs): def inner(host, **kwargs):
kwargs.update({ kwargs.update({
'source_address': source_address, 'source_address': source_address,
'interface': interface,
'timeout': timeout 'timeout': timeout
}) })
if context: if context:
@ -528,9 +561,10 @@ class SpeedtestHTTPHandler(AbstractHTTPHandler):
"""Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the """Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the
args we need for ``source_address`` and ``timeout`` args we need for ``source_address`` and ``timeout``
""" """
def __init__(self, debuglevel=0, source_address=None, timeout=10): def __init__(self, debuglevel=0, source_address=None, timeout=10, interface=None):
AbstractHTTPHandler.__init__(self, debuglevel) AbstractHTTPHandler.__init__(self, debuglevel)
self.source_address = source_address self.source_address = source_address
self.interface = interface
self.timeout = timeout self.timeout = timeout
def http_open(self, req): def http_open(self, req):
@ -538,7 +572,8 @@ class SpeedtestHTTPHandler(AbstractHTTPHandler):
_build_connection( _build_connection(
SpeedtestHTTPConnection, SpeedtestHTTPConnection,
self.source_address, self.source_address,
self.timeout self.timeout,
interface=self.interface
), ),
req req
) )
@ -551,10 +586,11 @@ class SpeedtestHTTPSHandler(AbstractHTTPHandler):
args we need for ``source_address`` and ``timeout`` args we need for ``source_address`` and ``timeout``
""" """
def __init__(self, debuglevel=0, context=None, source_address=None, def __init__(self, debuglevel=0, context=None, source_address=None,
timeout=10): timeout=10, interface=None):
AbstractHTTPHandler.__init__(self, debuglevel) AbstractHTTPHandler.__init__(self, debuglevel)
self._context = context self._context = context
self.source_address = source_address self.source_address = source_address
self.interface = interface
self.timeout = timeout self.timeout = timeout
def https_open(self, req): def https_open(self, req):
@ -564,6 +600,7 @@ class SpeedtestHTTPSHandler(AbstractHTTPHandler):
self.source_address, self.source_address,
self.timeout, self.timeout,
context=self._context, context=self._context,
interface=self.interface
), ),
req req
) )
@ -571,7 +608,7 @@ class SpeedtestHTTPSHandler(AbstractHTTPHandler):
https_request = AbstractHTTPHandler.do_request_ https_request = AbstractHTTPHandler.do_request_
def build_opener(source_address=None, timeout=10): def build_opener(source_address=None, timeout=10, interface=None):
"""Function similar to ``urllib2.build_opener`` that will build """Function similar to ``urllib2.build_opener`` that will build
an ``OpenerDirector`` with the explicit handlers we want, an ``OpenerDirector`` with the explicit handlers we want,
``source_address`` for binding, ``timeout`` and our custom ``source_address`` for binding, ``timeout`` and our custom
@ -587,12 +624,17 @@ def build_opener(source_address=None, timeout=10):
else: else:
source_address_tuple = None source_address_tuple = None
if interface:
printer('Binding to interface: %s' % (interface,), debug = True)
handlers = [ handlers = [
ProxyHandler(), ProxyHandler(),
SpeedtestHTTPHandler(source_address=source_address_tuple, SpeedtestHTTPHandler(source_address=source_address_tuple,
timeout=timeout), timeout=timeout,
interface=interface),
SpeedtestHTTPSHandler(source_address=source_address_tuple, SpeedtestHTTPSHandler(source_address=source_address_tuple,
timeout=timeout), timeout=timeout,
interface=interface),
HTTPDefaultErrorHandler(), HTTPDefaultErrorHandler(),
HTTPRedirectHandler(), HTTPRedirectHandler(),
HTTPErrorProcessor() HTTPErrorProcessor()
@ -1074,12 +1116,13 @@ class Speedtest(object):
"""Class for performing standard speedtest.net testing operations""" """Class for performing standard speedtest.net testing operations"""
def __init__(self, config=None, source_address=None, timeout=10, def __init__(self, config=None, source_address=None, timeout=10,
secure=False, shutdown_event=None): secure=False, shutdown_event=None, interface=None):
self.config = {} self.config = {}
self._source_address = source_address self._source_address = source_address
self._timeout = timeout self._timeout = timeout
self._opener = build_opener(source_address, timeout) self._interface = interface
self._opener = build_opener(source_address, timeout, interface)
self._secure = secure self._secure = secure
@ -1755,6 +1798,7 @@ def parse_args():
'supplied multiple times') 'supplied multiple times')
parser.add_argument('--mini', help='URL of the Speedtest Mini server') parser.add_argument('--mini', help='URL of the Speedtest Mini server')
parser.add_argument('--source', help='Source IP address to bind to') parser.add_argument('--source', help='Source IP address to bind to')
parser.add_argument('--interface', help='Interface to bind to. May require root.')
parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT, parser.add_argument('--timeout', default=10, type=PARSER_TYPE_FLOAT,
help='HTTP timeout in seconds. Default 10') help='HTTP timeout in seconds. Default 10')
parser.add_argument('--secure', action='store_true', parser.add_argument('--secure', action='store_true',
@ -1790,6 +1834,7 @@ def validate_optional_args(args):
optional_args = { optional_args = {
'json': ('json/simplejson python module', json), 'json': ('json/simplejson python module', json),
'secure': ('SSL support', HTTPSConnection), 'secure': ('SSL support', HTTPSConnection),
'interface': ('Interface binding', SO_BINDTODEVICE),
} }
for arg, info in optional_args.items(): for arg, info in optional_args.items():
@ -1871,6 +1916,7 @@ def shell():
try: try:
speedtest = Speedtest( speedtest = Speedtest(
source_address=args.source, source_address=args.source,
interface=args.interface,
timeout=args.timeout, timeout=args.timeout,
secure=args.secure secure=args.secure
) )