Skip to content

Commit 7b490d1

Browse files
committed
[DTLS] add TestPSK and test_psk_hint_fail
1 parent c73bae2 commit 7b490d1

File tree

7 files changed

+199
-34
lines changed

7 files changed

+199
-34
lines changed

dtls/src/config.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub struct Config {
4545
// psk sets the pre-shared key used by this DTLS connection
4646
// If psk is non-nil only psk cipher_suites will be used
4747
pub(crate) psk: Option<PSKCallback>,
48-
pub(crate) psk_identity_hint: Vec<u8>,
48+
pub(crate) psk_identity_hint: Option<Vec<u8>>,
4949

5050
// insecure_skip_verify controls whether a client verifies the
5151
// server's certificate chain and host name.
@@ -151,7 +151,7 @@ pub(crate) fn validate_config(config: &Config) -> Result<(), Error> {
151151
return Err(ERR_PSK_AND_CERTIFICATE.clone());
152152
}
153153

154-
if !config.psk_identity_hint.is_empty() && config.psk.is_none() {
154+
if config.psk_identity_hint.is_some() && config.psk.is_none() {
155155
return Err(ERR_IDENTITY_NO_PSK.clone());
156156
}
157157

dtls/src/conn.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ impl Conn {
172172

173173
let cfg = HandshakeConfig {
174174
local_psk_callback: config.psk.take(),
175-
local_psk_identity_hint: config.psk_identity_hint.clone(),
175+
local_psk_identity_hint: config.psk_identity_hint.take(),
176176
local_cipher_suites,
177177
local_signature_schemes,
178178
extended_master_secret: config.extended_master_secret,
@@ -326,6 +326,9 @@ impl Conn {
326326
srv_cli_str(is_client),
327327
err
328328
);
329+
if err == *ERR_ALERT_FATAL_OR_CLOSE {
330+
break;
331+
}
329332
}
330333
}
331334
}
@@ -950,13 +953,13 @@ impl Conn {
950953
let mut reader = BufReader::new(out.as_slice());
951954
let raw_handshake = match Handshake::unmarshal(&mut reader) {
952955
Ok(rh) => {
953-
trace!(
956+
/*trace!(
954957
"Recv [handshake:{}] -> {} (epoch: {}, seq: {})",
955958
srv_cli_str(ctx.is_client),
956959
rh.handshake_header.handshake_type.to_string(),
957960
h.epoch,
958961
rh.handshake_header.message_sequence
959-
);
962+
);*/
960963
rh
961964
}
962965
Err(err) => {

dtls/src/conn/conn_test.rs

Lines changed: 164 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,18 @@ use crate::errors::*;
1010
use tokio::net::UdpSocket;
1111

1212
use std::time::SystemTime;
13+
1314
//use std::io::Write;
1415

16+
lazy_static! {
17+
pub static ref ERR_TEST_PSK_INVALID_IDENTITY: Error =
18+
Error::new("TestPSK: Server got invalid identity".to_owned());
19+
pub static ref ERR_PSK_REJECTED: Error = Error::new("PSK Rejected".to_owned());
20+
pub static ref ERR_NOT_EXPECTED_CHAIN: Error = Error::new("not expected chain".to_owned());
21+
pub static ref ERR_EXPECTED_CHAIN: Error = Error::new("expected chain".to_owned());
22+
pub static ref ERR_WRONG_CERT: Error = Error::new("wrong cert".to_owned());
23+
}
24+
1525
async fn build_pipe() -> Result<(Conn, Conn), Error> {
1626
let (ua, ub) = pipe().await?;
1727

@@ -39,10 +49,13 @@ async fn pipe_conn(ca: UdpSocket, cb: UdpSocket) -> Result<(Conn, Conn), Error>
3949
ca,
4050
Config {
4151
srtp_protection_profiles: vec![SRTPProtectionProfile::SRTP_AES128_CM_HMAC_SHA1_80],
52+
//TODO: change PSK to cert
4253
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256],
54+
psk: Some(psk_callback_client),
55+
psk_identity_hint: Some("WebRTC.rs DTLS Server".as_bytes().to_vec()),
4356
..Default::default()
4457
},
45-
false,
58+
false, //TODO: use ceritificate
4659
)
4760
.await;
4861

@@ -54,10 +67,13 @@ async fn pipe_conn(ca: UdpSocket, cb: UdpSocket) -> Result<(Conn, Conn), Error>
5467
cb,
5568
Config {
5669
srtp_protection_profiles: vec![SRTPProtectionProfile::SRTP_AES128_CM_HMAC_SHA1_80],
70+
//TODO: change PSK to cert
5771
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256],
72+
psk: Some(psk_callback_server),
73+
psk_identity_hint: Some("WebRTC.rs DTLS Client".as_bytes().to_vec()),
5874
..Default::default()
5975
},
60-
false,
76+
false, //TODO: use ceritificate
6177
)
6278
.await?;
6379

@@ -86,16 +102,17 @@ fn psk_callback_server(hint: &[u8]) -> Result<Vec<u8>, Error> {
86102
Ok(vec![0xAB, 0xC1, 0x23])
87103
}
88104

105+
fn psk_callback_hint_fail(_hint: &[u8]) -> Result<Vec<u8>, Error> {
106+
Err(ERR_PSK_REJECTED.clone())
107+
}
108+
89109
async fn create_test_client(
90110
ca: UdpSocket,
91111
mut cfg: Config,
92112
generate_certificate: bool,
93113
) -> Result<Conn, Error> {
94114
if generate_certificate {
95115
//TODO:
96-
} else {
97-
cfg.psk = Some(psk_callback_client);
98-
cfg.psk_identity_hint = "WebRTC.rs DTLS Server".as_bytes().to_vec();
99116
}
100117

101118
cfg.insecure_skip_verify = true;
@@ -104,15 +121,13 @@ async fn create_test_client(
104121

105122
async fn create_test_server(
106123
cb: UdpSocket,
107-
mut cfg: Config,
124+
cfg: Config,
108125
generate_certificate: bool,
109126
) -> Result<Conn, Error> {
110127
if generate_certificate {
111128
//TODO:
112-
} else {
113-
cfg.psk = Some(psk_callback_server);
114-
cfg.psk_identity_hint = "WebRTC.rs DTLS Client".as_bytes().to_vec();
115129
}
130+
116131
Conn::new(cb, cfg, false, None).await
117132
}
118133

@@ -340,11 +355,11 @@ async fn test_handshake_with_alert() -> Result<(), Error> {
340355
341356
let (ca, cb) = pipe().await?;
342357
tokio::spawn(async move {
343-
let result = create_test_client(ca, config_client, false).await;
358+
let result = create_test_client(ca, config_client, false).await; //TODO: use certificate
344359
let _ = client_err_tx.send(result).await;
345360
});
346361
347-
let result_server = create_test_server(cb, config_server, false).await;
362+
let result_server = create_test_server(cb, config_server, false).await; //TODO: use certificate
348363
if let Err(err) = result_server {
349364
assert_eq!(
350365
err, err_server,
@@ -495,3 +510,141 @@ async fn test_export_keying_material() -> Result<(), Error> {
495510

496511
Ok(())
497512
}
513+
514+
#[tokio::test]
515+
async fn test_psk() -> Result<(), Error> {
516+
/*env_logger::Builder::new()
517+
.format(|buf, record| {
518+
writeln!(
519+
buf,
520+
"{}:{} [{}] {} - {}",
521+
record.file().unwrap_or("unknown"),
522+
record.line().unwrap_or(0),
523+
record.level(),
524+
chrono::Local::now().format("%H:%M:%S.%6f"),
525+
record.args()
526+
)
527+
})
528+
.filter(None, LevelFilter::Trace)
529+
.init();*/
530+
531+
let tests = vec![
532+
(
533+
"Server identity specified",
534+
Some("Test Identity".as_bytes().to_vec()),
535+
),
536+
("Server identity nil", None),
537+
];
538+
539+
for (name, server_identity) in tests {
540+
let client_identity = "Client Identity".as_bytes();
541+
let (client_res_tx, mut client_res_rx) = mpsc::channel(1);
542+
543+
let (ca, cb) = pipe().await?;
544+
tokio::spawn(async move {
545+
let conf = Config {
546+
psk: Some(psk_callback_client),
547+
psk_identity_hint: Some(client_identity.to_vec()),
548+
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256], //TODO: change it to TLS_PSK_WITH_AES_128_CCM_8
549+
..Default::default()
550+
};
551+
552+
let result = create_test_client(ca, conf, false).await;
553+
let _ = client_res_tx.send(result).await;
554+
});
555+
556+
let config = Config {
557+
psk: Some(psk_callback_server),
558+
psk_identity_hint: server_identity,
559+
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256], //TODO: change it to TLS_PSK_WITH_AES_128_CCM_8
560+
..Default::default()
561+
};
562+
563+
let mut server = create_test_server(cb, config, false).await?;
564+
565+
if let Some(result) = client_res_rx.recv().await {
566+
if let Ok(mut client) = result {
567+
client.close().await?;
568+
} else {
569+
assert!(
570+
false,
571+
"{}: Expected create_test_client successfully, but got error",
572+
name,
573+
);
574+
}
575+
}
576+
577+
server.close().await?;
578+
}
579+
580+
Ok(())
581+
}
582+
583+
#[tokio::test]
584+
async fn test_psk_hint_fail() -> Result<(), Error> {
585+
/*env_logger::Builder::new()
586+
.format(|buf, record| {
587+
writeln!(
588+
buf,
589+
"{}:{} [{}] {} - {}",
590+
record.file().unwrap_or("unknown"),
591+
record.line().unwrap_or(0),
592+
record.level(),
593+
chrono::Local::now().format("%H:%M:%S.%6f"),
594+
record.args()
595+
)
596+
})
597+
.filter(None, LevelFilter::Trace)
598+
.init();*/
599+
600+
let (client_res_tx, mut client_res_rx) = mpsc::channel(1);
601+
602+
let (ca, cb) = pipe().await?;
603+
tokio::spawn(async move {
604+
let conf = Config {
605+
psk: Some(psk_callback_hint_fail),
606+
psk_identity_hint: Some(vec![]),
607+
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256], //TODO: change it to TLS_PSK_WITH_AES_128_CCM_8
608+
..Default::default()
609+
};
610+
611+
let result = create_test_client(ca, conf, false).await;
612+
let _ = client_res_tx.send(result).await;
613+
});
614+
615+
let config = Config {
616+
psk: Some(psk_callback_hint_fail),
617+
psk_identity_hint: Some(vec![]),
618+
cipher_suites: vec![CipherSuiteID::TLS_PSK_WITH_AES_128_GCM_SHA256], //TODO: change it to TLS_PSK_WITH_AES_128_CCM_8
619+
..Default::default()
620+
};
621+
622+
if let Err(server_err) = create_test_server(cb, config, false).await {
623+
assert_eq!(
624+
server_err,
625+
ERR_ALERT_FATAL_OR_CLOSE.clone(),
626+
"TestPSK: Server error exp({}) failed({})",
627+
ERR_ALERT_FATAL_OR_CLOSE.clone(),
628+
server_err,
629+
);
630+
} else {
631+
assert!(false, "Expected server error, but got OK");
632+
}
633+
634+
let result = client_res_rx.recv().await;
635+
if let Some(client) = result {
636+
if let Err(client_err) = client {
637+
assert_eq!(
638+
client_err,
639+
ERR_PSK_REJECTED.clone(),
640+
"TestPSK: Client error exp({}) failed({})",
641+
ERR_PSK_REJECTED.clone(),
642+
client_err,
643+
);
644+
} else {
645+
assert!(false, "Expected client error, but got OK");
646+
}
647+
}
648+
649+
Ok(())
650+
}

dtls/src/flight/flight3.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ impl Flight for Flight3 {
9191
}
9292
}
9393

94-
let result = if cfg.local_psk_callback.is_none() {
94+
let result = if cfg.local_psk_callback.is_some() {
9595
cache
9696
.full_pull_map(
9797
state.handshake_recv_sequence,

dtls/src/flight/flight4.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ impl Flight for Flight4 {
660660
reset_local_sequence_number: false,
661661
});
662662
}
663-
} else if !cfg.local_psk_identity_hint.is_empty() {
663+
} else if let Some(local_psk_identity_hint) = &cfg.local_psk_identity_hint {
664664
// To help the client in selecting which identity to use, the server
665665
// can provide a "PSK identity hint" in the ServerKeyExchange message.
666666
// If no hint is provided, the ServerKeyExchange message is omitted.
@@ -672,7 +672,7 @@ impl Flight for Flight4 {
672672
0,
673673
Content::Handshake(Handshake::new(HandshakeMessage::ServerKeyExchange(
674674
HandshakeMessageServerKeyExchange {
675-
identity_hint: cfg.local_psk_identity_hint.clone(),
675+
identity_hint: local_psk_identity_hint.clone(),
676676
elliptic_curve_type: EllipticCurveType::Unsupported,
677677
named_curve: NamedCurve::Unsupported,
678678
public_key: vec![],

dtls/src/flight/flight5.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ impl Flight for Flight5 {
235235
if let Some(local_keypair) = &state.local_keypair {
236236
client_key_exchange.public_key = local_keypair.public_key.clone();
237237
}
238-
} else {
239-
client_key_exchange.identity_hint = cfg.local_psk_identity_hint.clone();
238+
} else if let Some(local_psk_identity_hint) = &cfg.local_psk_identity_hint {
239+
client_key_exchange.identity_hint = local_psk_identity_hint.clone();
240240
}
241241

242242
pkts.push(Packet {

0 commit comments

Comments
 (0)