visitor.rs - source

sqlparser/ast/

visitor.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
//! Recursive visitors for AST nodes. See [`Visitor`] for more details.
19
20use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor, Value};
21use core::ops::ControlFlow;
22
23/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
24/// recursively visiting parsed SQL statements.
25///
26/// # Note
27///
28/// This trait should be automatically derived for sqlparser AST nodes
29/// using the [Visit](sqlparser_derive::Visit) proc macro.
30///
31/// ```text
32/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
33/// ```
34pub trait Visit {
35    /// Visit this node with the provided [`Visitor`].
36    ///
37    /// Implementations should call the appropriate visitor hooks to traverse
38    /// child nodes and return a `ControlFlow` value to allow early exit.
39    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break>;
40}
41
42/// A type that can be visited by a [`VisitorMut`]. See [`VisitorMut`] for
43/// recursively visiting parsed SQL statements.
44///
45/// # Note
46///
47/// This trait should be automatically derived for sqlparser AST nodes
48/// using the [VisitMut](sqlparser_derive::VisitMut) proc macro.
49///
50/// ```text
51/// #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
52/// ```
53pub trait VisitMut {
54    /// Mutably visit this node with the provided [`VisitorMut`].
55    ///
56    /// Implementations should call the appropriate mutable visitor hooks to
57    /// traverse and allow in-place mutation of child nodes. Returning a
58    /// `ControlFlow` value permits early termination of the traversal.
59    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break>;
60}
61
62impl<T: Visit> Visit for Option<T> {
63    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
64        if let Some(s) = self {
65            s.visit(visitor)?;
66        }
67        ControlFlow::Continue(())
68    }
69}
70
71impl<T: Visit> Visit for Vec<T> {
72    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
73        for v in self {
74            v.visit(visitor)?;
75        }
76        ControlFlow::Continue(())
77    }
78}
79
80impl<T: Visit> Visit for Box<T> {
81    fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
82        T::visit(self, visitor)
83    }
84}
85
86impl<T: VisitMut> VisitMut for Option<T> {
87    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
88        if let Some(s) = self {
89            s.visit(visitor)?;
90        }
91        ControlFlow::Continue(())
92    }
93}
94
95impl<T: VisitMut> VisitMut for Vec<T> {
96    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
97        for v in self {
98            v.visit(visitor)?;
99        }
100        ControlFlow::Continue(())
101    }
102}
103
104impl<T: VisitMut> VisitMut for Box<T> {
105    fn visit<V: VisitorMut>(&mut self, visitor: &mut V) -> ControlFlow<V::Break> {
106        T::visit(self, visitor)
107    }
108}
109
// Implements both `Visit` and `VisitMut` as no-ops for every listed type.
//
// These are leaf types that contain no AST nodes, so visiting them never
// calls a visitor hook and always continues the traversal.
macro_rules! visit_noop {
    ($($leaf:ty),+) => {
        $(
            impl Visit for $leaf {
                fn visit<V: Visitor>(&self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }

            impl VisitMut for $leaf {
                fn visit<V: VisitorMut>(&mut self, _visitor: &mut V) -> ControlFlow<V::Break> {
                    ControlFlow::Continue(())
                }
            }
        )+
    };
}
124
125visit_noop!(u8, u16, u32, u64, i8, i16, i32, i64, char, bool, String);
126
127#[cfg(feature = "bigdecimal")]
128visit_noop!(bigdecimal::BigDecimal);
129
130/// A visitor that can be used to walk an AST tree.
131///
132/// `pre_visit_` methods are invoked before visiting all children of the
133/// node and `post_visit_` methods are invoked after visiting all
134/// children of the node.
135///
136/// # See also
137///
138/// These methods provide a more concise way of visiting nodes of a certain type:
139/// * [visit_relations]
140/// * [visit_expressions]
141/// * [visit_statements]
142///
143/// # Example
144/// ```
145/// # use sqlparser::parser::Parser;
146/// # use sqlparser::dialect::GenericDialect;
147/// # use sqlparser::ast::{Visit, Visitor, ObjectName, Expr};
148/// # use core::ops::ControlFlow;
149/// // A structure that records statements and relations
150/// #[derive(Default)]
151/// struct V {
152///    visited: Vec<String>,
153/// }
154///
155/// // Visit relations and exprs before children are visited (depth first walk)
156/// // Note you can also visit statements and visit exprs after children have been visited
157/// impl Visitor for V {
158///   type Break = ();
159///
160///   fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
161///     self.visited.push(format!("PRE: RELATION: {}", relation));
162///     ControlFlow::Continue(())
163///   }
164///
165///   fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
166///     self.visited.push(format!("PRE: EXPR: {}", expr));
167///     ControlFlow::Continue(())
168///   }
169/// }
170///
171/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
172/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
173///    .unwrap();
174///
175/// // Drive the visitor through the AST
176/// let mut visitor = V::default();
177/// statements.visit(&mut visitor);
178///
179/// // The visitor has visited statements and expressions in pre-traversal order
180/// let expected : Vec<_> = [
181///   "PRE: EXPR: a",
182///   "PRE: RELATION: foo",
183///   "PRE: EXPR: x IN (SELECT y FROM bar)",
184///   "PRE: EXPR: x",
185///   "PRE: EXPR: y",
186///   "PRE: RELATION: bar",
187/// ]
188///   .into_iter().map(|s| s.to_string()).collect();
189///
190/// assert_eq!(visitor.visited, expected);
191/// ```
192pub trait Visitor {
193    /// Type returned when the recursion returns early.
194    ///
195    /// Important note: The `Break` type should be kept as small as possible to prevent
196    /// stack overflow during recursion. If you need to return an error, consider
197    /// boxing it with `Box` to minimize stack usage.
198    type Break;
199
200    /// Invoked for any queries that appear in the AST before visiting children
201    fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
202        ControlFlow::Continue(())
203    }
204
205    /// Invoked for any queries that appear in the AST after visiting children
206    fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
207        ControlFlow::Continue(())
208    }
209
210    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
211    fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
212        ControlFlow::Continue(())
213    }
214
215    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
216    fn post_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
217        ControlFlow::Continue(())
218    }
219
220    /// Invoked for any table factors that appear in the AST before visiting children
221    fn pre_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
222        ControlFlow::Continue(())
223    }
224
225    /// Invoked for any table factors that appear in the AST after visiting children
226    fn post_visit_table_factor(&mut self, _table_factor: &TableFactor) -> ControlFlow<Self::Break> {
227        ControlFlow::Continue(())
228    }
229
230    /// Invoked for any expressions that appear in the AST before visiting children
231    fn pre_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
232        ControlFlow::Continue(())
233    }
234
235    /// Invoked for any expressions that appear in the AST
236    fn post_visit_expr(&mut self, _expr: &Expr) -> ControlFlow<Self::Break> {
237        ControlFlow::Continue(())
238    }
239
240    /// Invoked for any statements that appear in the AST before visiting children
241    fn pre_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
242        ControlFlow::Continue(())
243    }
244
245    /// Invoked for any statements that appear in the AST after visiting children
246    fn post_visit_statement(&mut self, _statement: &Statement) -> ControlFlow<Self::Break> {
247        ControlFlow::Continue(())
248    }
249
250    /// Invoked for any Value that appear in the AST before visiting children
251    fn pre_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
252        ControlFlow::Continue(())
253    }
254
255    /// Invoked for any Value that appear in the AST after visiting children
256    fn post_visit_value(&mut self, _value: &Value) -> ControlFlow<Self::Break> {
257        ControlFlow::Continue(())
258    }
259}
260
261/// A visitor that can be used to mutate an AST tree.
262///
263/// `pre_visit_` methods are invoked before visiting all children of the
264/// node and `post_visit_` methods are invoked after visiting all
265/// children of the node.
266///
267/// # See also
268///
269/// These methods provide a more concise way of visiting nodes of a certain type:
270/// * [visit_relations_mut]
271/// * [visit_expressions_mut]
272/// * [visit_statements_mut]
273///
274/// # Example
275/// ```
276/// # use sqlparser::parser::Parser;
277/// # use sqlparser::dialect::GenericDialect;
278/// # use sqlparser::ast::{VisitMut, VisitorMut, ObjectName, Expr, Ident};
279/// # use core::ops::ControlFlow;
280///
281/// // A visitor that replaces "to_replace" with "replaced" in all expressions
282/// struct Replacer;
283///
284/// // Visit each expression after its children have been visited
285/// impl VisitorMut for Replacer {
286///   type Break = ();
287///
288///   fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
289///     if let Expr::Identifier(Ident{ value, ..}) = expr {
290///         *value = value.replace("to_replace", "replaced")
291///     }
292///     ControlFlow::Continue(())
293///   }
294/// }
295///
296/// let sql = "SELECT to_replace FROM foo where to_replace IN (SELECT to_replace FROM bar)";
297/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
298///
299/// // Drive the visitor through the AST
300/// statements.visit(&mut Replacer);
301///
302/// assert_eq!(statements[0].to_string(), "SELECT replaced FROM foo WHERE replaced IN (SELECT replaced FROM bar)");
303/// ```
304pub trait VisitorMut {
305    /// Type returned when the recursion returns early.
306    ///
307    /// Important note: The `Break` type should be kept as small as possible to prevent
308    /// stack overflow during recursion. If you need to return an error, consider
309    /// boxing it with `Box` to minimize stack usage.
310    type Break;
311
312    /// Invoked for any queries that appear in the AST before visiting children
313    fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
314        ControlFlow::Continue(())
315    }
316
317    /// Invoked for any queries that appear in the AST after visiting children
318    fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
319        ControlFlow::Continue(())
320    }
321
322    /// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
323    fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
324        ControlFlow::Continue(())
325    }
326
327    /// Invoked for any relations (e.g. tables) that appear in the AST after visiting children
328    fn post_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
329        ControlFlow::Continue(())
330    }
331
332    /// Invoked for any table factors that appear in the AST before visiting children
333    fn pre_visit_table_factor(
334        &mut self,
335        _table_factor: &mut TableFactor,
336    ) -> ControlFlow<Self::Break> {
337        ControlFlow::Continue(())
338    }
339
340    /// Invoked for any table factors that appear in the AST after visiting children
341    fn post_visit_table_factor(
342        &mut self,
343        _table_factor: &mut TableFactor,
344    ) -> ControlFlow<Self::Break> {
345        ControlFlow::Continue(())
346    }
347
348    /// Invoked for any expressions that appear in the AST before visiting children
349    fn pre_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
350        ControlFlow::Continue(())
351    }
352
353    /// Invoked for any expressions that appear in the AST
354    fn post_visit_expr(&mut self, _expr: &mut Expr) -> ControlFlow<Self::Break> {
355        ControlFlow::Continue(())
356    }
357
358    /// Invoked for any statements that appear in the AST before visiting children
359    fn pre_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
360        ControlFlow::Continue(())
361    }
362
363    /// Invoked for any statements that appear in the AST after visiting children
364    fn post_visit_statement(&mut self, _statement: &mut Statement) -> ControlFlow<Self::Break> {
365        ControlFlow::Continue(())
366    }
367
368    /// Invoked for any value that appear in the AST before visiting children
369    fn pre_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
370        ControlFlow::Continue(())
371    }
372
373    /// Invoked for any statements that appear in the AST after visiting children
374    fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
375        ControlFlow::Continue(())
376    }
377}
378
379struct RelationVisitor<F>(F);
380
381impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F> {
382    type Break = E;
383
384    fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
385        self.0(relation)
386    }
387}
388
389impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
390    type Break = E;
391
392    fn post_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
393        self.0(relation)
394    }
395}
396
397/// Invokes the provided closure on all relations (e.g. table names) present in `v`
398///
399/// # Example
400/// ```
401/// # use sqlparser::parser::Parser;
402/// # use sqlparser::dialect::GenericDialect;
403/// # use sqlparser::ast::{visit_relations};
404/// # use core::ops::ControlFlow;
405/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
406/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
407///    .unwrap();
408///
409/// // visit statements, capturing relations (table names)
410/// let mut visited = vec![];
411/// visit_relations(&statements, |relation| {
412///   visited.push(format!("RELATION: {}", relation));
413///   ControlFlow::<()>::Continue(())
414/// });
415///
416/// let expected : Vec<_> = [
417///   "RELATION: foo",
418///   "RELATION: bar",
419/// ]
420///   .into_iter().map(|s| s.to_string()).collect();
421///
422/// assert_eq!(visited, expected);
423/// ```
424pub fn visit_relations<V, E, F>(v: &V, f: F) -> ControlFlow<E>
425where
426    V: Visit,
427    F: FnMut(&ObjectName) -> ControlFlow<E>,
428{
429    let mut visitor = RelationVisitor(f);
430    v.visit(&mut visitor)?;
431    ControlFlow::Continue(())
432}
433
434/// Invokes the provided closure with a mutable reference to all relations (e.g. table names)
435/// present in `v`.
436///
437/// When the closure mutates its argument, the new mutated relation will not be visited again.
438///
439/// # Example
440/// ```
441/// # use sqlparser::parser::Parser;
442/// # use sqlparser::dialect::GenericDialect;
443/// # use sqlparser::ast::{ObjectName, ObjectNamePart, Ident, visit_relations_mut};
444/// # use core::ops::ControlFlow;
445/// let sql = "SELECT a FROM foo";
446/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
447///    .unwrap();
448///
449/// // visit statements, renaming table foo to bar
450/// visit_relations_mut(&mut statements, |table| {
451///   table.0[0] = ObjectNamePart::Identifier(Ident::new("bar"));
452///   ControlFlow::<()>::Continue(())
453/// });
454///
455/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
456/// ```
457pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
458where
459    V: VisitMut,
460    F: FnMut(&mut ObjectName) -> ControlFlow<E>,
461{
462    let mut visitor = RelationVisitor(f);
463    v.visit(&mut visitor)?;
464    ControlFlow::Continue(())
465}
466
467struct ExprVisitor<F>(F);
468
469impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
470    type Break = E;
471
472    fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
473        self.0(expr)
474    }
475}
476
477impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
478    type Break = E;
479
480    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
481        self.0(expr)
482    }
483}
484
485/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
486///
487/// # Example
488/// ```
489/// # use sqlparser::parser::Parser;
490/// # use sqlparser::dialect::GenericDialect;
491/// # use sqlparser::ast::{visit_expressions};
492/// # use core::ops::ControlFlow;
493/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
494/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
495///    .unwrap();
496///
497/// // visit all expressions
498/// let mut visited = vec![];
499/// visit_expressions(&statements, |expr| {
500///   visited.push(format!("EXPR: {}", expr));
501///   ControlFlow::<()>::Continue(())
502/// });
503///
504/// let expected : Vec<_> = [
505///   "EXPR: a",
506///   "EXPR: x IN (SELECT y FROM bar)",
507///   "EXPR: x",
508///   "EXPR: y",
509/// ]
510///   .into_iter().map(|s| s.to_string()).collect();
511///
512/// assert_eq!(visited, expected);
513/// ```
514pub fn visit_expressions<V, E, F>(v: &V, f: F) -> ControlFlow<E>
515where
516    V: Visit,
517    F: FnMut(&Expr) -> ControlFlow<E>,
518{
519    let mut visitor = ExprVisitor(f);
520    v.visit(&mut visitor)?;
521    ControlFlow::Continue(())
522}
523
524/// Invokes the provided closure iteratively with a mutable reference to all expressions
525/// present in `v`.
526///
527/// This performs a depth-first search, so if the closure mutates the expression
528///
529/// # Example
530///
531/// ## Remove all select limits in sub-queries
532/// ```
533/// # use sqlparser::parser::Parser;
534/// # use sqlparser::dialect::GenericDialect;
535/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
536/// # use core::ops::ControlFlow;
537/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
538/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
539///
540/// // Remove all select limits in sub-queries
541/// visit_expressions_mut(&mut statements, |expr| {
542///   if let Expr::Subquery(q) = expr {
543///      q.limit_clause = None;
544///   }
545///   ControlFlow::<()>::Continue(())
546/// });
547///
548/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
549/// ```
550///
551/// ## Wrap column name in function call
552///
553/// This demonstrates how to effectively replace an expression with another more complicated one
554/// that references the original. This example avoids unnecessary allocations by using the
555/// [`std::mem`] family of functions.
556///
557/// ```
558/// # use sqlparser::parser::Parser;
559/// # use sqlparser::dialect::GenericDialect;
560/// # use sqlparser::ast::*;
561/// # use core::ops::ControlFlow;
562/// let sql = "SELECT x, y FROM t";
563/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
564///
565/// visit_expressions_mut(&mut statements, |expr| {
566///   if matches!(expr, Expr::Identifier(col_name) if col_name.value == "x") {
567///     let old_expr = std::mem::replace(expr, Expr::value(Value::Null));
568///     *expr = Expr::Function(Function {
569///           name: ObjectName::from(vec![Ident::new("f")]),
570///           uses_odbc_syntax: false,
571///           args: FunctionArguments::List(FunctionArgumentList {
572///               duplicate_treatment: None,
573///               args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
574///               clauses: vec![],
575///           }),
576///           null_treatment: None,
577///           filter: None,
578///           over: None,
579///           parameters: FunctionArguments::None,
580///           within_group: vec![],
581///      });
582///   }
583///   ControlFlow::<()>::Continue(())
584/// });
585///
586/// assert_eq!(statements[0].to_string(), "SELECT f(x), y FROM t");
587/// ```
588pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
589where
590    V: VisitMut,
591    F: FnMut(&mut Expr) -> ControlFlow<E>,
592{
593    v.visit(&mut ExprVisitor(f))?;
594    ControlFlow::Continue(())
595}
596
597struct StatementVisitor<F>(F);
598
599impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
600    type Break = E;
601
602    fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
603        self.0(statement)
604    }
605}
606
607impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
608    type Break = E;
609
610    fn post_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
611        self.0(statement)
612    }
613}
614
615/// Invokes the provided closure iteratively with a mutable reference to all statements
616/// present in `v` (e.g. `SELECT`, `CREATE TABLE`, etc).
617///
618/// # Example
619/// ```
620/// # use sqlparser::parser::Parser;
621/// # use sqlparser::dialect::GenericDialect;
622/// # use sqlparser::ast::{visit_statements};
623/// # use core::ops::ControlFlow;
624/// let sql = "SELECT a FROM foo where x IN (SELECT y FROM bar); CREATE TABLE baz(q int)";
625/// let statements = Parser::parse_sql(&GenericDialect{}, sql)
626///    .unwrap();
627///
628/// // visit all statements
629/// let mut visited = vec![];
630/// visit_statements(&statements, |stmt| {
631///   visited.push(format!("STATEMENT: {}", stmt));
632///   ControlFlow::<()>::Continue(())
633/// });
634///
635/// let expected : Vec<_> = [
636///   "STATEMENT: SELECT a FROM foo WHERE x IN (SELECT y FROM bar)",
637///   "STATEMENT: CREATE TABLE baz (q INT)"
638/// ]
639///   .into_iter().map(|s| s.to_string()).collect();
640///
641/// assert_eq!(visited, expected);
642/// ```
643pub fn visit_statements<V, E, F>(v: &V, f: F) -> ControlFlow<E>
644where
645    V: Visit,
646    F: FnMut(&Statement) -> ControlFlow<E>,
647{
648    let mut visitor = StatementVisitor(f);
649    v.visit(&mut visitor)?;
650    ControlFlow::Continue(())
651}
652
653/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
654///
655/// # Example
656/// ```
657/// # use sqlparser::parser::Parser;
658/// # use sqlparser::dialect::GenericDialect;
659/// # use sqlparser::ast::{Statement, visit_statements_mut};
660/// # use core::ops::ControlFlow;
661/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
662/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
663///
664/// // Remove all select limits in outer statements (not in sub-queries)
665/// visit_statements_mut(&mut statements, |stmt| {
666///   if let Statement::Query(q) = stmt {
667///      q.limit_clause = None;
668///   }
669///   ControlFlow::<()>::Continue(())
670/// });
671///
672/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
673/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
674/// ```
675pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
676where
677    V: VisitMut,
678    F: FnMut(&mut Statement) -> ControlFlow<E>,
679{
680    v.visit(&mut StatementVisitor(f))?;
681    ControlFlow::Continue(())
682}
683
684#[cfg(test)]
685mod tests {
686    use super::*;
687    use crate::ast::Statement;
688    use crate::dialect::GenericDialect;
689    use crate::parser::Parser;
690    use crate::tokenizer::Tokenizer;
691
692    #[derive(Default)]
693    struct TestVisitor {
694        visited: Vec<String>,
695    }
696
697    impl Visitor for TestVisitor {
698        type Break = ();
699
700        /// Invoked for any queries that appear in the AST before visiting children
701        fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
702            self.visited.push(format!("PRE: QUERY: {query}"));
703            ControlFlow::Continue(())
704        }
705
706        /// Invoked for any queries that appear in the AST after visiting children
707        fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
708            self.visited.push(format!("POST: QUERY: {query}"));
709            ControlFlow::Continue(())
710        }
711
712        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
713            self.visited.push(format!("PRE: RELATION: {relation}"));
714            ControlFlow::Continue(())
715        }
716
717        fn post_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
718            self.visited.push(format!("POST: RELATION: {relation}"));
719            ControlFlow::Continue(())
720        }
721
722        fn pre_visit_table_factor(
723            &mut self,
724            table_factor: &TableFactor,
725        ) -> ControlFlow<Self::Break> {
726            self.visited
727                .push(format!("PRE: TABLE FACTOR: {table_factor}"));
728            ControlFlow::Continue(())
729        }
730
731        fn post_visit_table_factor(
732            &mut self,
733            table_factor: &TableFactor,
734        ) -> ControlFlow<Self::Break> {
735            self.visited
736                .push(format!("POST: TABLE FACTOR: {table_factor}"));
737            ControlFlow::Continue(())
738        }
739
740        fn pre_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
741            self.visited.push(format!("PRE: EXPR: {expr}"));
742            ControlFlow::Continue(())
743        }
744
745        fn post_visit_expr(&mut self, expr: &Expr) -> ControlFlow<Self::Break> {
746            self.visited.push(format!("POST: EXPR: {expr}"));
747            ControlFlow::Continue(())
748        }
749
750        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
751            self.visited.push(format!("PRE: STATEMENT: {statement}"));
752            ControlFlow::Continue(())
753        }
754
755        fn post_visit_statement(&mut self, statement: &Statement) -> ControlFlow<Self::Break> {
756            self.visited.push(format!("POST: STATEMENT: {statement}"));
757            ControlFlow::Continue(())
758        }
759    }
760
761    fn do_visit<V: Visitor<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
762        let dialect = GenericDialect {};
763        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
764        let s = Parser::new(&dialect)
765            .with_tokens(tokens)
766            .parse_statement()
767            .unwrap();
768
769        let flow = s.visit(visitor);
770        assert_eq!(flow, ControlFlow::Continue(()));
771        s
772    }
773
774    #[test]
775    fn test_sql() {
776        let tests = vec![
777            (
778                "SELECT * from table_name as my_table",
779                vec![
780                    "PRE: STATEMENT: SELECT * FROM table_name AS my_table",
781                    "PRE: QUERY: SELECT * FROM table_name AS my_table",
782                    "PRE: TABLE FACTOR: table_name AS my_table",
783                    "PRE: RELATION: table_name",
784                    "POST: RELATION: table_name",
785                    "POST: TABLE FACTOR: table_name AS my_table",
786                    "POST: QUERY: SELECT * FROM table_name AS my_table",
787                    "POST: STATEMENT: SELECT * FROM table_name AS my_table",
788                ],
789            ),
790            (
791                "SELECT * from t1 join t2 on t1.id = t2.t1_id",
792                vec![
793                    "PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
794                    "PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
795                    "PRE: TABLE FACTOR: t1",
796                    "PRE: RELATION: t1",
797                    "POST: RELATION: t1",
798                    "POST: TABLE FACTOR: t1",
799                    "PRE: TABLE FACTOR: t2",
800                    "PRE: RELATION: t2",
801                    "POST: RELATION: t2",
802                    "POST: TABLE FACTOR: t2",
803                    "PRE: EXPR: t1.id = t2.t1_id",
804                    "PRE: EXPR: t1.id",
805                    "POST: EXPR: t1.id",
806                    "PRE: EXPR: t2.t1_id",
807                    "POST: EXPR: t2.t1_id",
808                    "POST: EXPR: t1.id = t2.t1_id",
809                    "POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
810                    "POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
811                ],
812            ),
813            (
814                "SELECT * from t1 where EXISTS(SELECT column from t2)",
815                vec![
816                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
817                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
818                    "PRE: TABLE FACTOR: t1",
819                    "PRE: RELATION: t1",
820                    "POST: RELATION: t1",
821                    "POST: TABLE FACTOR: t1",
822                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
823                    "PRE: QUERY: SELECT column FROM t2",
824                    "PRE: EXPR: column",
825                    "POST: EXPR: column",
826                    "PRE: TABLE FACTOR: t2",
827                    "PRE: RELATION: t2",
828                    "POST: RELATION: t2",
829                    "POST: TABLE FACTOR: t2",
830                    "POST: QUERY: SELECT column FROM t2",
831                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
832                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
833                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
834                ],
835            ),
836            (
837                "SELECT * from t1 where EXISTS(SELECT column from t2)",
838                vec![
839                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
840                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
841                    "PRE: TABLE FACTOR: t1",
842                    "PRE: RELATION: t1",
843                    "POST: RELATION: t1",
844                    "POST: TABLE FACTOR: t1",
845                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
846                    "PRE: QUERY: SELECT column FROM t2",
847                    "PRE: EXPR: column",
848                    "POST: EXPR: column",
849                    "PRE: TABLE FACTOR: t2",
850                    "PRE: RELATION: t2",
851                    "POST: RELATION: t2",
852                    "POST: TABLE FACTOR: t2",
853                    "POST: QUERY: SELECT column FROM t2",
854                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
855                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
856                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
857                ],
858            ),
859            (
860                "SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
861                vec![
862                    "PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
863                    "PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
864                    "PRE: TABLE FACTOR: t1",
865                    "PRE: RELATION: t1",
866                    "POST: RELATION: t1",
867                    "POST: TABLE FACTOR: t1",
868                    "PRE: EXPR: EXISTS (SELECT column FROM t2)",
869                    "PRE: QUERY: SELECT column FROM t2",
870                    "PRE: EXPR: column",
871                    "POST: EXPR: column",
872                    "PRE: TABLE FACTOR: t2",
873                    "PRE: RELATION: t2",
874                    "POST: RELATION: t2",
875                    "POST: TABLE FACTOR: t2",
876                    "POST: QUERY: SELECT column FROM t2",
877                    "POST: EXPR: EXISTS (SELECT column FROM t2)",
878                    "PRE: TABLE FACTOR: t3",
879                    "PRE: RELATION: t3",
880                    "POST: RELATION: t3",
881                    "POST: TABLE FACTOR: t3",
882                    "POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
883                    "POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
884                ],
885            ),
886            (
887                concat!(
888                    "SELECT * FROM monthly_sales ",
889                    "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
890                    "ORDER BY EMPID"
891                ),
892                vec![
893                    "PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
894                    "PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
895                    "PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
896                    "PRE: TABLE FACTOR: monthly_sales",
897                    "PRE: RELATION: monthly_sales",
898                    "POST: RELATION: monthly_sales",
899                    "POST: TABLE FACTOR: monthly_sales",
900                    "PRE: EXPR: SUM(a.amount)",
901                    "PRE: EXPR: a.amount",
902                    "POST: EXPR: a.amount",
903                    "POST: EXPR: SUM(a.amount)",
904                    "PRE: EXPR: a.MONTH",
905                    "POST: EXPR: a.MONTH",
906                    "PRE: EXPR: 'JAN'",
907                    "POST: EXPR: 'JAN'",
908                    "PRE: EXPR: 'FEB'",
909                    "POST: EXPR: 'FEB'",
910                    "PRE: EXPR: 'MAR'",
911                    "POST: EXPR: 'MAR'",
912                    "PRE: EXPR: 'APR'",
913                    "POST: EXPR: 'APR'",
914                    "POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
915                    "PRE: EXPR: EMPID",
916                    "POST: EXPR: EMPID",
917                    "POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
918                    "POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
919                ]
920            ),
921            (
922                "SHOW COLUMNS FROM t1",
923                vec![
924                    "PRE: STATEMENT: SHOW COLUMNS FROM t1",
925                    "PRE: RELATION: t1",
926                    "POST: RELATION: t1",
927                    "POST: STATEMENT: SHOW COLUMNS FROM t1",
928                ],
929            ),
930        ];
931        for (sql, expected) in tests {
932            let mut visitor = TestVisitor::default();
933            let _ = do_visit(sql, &mut visitor);
934            let actual: Vec<_> = visitor.visited.iter().map(|x| x.as_str()).collect();
935            assert_eq!(actual, expected)
936        }
937    }
938
939    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
940
941    impl Visitor for QuickVisitor {
942        type Break = ();
943    }
944
945    #[test]
946    fn overflow() {
947        let cond = (0..1000)
948            .map(|n| format!("X = {n}"))
949            .collect::<Vec<_>>()
950            .join(" OR ");
951        let sql = format!("SELECT x where {cond}");
952
953        let dialect = GenericDialect {};
954        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
955        let s = Parser::new(&dialect)
956            .with_tokens(tokens)
957            .parse_statement()
958            .unwrap();
959
960        let mut visitor = QuickVisitor {};
961        let flow = s.visit(&mut visitor);
962        assert_eq!(flow, ControlFlow::Continue(()));
963    }
964}
965
#[cfg(test)]
mod visit_mut_tests {
    use crate::ast::{Statement, Value, VisitMut, VisitorMut};
    use crate::dialect::GenericDialect;
    use crate::parser::Parser;
    use crate::tokenizer::Tokenizer;
    use core::ops::ControlFlow;

    /// Mutating visitor that overwrites every visited [`Value`] with a
    /// numbered placeholder string (`REDACTED_1`, `REDACTED_2`, ...).
    #[derive(Default)]
    struct MutatorVisitor {
        /// Count of values rewritten so far; bumped before each rewrite, so
        /// numbering starts at 1.
        index: u64,
    }

    impl VisitorMut for MutatorVisitor {
        type Break = ();

        /// Replaces `value` in place with the next `REDACTED_<n>` single-quoted string.
        fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
            self.index += 1;
            let placeholder = format!("REDACTED_{}", self.index);
            *value = Value::SingleQuotedString(placeholder);
            ControlFlow::Continue(())
        }

        /// Does nothing after a value has been visited; always continues.
        fn post_visit_value(&mut self, _value: &mut Value) -> ControlFlow<Self::Break> {
            ControlFlow::Continue(())
        }
    }

    /// Parses `sql` as one statement with [`GenericDialect`], runs `visitor`
    /// over it mutably, and returns the (possibly rewritten) statement.
    ///
    /// # Panics
    ///
    /// Panics if tokenizing or parsing fails, or if the visitor breaks early.
    fn do_visit_mut<V: VisitorMut<Break = ()>>(sql: &str, visitor: &mut V) -> Statement {
        let dialect = GenericDialect {};
        let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
        let mut statement = Parser::new(&dialect)
            .with_tokens(tokens)
            .parse_statement()
            .unwrap();

        assert_eq!(statement.visit(visitor), ControlFlow::Continue(()));
        statement
    }

    /// Every literal inside the PIVOT `IN (...)` list is redacted in visit order.
    #[test]
    fn test_value_redact() {
        let cases = [(
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
            concat!(
                "SELECT * FROM monthly_sales ",
                "PIVOT(SUM(a.amount) FOR a.MONTH IN ('REDACTED_1', 'REDACTED_2', 'REDACTED_3', 'REDACTED_4')) AS p (c, d) ",
                "ORDER BY EMPID"
            ),
        )];

        for &(input, expected) in &cases {
            let mut redactor = MutatorVisitor::default();
            let rewritten = do_visit_mut(input, &mut redactor);
            assert_eq!(rewritten.to_string(), expected);
        }
    }
}