Skip to content

Commit 34db8de

Browse files
committed
feat(http1) Add support for writing Trailer Fields
Closes #2719
1 parent 429ad8a commit 34db8de

File tree

8 files changed

+473
-30
lines changed

8 files changed

+473
-30
lines changed

src/proto/h1/conn.rs

+38
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ where
7575
// We assume a modern world where the remote speaks HTTP/1.1.
7676
// If they tell us otherwise, we'll downgrade in `read_head`.
7777
version: Version::HTTP_11,
78+
allow_trailer_fields: false,
7879
},
7980
_marker: PhantomData,
8081
}
@@ -264,6 +265,16 @@ where
264265
self.state.reading = Reading::Body(Decoder::new(msg.decode));
265266
}
266267

268+
if let Some(Ok(te_value)) = msg.head.headers.get("te").map(|v| v.to_str()) {
269+
if te_value.eq_ignore_ascii_case("trailers") {
270+
self.state.allow_trailer_fields = true;
271+
} else {
272+
self.state.allow_trailer_fields = false;
273+
}
274+
} else {
275+
self.state.allow_trailer_fields = false;
276+
}
277+
267278
Poll::Ready(Some(Ok((msg.head, msg.decode, wants))))
268279
}
269280

@@ -640,6 +651,31 @@ where
640651
self.state.writing = state;
641652
}
642653

654+
pub(crate) fn write_trailers(&mut self, trailers: HeaderMap) {
655+
if T::is_server() && self.state.allow_trailer_fields == false {
656+
debug!("trailers not allowed to be sent");
657+
return;
658+
}
659+
debug_assert!(self.can_write_body() && self.can_buffer_body());
660+
661+
match self.state.writing {
662+
Writing::Body(ref encoder) => {
663+
if let Some(enc_buf) =
664+
encoder.encode_trailers(trailers, self.state.title_case_headers)
665+
{
666+
self.io.buffer(enc_buf);
667+
668+
self.state.writing = if encoder.is_last() || encoder.is_close_delimited() {
669+
Writing::Closed
670+
} else {
671+
Writing::KeepAlive
672+
};
673+
}
674+
}
675+
_ => unreachable!("write_trailers invalid state: {:?}", self.state.writing),
676+
}
677+
}
678+
643679
pub(crate) fn write_body_and_end(&mut self, chunk: B) {
644680
debug_assert!(self.can_write_body() && self.can_buffer_body());
645681
// empty chunks should be discarded at Dispatcher level
@@ -842,6 +878,8 @@ struct State {
842878
upgrade: Option<crate::upgrade::Pending>,
843879
/// Either HTTP/1.0 or 1.1 connection
844880
version: Version,
881+
/// Flag to track if trailer fields are allowed to be sent
882+
allow_trailer_fields: bool,
845883
}
846884

847885
#[derive(Debug)]

src/proto/h1/dispatch.rs

+24-18
Original file line numberDiff line numberDiff line change
@@ -351,27 +351,33 @@ where
351351
*clear_body = true;
352352
crate::Error::new_user_body(e)
353353
})?;
354-
let chunk = if let Ok(data) = frame.into_data() {
355-
data
356-
} else {
357-
trace!("discarding non-data frame");
358-
continue;
359-
};
360-
let eos = body.is_end_stream();
361-
if eos {
362-
*clear_body = true;
363-
if chunk.remaining() == 0 {
364-
trace!("discarding empty chunk");
365-
self.conn.end_body()?;
354+
355+
if frame.is_data() {
356+
let chunk = frame.into_data().unwrap_or_else(|_| unreachable!());
357+
let eos = body.is_end_stream();
358+
if eos {
359+
*clear_body = true;
360+
if chunk.remaining() == 0 {
361+
trace!("discarding empty chunk");
362+
self.conn.end_body()?;
363+
} else {
364+
self.conn.write_body_and_end(chunk);
365+
}
366366
} else {
367-
self.conn.write_body_and_end(chunk);
367+
if chunk.remaining() == 0 {
368+
trace!("discarding empty chunk");
369+
continue;
370+
}
371+
self.conn.write_body(chunk);
368372
}
373+
} else if frame.is_trailers() {
374+
*clear_body = true;
375+
self.conn.write_trailers(
376+
frame.into_trailers().unwrap_or_else(|_| unreachable!()),
377+
);
369378
} else {
370-
if chunk.remaining() == 0 {
371-
trace!("discarding empty chunk");
372-
continue;
373-
}
374-
self.conn.write_body(chunk);
379+
trace!("discarding unknown frame");
380+
continue;
375381
}
376382
} else {
377383
*clear_body = true;

src/proto/h1/encode.rs

+118-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
use std::collections::HashMap;
12
use std::fmt;
23
use std::io::IoSlice;
34

45
use bytes::buf::{Chain, Take};
5-
use bytes::Buf;
6+
use bytes::{Buf, Bytes};
7+
use http::{
8+
header::{
9+
AUTHORIZATION, CACHE_CONTROL, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_RANGE,
10+
CONTENT_TYPE, HOST, MAX_FORWARDS, SET_COOKIE, TRAILER, TRANSFER_ENCODING,
11+
},
12+
HeaderMap, HeaderName, HeaderValue,
13+
};
614

715
use super::io::WriteBuf;
16+
use super::role::{write_headers, write_headers_title_case};
817

918
type StaticBuf = &'static [u8];
1019

@@ -26,7 +35,7 @@ pub(crate) struct NotEof(u64);
2635
#[derive(Debug, PartialEq, Clone)]
2736
enum Kind {
2837
/// An Encoder for when Transfer-Encoding includes `chunked`.
29-
Chunked,
38+
Chunked(Option<Vec<HeaderValue>>),
3039
/// An Encoder for when Content-Length is set.
3140
///
3241
/// Enforces that the body is not longer than the Content-Length header.
@@ -45,6 +54,7 @@ enum BufKind<B> {
4554
Limited(Take<B>),
4655
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
4756
ChunkedEnd(StaticBuf),
57+
Trailers(Chain<Chain<StaticBuf, Bytes>, StaticBuf>),
4858
}
4959

5060
impl Encoder {
@@ -55,7 +65,7 @@ impl Encoder {
5565
}
5666
}
5767
pub(crate) fn chunked() -> Encoder {
58-
Encoder::new(Kind::Chunked)
68+
Encoder::new(Kind::Chunked(None))
5969
}
6070

6171
pub(crate) fn length(len: u64) -> Encoder {
@@ -67,6 +77,16 @@ impl Encoder {
6777
Encoder::new(Kind::CloseDelimited)
6878
}
6979

80+
pub(crate) fn into_chunked_with_trailing_fields(self, trailers: Vec<HeaderValue>) -> Encoder {
81+
match self.kind {
82+
Kind::Chunked(_) => Encoder {
83+
kind: Kind::Chunked(Some(trailers)),
84+
is_last: self.is_last,
85+
},
86+
_ => self,
87+
}
88+
}
89+
7090
pub(crate) fn is_eof(&self) -> bool {
7191
matches!(self.kind, Kind::Length(0))
7292
}
@@ -89,10 +109,17 @@ impl Encoder {
89109
}
90110
}
91111

112+
pub(crate) fn is_chunked(&self) -> bool {
113+
match self.kind {
114+
Kind::Chunked(_) => true,
115+
_ => false,
116+
}
117+
}
118+
92119
pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
93120
match self.kind {
94121
Kind::Length(0) => Ok(None),
95-
Kind::Chunked => Ok(Some(EncodedBuf {
122+
Kind::Chunked(_) => Ok(Some(EncodedBuf {
96123
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
97124
})),
98125
#[cfg(feature = "server")]
@@ -109,7 +136,7 @@ impl Encoder {
109136
debug_assert!(len > 0, "encode() called with empty buf");
110137

111138
let kind = match self.kind {
112-
Kind::Chunked => {
139+
Kind::Chunked(_) => {
113140
trace!("encoding chunked {}B", len);
114141
let buf = ChunkSize::new(len)
115142
.chain(msg)
@@ -136,6 +163,54 @@ impl Encoder {
136163
EncodedBuf { kind }
137164
}
138165

166+
pub(crate) fn encode_trailers<B>(
167+
&self,
168+
mut trailers: HeaderMap,
169+
title_case_headers: bool,
170+
) -> Option<EncodedBuf<B>> {
171+
match &self.kind {
172+
Kind::Chunked(allowed_trailer_fields) => {
173+
let allowed_trailer_fields_map = match allowed_trailer_fields {
174+
Some(ref allowed_trailer_fields) => {
175+
allowed_trailer_field_map(&allowed_trailer_fields)
176+
}
177+
None => return None,
178+
};
179+
180+
let mut cur_name = None;
181+
let mut allowed_trailers = HeaderMap::new();
182+
183+
for (opt_name, value) in trailers.drain() {
184+
if let Some(n) = opt_name {
185+
cur_name = Some(n);
186+
}
187+
let name = cur_name.as_ref().expect("current header name");
188+
189+
if allowed_trailer_fields_map.contains_key(name.as_str())
190+
&& !invalid_trailer_field(name)
191+
{
192+
allowed_trailers.insert(name, value);
193+
}
194+
}
195+
196+
let mut buf = Vec::new();
197+
if title_case_headers {
198+
write_headers_title_case(&allowed_trailers, &mut buf);
199+
} else {
200+
write_headers(&allowed_trailers, &mut buf);
201+
}
202+
203+
Some(EncodedBuf {
204+
kind: BufKind::Trailers(b"0\r\n".chain(Bytes::from(buf)).chain(b"\r\n")),
205+
})
206+
}
207+
_ => {
208+
debug!("attempted to encode trailers for non-chunked response");
209+
None
210+
}
211+
}
212+
}
213+
139214
pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
140215
where
141216
B: Buf,
@@ -144,7 +219,7 @@ impl Encoder {
144219
debug_assert!(len > 0, "encode() called with empty buf");
145220

146221
match self.kind {
147-
Kind::Chunked => {
222+
Kind::Chunked(_) => {
148223
trace!("encoding chunked {}B", len);
149224
let buf = ChunkSize::new(len)
150225
.chain(msg)
@@ -181,6 +256,39 @@ impl Encoder {
181256
}
182257
}
183258

259+
fn invalid_trailer_field(name: &HeaderName) -> bool {
260+
match name {
261+
&AUTHORIZATION => true,
262+
&CACHE_CONTROL => true,
263+
&CONTENT_ENCODING => true,
264+
&CONTENT_LENGTH => true,
265+
&CONTENT_RANGE => true,
266+
&CONTENT_TYPE => true,
267+
&HOST => true,
268+
&MAX_FORWARDS => true,
269+
&SET_COOKIE => true,
270+
&TRAILER => true,
271+
&TRANSFER_ENCODING => true,
272+
_ => false,
273+
}
274+
}
275+
276+
fn allowed_trailer_field_map(allowed_trailer_fields: &Vec<HeaderValue>) -> HashMap<String, ()> {
277+
let mut trailer_map = HashMap::new();
278+
279+
for header_value in allowed_trailer_fields {
280+
if let Ok(header_str) = header_value.to_str() {
281+
let items: Vec<&str> = header_str.split(',').map(|item| item.trim()).collect();
282+
283+
for item in items {
284+
trailer_map.entry(item.to_string()).or_insert(());
285+
}
286+
}
287+
}
288+
289+
trailer_map
290+
}
291+
184292
impl<B> Buf for EncodedBuf<B>
185293
where
186294
B: Buf,
@@ -192,6 +300,7 @@ where
192300
BufKind::Limited(ref b) => b.remaining(),
193301
BufKind::Chunked(ref b) => b.remaining(),
194302
BufKind::ChunkedEnd(ref b) => b.remaining(),
303+
BufKind::Trailers(ref b) => b.remaining(),
195304
}
196305
}
197306

@@ -202,6 +311,7 @@ where
202311
BufKind::Limited(ref b) => b.chunk(),
203312
BufKind::Chunked(ref b) => b.chunk(),
204313
BufKind::ChunkedEnd(ref b) => b.chunk(),
314+
BufKind::Trailers(ref b) => b.chunk(),
205315
}
206316
}
207317

@@ -212,6 +322,7 @@ where
212322
BufKind::Limited(ref mut b) => b.advance(cnt),
213323
BufKind::Chunked(ref mut b) => b.advance(cnt),
214324
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
325+
BufKind::Trailers(ref mut b) => b.advance(cnt),
215326
}
216327
}
217328

@@ -222,6 +333,7 @@ where
222333
BufKind::Limited(ref b) => b.chunks_vectored(dst),
223334
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
224335
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
336+
BufKind::Trailers(ref b) => b.chunks_vectored(dst),
225337
}
226338
}
227339
}

0 commit comments

Comments
 (0)