sqlparser/ast/
visitor.rs1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Recursive visitors for AST nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35 /// Visit this node with the provided [`Visitor`].
36 ///
37 /// Implementations should call the appropriate visitor hooks to traverse
38 /// child nodes and return a `ControlFlow` value to allow early exit.
39 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
40}
41
42/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
43/// recursively visiting parsed SQL statements.
44///
45/// # Note
46///
47/// This trait should be automatically derived for sqlparser AST nodes
48/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
49///
50/// ```text
51/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
52/// ```
53pub trait VisitMut {
54 /// Mutably visit this node with the provided [`VisitorMut`].
55 ///
56 /// Implementations should call the appropriate mutable visitor hooks to
57 /// traverse and allow in-place mutation of child nodes. Returning a
58 /// `ControlFlow` value permits early termination of the traversal.
59 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
60}
61
62impl<T: Visit> Visit for Option<T> {
63 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64 if let Some(s) = self {
65 s.visit(visitor)?;
66 }
67 ControlFlow::Continue(())
68 }
69}
70
71impl<T: Visit> Visit for Vec<T> {
72 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73 for v in self {
74 v.visit(visitor)?;
75 }
76 ControlFlow::Continue(())
77 }
78}
79
80impl<T: Visit> Visit for Box<T> {
81 fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
82 T::visit(self, visitor)
83 }
84}
85
86impl<T: VisitMut> VisitMut for Option<T> {
87 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88 if let Some(s) = self {
89 s.visit(visitor)?;
90 }
91 ControlFlow::Continue(())
92 }
93}
94
95impl<T: VisitMut> VisitMut for Vec<T> {
96 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97 for v in self {
98 v.visit(visitor)?;
99 }
100 ControlFlow::Continue(())
101 }
102}
103
104impl<T: VisitMut> VisitMut for Box<T> {
105 fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
106 T::visit(self, visitor)
107 }
108}
109
110macro_rules! visit_noop {
111 ($($t:ty),+) => {
112 $(impl Visit for $t {
113 fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
114 ControlFlow::Continue(())
115 }
116 })+
117 $(impl VisitMut for $t {
118 fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
119 ControlFlow::Continue(())
120 }
121 })+
122 };
123}
124
125visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
126
127#[cfg(feature = "bigdecimal")]
128visit_noop!(bigdecimal::BigDecimal);
129
130/// A visitor that can be used to walk an AST tree.
131///
132/// `pre_visit_` methods are invoked before visiting all children of the
133/// node and `post_visit_` methods are invoked after visiting all
134/// children of the node.
135///
136/// # See also
137///
138/// These methods provide a more concise way of visiting nodes of a certain type:
139/// * [visit_relations]
140/// * [visit_expressions]
141/// * [visit_statements]
142///
143/// # Example
144/// ```
145/// # use sqlparser::parser::Parser;
146/// # use sqlparser::dialect::GenericDialect;
147/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
148/// # use core::ops::ControlFlow;
149/// // A structure that records statements and relations
150/// #[derive(Default)]
151/// struct V {
152/// visited: Vec<String>,
153/// }
154///
155/// // Visit relations and exprs before children are visited (depth first walk)
156/// // Note you can also visit statements and visit exprs after children have been visited
157/// impl Visitor for V {
158/// type Break = ();
159///
160/// fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
161/// self.visited.push(format!("PRE: RELATION: {}", relation));
162/// ControlFlow::Continue(())
163/// }
164///
165/// fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
166/// self.visited.push(format!("PRE: EXPR: {}", expr));
167/// ControlFlow::Continue(())
168/// }
169/// }
170///
171/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
172/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
173/// .unwrap();
174///
175/// // Drive the visitor through the AST
176/// let mut visitor = V::default();
177/// statements.visit(&mut visitor);
178///
179/// // The visitor has visited statements and expressions in pre-traversal order
180/// let expected : Vec<_> = [
181/// "PRE: EXPR: a",
182/// "PRE: RELATION: foo",
183/// "PRE: EXPR: x IN (SELECT y FROM bar)",
184/// "PRE: EXPR: x",
185/// "PRE: EXPR: y",
186/// "PRE: RELATION: bar",
187/// ]
188/// .into_iter().map(|s| s.to_string()).collect();
189///
190/// assert_eq!(visitor.visited, expected);
191/// ```
192pub trait Visitor {
193 /// Type returned when the recursion returns early.
194 ///
195 /// Important note: The `Break` type should be kept as small as possible to prevent
196 /// stack overflow during recursion. If you need to return an error, consider
197 /// boxing it with `Box` to minimize stack usage.
198 type Break;
199
200 /// Invoked for any queries that appear in the AST before visiting children
201 fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
202 ControlFlow::Continue(())
203 }
204
205 /// Invoked for any queries that appear in the AST after visiting children
206 fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
207 ControlFlow::Continue(())
208 }
209
210 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
211 fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
212 ControlFlow::Continue(())
213 }
214
215 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
216 fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
217 ControlFlow::Continue(())
218 }
219
220 /// Invoked for any table factors that appear in the AST before visiting children
221 fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
222 ControlFlow::Continue(())
223 }
224
225 /// Invoked for any table factors that appear in the AST after visiting children
226 fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
227 ControlFlow::Continue(())
228 }
229
230 /// Invoked for any expressions that appear in the AST before visiting children
231 fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
232 ControlFlow::Continue(())
233 }
234
235 /// Invoked for any expressions that appear in the AST
236 fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
237 ControlFlow::Continue(())
238 }
239
240 /// Invoked for any statements that appear in the AST before visiting children
241 fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
242 ControlFlow::Continue(())
243 }
244
245 /// Invoked for any statements that appear in the AST after visiting children
246 fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
247 ControlFlow::Continue(())
248 }
249
250 /// Invoked for any Value that appear in the AST before visiting children
251 fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
252 ControlFlow::Continue(())
253 }
254
255 /// Invoked for any Value that appear in the AST after visiting children
256 fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
257 ControlFlow::Continue(())
258 }
259}
260
261/// A visitor that can be used to mutate an AST tree.
262///
263/// `pre_visit_` methods are invoked before visiting all children of the
264/// node and `post_visit_` methods are invoked after visiting all
265/// children of the node.
266///
267/// # See also
268///
269/// These methods provide a more concise way of visiting nodes of a certain type:
270/// * [visit_relations_mut]
271/// * [visit_expressions_mut]
272/// * [visit_statements_mut]
273///
274/// # Example
275/// ```
276/// # use sqlparser::parser::Parser;
277/// # use sqlparser::dialect::GenericDialect;
278/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
279/// # use core::ops::ControlFlow;
280///
281/// // A visitor that replaces "to_replace" with "replaced" in all expressions
282/// struct Replacer;
283///
284/// // Visit each expression after its children have been visited
285/// impl VisitorMut for Replacer {
286/// type Break = ();
287///
288/// fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
289/// if let Expr::Identifier(Ident{ value, ..}) = expr {
290/// *value = value.replace("to_replace", "replaced")
291/// }
292/// ControlFlow::Continue(())
293/// }
294/// }
295///
296/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
297/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
298///
299/// // Drive the visitor through the AST
300/// statements.visit(&mut Replacer);
301///
302/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
303/// ```
304pub trait VisitorMut {
305 /// Type returned when the recursion returns early.
306 ///
307 /// Important note: The `Break` type should be kept as small as possible to prevent
308 /// stack overflow during recursion. If you need to return an error, consider
309 /// boxing it with `Box` to minimize stack usage.
310 type Break;
311
312 /// Invoked for any queries that appear in the AST before visiting children
313 fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
314 ControlFlow::Continue(())
315 }
316
317 /// Invoked for any queries that appear in the AST after visiting children
318 fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
319 ControlFlow::Continue(())
320 }
321
322 /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
323 fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
324 ControlFlow::Continue(())
325 }
326
327 /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
328 fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
329 ControlFlow::Continue(())
330 }
331
332 /// Invoked for any table factors that appear in the AST before visiting children
333 fn pre_visit_table_factor(
334 &mut self,
335 _table_factor: &mut TableFactor,
336 ) -> ControlFlow<Self::Break> {
337 ControlFlow::Continue(())
338 }
339
340 /// Invoked for any table factors that appear in the AST after visiting children
341 fn post_visit_table_factor(
342 &mut self,
343 _table_factor: &mut TableFactor,
344 ) -> ControlFlow<Self::Break> {
345 ControlFlow::Continue(())
346 }
347
348 /// Invoked for any expressions that appear in the AST before visiting children
349 fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
350 ControlFlow::Continue(())
351 }
352
353 /// Invoked for any expressions that appear in the AST
354 fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
355 ControlFlow::Continue(())
356 }
357
358 /// Invoked for any statements that appear in the AST before visiting children
359 fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
360 ControlFlow::Continue(())
361 }
362
363 /// Invoked for any statements that appear in the AST after visiting children
364 fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
365 ControlFlow::Continue(())
366 }
367
368 /// Invoked for any value that appear in the AST before visiting children
369 fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
370 ControlFlow::Continue(())
371 }
372
373 /// Invoked for any statements that appear in the AST after visiting children
374 fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
375 ControlFlow::Continue(())
376 }
377}
378
379struct RelationVisitor<F>(F);
380
381impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
382 type Break = E;
383
384 fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
385 self.0(relation)
386 }
387}
388
389impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
390 type Break = E;
391
392 fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
393 self.0(relation)
394 }
395}
396
397/// Invokes the provided closure on all relations (e.g. table names) present in `v`
398///
399/// # Example
400/// ```
401/// # use sqlparser::parser::Parser;
402/// # use sqlparser::dialect::GenericDialect;
403/// # use sqlparser::ast::{visit_relations};
404/// # use core::ops::ControlFlow;
405/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
406/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
407/// .unwrap();
408///
409/// // visit statements, capturing relations (table names)
410/// let mut visited = vec![];
411/// visit_relations(&statements, |relation| {
412/// visited.push(format!("RELATION: {}", relation));
413/// ControlFlow::<()>::Continue(())
414/// });
415///
416/// let expected : Vec<_> = [
417/// "RELATION: foo",
418/// "RELATION: bar",
419/// ]
420/// .into_iter().map(|s| s.to_string()).collect();
421///
422/// assert_eq!(visited, expected);
423/// ```
424pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
425where
426 V: Visit,
427 F: FnMut(&ObjectName) -> ControlFlow<E>,
428{
429 let mut visitor = RelationVisitor(f);
430 v.visit(&mut visitor)?;
431 ControlFlow::Continue(())
432}
433
434/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
435/// present in `v`.
436///
437/// When the closure mutates its argument, the new mutated relation will not be visited again.
438///
439/// # Example
440/// ```
441/// # use sqlparser::parser::Parser;
442/// # use sqlparser::dialect::GenericDialect;
443/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
444/// # use core::ops::ControlFlow;
445/// let sql = "SELECT a FROM foo";
446/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
447/// .unwrap();
448///
449/// // visit statements, renaming table foo to bar
450/// visit_relations_mut(&mut statements, |table| {
451/// table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
452/// ControlFlow::<()>::Continue(())
453/// });
454///
455/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
456/// ```
457pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
458where
459 V: VisitMut,
460 F: FnMut(&mut ObjectName) -> ControlFlow<E>,
461{
462 let mut visitor = RelationVisitor(f);
463 v.visit(&mut visitor)?;
464 ControlFlow::Continue(())
465}
466
467struct ExprVisitor<F>(F);
468
469impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
470 type Break = E;
471
472 fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
473 self.0(expr)
474 }
475}
476
477impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
478 type Break = E;
479
480 fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
481 self.0(expr)
482 }
483}
484
485/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
486///
487/// # Example
488/// ```
489/// # use sqlparser::parser::Parser;
490/// # use sqlparser::dialect::GenericDialect;
491/// # use sqlparser::ast::{visit_expressions};
492/// # use core::ops::ControlFlow;
493/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
494/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
495/// .unwrap();
496///
497/// // visit all expressions
498/// let mut visited = vec![];
499/// visit_expressions(&statements, |expr| {
500/// visited.push(format!("EXPR: {}", expr));
501/// ControlFlow::<()>::Continue(())
502/// });
503///
504/// let expected : Vec<_> = [
505/// "EXPR: a",
506/// "EXPR: x IN (SELECT y FROM bar)",
507/// "EXPR: x",
508/// "EXPR: y",
509/// ]
510/// .into_iter().map(|s| s.to_string()).collect();
511///
512/// assert_eq!(visited, expected);
513/// ```
514pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
515where
516 V: Visit,
517 F: FnMut(&Expr) -> ControlFlow<E>,
518{
519 let mut visitor = ExprVisitor(f);
520 v.visit(&mut visitor)?;
521 ControlFlow::Continue(())
522}
523
524/// Invokes the provided closure iteratively with a mutable reference to all expressions
525/// present in `v`.
526///
527/// This performs a depth-first search, so if the closure mutates the expression
528///
529/// # Example
530///
531/// ## Remove all select limits in sub-queries
532/// ```
533/// # use sqlparser::parser::Parser;
534/// # use sqlparser::dialect::GenericDialect;
535/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
536/// # use core::ops::ControlFlow;
537/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
538/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
539///
540/// // Remove all select limits in sub-queries
541/// visit_expressions_mut(&mut statements, |expr| {
542/// if let Expr::Subquery(q) = expr {
543/// q.limit_clause = None;
544/// }
545/// ControlFlow::<()>::Continue(())
546/// });
547///
548/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
549/// ```
550///
551/// ## Wrap column name in function call
552///
553/// This demonstrates how to effectively replace an expression with another more complicated one
554/// that references the original. This example avoids unnecessary allocations by using the
555/// [`std::mem`] family of functions.
556///
557/// ```
558/// # use sqlparser::parser::Parser;
559/// # use sqlparser::dialect::GenericDialect;
560/// # use sqlparser::ast::*;
561/// # use core::ops::ControlFlow;
562/// let sql = "SELECT x, y FROM t";
563/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
564///
565/// visit_expressions_mut(&mut statements, |expr| {
566/// if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
567/// let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
568/// *expr = Expr::Function(Function {
569/// name: ObjectName::from(vec![Ident::new("f")]),
570/// uses_odbc_syntax: false,
571/// args: FunctionArguments::List(FunctionArgumentList {
572/// duplicate_treatment: None,
573/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
574/// clauses: vec![],
575/// }),
576/// null_treatment: None,
577/// filter: None,
578/// over: None,
579/// parameters: FunctionArguments::None,
580/// within_group: vec![],
581/// });
582/// }
583/// ControlFlow::<()>::Continue(())
584/// });
585///
586/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
587/// ```
588pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
589where
590 V: VisitMut,
591 F: FnMut(&mut Expr) -> ControlFlow<E>,
592{
593 v.visit(&mut ExprVisitor(f))?;
594 ControlFlow::Continue(())
595}
596
597struct StatementVisitor<F>(F);
598
599impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
600 type Break = E;
601
602 fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
603 self.0(statement)
604 }
605}
606
607impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
608 type Break = E;
609
610 fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
611 self.0(statement)
612 }
613}
614
615/// Invokes the provided closure iteratively with a mutable reference to all statements
616/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
617///
618/// # Example
619/// ```
620/// # use sqlparser::parser::Parser;
621/// # use sqlparser::dialect::GenericDialect;
622/// # use sqlparser::ast::{visit_statements};
623/// # use core::ops::ControlFlow;
624/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
625/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
626/// .unwrap();
627///
628/// // visit all statements
629/// let mut visited = vec![];
630/// visit_statements(&statements, |stmt| {
631/// visited.push(format!("STATEMENT: {}", stmt));
632/// ControlFlow::<()>::Continue(())
633/// });
634///
635/// let expected : Vec<_> = [
636/// "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
637/// "STATEMENT: CREATE TABLE baz (q INT)"
638/// ]
639/// .into_iter().map(|s| s.to_string()).collect();
640///
641/// assert_eq!(visited, expected);
642/// ```
643pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
644where
645 V: Visit,
646 F: FnMut(&Statement) -> ControlFlow<E>,
647{
648 let mut visitor = StatementVisitor(f);
649 v.visit(&mut visitor)?;
650 ControlFlow::Continue(())
651}
652
653/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
654///
655/// # Example
656/// ```
657/// # use sqlparser::parser::Parser;
658/// # use sqlparser::dialect::GenericDialect;
659/// # use sqlparser::ast::{Statement, visit_statements_mut};
660/// # use core::ops::ControlFlow;
661/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
662/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
663///
664/// // Remove all select limits in outer statements (not in sub-queries)
665/// visit_statements_mut(&mut statements, |stmt| {
666/// if let Statement::Query(q) = stmt {
667/// q.limit_clause = None;
668/// }
669/// ControlFlow::<()>::Continue(())
670/// });
671///
672/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
673/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
674/// ```
675pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
676where
677 V: VisitMut,
678 F: FnMut(&mut Statement) -> ControlFlow<E>,
679{
680 v.visit(&mut StatementVisitor(f))?;
681 ControlFlow::Continue(())
682}
683
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Statement;
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;

    /// Visitor that logs each overridden hook call as a
    /// `"<PRE|POST>: <KIND>: <node>"` line, in the order the calls happen.
    #[derive(Default)]
    struct TestVisitor {
        visited: Vec<String>,
    }

    impl Visitor for TestVisitor {
        type Break = ();

        /// Invoked for any queries that appear in the AST before visiting children
        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        /// Invoked for any queries that appear in the AST after visiting children
        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: QUERY: {query}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: RELATION: {relation}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn post_visit_table_factor(
            &mut self,
            table_factor: &TableFactor,
        ) -> ControlFlow<Self::Break> {
            self.visited
                .push(format!("POST: TABLE FACTOR: {table_factor}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: EXPR: {expr}"));
            ControlFlow::Continue(())
        }

        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("PRE: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }

        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
            self.visited.push(format!("POST: STATEMENT: {statement}"));
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement, walks it with `visitor`, checks
    /// that the walk ran to completion, and returns the parsed statement.
    fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let flow = s.visit(visitor);
        assert_eq!(flow, ControlFlow::Continue(()));
        s
    }

    #[test]
    fn test_sql() {
        // Each case pairs a SQL input with the exact hook log expected from
        // `TestVisitor`.
        let tests = [
            (
                "SELECT * from table_name as my_table",
                vec![
                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
                    "PRE: TABLE FACTOR: table_name AS my_table",
                    "PRE: RELATION: table_name",
                    "POST: RELATION: table_name",
                    "POST: TABLE FACTOR: table_name AS my_table",
                    "POST: QUERY: SELECT * FROM table_name AS my_table",
                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
                ],
            ),
            (
                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "PRE: EXPR: t1.id = t2.t1_id",
                    "PRE: EXPR: t1.id",
                    "POST: EXPR: t1.id",
                    "PRE: EXPR: t2.t1_id",
                    "POST: EXPR: t2.t1_id",
                    "POST: EXPR: t1.id = t2.t1_id",
                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2)",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
                ],
            ),
            (
                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
                vec![
                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "PRE: TABLE FACTOR: t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: TABLE FACTOR: t1",
                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: QUERY: SELECT column FROM t2",
                    "PRE: EXPR: column",
                    "POST: EXPR: column",
                    "PRE: TABLE FACTOR: t2",
                    "PRE: RELATION: t2",
                    "POST: RELATION: t2",
                    "POST: TABLE FACTOR: t2",
                    "POST: QUERY: SELECT column FROM t2",
                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
                    "PRE: TABLE FACTOR: t3",
                    "PRE: RELATION: t3",
                    "POST: RELATION: t3",
                    "POST: TABLE FACTOR: t3",
                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
                ],
            ),
            (
                concat!(
                    "SELECT * FROM monthly_sales ",
                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                    "ORDER BY EMPID"
                ),
                vec![
                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: TABLE FACTOR: monthly_sales",
                    "PRE: RELATION: monthly_sales",
                    "POST: RELATION: monthly_sales",
                    "POST: TABLE FACTOR: monthly_sales",
                    "PRE: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.amount",
                    "POST: EXPR: a.amount",
                    "POST: EXPR: SUM(a.amount)",
                    "PRE: EXPR: a.MONTH",
                    "POST: EXPR: a.MONTH",
                    "PRE: EXPR: 'JAN'",
                    "POST: EXPR: 'JAN'",
                    "PRE: EXPR: 'FEB'",
                    "POST: EXPR: 'FEB'",
                    "PRE: EXPR: 'MAR'",
                    "POST: EXPR: 'MAR'",
                    "PRE: EXPR: 'APR'",
                    "POST: EXPR: 'APR'",
                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
                    "PRE: EXPR: EMPID",
                    "POST: EXPR: EMPID",
                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
                ],
            ),
            (
                "SHOW COLUMNS FROM t1",
                vec![
                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
                    "PRE: RELATION: t1",
                    "POST: RELATION: t1",
                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
                ],
            ),
        ];
        for (sql, expected) in tests {
            let mut visitor = TestVisitor::default();
            let _ = do_visit(sql, &mut visitor);
            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
            // Name the SQL input so a failure identifies the offending case.
            assert_eq!(actual, expected, "unexpected visit order for `{sql}`");
        }
    }

    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes

    impl Visitor for QuickVisitor {
        type Break = ();
    }

    /// Walks a statement whose `WHERE` clause chains 1000 `OR`-ed comparisons
    /// and checks that the walk completes.
    #[test]
    fn overflow() {
        let cond = (0..1000)
            .map(|n| format!("X = {n}"))
            .collect::<Vec<_>>()
            .join(" OR ");
        let sql = format!("SELECT x where {cond}");

        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
        let s = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        let mut visitor = QuickVisitor {};
        let flow = s.visit(&mut visitor);
        assert_eq!(flow, ControlFlow::Continue(()));
    }
}
965
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Visitor that overwrites every value it is handed with a numbered
    /// single-quoted placeholder string.
    #[derive(Default)]
    struct MutatorVisitor {
        // Number of values replaced so far; used as the placeholder suffix.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            *value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as a single statement, walks it mutably with `visitor`,
    /// checks that the walk ran to completion, and returns the statement.
    fn do_visit_mut<V: VisitorMut<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        assert_eq!(statement.visit(visitor), ControlFlow::Continue(()));
        statement
    }

    #[test]
    fn test_value_redact() {
        // (input SQL, SQL expected after every value has been replaced)
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for (input, redacted) in cases {
            let mut redactor = MutatorVisitor::default();
            let rewritten = do_visit_mut(input, &mut redactor);
            assert_eq!(rewritten.to_string(), redacted)
        }
    }
}
1029}