Skip to content

Commit 4cd10cb

Browse files
committed
programs: Extract client/server programs from tests
The code is extracted from the tests in a slightly simpler version and augmented with command-line arguments. This patch further removes the end-to-end tests using these classes: They were skipped anyway because too complex and flaky.
1 parent 434bc44 commit 4cd10cb

File tree

6 files changed

+334
-409
lines changed

6 files changed

+334
-409
lines changed

.github/ISSUE_TEMPLATE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ This repository's issues are reserved for feature requests and bug reports.
2626

2727
### Steps to reproduce
2828

29+
<!-- Using `programs/client.py` and `programs/server.py` if possible -->
30+
2931
1.
3032
1.
3133
1.

ChangeLog

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
* tls: Add support for session caching.
44
* tls: Add context manager to `TLSWrappedSocket`
5+
* programs: Add example DTLS and TLS client and server.
56
* ci: Drop CircleCI.
67
* Update wheels to mbedtls 2.16.12
78
* Add support for Python 3.10.

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include *.toml
66
include *.txt
77
include ChangeLog
88
recursive-include docs *.rst
9+
recursive-include programs *.py
910
recursive-include requirements *.in
1011
recursive-include requirements *.txt
1112
recursive-include scripts *.sh

programs/client.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#!/usr/bin/env python
2+
"""An example DTLS/TLS client.
3+
4+
Run ./programs/client.py --help.
5+
6+
"""
7+
8+
import argparse
9+
import socket
10+
from contextlib import suppress
11+
12+
from mbedtls.exceptions import TLSError
13+
from mbedtls.tls import (
14+
ClientContext,
15+
DTLSConfiguration,
16+
TLSConfiguration,
17+
_enable_debug_output,
18+
_set_debug_level,
19+
)
20+
21+
__all__ = ["Client"]
22+
23+
24+
def _echo_tls(sock, buffer, chunksize):
25+
view = memoryview(buffer)
26+
received = bytearray()
27+
for idx in range(0, len(view), chunksize):
28+
part = view[idx : idx + chunksize]
29+
_sent = sock.send(part)
30+
received += sock.recv(chunksize)
31+
return received
32+
33+
34+
def _echo_dtls(sock, buffer, chunksize):
35+
view = memoryview(buffer)
36+
received = bytearray()
37+
for idx in range(0, len(view), chunksize):
38+
part = view[idx : idx + chunksize]
39+
_sent = sock.send(part)
40+
data, _addr = sock.recvfrom(chunksize)
41+
received += data
42+
return received
43+
44+
45+
class Client:
46+
def __init__(self, cli_conf, proto, srv_address, srv_hostname):
47+
super().__init__()
48+
self.cli_conf = cli_conf
49+
self.proto = proto
50+
self.srv_address = srv_address
51+
self.srv_hostname = srv_hostname
52+
self._sock = None
53+
self._echo = {
54+
socket.SOCK_STREAM: _echo_tls,
55+
socket.SOCK_DGRAM: _echo_dtls,
56+
}[self.proto]
57+
58+
def __enter__(self):
59+
self.start()
60+
return self
61+
62+
def __exit__(self, *exc_info):
63+
self.stop()
64+
65+
def __del__(self):
66+
self.stop()
67+
68+
@property
69+
def context(self):
70+
if self._sock is None:
71+
return None
72+
return self._sock.context
73+
74+
def do_handshake(self):
75+
if not self._sock:
76+
return
77+
78+
self._sock.do_handshake()
79+
80+
def echo(self, buffer, chunksize):
81+
if not self._sock:
82+
return b""
83+
84+
return bytes(self._echo(self._sock, buffer, chunksize))
85+
86+
def start(self):
87+
if self._sock:
88+
self.stop()
89+
90+
self._sock = ClientContext(self.cli_conf).wrap_socket(
91+
socket.socket(socket.AF_INET, self.proto),
92+
server_hostname=self.srv_hostname,
93+
)
94+
self._sock.connect(self.srv_address)
95+
96+
def stop(self):
97+
if not self._sock:
98+
return
99+
100+
with suppress(TLSError, OSError):
101+
self._sock.close()
102+
self._sock = None
103+
104+
def restart(self):
105+
self.stop()
106+
self.start()
107+
108+
109+
def parse_args():
110+
class PSKArg(argparse.Action):
111+
def __call__(self, parser, namespace, values, option_string=None):
112+
def entry(kv):
113+
k, v = kv.split(b"=")
114+
return k.decode("utf-8"), v
115+
116+
setattr(namespace, self.dest, entry(values))
117+
118+
parser = argparse.ArgumentParser(description="client")
119+
group = parser.add_mutually_exclusive_group(required=True)
120+
group.add_argument(
121+
"--tls", dest="proto", action="store_const", const=socket.SOCK_STREAM
122+
)
123+
group.add_argument(
124+
"--dtls", dest="proto", action="store_const", const=socket.SOCK_DGRAM
125+
)
126+
parser.add_argument("--address", default="127.0.0.1")
127+
parser.add_argument("--port", default=4433, type=int)
128+
parser.add_argument("--debug", type=int)
129+
parser.add_argument("--server-name", default="localhost")
130+
parser.add_argument(
131+
"--psk",
132+
type=lambda x: x.encode("latin1"),
133+
action=PSKArg,
134+
metavar="CLI=SECRET",
135+
)
136+
parser.add_argument("message", default="hello")
137+
return parser.parse_args()
138+
139+
140+
def main(args):
141+
conf = {
142+
socket.SOCK_STREAM: TLSConfiguration,
143+
socket.SOCK_DGRAM: DTLSConfiguration,
144+
}[args.proto](pre_shared_key=args.psk, validate_certificates=False)
145+
146+
if args.debug is not None:
147+
_enable_debug_output(conf)
148+
_set_debug_level(args.debug)
149+
150+
with Client(
151+
conf, args.proto, (args.address, args.port), args.server_name
152+
) as cli:
153+
cli.do_handshake()
154+
received = cli.echo(args.message.encode("utf-8"), 1024)
155+
print(received.decode("utf-8"))
156+
157+
158+
if __name__ == "__main__":
159+
main(parse_args())

programs/server.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#!/usr/bin/env python
2+
"""An example DTLS/TLS server.
3+
4+
Run ./programs/server.py --help.
5+
6+
"""
7+
8+
import argparse
9+
import socket
10+
from contextlib import suppress
11+
from functools import partial
12+
13+
from mbedtls.exceptions import TLSError
14+
from mbedtls.tls import (
15+
DTLSConfiguration,
16+
HelloVerifyRequest,
17+
ServerContext,
18+
TLSConfiguration,
19+
_enable_debug_output,
20+
_set_debug_level,
21+
)
22+
23+
__all__ = ["Server"]
24+
25+
26+
def _make_tls_connection(sock):
27+
assert sock
28+
conn, _addr = sock.accept()
29+
return conn
30+
31+
32+
def _make_dtls_connection(sock):
33+
assert sock
34+
conn, addr = sock.accept()
35+
conn.setcookieparam(addr[0].encode("ascii"))
36+
with suppress(HelloVerifyRequest):
37+
conn.do_handshake()
38+
39+
_, (conn, addr) = conn, conn.accept()
40+
_.close()
41+
conn.setcookieparam(addr[0].encode("ascii"))
42+
return conn
43+
44+
45+
class Server:
46+
def __init__(self, srv_conf, proto, address):
47+
super().__init__()
48+
self.srv_conf = srv_conf
49+
self.proto = proto
50+
self.address = address
51+
self._make_connection = {
52+
socket.SOCK_STREAM: _make_tls_connection,
53+
socket.SOCK_DGRAM: _make_dtls_connection,
54+
}[self.proto]
55+
self._sock = None
56+
57+
def __enter__(self):
58+
self.start()
59+
return self
60+
61+
def __exit__(self, *exc_info):
62+
self.stop()
63+
64+
def __del__(self):
65+
self.stop()
66+
67+
@property
68+
def context(self):
69+
if self._sock is None:
70+
return None
71+
return self._sock.context
72+
73+
def start(self):
74+
if self._sock:
75+
self.stop()
76+
77+
self._sock = ServerContext(self.srv_conf).wrap_socket(
78+
socket.socket(socket.AF_INET, self.proto)
79+
)
80+
self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
81+
self._sock.bind(self.address)
82+
if self.proto is socket.SOCK_STREAM:
83+
self._sock.listen(1)
84+
85+
def stop(self):
86+
if not self._sock:
87+
return
88+
89+
self._sock.close()
90+
self._sock = None
91+
92+
def run(self, conn_handler):
93+
if not self._sock:
94+
raise ConnectionRefusedError("server not started")
95+
96+
while True:
97+
self._run(conn_handler)
98+
99+
def _run(self, conn_handler):
100+
with self._make_connection(self._sock) as conn:
101+
try:
102+
conn.do_handshake()
103+
except TLSError:
104+
return
105+
conn_handler(conn)
106+
107+
108+
def echo_handler(conn, *, packet_size):
109+
data = conn.recv(packet_size)
110+
sent = 0
111+
view = memoryview(data)
112+
while sent != len(data):
113+
sent += conn.send(view[sent:])
114+
115+
116+
def parse_args():
117+
class PSKStoreArg(argparse.Action):
118+
def __call__(self, parser, namespace, values, option_string=None):
119+
def entry(kv):
120+
k, v = kv.split(b"=")
121+
return k.decode("utf-8"), v
122+
123+
setattr(
124+
namespace,
125+
self.dest,
126+
dict(entry(kv) for kv in values.split(b",")),
127+
)
128+
129+
parser = argparse.ArgumentParser(description="server")
130+
group = parser.add_mutually_exclusive_group(required=True)
131+
group.add_argument(
132+
"--tls", dest="proto", action="store_const", const=socket.SOCK_STREAM
133+
)
134+
group.add_argument(
135+
"--dtls", dest="proto", action="store_const", const=socket.SOCK_DGRAM
136+
)
137+
parser.add_argument("--address", default="0.0.0.0")
138+
parser.add_argument("--port", default=4433, type=int)
139+
parser.add_argument("--debug", type=int)
140+
parser.add_argument(
141+
"--psk-store",
142+
type=lambda x: x.encode("latin1"),
143+
action=PSKStoreArg,
144+
metavar="CLI1=SECRET1,CLI2=SECRET2...",
145+
)
146+
return parser.parse_args()
147+
148+
149+
def main(args):
150+
conf = {
151+
socket.SOCK_STREAM: TLSConfiguration,
152+
socket.SOCK_DGRAM: DTLSConfiguration,
153+
}[args.proto](
154+
pre_shared_key_store=args.psk_store,
155+
validate_certificates=False,
156+
)
157+
158+
if args.debug is not None:
159+
_enable_debug_output(conf)
160+
_set_debug_level(args.debug)
161+
162+
with Server(conf, args.proto, (args.address, args.port)) as srv:
163+
srv.run(partial(echo_handler, packet_size=4069))
164+
165+
166+
if __name__ == "__main__":
167+
import faulthandler
168+
169+
faulthandler.enable()
170+
with suppress(KeyboardInterrupt):
171+
main(parse_args())

0 commit comments

Comments
 (0)