Skip to content

Commit 1f1b1eb

Browse files
authored
Merge pull request tree-sitter#1797 from tree-sitter/sibling-patterns-inside-errors
Fix performance pitfall when matching "non-rooted" patterns in the presence of errors
2 parents 01df16c + 79eaa68 commit 1f1b1eb

File tree

6 files changed

+209
-18
lines changed

6 files changed

+209
-18
lines changed

cli/src/tests/query_test.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,66 @@ fn test_query_matches_with_too_many_permutations_to_track() {
16871687
});
16881688
}
16891689

1690+
#[test]
1691+
fn test_query_sibling_patterns_dont_match_children_of_an_error() {
1692+
allocations::record(|| {
1693+
let language = get_language("rust");
1694+
let query = Query::new(
1695+
language,
1696+
r#"
1697+
("{" @open "}" @close)
1698+
1699+
[
1700+
(line_comment)
1701+
(block_comment)
1702+
] @comment
1703+
1704+
("<" @first "<" @second)
1705+
"#,
1706+
)
1707+
.unwrap();
1708+
1709+
// Most of the document will fail to parse, resulting in a
1710+
// large number of tokens that are *direct* children of an
1711+
// ERROR node.
1712+
//
1713+
// These children should still match, unless they are part
1714+
// of a "non-rooted" pattern, in which there are multiple
1715+
// top-level sibling nodes. Those patterns should not match
1716+
// directly inside of an error node, because the contents of
1717+
// an error node are not syntactically well-structured, so we
1718+
// would get many spurious matches.
1719+
let source = "
1720+
fn a() {}
1721+
1722+
<<<<<<<<<< add pub b fn () {}
1723+
// comment 1
1724+
pub fn b() {
1725+
/* comment 2 */
1726+
==========
1727+
pub fn c() {
1728+
// comment 3
1729+
>>>>>>>>>> add pub c fn () {}
1730+
}
1731+
";
1732+
1733+
let mut parser = Parser::new();
1734+
parser.set_language(language).unwrap();
1735+
let tree = parser.parse(&source, None).unwrap();
1736+
let mut cursor = QueryCursor::new();
1737+
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
1738+
assert_eq!(
1739+
collect_matches(matches, &query, source),
1740+
&[
1741+
(0, vec![("open", "{"), ("close", "}")]),
1742+
(1, vec![("comment", "// comment 1")]),
1743+
(1, vec![("comment", "/* comment 2 */")]),
1744+
(1, vec![("comment", "// comment 3")]),
1745+
],
1746+
);
1747+
});
1748+
}
1749+
16901750
#[test]
16911751
fn test_query_matches_with_alternatives_and_too_many_permutations_to_track() {
16921752
allocations::record(|| {
@@ -3919,6 +3979,97 @@ fn test_query_is_pattern_guaranteed_at_step() {
39193979
});
39203980
}
39213981

3982+
#[test]
3983+
fn test_query_is_pattern_rooted() {
3984+
struct Row {
3985+
description: &'static str,
3986+
pattern: &'static str,
3987+
is_rooted: bool,
3988+
}
3989+
3990+
let rows = [
3991+
Row {
3992+
description: "simple token",
3993+
pattern: r#"(identifier)"#,
3994+
is_rooted: true,
3995+
},
3996+
Row {
3997+
description: "simple non-terminal",
3998+
pattern: r#"(function_definition name: (identifier))"#,
3999+
is_rooted: true,
4000+
},
4001+
Row {
4002+
description: "alternative of many tokens",
4003+
pattern: r#"["if" "def" (identifier) (comment)]"#,
4004+
is_rooted: true,
4005+
},
4006+
Row {
4007+
description: "alternative of many non-terminals",
4008+
pattern: r#"[
4009+
(function_definition name: (identifier))
4010+
(class_definition name: (identifier))
4011+
(block)
4012+
]"#,
4013+
is_rooted: true,
4014+
},
4015+
Row {
4016+
description: "two siblings",
4017+
pattern: r#"("{" "}")"#,
4018+
is_rooted: false,
4019+
},
4020+
Row {
4021+
description: "top-level repetition",
4022+
pattern: r#"(comment)*"#,
4023+
is_rooted: false,
4024+
},
4025+
Row {
4026+
description: "alternative where one option has two siblings",
4027+
pattern: r#"[
4028+
(block)
4029+
(class_definition)
4030+
("(" ")")
4031+
(function_definition)
4032+
]"#,
4033+
is_rooted: false,
4034+
},
4035+
Row {
4036+
description: "alternative where one option has a top-level repetition",
4037+
pattern: r#"[
4038+
(block)
4039+
(class_definition)
4040+
(comment)*
4041+
(function_definition)
4042+
]"#,
4043+
is_rooted: false,
4044+
},
4045+
];
4046+
4047+
allocations::record(|| {
4048+
eprintln!("");
4049+
4050+
let language = get_language("python");
4051+
for row in &rows {
4052+
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
4053+
if !row.description.contains(filter.as_str()) {
4054+
continue;
4055+
}
4056+
}
4057+
eprintln!(" query example: {:?}", row.description);
4058+
let query = Query::new(language, row.pattern).unwrap();
4059+
assert_eq!(
4060+
query.is_pattern_rooted(0),
4061+
row.is_rooted,
4062+
"Description: {}, Pattern: {:?}",
4063+
row.description,
4064+
row.pattern
4065+
.split_ascii_whitespace()
4066+
.collect::<Vec<_>>()
4067+
.join(" "),
4068+
)
4069+
}
4070+
});
4071+
}
4072+
39224073
#[test]
39234074
fn test_capture_quantifiers() {
39244075
struct Row {

lib/binding_rust/bindings.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,9 @@ extern "C" {
658658
length: *mut u32,
659659
) -> *const TSQueryPredicateStep;
660660
}
661+
extern "C" {
662+
pub fn ts_query_is_pattern_rooted(self_: *const TSQuery, pattern_index: u32) -> bool;
663+
}
661664
extern "C" {
662665
pub fn ts_query_is_pattern_guaranteed_at_step(self_: *const TSQuery, byte_offset: u32) -> bool;
663666
}

lib/binding_rust/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,12 @@ impl Query {
16991699
unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) }
17001700
}
17011701

1702+
/// Check if a given pattern within a query has a single root node.
1703+
#[doc(alias = "ts_query_is_pattern_guaranteed_at_step")]
1704+
pub fn is_pattern_rooted(&self, index: usize) -> bool {
1705+
unsafe { ffi::ts_query_is_pattern_rooted(self.ptr.as_ptr(), index as u32) }
1706+
}
1707+
17021708
/// Check if a given step in a query is 'definite'.
17031709
///
17041710
/// A query step is 'definite' if its parent pattern will be guaranteed to match

lib/include/tree_sitter/api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,11 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern(
733733
uint32_t *length
734734
);
735735

736+
bool ts_query_is_pattern_rooted(
737+
const TSQuery *self,
738+
uint32_t pattern_index
739+
);
740+
736741
bool ts_query_is_pattern_guaranteed_at_step(
737742
const TSQuery *self,
738743
uint32_t byte_offset

lib/src/query.c

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,7 +2101,7 @@ static TSQueryError ts_query__parse_pattern(
21012101
return e;
21022102
}
21032103

2104-
if(start_index == starting_step_index) {
2104+
if (start_index == starting_step_index) {
21052105
capture_quantifiers_replace(capture_quantifiers, &branch_capture_quantifiers);
21062106
} else {
21072107
capture_quantifiers_join_all(capture_quantifiers, &branch_capture_quantifiers);
@@ -2167,10 +2167,10 @@ static TSQueryError ts_query__parse_pattern(
21672167
}
21682168

21692169
capture_quantifiers_add_all(capture_quantifiers, &child_capture_quantifiers);
2170-
2171-
child_is_immediate = false;
21722170
capture_quantifiers_clear(&child_capture_quantifiers);
2171+
child_is_immediate = false;
21732172
}
2173+
21742174
capture_quantifiers_delete(&child_capture_quantifiers);
21752175
}
21762176

@@ -2630,11 +2630,13 @@ TSQuery *ts_query_new(
26302630

26312631
// Determine whether the pattern has a single root node. This affects
26322632
// decisions about whether or not to start matching the pattern when
2633-
// a query cursor has a range restriction.
2633+
// a query cursor has a range restriction or when immediately within an
2634+
// error node.
26342635
uint32_t start_depth = step->depth;
26352636
bool is_rooted = start_depth == 0;
26362637
for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) {
26372638
QueryStep *step = &self->steps.contents[step_index];
2639+
if (step->is_dead_end) break;
26382640
if (step->depth == start_depth) {
26392641
is_rooted = false;
26402642
break;
@@ -2751,6 +2753,19 @@ uint32_t ts_query_start_byte_for_pattern(
27512753
return self->patterns.contents[pattern_index].start_byte;
27522754
}
27532755

2756+
bool ts_query_is_pattern_rooted(
2757+
const TSQuery *self,
2758+
uint32_t pattern_index
2759+
) {
2760+
for (unsigned i = 0; i < self->pattern_map.size; i++) {
2761+
PatternEntry *entry = &self->pattern_map.contents[i];
2762+
if (entry->pattern_index == pattern_index) {
2763+
if (!entry->is_rooted) return false;
2764+
}
2765+
}
2766+
return true;
2767+
}
2768+
27542769
bool ts_query_is_pattern_guaranteed_at_step(
27552770
const TSQuery *self,
27562771
uint32_t byte_offset
@@ -3324,20 +3339,28 @@ static inline bool ts_query_cursor__advance(
33243339
point_gt(ts_node_end_point(parent_node), self->start_point) &&
33253340
point_lt(ts_node_start_point(parent_node), self->end_point)
33263341
);
3342+
bool node_is_error = symbol == ts_builtin_sym_error;
3343+
bool parent_is_error =
3344+
!ts_node_is_null(parent_node) &&
3345+
ts_node_symbol(parent_node) == ts_builtin_sym_error;
33273346

33283347
// Add new states for any patterns whose root node is a wildcard.
3329-
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
3330-
PatternEntry *pattern = &self->query->pattern_map.contents[i];
3348+
if (!node_is_error) {
3349+
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
3350+
PatternEntry *pattern = &self->query->pattern_map.contents[i];
33313351

3332-
// If this node matches the first step of the pattern, then add a new
3333-
// state at the start of this pattern.
3334-
QueryStep *step = &self->query->steps.contents[pattern->step_index];
3335-
if (
3336-
(pattern->is_rooted ? node_intersects_range : parent_intersects_range) &&
3337-
(!step->field || field_id == step->field) &&
3338-
(!step->supertype_symbol || supertype_count > 0)
3339-
) {
3340-
ts_query_cursor__add_state(self, pattern);
3352+
// If this node matches the first step of the pattern, then add a new
3353+
// state at the start of this pattern.
3354+
QueryStep *step = &self->query->steps.contents[pattern->step_index];
3355+
if (
3356+
(pattern->is_rooted ?
3357+
node_intersects_range :
3358+
(parent_intersects_range && !parent_is_error)) &&
3359+
(!step->field || field_id == step->field) &&
3360+
(!step->supertype_symbol || supertype_count > 0)
3361+
) {
3362+
ts_query_cursor__add_state(self, pattern);
3363+
}
33413364
}
33423365
}
33433366

@@ -3351,7 +3374,9 @@ static inline bool ts_query_cursor__advance(
33513374
// If this node matches the first step of the pattern, then add a new
33523375
// state at the start of this pattern.
33533376
if (
3354-
(pattern->is_rooted ? node_intersects_range : parent_intersects_range) &&
3377+
(pattern->is_rooted ?
3378+
node_intersects_range :
3379+
(parent_intersects_range && !parent_is_error)) &&
33553380
(!step->field || field_id == step->field)
33563381
) {
33573382
ts_query_cursor__add_state(self, pattern);
@@ -3381,7 +3406,7 @@ static inline bool ts_query_cursor__advance(
33813406
// pattern.
33823407
bool node_does_match = false;
33833408
if (step->symbol == WILDCARD_SYMBOL) {
3384-
node_does_match = is_named || !step->is_named;
3409+
node_does_match = !node_is_error && (is_named || !step->is_named);
33853410
} else {
33863411
node_does_match = symbol == step->symbol;
33873412
}

lib/src/stack.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,8 @@ bool ts_stack_print_dot_graph(Stack *self, const TSLanguage *language, FILE *f)
777777
);
778778

779779
if (head->summary) {
780-
fprintf(f, "\nsummary_size: %u", head->summary->size);
780+
fprintf(f, "\nsummary:");
781+
for (uint32_t j = 0; j < head->summary->size; j++) fprintf(f, " %u", head->summary->contents[j].state);
781782
}
782783

783784
if (head->last_external_token.ptr) {

0 commit comments

Comments
 (0)