
Commit 3e313b1

fix: Changes template file path to relative path (feast-dev#4624)
* fix: changes following issue 4593

  Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

* fix: changes following issue 4593 - Fixed file path in templates to be relative path

  Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

* fix: Fixes to relative path in FileSource

  Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>

---------

Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
1 parent f05e928 commit 3e313b1

14 files changed: +104 -32

sdk/python/feast/feature_store.py

+2 -2

@@ -713,7 +713,7 @@ def plan(
             >>> fs = FeatureStore(repo_path="project/feature_repo")
             >>> driver = Entity(name="driver_id", description="driver id")
             >>> driver_hourly_stats = FileSource(
-            ...     path="project/feature_repo/data/driver_stats.parquet",
+            ...     path="data/driver_stats.parquet",
             ...     timestamp_field="event_timestamp",
             ...     created_timestamp_column="created",
             ... )
@@ -827,7 +827,7 @@ def apply(
             >>> fs = FeatureStore(repo_path="project/feature_repo")
             >>> driver = Entity(name="driver_id", description="driver id")
             >>> driver_hourly_stats = FileSource(
-            ...     path="project/feature_repo/data/driver_stats.parquet",
+            ...     path="data/driver_stats.parquet",
             ...     timestamp_field="event_timestamp",
             ...     created_timestamp_column="created",
             ... )
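Note: the docstring change above captures the new contract — a relative FileSource path is now interpreted against the feature repository root rather than the process working directory. A minimal sketch of the resolution rule, using the repo layout from the docstring (paths illustrative):

    from pathlib import Path

    # Relative source paths are anchored at the repo root (repo_path);
    # absolute paths pass through unchanged.
    repo_path = Path("project/feature_repo")
    source_path = "data/driver_stats.parquet"  # relative, as in the updated docstring

    resolved = (
        Path(source_path)
        if Path(source_path).is_absolute()
        else repo_path / source_path
    )
    print(resolved)  # project/feature_repo/data/driver_stats.parquet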

sdk/python/feast/infra/offline_stores/dask.py

+33 -7

@@ -57,6 +57,7 @@ def __init__(
         self,
         evaluation_function: Callable,
         full_feature_names: bool,
+        repo_path: str,
         on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
         metadata: Optional[RetrievalMetadata] = None,
     ):
@@ -67,6 +68,7 @@ def __init__(
         self._full_feature_names = full_feature_names
         self._on_demand_feature_views = on_demand_feature_views or []
         self._metadata = metadata
+        self.repo_path = repo_path
 
     @property
     def full_feature_names(self) -> bool:
@@ -99,8 +101,13 @@ def persist(
         if not allow_overwrite and os.path.exists(storage.file_options.uri):
             raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
 
+        if not Path(storage.file_options.uri).is_absolute():
+            absolute_path = Path(self.repo_path) / storage.file_options.uri
+        else:
+            absolute_path = Path(storage.file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            storage.file_options.uri,
+            str(absolute_path),
             storage.file_options.s3_endpoint_override,
         )
@@ -243,7 +250,9 @@ def evaluate_historical_retrieval():
 
             all_join_keys = list(set(all_join_keys + join_keys))
 
-            df_to_join = _read_datasource(feature_view.batch_source)
+            df_to_join = _read_datasource(
+                feature_view.batch_source, config.repo_path
+            )
 
             df_to_join, timestamp_field = _field_mapping(
                 df_to_join,
@@ -297,6 +306,7 @@ def evaluate_historical_retrieval():
                 min_event_timestamp=entity_df_event_timestamp_range[0],
                 max_event_timestamp=entity_df_event_timestamp_range[1],
             ),
+            repo_path=str(config.repo_path),
         )
         return job
@@ -316,7 +326,7 @@ def pull_latest_from_table_or_query(
 
         # Create lazy function that is only called from the RetrievalJob object
         def evaluate_offline_job():
-            source_df = _read_datasource(data_source)
+            source_df = _read_datasource(data_source, config.repo_path)
 
             source_df = _normalize_timestamp(
                 source_df, timestamp_field, created_timestamp_column
@@ -377,6 +387,7 @@ def evaluate_offline_job():
         return DaskRetrievalJob(
             evaluation_function=evaluate_offline_job,
             full_feature_names=False,
+            repo_path=str(config.repo_path),
         )
 
     @staticmethod
@@ -420,8 +431,13 @@ def write_logged_features(
         # Since this code will be mostly used from Go-created thread, it's better to avoid producing new threads
         data = pyarrow.parquet.read_table(data, use_threads=False, pre_buffer=False)
 
+        if config.repo_path is not None and not Path(destination.path).is_absolute():
+            absolute_path = config.repo_path / destination.path
+        else:
+            absolute_path = Path(destination.path)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            destination.path,
+            str(absolute_path),
             destination.s3_endpoint_override,
         )
@@ -456,8 +472,14 @@ def offline_write_batch(
         )
 
         file_options = feature_view.batch_source.file_options
+
+        if config.repo_path is not None and not Path(file_options.uri).is_absolute():
+            absolute_path = config.repo_path / file_options.uri
+        else:
+            absolute_path = Path(file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            file_options.uri, file_options.s3_endpoint_override
+            str(absolute_path), file_options.s3_endpoint_override
         )
         prev_table = pyarrow.parquet.read_table(
             path, filesystem=filesystem, memory_map=True
@@ -493,7 +515,7 @@ def _get_entity_df_event_timestamp_range(
     )
 
 
-def _read_datasource(data_source) -> dd.DataFrame:
+def _read_datasource(data_source, repo_path) -> dd.DataFrame:
     storage_options = (
         {
             "client_kwargs": {
@@ -504,8 +526,12 @@ def _read_datasource(data_source) -> dd.DataFrame:
         else None
     )
 
+    if not Path(data_source.path).is_absolute():
+        path = repo_path / data_source.path
+    else:
+        path = data_source.path
     return dd.read_parquet(
-        data_source.path,
+        path,
         storage_options=storage_options,
     )
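Note: every Dask code path above repeats the same four-line resolution idiom. A possible refactor capturing it — not part of this commit, and the helper name is invented:

    from pathlib import Path
    from typing import Union

    def resolve_against_repo(repo_path: Union[str, Path], uri: str) -> str:
        """Hypothetical helper mirroring the pattern added throughout dask.py:
        relative URIs are anchored at the repository root, absolute paths pass
        through untouched. Scheme-prefixed URIs (e.g. s3://...) are not
        'absolute' in pathlib terms, so this sketch assumes plain local paths."""
        p = Path(uri)
        return str(p) if p.is_absolute() else str(Path(repo_path) / p)

    # resolve_against_repo("project/feature_repo", "data/driver_stats.parquet")
    # -> "project/feature_repo/data/driver_stats.parquet"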

sdk/python/feast/infra/offline_stores/duckdb.py

+14 -3

@@ -27,7 +27,7 @@
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 
 
-def _read_data_source(data_source: DataSource) -> Table:
+def _read_data_source(data_source: DataSource, repo_path: str) -> Table:
     assert isinstance(data_source, FileSource)
 
     if isinstance(data_source.file_format, ParquetFormat):
@@ -43,21 +43,32 @@ def _read_data_source(data_source: DataSource) -> Table:
 def _write_data_source(
     table: Table,
     data_source: DataSource,
+    repo_path: str,
     mode: str = "append",
     allow_overwrite: bool = False,
 ):
     assert isinstance(data_source, FileSource)
 
     file_options = data_source.file_options
 
-    if mode == "overwrite" and not allow_overwrite and os.path.exists(file_options.uri):
+    if not Path(file_options.uri).is_absolute():
+        absolute_path = Path(repo_path) / file_options.uri
+    else:
+        absolute_path = Path(file_options.uri)
+
+    if (
+        mode == "overwrite"
+        and not allow_overwrite
+        and os.path.exists(str(absolute_path))
+    ):
         raise SavedDatasetLocationAlreadyExists(location=file_options.uri)
 
     if isinstance(data_source.file_format, ParquetFormat):
        if mode == "overwrite":
            table = table.to_pyarrow()
+
            filesystem, path = FileSource.create_filesystem_and_path(
-               file_options.uri,
+               str(absolute_path),
                file_options.s3_endpoint_override,
            )
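Note: the overwrite guard now tests the resolved location rather than the raw URI; otherwise a relative URI would be checked against the process working directory and could silently miss an existing dataset. A sketch of the difference, with illustrative paths:

    import os
    from pathlib import Path

    repo_path = Path("project/feature_repo")   # assumed repo root
    uri = "data/saved_dataset.parquet"         # relative FileSource URI

    # Before: resolved against the current working directory.
    exists_in_cwd = os.path.exists(uri)

    # After: resolved against the feature repository.
    exists_in_repo = os.path.exists(str(repo_path / uri))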

sdk/python/feast/infra/offline_stores/file_source.py

+10 -1

@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Callable, Dict, Iterable, List, Optional, Tuple
 
 import pyarrow
@@ -154,8 +155,16 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
     def get_table_column_names_and_types(
         self, config: RepoConfig
     ) -> Iterable[Tuple[str, str]]:
+        if (
+            config.repo_path is not None
+            and not Path(self.file_options.uri).is_absolute()
+        ):
+            absolute_path = config.repo_path / self.file_options.uri
+        else:
+            absolute_path = Path(self.file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
-            self.path, self.file_options.s3_endpoint_override
+            str(absolute_path), self.file_options.s3_endpoint_override
        )
 
        # TODO why None check necessary
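Note: unlike the Dask retrieval job, this code path guards against config.repo_path being None (it is declared Optional on RepoConfig) and falls back to using the URI as-is. A condensed sketch of the added branch — the helper name is ours, not the source's:

    from pathlib import Path
    from typing import Optional

    def _absolute_uri(repo_path: Optional[Path], uri: str) -> Path:
        # Only anchor at the repo root when one is known and the URI is relative.
        if repo_path is not None and not Path(uri).is_absolute():
            return repo_path / uri
        return Path(uri)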

sdk/python/feast/infra/offline_stores/ibis.py

+25 -12

@@ -46,8 +46,8 @@ def pull_latest_from_table_or_query_ibis(
     created_timestamp_column: Optional[str],
     start_date: datetime,
     end_date: datetime,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
@@ -57,7 +57,7 @@ def pull_latest_from_table_or_query_ibis(
     start_date = start_date.astimezone(tz=timezone.utc)
     end_date = end_date.astimezone(tz=timezone.utc)
 
-    table = data_source_reader(data_source)
+    table = data_source_reader(data_source, str(config.repo_path))
 
     table = table.select(*fields)
 
@@ -87,6 +87,7 @@ def pull_latest_from_table_or_query_ibis(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )
 
 
@@ -147,8 +148,8 @@ def get_historical_features_ibis(
     entity_df: Union[pd.DataFrame, str],
     registry: BaseRegistry,
     project: str,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     full_feature_names: bool = False,
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
@@ -174,7 +175,9 @@ def get_historical_features_ibis(
     def read_fv(
         feature_view: FeatureView, feature_refs: List[str], full_feature_names: bool
     ) -> Tuple:
-        fv_table: Table = data_source_reader(feature_view.batch_source)
+        fv_table: Table = data_source_reader(
+            feature_view.batch_source, str(config.repo_path)
+        )
 
         for old_name, new_name in feature_view.batch_source.field_mapping.items():
             if old_name in fv_table.columns:
@@ -247,6 +250,7 @@ def read_fv(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )
 
 
@@ -258,16 +262,16 @@ def pull_all_from_table_or_query_ibis(
     timestamp_field: str,
     start_date: datetime,
     end_date: datetime,
-    data_source_reader: Callable[[DataSource], Table],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_reader: Callable[[DataSource, str], Table],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
     staging_location: Optional[str] = None,
     staging_location_endpoint_override: Optional[str] = None,
 ) -> RetrievalJob:
     fields = join_key_columns + feature_name_columns + [timestamp_field]
     start_date = start_date.astimezone(tz=timezone.utc)
     end_date = end_date.astimezone(tz=timezone.utc)
 
-    table = data_source_reader(data_source)
+    table = data_source_reader(data_source, str(config.repo_path))
 
     table = table.select(*fields)
 
@@ -290,6 +294,7 @@ def pull_all_from_table_or_query_ibis(
         data_source_writer=data_source_writer,
         staging_location=staging_location,
         staging_location_endpoint_override=staging_location_endpoint_override,
+        repo_path=str(config.repo_path),
     )
 
 
@@ -319,7 +324,7 @@ def offline_write_batch_ibis(
     feature_view: FeatureView,
     table: pyarrow.Table,
     progress: Optional[Callable[[int], Any]],
-    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+    data_source_writer: Callable[[pyarrow.Table, DataSource, str], None],
 ):
     pa_schema, column_names = get_pyarrow_schema_from_batch_source(
         config, feature_view.batch_source
@@ -330,7 +335,9 @@ def offline_write_batch_ibis(
             f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
         )
 
-    data_source_writer(ibis.memtable(table), feature_view.batch_source)
+    data_source_writer(
+        ibis.memtable(table), feature_view.batch_source, str(config.repo_path)
+    )
 
 
 def deduplicate(
@@ -469,6 +476,7 @@ def __init__(
         data_source_writer,
         staging_location,
         staging_location_endpoint_override,
+        repo_path,
     ) -> None:
         super().__init__()
         self.table = table
@@ -480,6 +488,7 @@ def __init__(
         self.data_source_writer = data_source_writer
         self.staging_location = staging_location
         self.staging_location_endpoint_override = staging_location_endpoint_override
+        self.repo_path = repo_path
 
     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
         return self.table.execute()
@@ -502,7 +511,11 @@ def persist(
         timeout: Optional[int] = None,
     ):
         self.data_source_writer(
-            self.table, storage.to_data_source(), "overwrite", allow_overwrite
+            self.table,
+            storage.to_data_source(),
+            self.repo_path,
+            "overwrite",
+            allow_overwrite,
         )
 
     @property
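Note: the ibis layer threads repo_path through as a plain string, so the reader and writer callbacks change signature from Callable[[DataSource], Table] to Callable[[DataSource, str], Table] (and similarly for writers). Offline stores built on this module must update their callbacks. A minimal conforming reader sketch — the function name is hypothetical, and the resolution mirrors duckdb.py's _read_data_source:

    from pathlib import Path

    import ibis
    from ibis.expr.types import Table

    from feast.data_source import DataSource
    from feast.infra.offline_stores.file_source import FileSource

    def my_data_source_reader(data_source: DataSource, repo_path: str) -> Table:
        # Matches the new Callable[[DataSource, str], Table] contract.
        assert isinstance(data_source, FileSource)
        path = Path(data_source.path)
        if not path.is_absolute():
            path = Path(repo_path) / path
        return ibis.read_parquet(str(path))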

sdk/python/feast/repo_config.py

+1 -0

@@ -193,6 +193,7 @@ class RepoConfig(FeastBaseModel):
     """ Flags (deprecated field): Feature flags for experimental features """
 
     repo_path: Optional[Path] = None
+    """When using relative path in FileSource path, this parameter is mandatory"""
 
     entity_key_serialization_version: StrictInt = 1
     """ Entity key serialization version: This version is used to control what serialization scheme is

sdk/python/feast/templates/cassandra/bootstrap.py

+3 -1

@@ -275,7 +275,9 @@ def bootstrap():
 
     # example_repo.py
     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )
 
     # store config yaml, interact with user and then customize file:
     settings = collect_cassandra_store_settings()
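Note: the same one-line template fix recurs in the hazelcast, hbase, and local templates below. Instead of baking the absolute parquet path into example_repo.py, the bootstrap now writes the path relative to the repo root; Path.relative_to does the trimming:

    from pathlib import Path

    # Illustrative paths; the bootstrap always creates the parquet file
    # under the repo's data directory, so relative_to cannot fail here
    # (it raises ValueError only when the target is not under repo_path).
    repo_path = Path("/home/user/project/feature_repo")
    driver_stats_path = repo_path / "data" / "driver_stats.parquet"

    print(driver_stats_path.relative_to(repo_path))  # data/driver_stats.parquet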

sdk/python/feast/templates/hazelcast/bootstrap.py

+3 -1

@@ -165,7 +165,9 @@ def bootstrap():
 
     # example_repo.py
     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )
 
     # store config yaml, interact with user and then customize file:
     settings = collect_hazelcast_online_store_settings()

sdk/python/feast/templates/hbase/bootstrap.py

+3 -1

@@ -23,7 +23,9 @@ def bootstrap():
     driver_df.to_parquet(path=str(driver_stats_path), allow_truncated_timestamps=True)
 
     example_py_file = repo_path / "example_repo.py"
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )
 
 
 if __name__ == "__main__":

sdk/python/feast/templates/local/bootstrap.py

+6 -2

@@ -25,8 +25,12 @@ def bootstrap():
 
     example_py_file = repo_path / "example_repo.py"
     replace_str_in_file(example_py_file, "%PROJECT_NAME%", str(project_name))
-    replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
-    replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path))
+    replace_str_in_file(
+        example_py_file, "%PARQUET_PATH%", str(driver_stats_path.relative_to(repo_path))
+    )
+    replace_str_in_file(
+        example_py_file, "%LOGGING_PATH%", str(data_path.relative_to(repo_path))
+    )
 
 
 if __name__ == "__main__":
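Note: after bootstrapping, the rendered example_repo.py therefore references its data by a repo-relative path, which resolves correctly no matter where feast apply is invoked from. Roughly what the substituted template yields (excerpt is illustrative, not the template's exact contents):

    from feast import FileSource

    # %PARQUET_PATH% is now replaced with a repo-relative path.
    driver_stats_source = FileSource(
        name="driver_hourly_stats_source",
        path="data/driver_stats.parquet",  # relative to the feature repo root
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
    )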
