feat: Add multiple entity support to dbt integration (#5901) · feast-dev/feast@05a4fb5
@@ -6,7 +6,7 @@
66"""
7788import logging
9-from typing import Any, List, Optional, Set
9+from typing import Any, List, Optional, Set, Union
10101111from jinja2 import BaseLoader, Environment
1212106106{% for fv in feature_views %}
107107{{ fv.var_name }} = FeatureView(
108108 name="{{ fv.name }}",
109- entities=[{{ fv.entity_var }}],
109+ entities=[{{ fv.entity_vars | join(', ') }}],
110110 ttl=timedelta(days={{ fv.ttl_days }}),
111111 schema=[
112112{% for field in fv.fields %}
@@ -220,7 +220,7 @@ def __init__(
220220def generate(
221221self,
222222models: List[DbtModel],
223-entity_column: str,
223+entity_columns: Union[str, List[str]],
224224manifest_path: str = "",
225225project_name: str = "",
226226exclude_columns: Optional[List[str]] = None,
@@ -231,7 +231,7 @@ def generate(
231231232232 Args:
233233 models: List of DbtModel objects to generate code for
234- entity_column: The entity/primary key column name
234+ entity_columns: Entity column name(s) - single string or list of strings
235235 manifest_path: Path to the dbt manifest (for documentation)
236236 project_name: dbt project name (for documentation)
237237 exclude_columns: Columns to exclude from features
@@ -240,25 +240,36 @@ def generate(
240240 Returns:
241241 Generated Python code as a string
242242 """
243-excluded = {entity_column, self.timestamp_field}
243+# Normalize entity_columns to list
244+entity_cols: List[str] = (
245+ [entity_columns] if isinstance(entity_columns, str) else entity_columns
246+ )
247+248+if not entity_cols:
249+raise ValueError("At least one entity column must be specified")
250+251+excluded = set(entity_cols) | {self.timestamp_field}
244252if exclude_columns:
245253excluded.update(exclude_columns)
246254247255# Collect all Feast types used for imports
248256type_imports: Set[str] = set()
249257250-# Prepare entity data
258+# Prepare entity data - create one entity per entity column
251259entities = []
252-entity_var = _make_var_name(entity_column)
253-entities.append(
254- {
255-"var_name": entity_var,
256-"name": entity_column,
257-"join_key": entity_column,
258-"description": "Entity key for dbt models",
259-"tags": {"source": "dbt"},
260- }
261- )
260+entity_vars = [] # Track variable names for feature views
261+for entity_col in entity_cols:
262+entity_var = _make_var_name(entity_col)
263+entity_vars.append(entity_var)
264+entities.append(
265+ {
266+"var_name": entity_var,
267+"name": entity_col,
268+"join_key": entity_col,
269+"description": "Entity key for dbt models",
270+"tags": {"source": "dbt"},
271+ }
272+ )
262273263274# Prepare data sources and feature views
264275data_sources = []
@@ -269,7 +280,9 @@ def generate(
269280column_names = [c.name for c in model.columns]
270281if self.timestamp_field not in column_names:
271282continue
272-if entity_column not in column_names:
283+284+# Skip if ANY entity column is missing
285+if not all(e in column_names for e in entity_cols):
273286continue
274287275288# Build tags
@@ -339,7 +352,7 @@ def generate(
339352 {
340353"var_name": fv_var,
341354"name": model.name,
342-"entity_var": entity_var,
355+"entity_vars": entity_vars,
343356"source_var": source_var,
344357"ttl_days": self.ttl_days,
345358"fields": fields,
@@ -366,7 +379,7 @@ def generate(
366379367380def generate_feast_code(
368381models: List[DbtModel],
369-entity_column: str,
382+entity_columns: Union[str, List[str]],
370383data_source_type: str = "bigquery",
371384timestamp_field: str = "event_timestamp",
372385ttl_days: int = 1,
@@ -380,7 +393,7 @@ def generate_feast_code(
380393381394 Args:
382395 models: List of DbtModel objects
383- entity_column: Primary key column name
396+ entity_columns: Entity column name(s) - single string or list of strings
384397 data_source_type: Type of data source (bigquery, snowflake, file)
385398 timestamp_field: Timestamp column name
386399 ttl_days: TTL in days for feature views
@@ -400,7 +413,7 @@ def generate_feast_code(
400413401414return generator.generate(
402415models=models,
403-entity_column=entity_column,
416+entity_columns=entity_columns,
404417manifest_path=manifest_path,
405418project_name=project_name,
406419exclude_columns=exclude_columns,