Skip to content

Commit 6b1ebd3

Browse files
pr0meamaanq
andauthored
feat!: implement StreamingIterator instead of Iterator for QueryMatches and QueryCaptures
This fixes UB when either `QueryMatches` or `QueryCaptures` had collect called on it. Co-authored-by: Amaan Qureshi <[email protected]>
1 parent 12007d3 commit 6b1ebd3

File tree

14 files changed

+271
-105
lines changed

14 files changed

+271
-105
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ serde_derive = "1.0.210"
8484
serde_json = { version = "1.0.128", features = ["preserve_order"] }
8585
similar = "2.6.0"
8686
smallbitvec = "2.5.3"
87+
streaming-iterator = "0.1.9"
8788
tempfile = "3.12.0"
8889
thiserror = "1.0.64"
8990
tiny_http = "0.12.0"

cli/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ serde_derive.workspace = true
5252
serde_json.workspace = true
5353
similar.workspace = true
5454
smallbitvec.workspace = true
55+
streaming-iterator.workspace = true
5556
tiny_http.workspace = true
5657
walkdir.workspace = true
5758
wasmparser.workspace = true

cli/src/query.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::{
88

99
use anstyle::AnsiColor;
1010
use anyhow::{Context, Result};
11+
use streaming_iterator::StreamingIterator;
1112
use tree_sitter::{Language, Parser, Point, Query, QueryCursor};
1213

1314
use crate::{
@@ -58,10 +59,10 @@ pub fn query_files_at_paths(
5859

5960
let start = Instant::now();
6061
if ordered_captures {
61-
for (mat, capture_index) in
62-
query_cursor.captures(&query, tree.root_node(), source_code.as_slice())
63-
{
64-
let capture = mat.captures[capture_index];
62+
let mut captures =
63+
query_cursor.captures(&query, tree.root_node(), source_code.as_slice());
64+
while let Some((mat, capture_index)) = captures.next() {
65+
let capture = mat.captures[*capture_index];
6566
let capture_name = &query.capture_names()[capture.index as usize];
6667
if !quiet && !should_test {
6768
writeln!(
@@ -81,7 +82,9 @@ pub fn query_files_at_paths(
8182
});
8283
}
8384
} else {
84-
for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) {
85+
let mut matches =
86+
query_cursor.matches(&query, tree.root_node(), source_code.as_slice());
87+
while let Some(m) = matches.next() {
8588
if !quiet && !should_test {
8689
writeln!(&mut stdout, " pattern: {}", m.pattern_index)?;
8790
}

cli/src/tests/helpers/query_helpers.rs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::{cmp::Ordering, fmt::Write, ops::Range};
22

33
use rand::prelude::Rng;
4+
use streaming_iterator::{IntoStreamingIterator, StreamingIterator};
45
use tree_sitter::{
56
Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, Tree, TreeCursor,
67
};
@@ -324,39 +325,39 @@ pub fn assert_query_matches(
324325
}
325326

326327
pub fn collect_matches<'a>(
327-
matches: impl Iterator<Item = QueryMatch<'a, 'a>>,
328+
mut matches: impl StreamingIterator<Item = QueryMatch<'a, 'a>>,
328329
query: &'a Query,
329330
source: &'a str,
330331
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
331-
matches
332-
.map(|m| {
333-
(
334-
m.pattern_index,
335-
format_captures(m.captures.iter().copied(), query, source),
336-
)
337-
})
338-
.collect()
332+
let mut result = Vec::new();
333+
while let Some(m) = matches.next() {
334+
result.push((
335+
m.pattern_index,
336+
format_captures(m.captures.iter().into_streaming_iter_ref(), query, source),
337+
));
338+
}
339+
result
339340
}
340341

341342
pub fn collect_captures<'a>(
342-
captures: impl Iterator<Item = (QueryMatch<'a, 'a>, usize)>,
343+
captures: impl StreamingIterator<Item = (QueryMatch<'a, 'a>, usize)>,
343344
query: &'a Query,
344345
source: &'a str,
345346
) -> Vec<(&'a str, &'a str)> {
346-
format_captures(captures.map(|(m, i)| m.captures[i]), query, source)
347+
format_captures(captures.map(|(m, i)| m.captures[*i]), query, source)
347348
}
348349

349350
fn format_captures<'a>(
350-
captures: impl Iterator<Item = QueryCapture<'a>>,
351+
mut captures: impl StreamingIterator<Item = QueryCapture<'a>>,
351352
query: &'a Query,
352353
source: &'a str,
353354
) -> Vec<(&'a str, &'a str)> {
354-
captures
355-
.map(|capture| {
356-
(
357-
query.capture_names()[capture.index as usize],
358-
capture.node.utf8_text(source.as_bytes()).unwrap(),
359-
)
360-
})
361-
.collect()
355+
let mut result = Vec::new();
356+
while let Some(capture) = captures.next() {
357+
result.push((
358+
query.capture_names()[capture.index as usize],
359+
capture.node.utf8_text(source.as_bytes()).unwrap(),
360+
));
361+
}
362+
result
362363
}

cli/src/tests/query_test.rs

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::{env, fmt::Write};
33
use indoc::indoc;
44
use lazy_static::lazy_static;
55
use rand::{prelude::StdRng, SeedableRng};
6+
use streaming_iterator::StreamingIterator;
67
use tree_sitter::{
78
CaptureQuantifier, Language, Node, Parser, Point, Query, QueryCursor, QueryError,
89
QueryErrorKind, QueryPredicate, QueryPredicateArg, QueryProperty,
@@ -2267,29 +2268,50 @@ fn test_query_matches_with_wildcard_at_root_intersecting_byte_range() {
22672268

22682269
// After the first line of the class definition
22692270
let offset = source.find("A:").unwrap() + 2;
2270-
let matches = cursor
2271-
.set_byte_range(offset..offset)
2272-
.matches(&query, tree.root_node(), source.as_bytes())
2273-
.map(|mat| mat.captures[0].node.kind())
2274-
.collect::<Vec<_>>();
2271+
let mut matches = Vec::new();
2272+
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
2273+
&query,
2274+
tree.root_node(),
2275+
source.as_bytes(),
2276+
);
2277+
2278+
while let Some(mat) = match_iter.next() {
2279+
if let Some(capture) = mat.captures.first() {
2280+
matches.push(capture.node.kind());
2281+
}
2282+
}
22752283
assert_eq!(matches, &["class_definition"]);
22762284

22772285
// After the first line of the function definition
22782286
let offset = source.find("b():").unwrap() + 4;
2279-
let matches = cursor
2280-
.set_byte_range(offset..offset)
2281-
.matches(&query, tree.root_node(), source.as_bytes())
2282-
.map(|mat| mat.captures[0].node.kind())
2283-
.collect::<Vec<_>>();
2287+
let mut matches = Vec::new();
2288+
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
2289+
&query,
2290+
tree.root_node(),
2291+
source.as_bytes(),
2292+
);
2293+
2294+
while let Some(mat) = match_iter.next() {
2295+
if let Some(capture) = mat.captures.first() {
2296+
matches.push(capture.node.kind());
2297+
}
2298+
}
22842299
assert_eq!(matches, &["class_definition", "function_definition"]);
22852300

22862301
// After the first line of the if statement
22872302
let offset = source.find("c:").unwrap() + 2;
2288-
let matches = cursor
2289-
.set_byte_range(offset..offset)
2290-
.matches(&query, tree.root_node(), source.as_bytes())
2291-
.map(|mat| mat.captures[0].node.kind())
2292-
.collect::<Vec<_>>();
2303+
let mut matches = Vec::new();
2304+
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
2305+
&query,
2306+
tree.root_node(),
2307+
source.as_bytes(),
2308+
);
2309+
2310+
while let Some(mat) = match_iter.next() {
2311+
if let Some(capture) = mat.captures.first() {
2312+
matches.push(capture.node.kind());
2313+
}
2314+
}
22932315
assert_eq!(
22942316
matches,
22952317
&["class_definition", "function_definition", "if_statement"]
@@ -2342,8 +2364,9 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() {
23422364

23432365
// Retrieve some captures
23442366
let mut results = Vec::new();
2345-
for (mat, capture_ix) in captures.by_ref().take(5) {
2346-
let capture = mat.captures[capture_ix];
2367+
let mut first_five = captures.by_ref().take(5);
2368+
while let Some((mat, capture_ix)) = first_five.next() {
2369+
let capture = mat.captures[*capture_ix];
23472370
results.push((
23482371
query.capture_names()[capture.index as usize],
23492372
&source[capture.node.byte_range()],
@@ -2365,8 +2388,8 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() {
23652388
// intersect the range.
23662389
results.clear();
23672390
captures.set_byte_range(source.find("Ok").unwrap()..source.len());
2368-
for (mat, capture_ix) in captures {
2369-
let capture = mat.captures[capture_ix];
2391+
while let Some((mat, capture_ix)) = captures.next() {
2392+
let capture = mat.captures[*capture_ix];
23702393
results.push((
23712394
query.capture_names()[capture.index as usize],
23722395
&source[capture.node.byte_range()],
@@ -2602,21 +2625,23 @@ fn test_query_matches_with_captured_wildcard_at_root() {
26022625
parser.set_language(&language).unwrap();
26032626
let tree = parser.parse(source, None).unwrap();
26042627

2605-
let match_capture_names_and_rows = cursor
2606-
.matches(&query, tree.root_node(), source.as_bytes())
2607-
.map(|m| {
2608-
m.captures
2609-
.iter()
2610-
.map(|c| {
2611-
(
2612-
query.capture_names()[c.index as usize],
2613-
c.node.kind(),
2614-
c.node.start_position().row,
2615-
)
2616-
})
2617-
.collect::<Vec<_>>()
2618-
})
2619-
.collect::<Vec<_>>();
2628+
let mut match_capture_names_and_rows = Vec::new();
2629+
let mut match_iter = cursor.matches(&query, tree.root_node(), source.as_bytes());
2630+
2631+
while let Some(m) = match_iter.next() {
2632+
let captures = m
2633+
.captures
2634+
.iter()
2635+
.map(|c| {
2636+
(
2637+
query.capture_names()[c.index as usize],
2638+
c.node.kind(),
2639+
c.node.start_position().row,
2640+
)
2641+
})
2642+
.collect::<Vec<_>>();
2643+
match_capture_names_and_rows.push(captures);
2644+
}
26202645

26212646
assert_eq!(
26222647
match_capture_names_and_rows,
@@ -3460,9 +3485,13 @@ fn test_query_captures_with_matches_removed() {
34603485
let mut cursor = QueryCursor::new();
34613486

34623487
let mut captured_strings = Vec::new();
3463-
for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) {
3464-
let capture = m.captures[i];
3488+
3489+
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
3490+
while let Some((m, i)) = captures.next() {
3491+
println!("captured: {:?}, {}", m, i);
3492+
let capture = m.captures[*i];
34653493
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
3494+
println!("captured: {:?}", text);
34663495
if text == "a" {
34673496
m.remove();
34683497
continue;
@@ -3504,8 +3533,9 @@ fn test_query_captures_with_matches_removed_before_they_finish() {
35043533
let mut cursor = QueryCursor::new();
35053534

35063535
let mut captured_strings = Vec::new();
3507-
for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) {
3508-
let capture = m.captures[i];
3536+
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
3537+
while let Some((m, i)) = captures.next() {
3538+
let capture = m.captures[*i];
35093539
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
35103540
if text == "as" {
35113541
m.remove();
@@ -3912,21 +3942,24 @@ fn test_query_random() {
39123942
panic!("failed to build query for pattern {pattern} - {e}. seed: {seed}");
39133943
}
39143944
};
3915-
let mut actual_matches = cursor
3916-
.matches(
3917-
&query,
3918-
test_tree.root_node(),
3919-
include_bytes!("parser_test.rs").as_ref(),
3920-
)
3921-
.map(|mat| Match {
3945+
let mut actual_matches = Vec::new();
3946+
let mut match_iter = cursor.matches(
3947+
&query,
3948+
test_tree.root_node(),
3949+
include_bytes!("parser_test.rs").as_ref(),
3950+
);
3951+
3952+
while let Some(mat) = match_iter.next() {
3953+
let transformed_match = Match {
39223954
last_node: None,
39233955
captures: mat
39243956
.captures
39253957
.iter()
39263958
.map(|c| (query.capture_names()[c.index as usize], c.node))
39273959
.collect::<Vec<_>>(),
3928-
})
3929-
.collect::<Vec<_>>();
3960+
};
3961+
actual_matches.push(transformed_match);
3962+
}
39303963

39313964
// actual_matches.sort_unstable();
39323965
actual_matches.dedup();
@@ -4908,12 +4941,12 @@ fn test_consecutive_zero_or_modifiers() {
49084941
assert!(matches.next().is_some());
49094942

49104943
let mut cursor = QueryCursor::new();
4911-
let matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes());
4944+
let mut matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes());
49124945

49134946
let mut len_3 = false;
49144947
let mut len_1 = false;
49154948

4916-
for m in matches {
4949+
while let Some(m) = matches.next() {
49174950
if m.captures.len() == 3 {
49184951
len_3 = true;
49194952
}

cli/src/tests/text_provider_test.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{iter, sync::Arc};
22

3+
use streaming_iterator::StreamingIterator;
34
use tree_sitter::{Language, Node, Parser, Point, Query, QueryCursor, TextProvider, Tree};
45

56
use crate::tests::helpers::fixtures::get_language;
@@ -30,8 +31,8 @@ fn tree_query<I: AsRef<[u8]>>(tree: &Tree, text: impl TextProvider<I>, language:
3031
let mut cursor = QueryCursor::new();
3132
let mut captures = cursor.captures(&query, tree.root_node(), text);
3233
let (match_, idx) = captures.next().unwrap();
33-
let capture = match_.captures[idx];
34-
assert_eq!(capture.index as usize, idx);
34+
let capture = match_.captures[*idx];
35+
assert_eq!(capture.index as usize, *idx);
3536
assert_eq!("comment", capture.node.kind());
3637
}
3738

0 commit comments

Comments
 (0)