feat: Added RemoteDatasetProxy that executes Ray Data operations remo… · feast-dev/feast@7128024
@@ -1,16 +1,145 @@
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import ray
 from ray.data import Dataset
 
 
+class RemoteDatasetProxy:
+    """Proxy class that executes Ray Data operations remotely on cluster workers."""
+
+    def __init__(self, dataset_ref: Any):
+        """Initialize with a reference to the remote dataset."""
+        self._dataset_ref = dataset_ref
+
+    def map_batches(self, func, **kwargs) -> "RemoteDatasetProxy":
+        """Execute map_batches remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_map_batches(dataset, function, batch_kwargs):
+            return dataset.map_batches(function, **batch_kwargs)
+
+        new_ref = _remote_map_batches.remote(self._dataset_ref, func, kwargs)
+        return RemoteDatasetProxy(new_ref)
+
+    def filter(self, fn) -> "RemoteDatasetProxy":
+        """Execute filter remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_filter(dataset, filter_fn):
+            return dataset.filter(filter_fn)
+
+        new_ref = _remote_filter.remote(self._dataset_ref, fn)
+        return RemoteDatasetProxy(new_ref)
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Execute to_pandas remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_pandas(dataset):
+            return dataset.to_pandas()
+
+        result_ref = _remote_to_pandas.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def to_arrow(self) -> pa.Table:
+        """Execute to_arrow remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_arrow(dataset):
+            return dataset.to_arrow()
+
+        result_ref = _remote_to_arrow.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def schema(self) -> Any:
+        """Get dataset schema."""
+
+        @ray.remote
+        def _remote_schema(dataset):
+            return dataset.schema()
+
+        schema_ref = _remote_schema.remote(self._dataset_ref)
+        return ray.get(schema_ref)
+
+    def sort(self, key, descending=False) -> "RemoteDatasetProxy":
+        """Execute sort remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_sort(dataset, sort_key, desc):
+            return dataset.sort(sort_key, descending=desc)
+
+        new_ref = _remote_sort.remote(self._dataset_ref, key, descending)
+        return RemoteDatasetProxy(new_ref)
+
+    def limit(self, count) -> "RemoteDatasetProxy":
+        """Execute limit remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_limit(dataset, limit_count):
+            return dataset.limit(limit_count)
+
+        new_ref = _remote_limit.remote(self._dataset_ref, count)
+        return RemoteDatasetProxy(new_ref)
+
+    def union(self, other) -> "RemoteDatasetProxy":
+        """Execute union remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_union(dataset1, dataset2):
+            return dataset1.union(dataset2)
+
+        new_ref = _remote_union.remote(self._dataset_ref, other._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def materialize(self) -> "RemoteDatasetProxy":
+        """Execute materialize remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_materialize(dataset):
+            return dataset.materialize()
+
+        new_ref = _remote_materialize.remote(self._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def count(self) -> int:
+        """Execute count remotely and return result."""
+
+        @ray.remote
+        def _remote_count(dataset):
+            return dataset.count()
+
+        result_ref = _remote_count.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def take(self, n=20) -> list:
+        """Execute take remotely and return result."""
+
+        @ray.remote
+        def _remote_take(dataset, num):
+            return dataset.take(num)
+
+        result_ref = _remote_take.remote(self._dataset_ref, n)
+        return ray.get(result_ref)
+
+    def __getattr__(self, name):
+        """Catch any method calls that we haven't explicitly implemented."""
+        raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'")
+
+
+def is_ray_data(data: Any) -> bool:
+    """Check if data is a Ray Dataset or RemoteDatasetProxy."""
+    return isinstance(data, (Dataset, RemoteDatasetProxy))
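Usage sketch (illustrative, not part of the commit): each proxy method wraps one Ray Data call in a `@ray.remote` task, so transformations chain lazily on the cluster and only terminal calls move data to the client. `_load_dataset` below is a hypothetical seeding task; the sketch assumes `ray.init()` has already connected to a cluster.

import ray
import ray.data

@ray.remote
def _load_dataset():
    # Hypothetical producer task; rows look like {"id": 0}, {"id": 1}, ...
    return ray.data.range(1_000)

proxy = RemoteDatasetProxy(_load_dataset.remote())

# Transformations return new proxies; the dataset never leaves the cluster.
evens = proxy.filter(lambda row: row["id"] % 2 == 0).limit(100)

# Terminal calls (count, take, to_pandas, to_arrow) fetch results to the client.
print(evens.count())
df = evens.to_pandas()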
+
+
 def normalize_timestamp_columns(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     columns: Union[str, List[str]],
     inplace: bool = False,
     exclude_columns: Optional[List[str]] = None,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     column_list = [columns] if isinstance(columns, str) else columns
     exclude_columns = exclude_columns or []
 
@@ -21,7 +150,7 @@ def apply_normalization(series: pd.Series) -> pd.Series:
             .astype("datetime64[ns, UTC]")
         )
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
 
         def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
             for column in column_list:
@@ -35,6 +164,7 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
 
         return data.map_batches(normalize_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         for column in column_list:
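A hedged example of the pandas path (the Ray and proxy paths are the same call, just lazy via map_batches). The tz-naive input assumption follows the `.astype("datetime64[ns, UTC]")` fragment above; only fragments of the function body appear in this diff.

import pandas as pd

df = pd.DataFrame({"event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"])})
df = normalize_timestamp_columns(df, "event_timestamp")
print(df["event_timestamp"].dtype)  # expected: datetime64[ns, UTC]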
@@ -44,13 +174,13 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
 
 
 def ensure_timestamp_compatibility(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     timestamp_fields: List[str],
     inplace: bool = False,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     from feast.utils import make_df_tzaware
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
 
         def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
             batch = make_df_tzaware(batch)
@@ -65,6 +195,7 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
 
         return data.map_batches(ensure_compatibility, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         from feast.utils import make_df_tzaware
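Sketch of the new type-based dispatch (illustrative, not part of the commit): `is_ray_data()` routes a Dataset or RemoteDatasetProxy through lazy map_batches, while a plain DataFrame is made tz-aware eagerly.

import pandas as pd
import ray.data

pdf = pd.DataFrame({"created": pd.to_datetime(["2024-01-01", "2024-01-02"])})
pdf = ensure_timestamp_compatibility(pdf, timestamp_fields=["created"])  # eager

ds = ray.data.from_pandas(pdf)
ds = ensure_timestamp_compatibility(ds, timestamp_fields=["created"])  # stays lazy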
@@ -77,22 +208,24 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
 
 
 def apply_field_mapping(
-    data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str]
-) -> Union[pd.DataFrame, Dataset]:
+    data: Union[pd.DataFrame, Dataset, Any],
+    field_mapping: Dict[str, str],
+) -> Union[pd.DataFrame, Dataset, Any]:
     def rename_columns(df: pd.DataFrame) -> pd.DataFrame:
         return df.rename(columns=field_mapping)
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(rename_columns, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return data.rename(columns=field_mapping)
 
 
 def deduplicate_by_keys_and_timestamp(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     join_keys: List[str],
     timestamp_columns: List[str],
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
         if batch.empty:
             return batch
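For apply_field_mapping, one call now covers all three input types (sketch reusing `pdf`, `ds`, and `proxy` from the sketches above; the proxy path forwards batch_format="pandas" through RemoteDatasetProxy.map_batches to the cluster):

mapping = {"driver_rating": "rating"}

renamed_df = apply_field_mapping(pdf, mapping)       # eager pandas rename
renamed_ds = apply_field_mapping(ds, mapping)        # lazy map_batches rename
renamed_proxy = apply_field_mapping(proxy, mapping)  # executes on the cluster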
@@ -110,9 +243,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
             return deduped_batch
         return batch
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(deduplicate_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return deduplicate_batch(data)
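A final hedged sketch for deduplicate_by_keys_and_timestamp: it keeps one row per join key based on the timestamp columns; the exact tie-breaking lives in the elided body of deduplicate_batch, so the expected output here is an assumption.

import pandas as pd

events = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-01"]),
        "trips": [10, 12, 7],
    }
)
latest = deduplicate_by_keys_and_timestamp(
    events, join_keys=["driver_id"], timestamp_columns=["event_timestamp"]
)
# Expected (assumed latest-wins policy): one row each for driver 1 and driver 2.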