Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add IPv6 support #39

Closed
wants to merge 9 commits into from
8 changes: 4 additions & 4 deletions src/simple_websocket/aiows.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,12 +383,12 @@ def __init__(self, url, subprotocols=None, headers=None,
if isinstance(self.subprotocols, str):
self.subprotocols = [self.subprotocols]

self.extra_headeers = []
self.extra_headers = []
if isinstance(headers, dict):
for key, value in headers.items():
self.extra_headeers.append((key, value))
self.extra_headers.append((key, value))
elif isinstance(headers, list):
self.extra_headeers = headers
self.extra_headers = headers

@classmethod
async def connect(cls, url, subprotocols=None, headers=None,
Expand Down Expand Up @@ -443,7 +443,7 @@ async def _connect(self):
async def handshake(self):
out_data = self.ws.send(Request(host=self.host, target=self.path,
subprotocols=self.subprotocols,
extra_headers=self.extra_headeers))
extra_headers=self.extra_headers))
self.wsock.write(out_data)

while True:
Expand Down
29 changes: 21 additions & 8 deletions src/simple_websocket/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ class Client(Base):
"""
def __init__(self, url, subprotocols=None, headers=None,
receive_bytes=4096, ping_interval=None, max_message_size=None,
ssl_context=None, thread_class=None, event_class=None):
ssl_context=None, thread_class=None, event_class=None,
address_family=None):
parsed_url = urlsplit(url)
is_secure = parsed_url.scheme in ['https', 'wss']
self.host = parsed_url.hostname
Expand All @@ -390,14 +391,21 @@ def __init__(self, url, subprotocols=None, headers=None,
if isinstance(self.subprotocols, str):
self.subprotocols = [self.subprotocols]

self.extra_headeers = []
self.extra_headers = []
if isinstance(headers, dict):
for key, value in headers.items():
self.extra_headeers.append((key, value))
self.extra_headers.append((key, value))
elif isinstance(headers, list):
self.extra_headeers = headers
self.extra_headers = headers

if address_family is None:
addr_family = socket.getaddrinfo(
self.host, self.port,
type=socket.SOCK_STREAM)[0][0]
sock = socket.socket(addr_family, socket.SOCK_STREAM)
else:
sock = socket.socket(address_family, socket.SOCK_STREAM)

sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
italodeverdade marked this conversation as resolved.
Show resolved Hide resolved
if is_secure: # pragma: no cover
if ssl_context is None:
ssl_context = ssl.create_default_context(
Expand All @@ -413,7 +421,8 @@ def __init__(self, url, subprotocols=None, headers=None,
@classmethod
def connect(cls, url, subprotocols=None, headers=None,
receive_bytes=4096, ping_interval=None, max_message_size=None,
ssl_context=None, thread_class=None, event_class=None):
ssl_context=None, thread_class=None, event_class=None,
address_family=None):
"""Returns a WebSocket client connection.

:param url: The connection URL. Both ``ws://`` and ``wss://`` URLs are
Expand Down Expand Up @@ -450,16 +459,20 @@ def connect(cls, url, subprotocols=None, headers=None,
:param event_class: The ``Event`` class to use when creating event
objects. The default is the `threading.Event``
class from the Python standard library.
:param address_family: The address family to use when creating the
socket. The default is ``None`` and will fallback
to first getaddrinfo result.
"""
return cls(url, subprotocols=subprotocols, headers=headers,
receive_bytes=receive_bytes, ping_interval=ping_interval,
max_message_size=max_message_size, ssl_context=ssl_context,
thread_class=thread_class, event_class=event_class)
thread_class=thread_class, event_class=event_class,
address_family=address_family)

def handshake(self):
out_data = self.ws.send(Request(host=self.host, target=self.path,
subprotocols=self.subprotocols,
extra_headers=self.extra_headeers))
extra_headers=self.extra_headers))
self.sock.send(out_data)

while True:
Expand Down