Skip to content

Commit ee1edff

Browse files
committed
h2 proto
1 parent 2447130 commit ee1edff

File tree

2 files changed

+192
-35
lines changed

2 files changed

+192
-35
lines changed

pproxy/proto.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -283,43 +283,51 @@ async def guess(self, reader, **kw):
283283
header = await reader.read_w(4)
284284
reader.rollback(header)
285285
return header in (b'GET ', b'HEAD', b'POST', b'PUT ', b'DELE', b'CONN', b'OPTI', b'TRAC', b'PATC')
286-
async def accept(self, reader, user, writer, users, authtable, httpget=None, **kw):
286+
async def accept(self, reader, user, writer, **kw):
287287
lines = await reader.read_until(b'\r\n\r\n')
288288
headers = lines[:-4].decode().split('\r\n')
289289
method, path, ver = HTTP_LINE.match(headers.pop(0)).groups()
290290
lines = '\r\n'.join(i for i in headers if not i.startswith('Proxy-'))
291291
headers = dict(i.split(': ', 1) for i in headers if ': ' in i)
292+
async def reply(code, message, body=None, wait=False):
293+
writer.write(message)
294+
if body:
295+
writer.write(body)
296+
if wait:
297+
await writer.drain()
298+
return await self.http_accept(user, method, path, None, ver, lines, headers['Host'], headers.get('Proxy-Authorization'), reply, **kw)
299+
async def http_accept(self, user, method, path, authority, ver, lines, host, pauth, reply, authtable, users, httpget, **kw):
292300
url = urllib.parse.urlparse(path)
293-
if method == 'GET' and not url.hostname and httpget:
294-
for path, text in httpget.items():
301+
if method == 'GET' and not url.hostname:
302+
for path, text in (httpget.items() if httpget else ()):
295303
if path == url.path:
296304
user = next(filter(lambda x: x.decode()==url.query, users), None) if users else True
297305
if user:
298306
if users:
299307
authtable.set_authed(user)
300308
if type(text) is str:
301-
text = (text % dict(host=headers["Host"])).encode()
302-
writer.write(f'{ver} 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nCache-Control: max-age=900\r\nContent-Length: {len(text)}\r\n\r\n'.encode() + text)
303-
await writer.drain()
309+
text = (text % dict(host=host)).encode()
310+
await reply(200, f'{ver} 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nCache-Control: max-age=900\r\nContent-Length: {len(text)}\r\n\r\n'.encode(), text, True)
304311
raise Exception('Connection closed')
305312
raise Exception(f'404 {method} {url.path}')
306313
if users:
307-
pauth = headers.get('Proxy-Authorization', None)
308314
user = authtable.authed()
309315
if not user:
310316
user = next(filter(lambda i: ('Basic '+base64.b64encode(i).decode()) == pauth, users), None)
311317
if user is None:
312-
writer.write(f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode())
318+
await reply(407, f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode(), wait=True)
313319
raise Exception('Unauthorized HTTP')
314320
authtable.set_authed(user)
315321
if method == 'CONNECT':
316-
host_name, port = netloc_split(path)
317-
return user, host_name, port, f'{ver} 200 OK\r\nConnection: close\r\n\r\n'.encode()
322+
host_name, port = netloc_split(authority or path)
323+
return user, host_name, port, lambda writer: reply(200, f'{ver} 200 OK\r\nConnection: close\r\n\r\n'.encode())
318324
else:
319-
url = urllib.parse.urlparse(path)
320-
host_name, port = netloc_split(url.netloc or headers.get("Host"), default_port=80)
325+
host_name, port = netloc_split(url.netloc or host, default_port=80)
321326
newpath = url._replace(netloc='', scheme='').geturl()
322-
return user, host_name, port, b'', f'{method} {newpath} {ver}\r\n{lines}\r\n\r\n'.encode()
327+
async def connected(writer):
328+
writer.write(f'{method} {newpath} {ver}\r\n{lines}\r\n\r\n'.encode())
329+
return True
330+
return user, host_name, port, connected
323331
async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw):
324332
writer_remote.write(f'CONNECT {host_name}:{port} HTTP/1.1\r\nHost: {myhost}'.encode() + (b'\r\nProxy-Authorization: Basic '+base64.b64encode(rauth) if rauth else b'') + b'\r\n\r\n')
325333
await reader_remote.read_until(b'\r\n\r\n')
@@ -370,6 +378,29 @@ def write(data, o=writer_remote.write):
370378
return o(data)
371379
writer_remote.write = write
372380

381+
class H2(HTTP):
382+
async def guess(self, reader, **kw):
383+
return True
384+
async def accept(self, reader, user, writer, **kw):
385+
if not writer.headers.done():
386+
await writer.headers
387+
headers = writer.headers.result()
388+
headers = {i.decode().lower():j.decode() for i,j in headers}
389+
lines = '\r\n'.join(i for i in headers if not i.startswith('proxy-') and not i.startswith(':'))
390+
async def reply(code, message, body=None, wait=False):
391+
writer.send_headers(((':status', str(code)),))
392+
if body:
393+
writer.write(body)
394+
if wait:
395+
await writer.drain()
396+
return await self.http_accept(user, headers[':method'], headers[':path'], headers[':authority'], '2.0', lines, '', headers.get('proxy-authorization'), reply, **kw)
397+
async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw):
398+
headers = [(':method', 'CONNECT'), (':scheme', 'https'), (':path', '/'),
399+
(':authority', f'{host_name}:{port}')]
400+
if rauth:
401+
headers.append(('proxy-authorization', 'Basic '+base64.b64encode(rauth)))
402+
writer_remote.send_headers(headers)
403+
373404
class SSH(BaseProtocol):
374405
async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw):
375406
pass
@@ -521,8 +552,8 @@ async def accept(protos, reader, **kw):
521552
raise Exception('Connection closed')
522553
if user:
523554
ret = await proto.accept(reader, user, **kw)
524-
while len(ret) < 5:
525-
ret += (b'',)
555+
while len(ret) < 4:
556+
ret += (None,)
526557
return (proto,) + ret
527558
raise Exception(f'Unsupported protocol')
528559

@@ -533,7 +564,7 @@ def udp_accept(protos, data, **kw):
533564
return (proto,) + ret
534565
raise Exception(f'Unsupported protocol {data[:10]}')
535566

536-
MAPPINGS = dict(direct=Direct, http=HTTP, httponly=HTTPOnly, ssh=SSH, socks5=Socks5, socks4=Socks4, socks=Socks5, ss=SS, ssr=SSR, redir=Redir, pf=Pf, tunnel=Tunnel, echo=Echo, ws=WS, trojan=Trojan, ssl='', secure='', quic='')
567+
MAPPINGS = dict(direct=Direct, http=HTTP, httponly=HTTPOnly, ssh=SSH, socks5=Socks5, socks4=Socks4, socks=Socks5, ss=SS, ssr=SSR, redir=Redir, pf=Pf, tunnel=Tunnel, echo=Echo, ws=WS, trojan=Trojan, h2=H2, ssl='', secure='', quic='')
537568
MAPPINGS['in'] = ''
538569

539570
def get_protos(rawprotos):
@@ -553,7 +584,7 @@ def get_protos(rawprotos):
553584
def sslwrap(reader, writer, sslcontext, server_side=False, server_hostname=None, verbose=None):
554585
if sslcontext is None:
555586
return reader, writer
556-
ssl_reader = type(reader)()
587+
ssl_reader = asyncio.StreamReader()
557588
class Protocol(asyncio.Protocol):
558589
def data_received(self, data):
559590
ssl_reader.feed_data(data)
@@ -576,13 +607,14 @@ def close(self):
576607
def _force_close(self, exc):
577608
if not self.closed:
578609
(verbose or print)(f'{exc} from {writer.get_extra_info("peername")[0]}')
610+
ssl._app_transport._closed = True
579611
self.close()
580612
def abort(self):
581613
self.close()
582614
ssl.connection_made(Transport())
583615
async def channel():
584616
try:
585-
while not reader.at_eof() and not ssl._app_transport.closed:
617+
while not reader.at_eof() and not ssl._app_transport._closed:
586618
data = await reader.read(65536)
587619
if not data:
588620
break
@@ -600,7 +632,8 @@ def write(self, data):
600632
def drain(self):
601633
return writer.drain()
602634
def is_closing(self):
603-
return ssl._app_transport.closed
635+
return ssl._app_transport._closed
604636
def close(self):
605-
ssl._app_transport.close()
637+
if not ssl._app_transport._closed:
638+
ssl._app_transport.close()
606639
return ssl_reader, Writer()

pproxy/server.py

Lines changed: 139 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s
7171
remote_text = f'{remote_ip}:{remote_port}'
7272
local_addr = None if server_ip in ('127.0.0.1', '::1', None) else (server_ip, 0)
7373
reader_cipher, _ = await prepare_ciphers(cipher, reader, writer, server_side=False)
74-
lproto, user, host_name, port, lbuf, rbuf = await proto.accept(protos, reader=reader, writer=writer, authtable=AuthTable(remote_ip, authtime), reader_cipher=reader_cipher, sock=writer.get_extra_info('socket'), **kwargs)
74+
lproto, user, host_name, port, client_connected = await proto.accept(protos, reader=reader, writer=writer, authtable=AuthTable(remote_ip, authtime), reader_cipher=reader_cipher, sock=writer.get_extra_info('socket'), **kwargs)
7575
if host_name == 'echo':
7676
asyncio.ensure_future(lproto.channel(reader, writer, DUMMY, DUMMY))
7777
elif host_name == 'empty':
@@ -87,13 +87,12 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s
8787
raise Exception(f'Connection timeout {roption.bind}')
8888
try:
8989
reader_remote, writer_remote = await roption.prepare_connection(reader_remote, writer_remote, host_name, port)
90-
writer.write(lbuf)
91-
writer_remote.write(rbuf)
90+
use_http = (await client_connected(writer_remote)) if client_connected else None
9291
except Exception:
9392
writer_remote.close()
9493
raise Exception('Unknown remote protocol')
9594
m = modstat(user, remote_ip, host_name)
96-
lchannel = lproto.http_channel if rbuf else lproto.channel
95+
lchannel = lproto.http_channel if use_http else lproto.channel
9796
asyncio.ensure_future(lproto.channel(reader_remote, writer, m(2+roption.direct), m(4+roption.direct)))
9897
asyncio.ensure_future(lchannel(reader, writer_remote, m(roption.direct), roption.connection_change))
9998
except Exception as ex:
@@ -304,6 +303,126 @@ def start_server(self, args, stream_handler=stream_handler):
304303
else:
305304
return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport'))
306305

306+
class ProxyH2(ProxySimple):
307+
def __init__(self, sslserver, sslclient, **kw):
308+
super().__init__(sslserver=None, sslclient=None, **kw)
309+
self.handshake = None
310+
self.h2sslserver = sslserver
311+
self.h2sslclient = sslclient
312+
async def handler(self, reader, writer, client_side=True, stream_handler=None, **kw):
313+
import h2.connection, h2.config, h2.events
314+
reader, writer = proto.sslwrap(reader, writer, self.h2sslclient if client_side else self.h2sslserver, not client_side, None)
315+
config = h2.config.H2Configuration(client_side=client_side)
316+
conn = h2.connection.H2Connection(config=config)
317+
streams = {}
318+
conn.initiate_connection()
319+
writer.write(conn.data_to_send())
320+
while not reader.at_eof() and not writer.is_closing():
321+
try:
322+
data = await reader.read(65636)
323+
if not data:
324+
break
325+
events = conn.receive_data(data)
326+
except Exception:
327+
pass
328+
writer.write(conn.data_to_send())
329+
for event in events:
330+
if isinstance(event, h2.events.RequestReceived) and not client_side:
331+
if event.stream_id not in streams:
332+
stream_reader, stream_writer = self.get_stream(conn, writer, event.stream_id)
333+
streams[event.stream_id] = (stream_reader, stream_writer)
334+
asyncio.ensure_future(stream_handler(stream_reader, stream_writer))
335+
else:
336+
stream_reader, stream_writer = streams[event.stream_id]
337+
stream_writer.headers.set_result(event.headers)
338+
elif isinstance(event, h2.events.SettingsAcknowledged) and client_side:
339+
self.handshake.set_result((conn, streams, writer))
340+
elif isinstance(event, h2.events.DataReceived):
341+
stream_reader, stream_writer = streams[event.stream_id]
342+
stream_reader.feed_data(event.data)
343+
conn.acknowledge_received_data(len(event.data), event.stream_id)
344+
writer.write(conn.data_to_send())
345+
elif isinstance(event, h2.events.StreamEnded) or isinstance(event, h2.events.StreamReset):
346+
stream_reader, stream_writer = streams[event.stream_id]
347+
stream_reader.feed_eof()
348+
if not stream_writer.closed:
349+
stream_writer.close()
350+
elif isinstance(event, h2.events.ConnectionTerminated):
351+
break
352+
elif isinstance(event, h2.events.WindowUpdated):
353+
if event.stream_id in streams:
354+
stream_reader, stream_writer = streams[event.stream_id]
355+
stream_writer.window_update()
356+
writer.write(conn.data_to_send())
357+
writer.close()
358+
def get_stream(self, conn, writer, stream_id):
359+
reader = asyncio.StreamReader()
360+
write_buffer = bytearray()
361+
write_wait = asyncio.Event()
362+
write_full = asyncio.Event()
363+
class StreamWriter():
364+
def __init__(self):
365+
self.closed = False
366+
self.headers = asyncio.get_event_loop().create_future()
367+
def get_extra_info(self, key):
368+
return writer.get_extra_info(key)
369+
def write(self, data):
370+
write_buffer.extend(data)
371+
write_wait.set()
372+
def drain(self):
373+
writer.write(conn.data_to_send())
374+
return writer.drain()
375+
def is_closing(self):
376+
return self.closed
377+
def close(self):
378+
self.closed = True
379+
write_wait.set()
380+
def window_update(self):
381+
write_full.set()
382+
def send_headers(self, headers):
383+
conn.send_headers(stream_id, headers)
384+
writer.write(conn.data_to_send())
385+
stream_writer = StreamWriter()
386+
async def write_job():
387+
while not stream_writer.closed:
388+
while len(write_buffer) > 0:
389+
while conn.local_flow_control_window(stream_id) <= 0:
390+
write_full.clear()
391+
await write_full.wait()
392+
if stream_writer.closed:
393+
break
394+
chunk_size = min(conn.local_flow_control_window(stream_id), len(write_buffer), conn.max_outbound_frame_size)
395+
conn.send_data(stream_id, write_buffer[:chunk_size])
396+
writer.write(conn.data_to_send())
397+
del write_buffer[:chunk_size]
398+
if not stream_writer.closed:
399+
write_wait.clear()
400+
await write_wait.wait()
401+
conn.send_data(stream_id, b'', end_stream=True)
402+
writer.write(conn.data_to_send())
403+
asyncio.ensure_future(write_job())
404+
return reader, stream_writer
405+
async def wait_h2_connection(self, local_addr, family):
406+
if self.handshake is not None:
407+
if not self.handshake.done():
408+
await self.handshake
409+
else:
410+
self.handshake = asyncio.get_event_loop().create_future()
411+
reader, writer = await super().wait_open_connection(None, None, local_addr, family)
412+
asyncio.ensure_future(self.handler(reader, writer))
413+
await self.handshake
414+
return self.handshake.result()
415+
async def wait_open_connection(self, host, port, local_addr, family):
416+
conn, streams, writer = await self.wait_h2_connection(local_addr, family)
417+
stream_id = conn.get_next_available_stream_id()
418+
conn._begin_new_stream(stream_id, stream_id%2)
419+
stream_reader, stream_writer = self.get_stream(conn, writer, stream_id)
420+
streams[stream_id] = (stream_reader, stream_writer)
421+
return stream_reader, stream_writer
422+
def start_server(self, args, stream_handler=stream_handler):
423+
handler = functools.partial(stream_handler, **vars(self), **args)
424+
return super().start_server(args, functools.partial(self.handler, client_side=False, stream_handler=handler))
425+
307426
class ProxyQUIC(ProxySimple):
308427
def __init__(self, quicserver, quicclient, **kw):
309428
super().__init__(**kw)
@@ -544,6 +663,8 @@ def proxies_by_uri(uri_jumps):
544663
jump = proxy_by_uri(uri, jump)
545664
return jump
546665

666+
sslcontexts = []
667+
547668
def proxy_by_uri(uri, jump):
548669
scheme, _, uri = uri.partition('://')
549670
url = urllib.parse.urlparse('s://'+uri)
@@ -558,17 +679,25 @@ def proxy_by_uri(uri, jump):
558679
if 'ssl' in rawprotos:
559680
sslclient.check_hostname = False
560681
sslclient.verify_mode = ssl.CERT_NONE
682+
sslcontexts.append(sslserver)
683+
sslcontexts.append(sslclient)
561684
else:
562685
sslserver = sslclient = None
563686
if 'quic' in rawprotos:
564687
try:
565688
import ssl, aioquic.quic.configuration
566689
except Exception:
567690
raise Exception('Missing library: "pip3 install aioquic"')
568-
import logging
569691
quicserver = aioquic.quic.configuration.QuicConfiguration(is_client=False)
570692
quicclient = aioquic.quic.configuration.QuicConfiguration()
571693
quicclient.verify_mode = ssl.CERT_NONE
694+
sslcontexts.append(quicserver)
695+
sslcontexts.append(quicclient)
696+
if 'h2' in rawprotos:
697+
try:
698+
import h2
699+
except Exception:
700+
raise Exception('Missing library: "pip3 install h2"')
572701
protonames = [i.name for i in protos]
573702
urlpath, _, plugins = url.path.partition(',')
574703
urlpath, _, lbind = urlpath.partition('@')
@@ -611,6 +740,8 @@ def proxy_by_uri(uri, jump):
611740
host_name=host_name, port=port, unix=not loc, lbind=lbind, sslclient=sslclient, sslserver=sslserver)
612741
if 'quic' in rawprotos:
613742
proxy = ProxyQUIC(quicserver, quicclient, **params)
743+
elif 'h2' in protonames:
744+
proxy = ProxyH2(**params)
614745
elif 'ssh' in protonames:
615746
proxy = ProxySSH(**params)
616747
else:
@@ -646,7 +777,7 @@ async def test_url(url, rserver):
646777
print(headers.decode()[:-4])
647778
print(f'--------------------------------')
648779
body = bytearray()
649-
while 1:
780+
while not reader.at_eof():
650781
s = await reader.read(65536)
651782
if not s:
652783
break
@@ -677,15 +808,8 @@ def main():
677808
args = parser.parse_args()
678809
if args.sslfile:
679810
sslfile = args.sslfile.split(',')
680-
for option in args.listen:
681-
if option.sslclient:
682-
option.sslclient.load_cert_chain(*sslfile)
683-
option.sslserver.load_cert_chain(*sslfile)
684-
for option in args.listen+args.ulisten+args.rserver+args.urserver:
685-
if isinstance(option, ProxyQUIC):
686-
option.quicserver.load_cert_chain(*sslfile)
687-
if isinstance(option, ProxyBackward) and isinstance(option.backward, ProxyQUIC):
688-
option.backward.quicserver.load_cert_chain(*sslfile)
811+
for context in sslcontexts:
812+
context.load_cert_chain(*sslfile)
689813
elif any(map(lambda o: o.sslclient or isinstance(o, ProxyQUIC), args.listen+args.ulisten)):
690814
print('You must specify --ssl to listen in ssl mode')
691815
return

0 commit comments

Comments
 (0)