diff --git a/python/spdylay.pyx b/python/spdylay.pyx index 1a427d9c..b5c0733e 100644 --- a/python/spdylay.pyx +++ b/python/spdylay.pyx @@ -1103,52 +1103,6 @@ try: import time from xml.sax.saxutils import escape - def send_cb(session, data): - ssctrl = session.user_data - wlen = ssctrl.sock.send(data) - return wlen - - def read_cb(session, stream_id, length, read_ctrl, source): - data = source.read(length) - if not data: - read_ctrl.flags = READ_EOF - return data - - def on_ctrl_recv_cb(session, frame): - ssctrl = session.user_data - if frame.frame_type == SYN_STREAM: - stream = Stream(frame.stream_id) - ssctrl.streams[frame.stream_id] = stream - - stream.process_headers(frame.nv) - elif frame.frame_type == HEADERS: - if frame.stream_id in ssctrl.streams: - stream = ssctrl.streams[frame.stream_id] - stream.process_headers(frame.nv) - - def on_data_chunk_recv_cb(session, flags, stream_id, data): - ssctrl = session.user_data - if stream_id in ssctrl.streams: - stream = ssctrl.streams[stream_id] - if stream.method == 'POST': - if not stream.rfile: - stream.rfile = io.BytesIO() - stream.rfile.write(data) - else: - # We don't allow request body if method is not POST - session.submit_rst_stream(stream_id, PROTOCOL_ERROR) - - def on_stream_close_cb(session, stream_id, status_code): - ssctrl = session.user_data - if stream_id in ssctrl.streams: - del ssctrl.streams[stream_id] - - def on_request_recv_cb(session, stream_id): - ssctrl = session.user_data - if stream_id in ssctrl.streams: - stream = ssctrl.streams[stream_id] - ssctrl.handler.handle_one_request(stream) - class Stream: def __init__(self, stream_id): self.stream_id = stream_id @@ -1180,9 +1134,7 @@ try: self.headers.append((k, v)) class SessionCtrl: - def __init__(self, handler, sock): - self.handler = handler - self.sock = sock + def __init__(self): self.streams = {} class BaseSPDYRequestHandler(socketserver.BaseRequestHandler): @@ -1271,7 +1223,7 @@ try: .format(stream.method)) self.wfile.seek(0) - data_prd = DataProvider(self.wfile, read_cb) + data_prd = DataProvider(self.wfile, self.read_cb) stream.data_prd = data_prd self.send_header(':version', 'HTTP/1.1') @@ -1281,38 +1233,80 @@ try: self.session.submit_response(stream.stream_id, self._response_headers, data_prd) + + def send_cb(self, session, data): + return self.sslsock.send(data) + + def read_cb(self, session, stream_id, length, read_ctrl, source): + data = source.read(length) + if not data: + read_ctrl.flags = READ_EOF + return data + + def on_ctrl_recv_cb(self, session, frame): + if frame.frame_type == SYN_STREAM: + stream = Stream(frame.stream_id) + self.ssctrl.streams[frame.stream_id] = stream + + stream.process_headers(frame.nv) + elif frame.frame_type == HEADERS: + if frame.stream_id in self.ssctrl.streams: + stream = self.ssctrl.streams[frame.stream_id] + stream.process_headers(frame.nv) + + def on_data_chunk_recv_cb(self, session, flags, stream_id, data): + if stream_id in self.ssctrl.streams: + stream = self.ssctrl.streams[stream_id] + if stream.method == 'POST': + if not stream.rfile: + stream.rfile = io.BytesIO() + stream.rfile.write(data) + else: + # We don't allow request body if method is not POST + session.submit_rst_stream(stream_id, PROTOCOL_ERROR) + + def on_stream_close_cb(self, session, stream_id, status_code): + if stream_id in self.ssctrl.streams: + del self.ssctrl.streams[stream_id] + + def on_request_recv_cb(self, session, stream_id): + if stream_id in self.ssctrl.streams: + stream = self.ssctrl.streams[stream_id] + self.handle_one_request(stream) + def handle(self): self.request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) # TODO We need to call handshake manually because 3.3.0b2 # crashes if do_handshake_on_connect=True - sock = self.server.ctx.wrap_socket(self.request, server_side=True, - do_handshake_on_connect=False) - sock.setblocking(False) + self.sslsock = self.server.ctx.wrap_socket(\ + self.request, + server_side=True, + do_handshake_on_connect=False) + + self.sslsock.setblocking(False) while True: try: - sock.do_handshake() + self.sslsock.do_handshake() break except ssl.SSLWantReadError as e: - select.select([sock], [], []) + select.select([self.sslsock], [], []) except ssl.SSLWantWriteError as e: - select.select([], [sock], []) + select.select([], [self.sslsock], []) - version = npn_get_version(sock.selected_npn_protocol()) + version = npn_get_version(self.sslsock.selected_npn_protocol()) if version == 0: return - ssctrl = SessionCtrl(self, sock) + self.ssctrl = SessionCtrl() self.session = Session(\ - SERVER, - version, - send_cb=send_cb, - on_ctrl_recv_cb=on_ctrl_recv_cb, - on_data_chunk_recv_cb=on_data_chunk_recv_cb, - on_stream_close_cb=on_stream_close_cb, - on_request_recv_cb=on_request_recv_cb, - user_data=ssctrl) + SERVER, version, + send_cb=self.send_cb, + on_ctrl_recv_cb=self.on_ctrl_recv_cb, + on_data_chunk_recv_cb=self.on_data_chunk_recv_cb, + on_stream_close_cb=self.on_stream_close_cb, + on_request_recv_cb=self.on_request_recv_cb) self.session.submit_settings(\ FLAG_SETTINGS_NONE, @@ -1322,7 +1316,7 @@ try: while self.session.want_read() or self.session.want_write(): want_read = want_write = False try: - data = sock.recv(4096) + data = self.sslsock.recv(4096) if data: self.session.recv(data) else: @@ -1339,8 +1333,8 @@ try: want_write = True if want_read or want_write: - select.select([sock] if want_read else [], - [sock] if want_write else [], + select.select([self.sslsock] if want_read else [], + [self.sslsock] if want_write else [], []) # The following methods and attributes are copied from