feat: Offline Store historical features retrieval based on datetime range · feast-dev/feast@27ec8ec
@@ -3,12 +3,13 @@
33import uuid
44import warnings
55from dataclasses import asdict, dataclass
6-from datetime import datetime, timezone
6+from datetime import datetime, timedelta, timezone
77from typing import (
88TYPE_CHECKING,
99Any,
1010Callable,
1111Dict,
12+KeysView,
1213List,
1314Optional,
1415Tuple,
@@ -151,10 +152,11 @@ def get_historical_features(
151152config: RepoConfig,
152153feature_views: List[FeatureView],
153154feature_refs: List[str],
154-entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
155+entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]],
155156registry: BaseRegistry,
156157project: str,
157158full_feature_names: bool = False,
159+**kwargs,
158160 ) -> RetrievalJob:
159161assert isinstance(config.offline_store, SparkOfflineStoreConfig)
160162date_partition_column_formats = []
@@ -175,33 +177,75 @@ def get_historical_features(
175177 )
176178tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()
177179178-entity_schema = _get_entity_schema(
179-spark_session=spark_session,
180-entity_df=entity_df,
181- )
182-event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
183-entity_schema=entity_schema,
184- )
185-entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
186-entity_df,
187-event_timestamp_col,
188-spark_session,
189- )
190-_upload_entity_df(
191-spark_session=spark_session,
192-table_name=tmp_entity_df_table_name,
193-entity_df=entity_df,
194-event_timestamp_col=event_timestamp_col,
195- )
180+# Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df.
181+# This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time.
182+non_entity_mode = entity_df is None
183+if non_entity_mode:
184+# Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans.
185+start_date, end_date = _compute_non_entity_dates(feature_views, kwargs)
186+entity_df_event_timestamp_range = (start_date, end_date)
187+188+# Build query contexts so we can reuse entity names and per-view table info consistently.
189+fv_query_contexts = offline_utils.get_feature_view_query_context(
190+feature_refs,
191+feature_views,
192+registry,
193+project,
194+entity_df_event_timestamp_range,
195+ )
196196197-expected_join_keys = offline_utils.get_expected_join_keys(
198-project=project, feature_views=feature_views, registry=registry
199- )
200-offline_utils.assert_expected_columns_in_entity_df(
201-entity_schema=entity_schema,
202-join_keys=expected_join_keys,
203-entity_df_event_timestamp_col=event_timestamp_col,
204- )
197+# Collect the union of entity columns required across all feature views.
198+all_entities = _gather_all_entities(fv_query_contexts)
199+200+# Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned.
201+_create_temp_entity_union_view(
202+spark_session=spark_session,
203+tmp_view_name=tmp_entity_df_table_name,
204+feature_views=feature_views,
205+fv_query_contexts=fv_query_contexts,
206+start_date=start_date,
207+end_date=end_date,
208+date_partition_column_formats=date_partition_column_formats,
209+ )
210+211+# Add a stable as-of timestamp column for PIT joins.
212+left_table_query_string, event_timestamp_col = _make_left_table_query(
213+end_date=end_date, tmp_view_name=tmp_entity_df_table_name
214+ )
215+entity_schema_keys = _entity_schema_keys_from(
216+all_entities=all_entities, event_timestamp_col=event_timestamp_col
217+ )
218+else:
219+entity_schema = _get_entity_schema(
220+spark_session=spark_session,
221+entity_df=entity_df,
222+ )
223+event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
224+entity_schema=entity_schema,
225+ )
226+entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
227+entity_df,
228+event_timestamp_col,
229+spark_session,
230+ )
231+_upload_entity_df(
232+spark_session=spark_session,
233+table_name=tmp_entity_df_table_name,
234+entity_df=entity_df,
235+event_timestamp_col=event_timestamp_col,
236+ )
237+left_table_query_string = tmp_entity_df_table_name
238+entity_schema_keys = cast(KeysView[str], entity_schema.keys())
239+240+if not non_entity_mode:
241+expected_join_keys = offline_utils.get_expected_join_keys(
242+project=project, feature_views=feature_views, registry=registry
243+ )
244+offline_utils.assert_expected_columns_in_entity_df(
245+entity_schema=entity_schema,
246+join_keys=expected_join_keys,
247+entity_df_event_timestamp_col=event_timestamp_col,
248+ )
205249206250query_context = offline_utils.get_feature_view_query_context(
207251feature_refs,
@@ -232,9 +276,9 @@ def get_historical_features(
232276feature_view_query_contexts=cast(
233277List[offline_utils.FeatureViewQueryContext], spark_query_context
234278 ),
235-left_table_query_string=tmp_entity_df_table_name,
279+left_table_query_string=left_table_query_string,
236280entity_df_event_timestamp_col=event_timestamp_col,
237-entity_df_columns=entity_schema.keys(),
281+entity_df_columns=entity_schema_keys,
238282query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
239283full_feature_names=full_feature_names,
240284 )
@@ -248,7 +292,7 @@ def get_historical_features(
248292 ),
249293metadata=RetrievalMetadata(
250294features=feature_refs,
251-keys=list(set(entity_schema.keys()) - {event_timestamp_col}),
295+keys=list(set(entity_schema_keys) - {event_timestamp_col}),
252296min_event_timestamp=entity_df_event_timestamp_range[0],
253297max_event_timestamp=entity_df_event_timestamp_range[1],
254298 ),
@@ -540,6 +584,114 @@ def get_spark_session_or_start_new_with_repoconfig(
540584return spark_session
def _compute_non_entity_dates(
    feature_views: List[FeatureView], kwargs: Dict[str, Any]
) -> Tuple[datetime, datetime]:
    """Resolve the (start_date, end_date) scan window for non-entity retrieval.

    Explicit ``start_date``/``end_date`` values in ``kwargs`` take precedence.
    A missing ``end_date`` defaults to "now" in UTC. A missing ``start_date``
    is derived from the largest feature-view TTL so the source scan stays
    bounded; when no view declares a TTL, a fixed 30-day window is used.

    Args:
        feature_views: Feature views whose TTLs bound the fallback window.
        kwargs: May contain optional ``start_date`` / ``end_date`` datetimes.

    Returns:
        Tuple of ``(start_date, end_date)``.
    """
    # Fallback window when neither an explicit start date nor any TTL exists.
    default_window = timedelta(days=30)

    start_date_opt = cast(Optional[datetime], kwargs.get("start_date"))
    end_date_opt = cast(Optional[datetime], kwargs.get("end_date"))
    end_date: datetime = end_date_opt or datetime.now(timezone.utc)

    if start_date_opt is None:
        # The largest TTL across views bounds how far back any feature value
        # can still be relevant, so it safely bounds the scan.
        max_ttl_seconds = max(
            (
                int(fv.ttl.total_seconds())
                for fv in feature_views
                if fv.ttl and isinstance(fv.ttl, timedelta)
            ),
            default=0,
        )
        start_date: datetime = (
            end_date - timedelta(seconds=max_ttl_seconds)
            if max_ttl_seconds > 0
            else end_date - default_window
        )
    else:
        start_date = start_date_opt
    return (start_date, end_date)
def _gather_all_entities(
    fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
) -> List[str]:
    """Collect the union of entity join keys across all query contexts.

    De-duplicates while preserving first-seen order so UNION schemas stay
    aligned across feature views.
    """
    # Dict keys preserve insertion order, giving an ordered de-dup for free.
    ordered_unique: Dict[str, None] = {}
    for query_context in fv_query_contexts:
        ordered_unique.update(dict.fromkeys(query_context.entities))
    return list(ordered_unique)
def _create_temp_entity_union_view(
    spark_session: SparkSession,
    tmp_view_name: str,
    feature_views: List[FeatureView],
    fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
    start_date: datetime,
    end_date: datetime,
    date_partition_column_formats: List[Optional[str]],
) -> None:
    """Create a temp view of the DISTINCT entity keys seen in the time window.

    Builds one time-bounded ``SELECT DISTINCT`` of entity columns per feature
    view over its batch source (partition-pruned when a date partition column
    and format are available), then registers the ``UNION DISTINCT`` of those
    selects as a temporary view named ``tmp_view_name``. This synthesizes a
    left table for point-in-time joins without requiring an entity_df upfront.

    Args:
        spark_session: Session used to register the temporary view.
        tmp_view_name: Name for the created/replaced temporary view.
        feature_views: Feature views whose batch sources are scanned.
        fv_query_contexts: Per-view query contexts providing entity names.
        start_date: Inclusive lower bound of the event-time window.
        end_date: Inclusive upper bound of the event-time window.
        date_partition_column_formats: Per-view strftime format for the date
            partition column, or None when the view has no usable partition.

    Returns:
        None. Side effect: creates or replaces the temporary view.
    """
    # Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront.
    start_date_str = _format_datetime(start_date)
    end_date_str = _format_datetime(end_date)

    # Compute the unified entity set to align schemas in the UNION.
    all_entities = _gather_all_entities(fv_query_contexts)

    per_view_selects: List[str] = []
    # NOTE(review): zip() assumes feature_views, fv_query_contexts, and
    # date_partition_column_formats are parallel (same length and order) —
    # confirm the call site guarantees this alignment.
    for fv, ctx, date_format in zip(
        feature_views, fv_query_contexts, date_partition_column_formats
    ):
        assert isinstance(fv.batch_source, SparkSource)
        from_expression = fv.batch_source.get_table_query_string()
        # Fall back to the conventional column name when the source does not
        # declare a timestamp field.
        timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
        date_partition_column = fv.batch_source.date_partition_column
        partition_clause = ""
        if date_partition_column and date_format:
            # Partition pruning: restrict the scan to partitions overlapping
            # the window, formatted with the source's partition date format.
            partition_clause = (
                f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
                f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
            )

        # Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned.
        select_entities: List[str] = []
        ctx_entities_set = set(ctx.entities)
        for col in all_entities:
            if col in ctx_entities_set:
                select_entities.append(f"CAST({col} AS STRING) AS {col}")
            else:
                select_entities.append(f"CAST(NULL AS STRING) AS {col}")

        per_view_selects.append(
            f"""
            SELECT DISTINCT {", ".join(select_entities)}
            FROM {from_expression}
            WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
            """
        )

    # UNION DISTINCT keeps one row per unique entity-key combination across views.
    union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects])
    spark_session.sql(
        f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}"
    )
def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]:
    """Wrap the entity view in a subquery that adds a constant as-of timestamp.

    Point-in-time joins require an event timestamp column on the left table;
    in non-entity mode every synthesized row is stamped with ``end_date``.

    Returns:
        Tuple of (left-table SQL subquery string, event timestamp column name).
    """
    ts_column = "entity_ts"
    stamped_end = _format_datetime(end_date)
    subquery = (
        f"(SELECT *, TIMESTAMP('{stamped_end}') AS {ts_column} "
        f"FROM {tmp_view_name})"
    )
    return subquery, ts_column
684+685+686+def _entity_schema_keys_from(
687+all_entities: List[str], event_timestamp_col: str
688+) -> KeysView[str]:
689+# Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing.
690+return cast(
691+KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys()
692+ )
693+694+543695def _get_entity_df_event_timestamp_range(
544696entity_df: Union[pd.DataFrame, str],
545697entity_df_event_timestamp_col: str,