Skip to content

Commit 762aeb0

Browse files
author
Tsotne Tabidze
authored
Add support for third party providers (feast-dev#1501)
* Add support for third party providers Signed-off-by: Tsotne Tabidze <[email protected]> * Add unit tests & assume provider names without dots refer to built-in providers Signed-off-by: Tsotne Tabidze <[email protected]>
1 parent cbb97d3 commit 762aeb0

File tree

5 files changed

+197
-8
lines changed

5 files changed

+197
-8
lines changed

sdk/python/feast/errors.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,20 @@ def __init__(self, name, project=None):
3030

3131
class FeastProviderLoginError(Exception):
    """Raised to signal that the user is not authenticated with their provider."""
33+
34+
35+
class FeastProviderNotImplementedError(Exception):
    """Error raised for a provider name that has no implementation.

    Args:
        provider_name: The provider name to report in the error message.
    """

    def __init__(self, provider_name):
        message = f"Provider '{provider_name}' is not implemented"
        super().__init__(message)
38+
39+
40+
class FeastProviderModuleImportError(Exception):
    """Error raised when the module holding a provider cannot be imported.

    Args:
        module_name: The module name to report in the error message.
    """

    def __init__(self, module_name):
        message = f"Could not import provider module '{module_name}'"
        super().__init__(message)
43+
44+
45+
class FeastProviderClassImportError(Exception):
    """Error raised when a provider class cannot be found in its module.

    Args:
        module_name: The module that was searched.
        class_name: The attribute name that was looked up in that module.
    """

    def __init__(self, module_name, class_name):
        message = (
            f"Could not import provider '{class_name}' from module '{module_name}'"
        )
        super().__init__(message)

sdk/python/feast/infra/provider.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import abc
2+
import importlib
23
from datetime import datetime
34
from pathlib import Path
45
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
56

67
import pandas
78
import pyarrow
89

10+
from feast import errors
911
from feast.entity import Entity
1012
from feast.feature_table import FeatureTable
1113
from feast.feature_view import FeatureView
@@ -135,16 +137,42 @@ def online_read(
135137

136138

137139
def get_provider(config: RepoConfig, repo_path: Path) -> Provider:
    """Build the provider instance named by ``config.provider``.

    A name without dots is treated as a built-in provider ("gcp" or "local").
    A dotted name is treated as ``<module path>.<class name>``: the module is
    imported and the class is instantiated with ``(config, repo_path)``.

    Args:
        config: Repo configuration; its ``provider`` attribute selects the provider.
        repo_path: Repository path, passed to the local and third-party providers.

    Returns:
        The constructed provider object.

    Raises:
        errors.FeastProviderNotImplementedError: Dot-less name that is not a
            built-in provider.
        errors.FeastProviderModuleImportError: The provider module failed to import.
        errors.FeastProviderClassImportError: The module has no attribute with the
            requested class name.
    """
    provider_name = config.provider

    # Built-in providers: imported lazily so only the selected one is loaded.
    if "." not in provider_name:
        if provider_name == "gcp":
            from feast.infra.gcp import GcpProvider

            return GcpProvider(config)
        if provider_name == "local":
            from feast.infra.local import LocalProvider

            return LocalProvider(config, repo_path)
        raise errors.FeastProviderNotImplementedError(provider_name)

    # Third-party provider: everything before the right-most dot is the module,
    # the remainder is the class, e.g. 'foo.bar.MyProvider' -> 'foo.bar', 'MyProvider'.
    module_name, _, class_name = provider_name.rpartition(".")

    try:
        provider_module = importlib.import_module(module_name)
    except Exception as import_failure:
        # Importing can fail for any reason (module missing, or an error raised
        # while the module executes), so keep the original exception chained.
        raise errors.FeastProviderModuleImportError(module_name) from import_failure

    try:
        provider_cls = getattr(provider_module, class_name)
    except AttributeError:
        # The only failure here is a missing attribute, so the original
        # exception adds nothing and is suppressed.
        raise errors.FeastProviderClassImportError(module_name, class_name) from None

    return provider_cls(config, repo_path)
148176

149177

150178
def _get_requested_feature_views_to_features_dict(

sdk/python/tests/cli_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import contextmanager
77
from pathlib import Path
88
from textwrap import dedent
9-
from typing import List
9+
from typing import List, Tuple
1010

1111
from feast import cli
1212
from feast.feature_store import FeatureStore
@@ -26,6 +26,19 @@ class CliRunner:
2626
def run(self, args: List[str], cwd: Path) -> subprocess.CompletedProcess:
    """Run the ``cli`` module's file with the current interpreter.

    Args:
        args: Command-line arguments appended after the script path.
        cwd: Working directory for the child process.

    Returns:
        The ``subprocess.CompletedProcess`` for the finished child process.
    """
    interpreter_and_script = [sys.executable, cli.__file__]
    return subprocess.run(interpreter_and_script + args, cwd=cwd)
2828

29+
def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]:
    """Run the ``cli`` module's file and capture its combined output.

    Args:
        args: Command-line arguments appended after the script path.
        cwd: Working directory for the child process.

    Returns:
        A ``(return_code, output)`` tuple, where ``output`` holds the bytes
        written to stdout with stderr merged into the same stream.
    """
    completed = subprocess.run(
        [sys.executable, cli.__file__] + args,
        cwd=cwd,
        stdout=subprocess.PIPE,
        # Merge stderr into stdout so error messages appear in the captured output.
        stderr=subprocess.STDOUT,
    )
    return completed.returncode, completed.stdout
41+
2942
@contextmanager
3043
def local_repo(self, example_repo_py: str):
3144
"""

sdk/python/tests/foo_provider.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from datetime import datetime
2+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
3+
4+
import pandas
5+
6+
from feast import Entity, FeatureTable, FeatureView, RepoConfig
7+
from feast.infra.offline_stores.offline_store import RetrievalJob
8+
from feast.infra.provider import Provider
9+
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
10+
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
11+
from feast.registry import Registry
12+
13+
14+
class FooProvider(Provider):
    """Stub provider whose methods all do nothing.

    The constructor accepts ``(config, repo_path)`` and ignores both; every
    other method is a no-op that returns ``None``.
    """

    def __init__(self, config, repo_path):
        """Accept and ignore the configuration and repo path."""

    def update_infra(
        self,
        project: str,
        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
        partial: bool,
    ):
        """No-op."""

    def teardown_infra(
        self,
        project: str,
        tables: Sequence[Union[FeatureTable, FeatureView]],
        entities: Sequence[Entity],
    ):
        """No-op."""

    def online_write_batch(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """No-op."""

    def materialize_single_feature_view(
        self,
        feature_view: FeatureView,
        start_date: datetime,
        end_date: datetime,
        registry: Registry,
        project: str,
    ) -> None:
        """No-op."""

    @staticmethod
    def get_historical_features(
        config: RepoConfig,
        feature_views: List[FeatureView],
        feature_refs: List[str],
        entity_df: Union[pandas.DataFrame, str],
        registry: Registry,
        project: str,
    ) -> RetrievalJob:
        """No-op; returns ``None`` rather than an actual retrieval job."""

    def online_read(
        self,
        project: str,
        table: Union[FeatureTable, FeatureView],
        entity_keys: List[EntityKeyProto],
    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
        """No-op; returns ``None`` rather than a list of rows."""

sdk/python/tests/test_cli_local.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import tempfile
2+
from contextlib import contextmanager
23
from pathlib import Path
34
from textwrap import dedent
45

@@ -110,3 +111,58 @@ def test_non_local_feature_repo() -> None:
110111

111112
result = runner.run(["teardown"], cwd=repo_path)
112113
assert result.returncode == 0
114+
115+
116+
@contextmanager
def setup_third_party_provider_repo(provider_name: str):
    """Create a temporary feature repo configured with ``provider_name``.

    The repo contains a ``feature_store.yaml`` naming the given provider and a
    ``foo/provider.py`` file copied from ``foo_provider.py`` next to this file.
    The temporary directory is removed when the context exits.

    Args:
        provider_name: Value written to the ``provider`` key of the config.

    Yields:
        Path of the temporary repo directory.
    """
    config_yaml = dedent(
        f"""
        project: foo
        registry: data/registry.db
        provider: {provider_name}
        online_store:
            path: data/online_store.db
            type: sqlite
        """
    )
    provider_source = Path(__file__).parent / "foo_provider.py"

    with tempfile.TemporaryDirectory() as tmp_dir:
        repo_root = Path(tmp_dir)

        # Write the repo configuration
        (repo_root / "feature_store.yaml").write_text(config_yaml)

        # Place the provider source at foo/provider.py inside the repo
        (repo_root / "foo").mkdir()
        (repo_root / "foo/provider.py").write_text(provider_source.read_text())

        yield repo_root
143+
144+
145+
def test_3rd_party_providers() -> None:
    """
    Test running apply on third party providers
    """
    # (provider name, expected return code, expected output fragment or None)
    scenarios = [
        # Name without dots that is not a built-in provider
        ("feast123", 1, b"Provider 'feast123' is not implemented"),
        # Dotted name whose module part cannot be imported
        ("feast_foo.provider", 1, b"Could not import provider module 'feast_foo'"),
        # Dotted name whose module imports but lacks the requested attribute
        (
            "foo.provider",
            1,
            b"Could not import provider 'provider' from module 'foo'",
        ),
        # Correct third-party provider name
        ("foo.provider.FooProvider", 0, None),
    ]

    runner = CliRunner()
    for provider_name, expected_code, expected_fragment in scenarios:
        with setup_third_party_provider_repo(provider_name) as repo_path:
            return_code, output = runner.run_with_output(["apply"], cwd=repo_path)
            assert return_code == expected_code
            if expected_fragment is not None:
                assert expected_fragment in output

0 commit comments

Comments
 (0)