feat: Added RemoteDatasetProxy that executes Ray Data operations remotely · feast-dev/feast@7128024

1-

from typing import Dict, List, Optional, Union

1+

from typing import Any, Dict, List, Optional, Union

2233

import numpy as np

44

import pandas as pd

5+

import pyarrow as pa

6+

import ray

57

from ray.data import Dataset

687910+

class RemoteDatasetProxy:
    """Proxy class that executes Ray Data operations remotely on cluster workers.

    Wraps an object reference to a dataset that lives on the Ray cluster.
    Transformations (``map_batches``, ``filter``, ``sort``, ``limit``,
    ``union``, ``materialize``) are submitted as Ray tasks and return a new
    proxy without blocking; terminal operations (``to_pandas``, ``to_arrow``,
    ``schema``, ``count``, ``take``) block on ``ray.get`` and transfer the
    result back to the client.
    """

    def __init__(self, dataset_ref: Any):
        """Initialize with a reference to the remote dataset.

        Args:
            dataset_ref: Ray object reference (or handle) to the dataset held
                on the cluster.
        """
        self._dataset_ref = dataset_ref

    def _remote_call(self, method: str, *args: Any, **kwargs: Any) -> Any:
        """Submit ``dataset.<method>(*args, **kwargs)`` as a Ray task.

        Centralizes the per-method ``@ray.remote`` boilerplate that was
        previously duplicated in every proxy method. Returns the object
        reference of the task result without blocking.
        """

        @ray.remote
        def _invoke(dataset, call_args, call_kwargs):
            # `method` is captured by closure and resolved on the worker.
            return getattr(dataset, method)(*call_args, **call_kwargs)

        return _invoke.remote(self._dataset_ref, args, kwargs)

    def map_batches(self, func, **kwargs) -> "RemoteDatasetProxy":
        """Execute map_batches remotely on cluster workers."""
        return RemoteDatasetProxy(self._remote_call("map_batches", func, **kwargs))

    def filter(self, fn) -> "RemoteDatasetProxy":
        """Execute filter remotely on cluster workers."""
        return RemoteDatasetProxy(self._remote_call("filter", fn))

    def to_pandas(self) -> pd.DataFrame:
        """Execute to_pandas remotely and transfer the result to the client."""
        return ray.get(self._remote_call("to_pandas"))

    def to_arrow(self) -> "pa.Table":
        """Execute to_arrow remotely and transfer the result to the client."""
        return ray.get(self._remote_call("to_arrow"))

    def schema(self) -> Any:
        """Get the dataset schema (blocks until the remote call resolves)."""
        return ray.get(self._remote_call("schema"))

    def sort(self, key, descending=False) -> "RemoteDatasetProxy":
        """Execute sort remotely on cluster workers."""
        return RemoteDatasetProxy(self._remote_call("sort", key, descending=descending))

    def limit(self, count) -> "RemoteDatasetProxy":
        """Execute limit remotely on cluster workers."""
        return RemoteDatasetProxy(self._remote_call("limit", count))

    def union(self, other) -> "RemoteDatasetProxy":
        """Execute union remotely on cluster workers.

        Args:
            other: Another ``RemoteDatasetProxy``, or a raw dataset / object
                reference. (Generalized: the original required a proxy and
                raised ``AttributeError`` for anything else.)
        """
        other_ref = (
            other._dataset_ref if isinstance(other, RemoteDatasetProxy) else other
        )

        @ray.remote
        def _remote_union(dataset1, dataset2):
            return dataset1.union(dataset2)

        return RemoteDatasetProxy(_remote_union.remote(self._dataset_ref, other_ref))

    def materialize(self) -> "RemoteDatasetProxy":
        """Execute materialize remotely on cluster workers."""
        return RemoteDatasetProxy(self._remote_call("materialize"))

    def count(self) -> int:
        """Execute count remotely and return the number of rows."""
        return ray.get(self._remote_call("count"))

    def take(self, n=20) -> list:
        """Execute take remotely and return up to ``n`` rows as a list."""
        return ray.get(self._remote_call("take", n))

    def __getattr__(self, name):
        """Fail fast for any dataset method not explicitly proxied."""
        raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'")

130+131+132+

def is_ray_data(data: Any) -> bool:
    """Return True when *data* is a Ray Dataset or a RemoteDatasetProxy."""
    ray_like = (Dataset, RemoteDatasetProxy)
    return isinstance(data, ray_like)

135+136+8137

def normalize_timestamp_columns(

9-

data: Union[pd.DataFrame, Dataset],

138+

data: Union[pd.DataFrame, Dataset, Any],

10139

columns: Union[str, List[str]],

11140

inplace: bool = False,

12141

exclude_columns: Optional[List[str]] = None,

13-

) -> Union[pd.DataFrame, Dataset]:

142+

) -> Union[pd.DataFrame, Dataset, Any]:

14143

column_list = [columns] if isinstance(columns, str) else columns

15144

exclude_columns = exclude_columns or []

16145

@@ -21,7 +150,7 @@ def apply_normalization(series: pd.Series) -> pd.Series:

21150

.astype("datetime64[ns, UTC]")

22151

)

2315224-

if isinstance(data, Dataset):

153+

if is_ray_data(data):

2515426155

def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:

27156

for column in column_list:

@@ -35,6 +164,7 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:

3516436165

return data.map_batches(normalize_batch, batch_format="pandas")

37166

else:

167+

assert isinstance(data, pd.DataFrame)

38168

if not inplace:

39169

data = data.copy()

40170

for column in column_list:

@@ -44,13 +174,13 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:

441744517546176

def ensure_timestamp_compatibility(

47-

data: Union[pd.DataFrame, Dataset],

177+

data: Union[pd.DataFrame, Dataset, Any],

48178

timestamp_fields: List[str],

49179

inplace: bool = False,

50-

) -> Union[pd.DataFrame, Dataset]:

180+

) -> Union[pd.DataFrame, Dataset, Any]:

51181

from feast.utils import make_df_tzaware

5218253-

if isinstance(data, Dataset):

183+

if is_ray_data(data):

5418455185

def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:

56186

batch = make_df_tzaware(batch)

@@ -65,6 +195,7 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:

6519566196

return data.map_batches(ensure_compatibility, batch_format="pandas")

67197

else:

198+

assert isinstance(data, pd.DataFrame)

68199

if not inplace:

69200

data = data.copy()

70201

from feast.utils import make_df_tzaware

@@ -77,22 +208,24 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:

772087820979210

def apply_field_mapping(

80-

data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str]

81-

) -> Union[pd.DataFrame, Dataset]:

211+

data: Union[pd.DataFrame, Dataset, Any],

212+

field_mapping: Dict[str, str],

213+

) -> Union[pd.DataFrame, Dataset, Any]:

82214

def rename_columns(df: pd.DataFrame) -> pd.DataFrame:

83215

return df.rename(columns=field_mapping)

8421685-

if isinstance(data, Dataset):

217+

if is_ray_data(data):

86218

return data.map_batches(rename_columns, batch_format="pandas")

87219

else:

220+

assert isinstance(data, pd.DataFrame)

88221

return data.rename(columns=field_mapping)

892229022391224

def deduplicate_by_keys_and_timestamp(

92-

data: Union[pd.DataFrame, Dataset],

225+

data: Union[pd.DataFrame, Dataset, Any],

93226

join_keys: List[str],

94227

timestamp_columns: List[str],

95-

) -> Union[pd.DataFrame, Dataset]:

228+

) -> Union[pd.DataFrame, Dataset, Any]:

96229

def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:

97230

if batch.empty:

98231

return batch

@@ -110,9 +243,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:

110243

return deduped_batch

111244

return batch

112245113-

if isinstance(data, Dataset):

246+

if is_ray_data(data):

114247

return data.map_batches(deduplicate_batch, batch_format="pandas")

115248

else:

249+

assert isinstance(data, pd.DataFrame)

116250

return deduplicate_batch(data)

117251118252