feat: Add multiple entity support to dbt integration (#5901) · feast-dev/feast@05a4fb5

@@ -6,7 +6,7 @@

66

"""

7788

import logging

9-

from typing import Any, List, Optional, Set

9+

from typing import Any, List, Optional, Set, Union

10101111

from jinja2 import BaseLoader, Environment

1212106106

{% for fv in feature_views %}

107107

{{ fv.var_name }} = FeatureView(

108108

name="{{ fv.name }}",

109-

entities=[{{ fv.entity_var }}],

109+

entities=[{{ fv.entity_vars | join(', ') }}],

110110

ttl=timedelta(days={{ fv.ttl_days }}),

111111

schema=[

112112

{% for field in fv.fields %}

@@ -220,7 +220,7 @@ def __init__(

220220

def generate(

221221

self,

222222

models: List[DbtModel],

223-

entity_column: str,

223+

entity_columns: Union[str, List[str]],

224224

manifest_path: str = "",

225225

project_name: str = "",

226226

exclude_columns: Optional[List[str]] = None,

@@ -231,7 +231,7 @@ def generate(

231231232232

Args:

233233

models: List of DbtModel objects to generate code for

234-

entity_column: The entity/primary key column name

234+

entity_columns: Entity column name(s) - single string or list of strings

235235

manifest_path: Path to the dbt manifest (for documentation)

236236

project_name: dbt project name (for documentation)

237237

exclude_columns: Columns to exclude from features

@@ -240,25 +240,36 @@ def generate(

240240

Returns:

241241

Generated Python code as a string

242242

"""

243-

excluded = {entity_column, self.timestamp_field}

243+

# Normalize entity_columns to list

244+

entity_cols: List[str] = (

245+

[entity_columns] if isinstance(entity_columns, str) else entity_columns

246+

)

247+248+

if not entity_cols:

249+

raise ValueError("At least one entity column must be specified")

250+251+

excluded = set(entity_cols) | {self.timestamp_field}

244252

if exclude_columns:

245253

excluded.update(exclude_columns)

246254247255

# Collect all Feast types used for imports

248256

type_imports: Set[str] = set()

249257250-

# Prepare entity data

258+

# Prepare entity data - create one entity per entity column

251259

entities = []

252-

entity_var = _make_var_name(entity_column)

253-

entities.append(

254-

{

255-

"var_name": entity_var,

256-

"name": entity_column,

257-

"join_key": entity_column,

258-

"description": "Entity key for dbt models",

259-

"tags": {"source": "dbt"},

260-

}

261-

)

260+

entity_vars = [] # Track variable names for feature views

261+

for entity_col in entity_cols:

262+

entity_var = _make_var_name(entity_col)

263+

entity_vars.append(entity_var)

264+

entities.append(

265+

{

266+

"var_name": entity_var,

267+

"name": entity_col,

268+

"join_key": entity_col,

269+

"description": "Entity key for dbt models",

270+

"tags": {"source": "dbt"},

271+

}

272+

)

262273263274

# Prepare data sources and feature views

264275

data_sources = []

@@ -269,7 +280,9 @@ def generate(

269280

column_names = [c.name for c in model.columns]

270281

if self.timestamp_field not in column_names:

271282

continue

272-

if entity_column not in column_names:

283+284+

# Skip if ANY entity column is missing

285+

if not all(e in column_names for e in entity_cols):

273286

continue

274287275288

# Build tags

@@ -339,7 +352,7 @@ def generate(

339352

{

340353

"var_name": fv_var,

341354

"name": model.name,

342-

"entity_var": entity_var,

355+

"entity_vars": entity_vars,

343356

"source_var": source_var,

344357

"ttl_days": self.ttl_days,

345358

"fields": fields,

@@ -366,7 +379,7 @@ def generate(

366379367380

def generate_feast_code(

368381

models: List[DbtModel],

369-

entity_column: str,

382+

entity_columns: Union[str, List[str]],

370383

data_source_type: str = "bigquery",

371384

timestamp_field: str = "event_timestamp",

372385

ttl_days: int = 1,

@@ -380,7 +393,7 @@ def generate_feast_code(

380393381394

Args:

382395

models: List of DbtModel objects

383-

entity_column: Primary key column name

396+

entity_columns: Entity column name(s) - single string or list of strings

384397

data_source_type: Type of data source (bigquery, snowflake, file)

385398

timestamp_field: Timestamp column name

386399

ttl_days: TTL in days for feature views

@@ -400,7 +413,7 @@ def generate_feast_code(

400413401414

return generator.generate(

402415

models=models,

403-

entity_column=entity_column,

416+

entity_columns=entity_columns,

404417

manifest_path=manifest_path,

405418

project_name=project_name,

406419

exclude_columns=exclude_columns,