feat: Offline Store historical features retrieval based on datetime range · feast-dev/feast@27ec8ec

@@ -3,12 +3,13 @@

33

import uuid

44

import warnings

55

from dataclasses import asdict, dataclass

6-

from datetime import datetime, timezone

6+

from datetime import datetime, timedelta, timezone

77

from typing import (

88

TYPE_CHECKING,

99

Any,

1010

Callable,

1111

Dict,

12+

KeysView,

1213

List,

1314

Optional,

1415

Tuple,

@@ -151,10 +152,11 @@ def get_historical_features(

151152

config: RepoConfig,

152153

feature_views: List[FeatureView],

153154

feature_refs: List[str],

154-

entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],

155+

entity_df: Optional[Union[pandas.DataFrame, str, pyspark.sql.DataFrame]],

155156

registry: BaseRegistry,

156157

project: str,

157158

full_feature_names: bool = False,

159+

**kwargs,

158160

) -> RetrievalJob:

159161

assert isinstance(config.offline_store, SparkOfflineStoreConfig)

160162

date_partition_column_formats = []

@@ -175,33 +177,75 @@ def get_historical_features(

175177

)

176178

tmp_entity_df_table_name = offline_utils.get_temp_entity_table_name()

177179178-

entity_schema = _get_entity_schema(

179-

spark_session=spark_session,

180-

entity_df=entity_df,

181-

)

182-

event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(

183-

entity_schema=entity_schema,

184-

)

185-

entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(

186-

entity_df,

187-

event_timestamp_col,

188-

spark_session,

189-

)

190-

_upload_entity_df(

191-

spark_session=spark_session,

192-

table_name=tmp_entity_df_table_name,

193-

entity_df=entity_df,

194-

event_timestamp_col=event_timestamp_col,

195-

)

180+

# Non-entity mode: synthesize a left table and timestamp range from start/end dates to avoid requiring entity_df.

181+

# This makes date-range retrievals possible without enumerating entities upfront; sources remain bounded by time.

182+

non_entity_mode = entity_df is None

183+

if non_entity_mode:

184+

# Why: derive bounded time window without requiring entities; uses max TTL fallback to constrain scans.

185+

start_date, end_date = _compute_non_entity_dates(feature_views, kwargs)

186+

entity_df_event_timestamp_range = (start_date, end_date)

187+188+

# Build query contexts so we can reuse entity names and per-view table info consistently.

189+

fv_query_contexts = offline_utils.get_feature_view_query_context(

190+

feature_refs,

191+

feature_views,

192+

registry,

193+

project,

194+

entity_df_event_timestamp_range,

195+

)

196196197-

expected_join_keys = offline_utils.get_expected_join_keys(

198-

project=project, feature_views=feature_views, registry=registry

199-

)

200-

offline_utils.assert_expected_columns_in_entity_df(

201-

entity_schema=entity_schema,

202-

join_keys=expected_join_keys,

203-

entity_df_event_timestamp_col=event_timestamp_col,

204-

)

197+

# Collect the union of entity columns required across all feature views.

198+

all_entities = _gather_all_entities(fv_query_contexts)

199+200+

# Build a UNION DISTINCT of per-feature-view entity projections, time-bounded and partition-pruned.

201+

_create_temp_entity_union_view(

202+

spark_session=spark_session,

203+

tmp_view_name=tmp_entity_df_table_name,

204+

feature_views=feature_views,

205+

fv_query_contexts=fv_query_contexts,

206+

start_date=start_date,

207+

end_date=end_date,

208+

date_partition_column_formats=date_partition_column_formats,

209+

)

210+211+

# Add a stable as-of timestamp column for PIT joins.

212+

left_table_query_string, event_timestamp_col = _make_left_table_query(

213+

end_date=end_date, tmp_view_name=tmp_entity_df_table_name

214+

)

215+

entity_schema_keys = _entity_schema_keys_from(

216+

all_entities=all_entities, event_timestamp_col=event_timestamp_col

217+

)

218+

else:

219+

entity_schema = _get_entity_schema(

220+

spark_session=spark_session,

221+

entity_df=entity_df,

222+

)

223+

event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(

224+

entity_schema=entity_schema,

225+

)

226+

entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(

227+

entity_df,

228+

event_timestamp_col,

229+

spark_session,

230+

)

231+

_upload_entity_df(

232+

spark_session=spark_session,

233+

table_name=tmp_entity_df_table_name,

234+

entity_df=entity_df,

235+

event_timestamp_col=event_timestamp_col,

236+

)

237+

left_table_query_string = tmp_entity_df_table_name

238+

entity_schema_keys = cast(KeysView[str], entity_schema.keys())

239+240+

if not non_entity_mode:

241+

expected_join_keys = offline_utils.get_expected_join_keys(

242+

project=project, feature_views=feature_views, registry=registry

243+

)

244+

offline_utils.assert_expected_columns_in_entity_df(

245+

entity_schema=entity_schema,

246+

join_keys=expected_join_keys,

247+

entity_df_event_timestamp_col=event_timestamp_col,

248+

)

205249206250

query_context = offline_utils.get_feature_view_query_context(

207251

feature_refs,

@@ -232,9 +276,9 @@ def get_historical_features(

232276

feature_view_query_contexts=cast(

233277

List[offline_utils.FeatureViewQueryContext], spark_query_context

234278

),

235-

left_table_query_string=tmp_entity_df_table_name,

279+

left_table_query_string=left_table_query_string,

236280

entity_df_event_timestamp_col=event_timestamp_col,

237-

entity_df_columns=entity_schema.keys(),

281+

entity_df_columns=entity_schema_keys,

238282

query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,

239283

full_feature_names=full_feature_names,

240284

)

@@ -248,7 +292,7 @@ def get_historical_features(

248292

),

249293

metadata=RetrievalMetadata(

250294

features=feature_refs,

251-

keys=list(set(entity_schema.keys()) - {event_timestamp_col}),

295+

keys=list(set(entity_schema_keys) - {event_timestamp_col}),

252296

min_event_timestamp=entity_df_event_timestamp_range[0],

253297

max_event_timestamp=entity_df_event_timestamp_range[1],

254298

),

@@ -540,6 +584,114 @@ def get_spark_session_or_start_new_with_repoconfig(

540584

return spark_session

541585542586587+

def _compute_non_entity_dates(

588+

feature_views: List[FeatureView], kwargs: Dict[str, Any]

589+

) -> Tuple[datetime, datetime]:

590+

# Why: bounds the scan window when no entity_df is provided using explicit dates or max TTL fallback.

591+

start_date_opt = cast(Optional[datetime], kwargs.get("start_date"))

592+

end_date_opt = cast(Optional[datetime], kwargs.get("end_date"))

593+

end_date: datetime = end_date_opt or datetime.now(timezone.utc)

594+595+

if start_date_opt is None:

596+

max_ttl_seconds = 0

597+

for fv in feature_views:

598+

if fv.ttl and isinstance(fv.ttl, timedelta):

599+

max_ttl_seconds = max(max_ttl_seconds, int(fv.ttl.total_seconds()))

600+

start_date: datetime = (

601+

end_date - timedelta(seconds=max_ttl_seconds)

602+

if max_ttl_seconds > 0

603+

else end_date - timedelta(days=30)

604+

)

605+

else:

606+

start_date = start_date_opt

607+

return (start_date, end_date)

608+609+610+

def _gather_all_entities(

611+

fv_query_contexts: List[offline_utils.FeatureViewQueryContext],

612+

) -> List[str]:

613+

# Why: ensure a unified entity set across feature views to align UNION schemas.

614+

all_entities: List[str] = []

615+

for ctx in fv_query_contexts:

616+

for e in ctx.entities:

617+

if e not in all_entities:

618+

all_entities.append(e)

619+

return all_entities

620+621+622+

def _create_temp_entity_union_view(
    spark_session: SparkSession,
    tmp_view_name: str,
    feature_views: List[FeatureView],
    fv_query_contexts: List[offline_utils.FeatureViewQueryContext],
    start_date: datetime,
    end_date: datetime,
    date_partition_column_formats: List[Optional[str]],
) -> None:
    """Create a Spark TEMPORARY VIEW of the distinct entity keys observed in the window.

    Builds one time-bounded ``SELECT DISTINCT`` projection per feature view's
    batch source and combines them with ``UNION DISTINCT``, registering the
    result under ``tmp_view_name``. This synthesizes a left table for
    non-entity-mode retrieval without requiring an entity_df upfront.

    Args:
        spark_session: Session used to register the temporary view.
        tmp_view_name: Name for the created/replaced temporary view.
        feature_views: Views whose batch sources are scanned for entity keys.
        fv_query_contexts: Per-view contexts supplying the entity column names.
        start_date: Inclusive lower bound on the source timestamp field.
        end_date: Inclusive upper bound on the source timestamp field.
        date_partition_column_formats: Per-view strftime formats for the
            optional partition column; ``None`` disables partition pruning
            for that view.
    """
    # Why: derive distinct entity keys observed in the time window without requiring an entity_df upfront.
    start_date_str = _format_datetime(start_date)
    end_date_str = _format_datetime(end_date)

    # Compute the unified entity set to align schemas in the UNION.
    all_entities = _gather_all_entities(fv_query_contexts)

    per_view_selects: List[str] = []
    # NOTE(review): zip() assumes feature_views, fv_query_contexts and
    # date_partition_column_formats are parallel, equal-length lists; if
    # the query-context builder ever returns fewer contexts than views,
    # zip silently truncates — confirm alignment at the call site.
    for fv, ctx, date_format in zip(
        feature_views, fv_query_contexts, date_partition_column_formats
    ):
        assert isinstance(fv.batch_source, SparkSource)
        from_expression = fv.batch_source.get_table_query_string()
        timestamp_field = fv.batch_source.timestamp_field or "event_timestamp"
        date_partition_column = fv.batch_source.date_partition_column
        partition_clause = ""
        if date_partition_column and date_format:
            # Prune partitions in addition to the timestamp filter so the
            # scan stays bounded on partitioned sources.
            partition_clause = (
                f" AND {date_partition_column} >= '{start_date.strftime(date_format)}'"
                f" AND {date_partition_column} <= '{end_date.strftime(date_format)}'"
            )

        # Fill missing entity columns with NULL and cast to STRING to keep UNION schemas aligned.
        select_entities: List[str] = []
        ctx_entities_set = set(ctx.entities)
        for col in all_entities:
            if col in ctx_entities_set:
                select_entities.append(f"CAST({col} AS STRING) AS {col}")
            else:
                select_entities.append(f"CAST(NULL AS STRING) AS {col}")

        per_view_selects.append(
            f"""
            SELECT DISTINCT {", ".join(select_entities)}
            FROM {from_expression}
            WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){partition_clause}
            """
        )

    union_query = "\nUNION DISTINCT\n".join([s.strip() for s in per_view_selects])
    spark_session.sql(
        f"CREATE OR REPLACE TEMPORARY VIEW {tmp_view_name} AS {union_query}"
    )

674+675+676+

def _make_left_table_query(end_date: datetime, tmp_view_name: str) -> Tuple[str, str]:
    """Wrap the entity-union view in a subquery carrying a fixed as-of timestamp.

    Non-entity mode has no per-row event timestamps, so every synthesized
    entity row gets the same as-of timestamp (``end_date``) for the
    point-in-time join.

    Args:
        end_date: The as-of instant stamped onto every row.
        tmp_view_name: Name of the temporary entity-union view to wrap.

    Returns:
        A ``(left_table_query_string, event_timestamp_col)`` pair.
    """
    as_of_column = "entity_ts"
    as_of_literal = f"TIMESTAMP('{_format_datetime(end_date)}')"
    subquery = f"(SELECT *, {as_of_literal} AS {as_of_column} FROM {tmp_view_name})"
    return subquery, as_of_column

684+685+686+

def _entity_schema_keys_from(

687+

all_entities: List[str], event_timestamp_col: str

688+

) -> KeysView[str]:

689+

# Why: pass a KeysView[str] to PIT query builder to match entity_df branch typing.

690+

return cast(

691+

KeysView[str], {k: None for k in (all_entities + [event_timestamp_col])}.keys()

692+

)

693+694+543695

def _get_entity_df_event_timestamp_range(

544696

entity_df: Union[pd.DataFrame, str],

545697

entity_df_event_timestamp_col: str,