feat: expose `DataFrame.parse_sql_expr` (#1274) · apache/datafusion-python@08901d5

3 files changed

lines changed

Original file line number | Diff line number | Diff line change

@@ -482,6 +482,28 @@ def filter(self, *predicates: Expr) -> DataFrame:

482482

df = df.filter(ensure_expr(p))

483483

return DataFrame(df)

484484
485+

def parse_sql_expr(self, expr: str) -> Expr:
    """Build a logical expression from SQL expression text.

    The text is parsed and processed against this DataFrame's current schema.

    Example::

        from datafusion import col, lit
        df.parse_sql_expr("a > 1")

    is expected to yield the same expression as::

        col("a") > lit(1)

    Args:
        expr: SQL expression string to convert into a datafusion expression.

    Returns:
        The resulting logical expression.
    """
    # Delegate parsing to the wrapped internal dataframe, then wrap the raw
    # result in the Python-facing ``Expr`` type.
    parsed = self.df.parse_sql_expr(expr)
    return Expr(parsed)

506+
485507

def with_column(self, name: str, expr: Expr) -> DataFrame:

486508

"""Add an additional column to the DataFrame.

487509
Original file line number | Diff line number | Diff line change

@@ -274,6 +274,36 @@ def test_filter(df):

274274

assert result.column(2) == pa.array([5])

275275
276276
277+

def test_parse_sql_expr(df):
    """Check that SQL-parsed predicates behave like hand-built expressions."""
    # A predicate parsed from SQL text and the equivalent hand-built predicate
    # should lead to the same logical plan.
    parsed_plan = df.filter(df.parse_sql_expr("a > 2")).logical_plan()
    built_plan = df.filter(column("a") > literal(2)).logical_plan()
    # Plans do not implement object equality, so compare their string forms.
    assert str(parsed_plan) == str(built_plan)

    predicate = df.parse_sql_expr("a > 2")
    projected = df.filter(predicate).select(
        column("a") + column("b"),
        column("a") - column("b"),
    )

    # Execute the query; only the first (and only) batch is inspected.
    batch = projected.collect()[0]

    assert batch.column(0) == pa.array([9])
    assert batch.column(1) == pa.array([-3])

    df.show()
    # Calling filter() with no predicates must leave the internal dataframe
    # unchanged.
    unfiltered = df.filter()
    assert df.df == unfiltered.df

    # Several parsed predicates can be passed to a single filter() call.
    first_predicate = df.parse_sql_expr("a > 1")
    second_predicate = df.parse_sql_expr("b != 6")
    batch = df.filter(first_predicate, second_predicate).collect()[0]

    assert batch.column(0) == pa.array([2])
    assert batch.column(1) == pa.array([5])
    assert batch.column(2) == pa.array([5])

305+
306+
277307

def test_show_empty(df, capsys):

278308

df_empty = df.filter(column("a") > literal(3))

279309

df_empty.show()

Original file line number | Diff line number | Diff line change

@@ -454,6 +454,14 @@ impl PyDataFrame {

454454

Ok(Self::new(df))

455455

}

456456
457+

fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {

458+

self.df

459+

.as_ref()

460+

.parse_sql_expr(&expr)

461+

.map(|e| PyExpr::from(e))

462+

.map_err(PyDataFusionError::from)

463+

}

464+
457465

fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {

458466

let df = self.df.as_ref().clone().with_column(name, expr.into())?;

459467

Ok(Self::new(df))