diff --git a/Cargo.toml b/Cargo.toml index 189ae19..256897f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,10 @@ anyhow = "1.0.81" clap = { version = "4.5.4", features = ["derive"] } csv = "1.3.0" serde = { version = "1.0.197", features = ["derive"] } -tree-sitter = "0.23" -tree-sitter-language = "0.1.0" -tree-sitter-tlaplus = "1.4.0" +streaming-iterator = "0.1.9" +tree-sitter = "0.24.3" +tree-sitter-language = "0.1.2" +tree-sitter-tlaplus = "1.5.0" [dev-dependencies] glob = "0.3.1" diff --git a/src/lib.rs b/src/lib.rs index a485182..8ddbf77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use crate::strmeasure::*; use serde::{Deserialize, Deserializer}; use std::ops::Range; +use streaming_iterator::StreamingIterator; use tree_sitter::{Node, Parser, Query, QueryCursor, Tree, TreeCursor}; pub enum Mode { @@ -334,7 +335,9 @@ impl JList { fn mark_jlists(tree: &Tree, query_cursor: &mut QueryCursor, tla_lines: &mut [TlaLine]) { let mut tree_cursor: TreeCursor = tree.walk(); - for capture in query_cursor.matches(&JList::query(), tree.root_node(), "".as_bytes()) { + let query = JList::query(); + let mut captures = query_cursor.matches(&query, tree.root_node(), "".as_bytes()); + while let Some(capture) = captures.next() { let node = capture.captures[0].node; let start_line = node.start_position().row; let line = &mut tla_lines[start_line]; @@ -360,11 +363,9 @@ fn mark_jlists(tree: &Tree, query_cursor: &mut QueryCursor, tla_lines: &mut [Tla line.jlists.push(jlist); } - for capture in query_cursor.matches( - &JList::terminating_infix_op_query(), - tree.root_node(), - "".as_bytes(), - ) { + let query = JList::terminating_infix_op_query(); + let mut captures = query_cursor.matches(&query, tree.root_node(), "".as_bytes()); + while let Some(capture) = captures.next() { let infix_op_node = capture.captures[0].node; let jlist_node = infix_op_node.child_by_field_name("lhs").unwrap(); let jlist_start_line_index = jlist_node.start_position().row; @@ -406,7 +407,8 @@ fn mark_symbols(tree: &Tree, cursor: &mut QueryCursor, tla_lines: &mut [TlaLine] .join(""); let query = Query::new(&tree_sitter_tlaplus::LANGUAGE.into(), queries).unwrap(); - for capture in cursor.matches(&query, tree.root_node(), "".as_bytes()) { + let mut captures = cursor.matches(&query, tree.root_node(), "".as_bytes()); + while let Some(capture) = captures.next() { let capture = capture.captures[0]; let mapping = &mappings[capture.index as usize]; let start_position = capture.node.start_position(); @@ -544,12 +546,9 @@ mod tests { .collect::>() .join(""); let query = Query::new(&tree_sitter_tlaplus::LANGUAGE.into(), &queries).unwrap(); - assert_eq!( - 0, - cursor - .matches(&query, tree.root_node(), "".as_bytes()) - .count() - ); + assert!(cursor + .matches(&query, tree.root_node(), "".as_bytes()) + .is_done()); } fn unwrap_conversion(input: Result) -> String {