// oxide_sql_core/ast/expression.rs
//! Expression AST types.
2
3use core::fmt;
4
5use crate::lexer::Span;
6
/// A literal value appearing in an SQL expression.
#[derive(Debug, Clone, PartialEq)]
pub enum Literal {
    /// A signed 64-bit integer literal.
    Integer(i64),
    /// A 64-bit floating-point literal.
    Float(f64),
    /// A string literal, held as an owned `String`.
    String(String),
    /// A blob literal, held as raw bytes.
    Blob(Vec<u8>),
    /// A boolean literal.
    Boolean(bool),
    /// The `NULL` literal.
    Null,
}
23
/// Binary (two-operand) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    // Arithmetic
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Modulo (`%`).
    Mod,

    // Comparison
    /// Equality (`=`).
    Eq,
    /// Inequality (`!=`).
    NotEq,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    LtEq,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    GtEq,

    // Logical
    /// Logical conjunction (`AND`).
    And,
    /// Logical disjunction (`OR`).
    Or,

    // String
    /// String concatenation (`||`).
    Concat,
    /// Pattern match (`LIKE`).
    Like,

    // Bitwise
    /// Bitwise AND (`&`).
    BitAnd,
    /// Bitwise OR (`|`).
    BitOr,
    /// Left shift (`<<`).
    LeftShift,
    /// Right shift (`>>`).
    RightShift,
}

impl BinaryOp {
    /// Returns the SQL spelling of this operator (symbol or keyword).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            // Arithmetic
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Mod => "%",
            // Comparison
            Self::Eq => "=",
            Self::NotEq => "!=",
            Self::Lt => "<",
            Self::LtEq => "<=",
            Self::Gt => ">",
            Self::GtEq => ">=",
            // Logical
            Self::And => "AND",
            Self::Or => "OR",
            // String
            Self::Concat => "||",
            Self::Like => "LIKE",
            // Bitwise
            Self::BitAnd => "&",
            Self::BitOr => "|",
            Self::LeftShift => "<<",
            Self::RightShift => ">>",
        }
    }

    /// Returns the binding strength of this operator.
    ///
    /// Larger values bind tighter; the levels run from `1` (`OR`) up to
    /// `9` (`*`, `/`, `%`).
    #[must_use]
    pub const fn precedence(&self) -> u8 {
        // Listed from tightest to loosest binding.
        match *self {
            Self::Mul | Self::Div | Self::Mod => 9,
            Self::Add | Self::Sub | Self::Concat => 8,
            Self::LeftShift | Self::RightShift => 7,
            Self::BitAnd => 6,
            Self::BitOr => 5,
            Self::Like => 4,
            Self::Eq | Self::NotEq | Self::Lt | Self::LtEq | Self::Gt | Self::GtEq => 3,
            Self::And => 2,
            Self::Or => 1,
        }
    }
}
100
101impl fmt::Display for Literal {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 match self {
104 Self::Integer(n) => write!(f, "{n}"),
105 Self::Float(v) => write!(f, "{v}"),
106 Self::String(s) => {
107 let escaped = s.replace('\'', "''");
108 write!(f, "'{escaped}'")
109 }
110 Self::Blob(bytes) => {
111 write!(f, "X'")?;
112 for b in bytes {
113 write!(f, "{b:02X}")?;
114 }
115 write!(f, "'")
116 }
117 Self::Boolean(true) => write!(f, "TRUE"),
118 Self::Boolean(false) => write!(f, "FALSE"),
119 Self::Null => write!(f, "NULL"),
120 }
121 }
122}
123
124impl fmt::Display for BinaryOp {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.write_str(self.as_str())
127 }
128}
129
/// Unary (single-operand) operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (`-`).
    Neg,
    /// Logical negation (`NOT`).
    Not,
    /// Bitwise complement (`~`).
    BitNot,
}

impl UnaryOp {
    /// Returns the SQL spelling of this operator (symbol or keyword).
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::BitNot => "~",
            Self::Not => "NOT",
            Self::Neg => "-",
        }
    }
}

/// Writes the operator exactly as returned by [`UnaryOp::as_str`].
impl fmt::Display for UnaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = self.as_str();
        f.write_str(symbol)
    }
}
158
159/// A function call expression.
160#[derive(Debug, Clone, PartialEq)]
161pub struct FunctionCall {
162 /// The function name.
163 pub name: String,
164 /// The arguments.
165 pub args: Vec<Expr>,
166 /// Whether DISTINCT was specified.
167 pub distinct: bool,
168}
169
170/// An SQL expression.
171#[derive(Debug, Clone, PartialEq)]
172pub enum Expr {
173 /// A literal value.
174 Literal(Literal),
175
176 /// A column reference (optionally qualified with table name).
177 Column {
178 /// Table name or alias (optional).
179 table: Option<String>,
180 /// Column name.
181 name: String,
182 /// Source span.
183 span: Span,
184 },
185
186 /// A binary expression.
187 Binary {
188 /// Left operand.
189 left: Box<Expr>,
190 /// Operator.
191 op: BinaryOp,
192 /// Right operand.
193 right: Box<Expr>,
194 },
195
196 /// A unary expression.
197 Unary {
198 /// Operator.
199 op: UnaryOp,
200 /// Operand.
201 operand: Box<Expr>,
202 },
203
204 /// A function call.
205 Function(FunctionCall),
206
207 /// A subquery.
208 Subquery(Box<super::SelectStatement>),
209
210 /// IS NULL expression.
211 IsNull {
212 /// The expression to check.
213 expr: Box<Expr>,
214 /// Whether this is IS NOT NULL.
215 negated: bool,
216 },
217
218 /// IN expression.
219 In {
220 /// The expression to check.
221 expr: Box<Expr>,
222 /// The list of values or subquery.
223 list: Vec<Expr>,
224 /// Whether this is NOT IN.
225 negated: bool,
226 },
227
228 /// BETWEEN expression.
229 Between {
230 /// The expression to check.
231 expr: Box<Expr>,
232 /// Lower bound.
233 low: Box<Expr>,
234 /// Upper bound.
235 high: Box<Expr>,
236 /// Whether this is NOT BETWEEN.
237 negated: bool,
238 },
239
240 /// CASE expression.
241 Case {
242 /// The operand (if any).
243 operand: Option<Box<Expr>>,
244 /// WHEN/THEN clauses.
245 when_clauses: Vec<(Expr, Expr)>,
246 /// ELSE clause.
247 else_clause: Option<Box<Expr>>,
248 },
249
250 /// CAST expression.
251 Cast {
252 /// Expression to cast.
253 expr: Box<Expr>,
254 /// Target type.
255 data_type: super::DataType,
256 },
257
258 /// Parenthesized expression.
259 Paren(Box<Expr>),
260
261 /// A parameter placeholder (? or :name).
262 Parameter {
263 /// The parameter index or name.
264 name: Option<String>,
265 /// Position in the query (1-based for ? placeholders).
266 position: usize,
267 },
268
269 /// Wildcard (*) in SELECT.
270 Wildcard {
271 /// Table qualifier (optional).
272 table: Option<String>,
273 },
274}
275
276impl fmt::Display for FunctionCall {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 // EXISTS gets special handling: the subquery already
279 // contains its own parentheses in the rendered form,
280 // so we render `EXISTS(SELECT ...)` instead of
281 // `EXISTS((SELECT ...))`.
282 if self.name == "EXISTS" {
283 if let [Expr::Subquery(q)] = self.args.as_slice() {
284 return write!(f, "EXISTS({q})");
285 }
286 }
287 write!(f, "{}(", self.name)?;
288 if self.distinct {
289 write!(f, "DISTINCT ")?;
290 }
291 for (i, arg) in self.args.iter().enumerate() {
292 if i > 0 {
293 write!(f, ", ")?;
294 }
295 write!(f, "{arg}")?;
296 }
297 write!(f, ")")
298 }
299}
300
301impl fmt::Display for Expr {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 match self {
304 Self::Literal(lit) => write!(f, "{lit}"),
305 Self::Column { table, name, .. } => {
306 if let Some(t) = table {
307 write!(f, "{t}.{name}")
308 } else {
309 write!(f, "{name}")
310 }
311 }
312 Self::Binary { left, op, right } => {
313 write!(f, "{left} {op} {right}")
314 }
315 Self::Unary { op, operand } => match op {
316 UnaryOp::Not => write!(f, "NOT {operand}"),
317 UnaryOp::Neg => write!(f, "-{operand}"),
318 UnaryOp::BitNot => write!(f, "~{operand}"),
319 },
320 Self::Function(func) => write!(f, "{func}"),
321 Self::Subquery(q) => write!(f, "({q})"),
322 Self::IsNull { expr, negated } => {
323 if *negated {
324 write!(f, "{expr} IS NOT NULL")
325 } else {
326 write!(f, "{expr} IS NULL")
327 }
328 }
329 Self::In {
330 expr,
331 list,
332 negated,
333 } => {
334 write!(f, "{expr}")?;
335 if *negated {
336 write!(f, " NOT IN (")?;
337 } else {
338 write!(f, " IN (")?;
339 }
340 for (i, item) in list.iter().enumerate() {
341 if i > 0 {
342 write!(f, ", ")?;
343 }
344 write!(f, "{item}")?;
345 }
346 write!(f, ")")
347 }
348 Self::Between {
349 expr,
350 low,
351 high,
352 negated,
353 } => {
354 if *negated {
355 write!(f, "{expr} NOT BETWEEN {low} AND {high}")
356 } else {
357 write!(f, "{expr} BETWEEN {low} AND {high}")
358 }
359 }
360 Self::Case {
361 operand,
362 when_clauses,
363 else_clause,
364 } => {
365 write!(f, "CASE")?;
366 if let Some(op) = operand {
367 write!(f, " {op}")?;
368 }
369 for (when, then) in when_clauses {
370 write!(f, " WHEN {when} THEN {then}")?;
371 }
372 if let Some(el) = else_clause {
373 write!(f, " ELSE {el}")?;
374 }
375 write!(f, " END")
376 }
377 Self::Cast { expr, data_type } => {
378 write!(f, "CAST({expr} AS {data_type})")
379 }
380 Self::Paren(inner) => write!(f, "({inner})"),
381 Self::Parameter { name, .. } => {
382 if let Some(n) = name {
383 write!(f, ":{n}")
384 } else {
385 write!(f, "?")
386 }
387 }
388 Self::Wildcard { table } => {
389 if let Some(t) = table {
390 write!(f, "{t}.*")
391 } else {
392 write!(f, "*")
393 }
394 }
395 }
396 }
397}
398
399impl Expr {
400 /// Creates a new column reference.
401 #[must_use]
402 pub fn column(name: impl Into<String>) -> Self {
403 Self::Column {
404 table: None,
405 name: name.into(),
406 span: Span::default(),
407 }
408 }
409
410 /// Creates a new qualified column reference.
411 #[must_use]
412 pub fn qualified_column(table: impl Into<String>, name: impl Into<String>) -> Self {
413 Self::Column {
414 table: Some(table.into()),
415 name: name.into(),
416 span: Span::default(),
417 }
418 }
419
420 /// Creates a new integer literal.
421 #[must_use]
422 pub const fn integer(value: i64) -> Self {
423 Self::Literal(Literal::Integer(value))
424 }
425
426 /// Creates a new float literal.
427 #[must_use]
428 pub const fn float(value: f64) -> Self {
429 Self::Literal(Literal::Float(value))
430 }
431
432 /// Creates a new string literal.
433 #[must_use]
434 pub fn string(value: impl Into<String>) -> Self {
435 Self::Literal(Literal::String(value.into()))
436 }
437
438 /// Creates a new boolean literal.
439 #[must_use]
440 pub const fn boolean(value: bool) -> Self {
441 Self::Literal(Literal::Boolean(value))
442 }
443
444 /// Creates a NULL literal.
445 #[must_use]
446 pub const fn null() -> Self {
447 Self::Literal(Literal::Null)
448 }
449
450 /// Creates a binary expression.
451 #[must_use]
452 pub fn binary(self, op: BinaryOp, right: Self) -> Self {
453 Self::Binary {
454 left: Box::new(self),
455 op,
456 right: Box::new(right),
457 }
458 }
459
460 /// Creates an equality expression.
461 #[must_use]
462 pub fn eq(self, right: Self) -> Self {
463 self.binary(BinaryOp::Eq, right)
464 }
465
466 /// Creates an inequality expression.
467 #[must_use]
468 pub fn not_eq(self, right: Self) -> Self {
469 self.binary(BinaryOp::NotEq, right)
470 }
471
472 /// Creates a less-than expression.
473 #[must_use]
474 pub fn lt(self, right: Self) -> Self {
475 self.binary(BinaryOp::Lt, right)
476 }
477
478 /// Creates a less-than-or-equal expression.
479 #[must_use]
480 pub fn lt_eq(self, right: Self) -> Self {
481 self.binary(BinaryOp::LtEq, right)
482 }
483
484 /// Creates a greater-than expression.
485 #[must_use]
486 pub fn gt(self, right: Self) -> Self {
487 self.binary(BinaryOp::Gt, right)
488 }
489
490 /// Creates a greater-than-or-equal expression.
491 #[must_use]
492 pub fn gt_eq(self, right: Self) -> Self {
493 self.binary(BinaryOp::GtEq, right)
494 }
495
496 /// Creates an AND expression.
497 #[must_use]
498 pub fn and(self, right: Self) -> Self {
499 self.binary(BinaryOp::And, right)
500 }
501
502 /// Creates an OR expression.
503 #[must_use]
504 pub fn or(self, right: Self) -> Self {
505 self.binary(BinaryOp::Or, right)
506 }
507
508 /// Creates an IS NULL expression.
509 #[must_use]
510 pub fn is_null(self) -> Self {
511 Self::IsNull {
512 expr: Box::new(self),
513 negated: false,
514 }
515 }
516
517 /// Creates an IS NOT NULL expression.
518 #[must_use]
519 pub fn is_not_null(self) -> Self {
520 Self::IsNull {
521 expr: Box::new(self),
522 negated: true,
523 }
524 }
525
526 /// Creates a BETWEEN expression.
527 #[must_use]
528 pub fn between(self, low: Self, high: Self) -> Self {
529 Self::Between {
530 expr: Box::new(self),
531 low: Box::new(low),
532 high: Box::new(high),
533 negated: false,
534 }
535 }
536
537 /// Creates a NOT BETWEEN expression.
538 #[must_use]
539 pub fn not_between(self, low: Self, high: Self) -> Self {
540 Self::Between {
541 expr: Box::new(self),
542 low: Box::new(low),
543 high: Box::new(high),
544 negated: true,
545 }
546 }
547
548 /// Creates an IN expression.
549 #[must_use]
550 pub fn in_list(self, list: Vec<Self>) -> Self {
551 Self::In {
552 expr: Box::new(self),
553 list,
554 negated: false,
555 }
556 }
557
558 /// Creates a NOT IN expression.
559 #[must_use]
560 pub fn not_in_list(self, list: Vec<Self>) -> Self {
561 Self::In {
562 expr: Box::new(self),
563 list,
564 negated: true,
565 }
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the relative ordering of a few precedence levels.
    #[test]
    fn test_binary_op_precedence() {
        // Each pair is (tighter-binding, looser-binding).
        let pairs = [
            (BinaryOp::Mul, BinaryOp::Add),
            (BinaryOp::And, BinaryOp::Or),
            (BinaryOp::Eq, BinaryOp::And),
        ];
        for (tighter, looser) in pairs.iter() {
            assert!(tighter.precedence() > looser.precedence());
        }
    }

    /// The basic constructors produce the expected variants.
    #[test]
    fn test_expr_builders() {
        let col = Expr::column("name");
        assert!(matches!(col, Expr::Column { name, .. } if name == "name"));

        let lit = Expr::integer(42);
        assert!(matches!(lit, Expr::Literal(Literal::Integer(42))));
    }

    /// Chained builder calls nest so the last combinator is the root.
    #[test]
    fn test_expr_chaining() {
        let is_adult = Expr::column("age").gt(Expr::integer(18));
        let is_active = Expr::column("status").eq(Expr::string("active"));
        let expr = is_adult.and(is_active);

        assert!(matches!(
            expr,
            Expr::Binary {
                op: BinaryOp::And,
                ..
            }
        ));
    }
}