feat: Add batch commit mode for MySQL OnlineStore (#5699) · feast-dev/feast@3cfe4eb

@@ -30,6 +30,8 @@ class MySQLOnlineStoreConfig(FeastConfigBaseModel):

3030

password: Optional[StrictStr] = None

3131

database: Optional[StrictStr] = None

3232

port: Optional[int] = None

33+

batch_write: Optional[bool] = False

34+

batch_size: Optional[int] = None

333534363537

class MySQLOnlineStore(OnlineStore):

@@ -51,7 +53,7 @@ def _get_conn(self, config: RepoConfig) -> Connection:

5153

password=online_store_config.password or "test",

5254

database=online_store_config.database or "feast",

5355

port=online_store_config.port or 3306,

54-

autocommit=True,

56+

autocommit=(not online_store_config.batch_write),

5557

)

5658

return self._conn

5759

@@ -69,29 +71,97 @@ def online_write_batch(

69717072

project = config.project

717372-

for entity_key, values, timestamp, created_ts in data:

73-

entity_key_bin = serialize_entity_key(

74-

entity_key,

75-

entity_key_serialization_version=3,

76-

).hex()

77-

timestamp = to_naive_utc(timestamp)

78-

if created_ts is not None:

79-

created_ts = to_naive_utc(created_ts)

80-81-

for feature_name, val in values.items():

82-

self.write_to_table(

83-

created_ts,

84-

cur,

85-

entity_key_bin,

86-

feature_name,

87-

project,

88-

table,

89-

timestamp,

90-

val,

91-

)

92-

conn.commit()

93-

if progress:

94-

progress(1)

74+

batch_write = config.online_store.batch_write

75+

if not batch_write:

76+

for entity_key, values, timestamp, created_ts in data:

77+

entity_key_bin = serialize_entity_key(

78+

entity_key,

79+

entity_key_serialization_version=3,

80+

).hex()

81+

timestamp = to_naive_utc(timestamp)

82+

if created_ts is not None:

83+

created_ts = to_naive_utc(created_ts)

84+85+

for feature_name, val in values.items():

86+

self.write_to_table(

87+

created_ts,

88+

cur,

89+

entity_key_bin,

90+

feature_name,

91+

project,

92+

table,

93+

timestamp,

94+

val,

95+

)

96+

conn.commit()

97+

if progress:

98+

progress(1)

99+

else:

100+

batch_size = config.online_store.bacth_size

101+

if not batch_size or batch_size < 2:

102+

raise ValueError("Batch size must be at least 2")

103+

insert_values = []

104+

for entity_key, values, timestamp, created_ts in data:

105+

entity_key_bin = serialize_entity_key(

106+

entity_key,

107+

entity_key_serialization_version=2,

108+

).hex()

109+

timestamp = to_naive_utc(timestamp)

110+

if created_ts is not None:

111+

created_ts = to_naive_utc(created_ts)

112+113+

for feature_name, val in values.items():

114+

serialized_val = val.SerializeToString()

115+

insert_values.append(

116+

(

117+

entity_key_bin,

118+

feature_name,

119+

serialized_val,

120+

timestamp,

121+

created_ts,

122+

)

123+

)

124+125+

if len(insert_values) >= batch_size:

126+

try:

127+

self._execute_batch(cur, project, table, insert_values)

128+

conn.commit()

129+

if progress:

130+

progress(len(insert_values))

131+

except Exception as e:

132+

conn.rollback()

133+

raise e

134+

insert_values.clear()

135+136+

if insert_values:

137+

try:

138+

self._execute_batch(cur, project, table, insert_values)

139+

conn.commit()

140+

if progress:

141+

progress(len(insert_values))

142+

except Exception as e:

143+

conn.rollback()

144+

raise e

145+146+

def _execute_batch(self, cur, project, table, insert_values):

147+

sql = f"""

148+

INSERT INTO {_table_id(project, table)}

149+

(entity_key, feature_name, value, event_ts, created_ts)

150+

values (%s, %s, %s, %s, %s)

151+

ON DUPLICATE KEY UPDATE

152+

value = VALUES(value),

153+

event_ts = VALUES(event_ts),

154+

created_ts = VALUES(created_ts);

155+

"""

156+

try:

157+

cur.executemany(sql, insert_values)

158+

except Exception as e:

159+

# Log SQL info for debugging without leaking sensitive data

160+

first_sample = insert_values[0] if insert_values else None

161+

raise RuntimeError(

162+

f"Failed to execute batch insert into table '{_table_id(project, table)}' "

163+

f"(rows={len(insert_values)}, sample={first_sample}): {e}"

164+

) from e

9516596166

@staticmethod

97167

def write_to_table(